diff --git a/.env.example b/.env.example index d92b8aa..dd80c2e 100644 --- a/.env.example +++ b/.env.example @@ -103,6 +103,10 @@ OLLAMA_MAX_CONCURRENT_REQUESTS=1 # OpenAI OPENAI_API_KEY= OPENAI_BASE_URL= +# Backend mode: +# official = OpenAI responses API +# compatible = OpenAI-compatible chat.completions servers like llama.cpp +OPENAI_BACKEND=official OPENAI_MODEL=gpt-4.1-nano OPENAI_IMAGE_MODEL=gpt-image-1-mini OPENAI_TRANSCRIPTION_MODEL=gpt-4o-mini-transcribe @@ -133,6 +137,7 @@ MCP_SERVERS= # OLLAMA_ADDRESS or OLLAMA_BASE_URL. # Capability aliases are also supported: IMAGE, THINK, RAG, EMBEDDING, # TRANSCRIPTION, STT, TTS. +# Backend override: OPENAI_BACKEND=official|compatible. # # Examples: # OPENAI_SPEECH_TO_TEXT_MODEL=gpt-4o-mini-transcribe diff --git a/OPENAI_COMPATIBLE_TARGET_IMPLEMENTATION.md b/OPENAI_COMPATIBLE_TARGET_IMPLEMENTATION.md new file mode 100644 index 0000000..1b84f2e --- /dev/null +++ b/OPENAI_COMPATIBLE_TARGET_IMPLEMENTATION.md @@ -0,0 +1,314 @@ +# OPENAI Compatible Target Implementation + +## Purpose + +Add a separate execution path for OpenAI-compatible backends such as `llama.cpp`, while keeping the current official OpenAI path unchanged. + +## Checklist + +- [x] Add explicit OpenAI backend mode in config +- [x] Route OpenAI requests to separate official and compatible runners +- [x] Keep official OpenAI on `responses.create(...)` +- [x] Add compatible `chat.completions.create(...)` runner +- [x] Add compatible tool-call extractors +- [x] Add backend selection tests +- [x] Add basic memory/config regression coverage +- [x] Normalize compatible streaming tool-call assembly +- [x] Preserve file upload behavior in compatible backend +- [x] Guard unsupported OpenAI-only tools for compatible backend +- [x] Add environment docs and example config entries +- [x] Add real-server integration coverage for compatible backend +- [x] Revisit shared orchestration extraction for further deduplication + +## Non-Goals + +1. Do not change the current official OpenAI `responses.create(...)` behavior. +2. Do not auto-switch behavior only because `OPENAI_BASE_URL` is set. +3. Do not merge compatible backend quirks into the official OpenAI runner. +4. Do not remove or weaken existing tool ranking, memory, RAG, logging, or upload behavior. + +## Current State + +1. `src/ai/unified-ai-runner.openai.ts` currently uses the official `responses` API. +2. `src/ai/provider-adapters.ts` already has provider-specific adapters and tool/result mapping. +3. `src/ai/provider-adapter-contract.ts` already contains `responses`-style extractors. +4. `src/ai/openai-chat-message.ts` currently models `responses`-style messages, not `chat.completions` tool messages. +5. `src/ai/unified-ai-request-pipeline.ts` prepares chat context and runtime state before the model call. +6. `src/ai/ai-runtime-target.ts` resolves provider targets, base URLs, models, and keys. +7. `src/ai/unified-ai-runner.tool-ranker.ts` already uses a `chat.completions`-style call path, which is closer to compatible backends. + +## Target Architecture + +1. Official OpenAI backend stays on `responses.create(...)`. +2. Compatible OpenAI backend uses `chat.completions.create(...)`. +3. Backend selection is explicit through config, for example `OPENAI_BACKEND=official|compatible`. +4. Shared preparation logic remains common. +5. Transport-specific request formatting and response parsing are split. + +## Configuration Design + +1. Add a new config value `OPENAI_BACKEND`. +2. Allowed values should be `official` and `compatible`. +3. Default must be `official`. +4. Keep `OPENAI_BASE_URL` as a transport setting only. +5. `OPENAI_BASE_URL` must not imply compatible mode by itself. +6. Extend environment schema and runtime config to expose this value. +7. Update env docs and example env files. + +## Step 1: Config and Target Selection + +1. Update `src/common/environment.ts`. +2. Add a new environment field for backend mode. +3. Add setters if the codebase uses runtime env mutation in tests. +4. Update the startup schema and runtime snapshot. +5. Add tests for default value and explicit `compatible` selection. + +Expected result: +- Official OpenAI stays unchanged by default. +- Explicit `OPENAI_BACKEND=compatible` selects the new execution path. + +## Step 2: Split Runner Selection + +1. Update the unified AI execution entry point. +2. Add a small backend selector for OpenAI targets. +3. Route official mode to the current runner. +4. Route compatible mode to a new compatible runner. +5. Keep other providers untouched. + +Expected result: +- One codepath for official OpenAI. +- One codepath for OpenAI-compatible servers. + +## Step 3: Shared Orchestration Extraction + +1. Identify logic that is identical for both OpenAI branches. +2. Extract common orchestration into a shared helper where possible. +3. Keep these pieces shared: + - memory prompt injection + - tool ranking + - tool loop control + - logging and timing + - cancellation handling + - file upload post-processing + - document RAG preparation and cleanup +4. Keep transport-specific pieces separate: + - request shape + - response parsing + - tool result message shape + - streaming event parsing + +Expected result: +- Less duplicate logic. +- Cleaner separation between official and compatible behavior. + +## Step 4: Compatible Message Model + +1. Update `src/ai/openai-chat-message.ts` or create a sibling type file for compatible chat messages. +2. Model `system`, `user`, `assistant`, and `tool` roles explicitly. +3. Support `tool_calls` on assistant messages. +4. Support `tool_call_id` on tool result messages. +5. Preserve support for text and multimodal user content where the backend supports it. +6. Avoid forcing `responses` output types into `chat.completions`. + +Expected result: +- Compatible runner can build valid `chat.completions` message arrays. + +## Step 5: Compatible Contract Extractors + +1. Extend `src/ai/provider-adapter-contract.ts`. +2. Add extractors for `chat.completions` tool calls. +3. Add extractors for `chat.completions` streaming tool call deltas. +4. Keep existing `responses` extractors intact. +5. Normalize tool call IDs, names, and argument text the same way as existing extractors. +6. Ensure arguments are always represented as JSON text for the tool loop. + +Expected result: +- Compatible runner can parse tool calls from both normal and streaming responses. + +## Step 6: Compatible Provider Adapter + +1. Update `src/ai/provider-adapters.ts`. +2. Add a separate adapter or branch for OpenAI-compatible chat.completions behavior. +3. Reuse existing tool ranking where safe. +4. Make `appendToolResults(...)` emit `role: "tool"` messages with `tool_call_id`. +5. Keep official OpenAI adapter outputting `function_call_output`. +6. Keep Mistral and Ollama unchanged. + +Expected result: +- Each backend uses the tool result shape it expects. + +## Step 7: Compatible Runner Implementation + +1. Create a new file such as `src/ai/unified-ai-runner.openai-compatible.ts`. +2. Use `openai.chat.completions.create(...)`. +3. Pass `messages`, `tools`, `model`, `stream`, and `signal`. +4. Map system prompt and memory prompt into the `messages` array correctly. +5. Keep the tool loop structure from the current runner. +6. Append assistant tool-call messages and tool result messages between rounds. +7. Continue until no tool calls remain or max rounds is reached. + +Expected result: +- Compatible backends can complete multi-round tool flows. + +## Step 8: Tool Call Loop Semantics + +1. Preserve `MAX_TOOL_ROUNDS`. +2. Preserve tool ranking before each round. +3. Preserve memory tool selection. +4. Preserve file search injection when document RAG is active. +5. Preserve file upload post-processing. +6. Preserve max-rounds warnings and continuation decisions. +7. Keep the final text visible in the stream message exactly as today. + +Expected result: +- Compatible backend behaves like the current runner from the user’s perspective. + +## Step 9: Streaming Behavior + +1. Implement streaming event handling for `chat.completions`. +2. Parse text deltas and append them to `TelegramStreamMessage`. +3. Parse `delta.tool_calls` and keep incremental tool-call state. +4. Update status text when tool usage starts and ends. +5. Keep image generation and file-search status handling if the backend emits compatible signals. +6. Finalize the stream only after the terminal completion event. + +Expected result: +- Streaming works without losing tool call state. + +## Step 10: Tool Result Handling + +1. After each tool execution round, append tool results using the compatible message format. +2. Ensure each tool result keeps the correct `tool_call_id`. +3. Preserve the existing file upload hook. +4. If upload fails, convert the failure into a tool result error string. +5. Preserve the same tool memory map behavior. + +Expected result: +- The backend receives a valid message history for the next round. + +## Step 11: Prompt and Memory Injection + +1. Keep `buildSystemInstruction(...)` as the source of system prompt assembly. +2. Keep `buildUserMemoryPrompt(...)` injected as a separate block. +3. Preserve the explicit separation between assistant memory and user memory. +4. Preserve the `user.md` and `system.md` memory layout. +5. Ensure compatible backend receives the same semantic prompt content. + +Expected result: +- Memory behavior stays identical across official and compatible backends. + +## Step 12: Tool Ranking Compatibility + +1. Review `src/ai/unified-ai-runner.tool-ranker.ts`. +2. Verify whether the current JSON response handling is safe for compatible backends. +3. If a backend cannot guarantee strict JSON mode, add a fallback parser. +4. Keep ranking inputs and outputs consistent across both branches. +5. Do not weaken tool selection heuristics. + +Expected result: +- Tool ranking remains deterministic enough for both branches. + +## Step 13: File Search and RAG + +1. Keep document RAG preparation in the request pipeline. +2. Keep vector store preparation for official OpenAI. +3. Decide whether compatible backend supports file search or needs a no-op fallback. +4. If unsupported, guard the tool list so the compatible backend never receives unsupported tools. +5. Keep cleanup behavior for temporary artifacts. + +Expected result: +- Compatible backend does not receive tools it cannot execute. + +## Step 14: Error Handling + +1. Preserve abort handling. +2. Preserve response failure handling. +3. Preserve stream error handling. +4. Surface backend-specific incompatibilities as explicit errors. +5. Do not silently fall back from compatible to official mode. +6. Keep logs actionable. + +Expected result: +- Failures are obvious and debuggable. + +## Step 15: Logging and Observability + +1. Keep the current AI logs and duration tracking. +2. Add backend mode to log metadata. +3. Log tool calls, tool outputs, and round transitions in both branches. +4. Preserve existing observability hooks. +5. Add explicit labels for official vs compatible runs. + +Expected result: +- Debugging remains easy after the split. + +## Step 16: Tests + +1. Add unit tests for backend selection. +2. Add unit tests for compatible message conversion. +3. Add unit tests for compatible tool call extraction. +4. Add integration tests for a tool-call round trip using mocked `chat.completions`. +5. Add tests proving the official `responses` path is unchanged. +6. Add tests for streaming tool call parsing if the backend supports it. +7. Add tests for fallback behavior in the tool ranker if needed. + +Expected result: +- Both branches are covered and regressions are visible quickly. + +## Step 17: Suggested File Changes + +1. `src/common/environment.ts` +2. `src/ai/ai-runtime-target.ts` +3. `src/ai/unified-ai-request-pipeline.ts` +4. `src/ai/unified-ai-runner.openai.ts` +5. `src/ai/unified-ai-runner.openai-compatible.ts` +6. `src/ai/provider-adapter-contract.ts` +7. `src/ai/provider-adapters.ts` +8. `src/ai/openai-chat-message.ts` +9. `src/ai/unified-ai-runner.tool-ranker.ts` +10. `test/*.test.mjs` +11. `.env.example` +12. Documentation files for backend selection + +## Implementation Order + +1. [x] Add config flag and wire it through environment parsing. +2. [x] Add backend selection logic. +3. [x] Add compatible message and extractor support. +4. [x] Create the compatible runner. +5. [x] Reuse shared orchestration where possible. +6. [x] Wire tests. +7. [x] Verify official behavior is unchanged. +8. [x] Verify compatible backend works with a real OpenAI-compatible server. + +## Verification Plan + +1. Run unit tests. +2. Run integration tests. +3. Verify official OpenAI path still uses `responses.create(...)`. +4. Verify compatible path uses `chat.completions.create(...)`. +5. Verify a `llama.cpp`-style server can complete a tool loop. +6. Verify memory tools still work. +7. Verify document RAG and file upload behavior do not regress. + +## Risks + +1. Some OpenAI-compatible servers do not support every official OpenAI feature. +2. Streaming tool call deltas may differ across providers. +3. JSON-mode assumptions in the ranker may not hold for all compatible servers. +4. Tool schema filtering may need backend-specific allowlists. +5. Message conversion mistakes can break tool loops silently if not tested. + +## Acceptance Criteria + +1. Official OpenAI behavior is unchanged. +2. Compatible backend can run a full chat loop with tools. +3. Tool calls are correctly extracted and executed. +4. Tool results are appended in the correct format. +5. Memory injection still works. +6. Document RAG and file upload behavior remain functional or fail explicitly. +7. Tests cover both branches. + +## Final Note + +The key design rule is simple: keep official OpenAI `responses` behavior intact, and introduce OpenAI-compatible `chat.completions` behavior as a separate backend mode with its own parsing and message shape. diff --git a/README.md b/README.md index 5675abd..5286aa6 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Bot for Telegram with a lot of commands and AI (Ollama/Mistral/OpenAI) written i ```bash cp .env.example .env # Edit .env: add BOT_TOKEN, CREATOR_ID and configure optional AI models (MISTRAL_API_KEY, OPENAI_API_KEY, OLLAMA_ADDRESS) +# For OpenAI-compatible servers (llama.cpp, etc.), set OPENAI_BACKEND=compatible and OPENAI_BASE_URL. # Optional: set DATABASE_URL to postgres://... for PostgreSQL or :memory: for ephemeral SQLite. # Optional: set DATA_PATH if you want to override the default local storage directory. ``` diff --git a/src/ai/ai-runtime-target.ts b/src/ai/ai-runtime-target.ts index 173aa91..da6cd09 100644 --- a/src/ai/ai-runtime-target.ts +++ b/src/ai/ai-runtime-target.ts @@ -6,7 +6,7 @@ import {AiModelCapabilities} from "../model/ai-model-capabilities.js"; import {AiProvider} from "../model/ai-provider.js"; export type AiCapabilityName = keyof AiModelCapabilities; -export type AiRuntimePurpose = AiCapabilityName | "chat"; +export type AiRuntimePurpose = AiCapabilityName | "chat" | "memoryCompress"; export type AiRuntimeTarget = { provider: AiProvider; @@ -24,6 +24,7 @@ const PURPOSE_SUFFIXES: Record = { thinking: ["THINKING", "THINK"], extendedThinking: ["EXTENDED_THINKING", "THINKING", "THINK"], tools: ["TOOLS", "CHAT"], + memoryCompress: ["MEMORY_COMPRESS"], toolRank: ["TOOL_RANK", "TOOL_RANKER"], audio: ["AUDIO"], documents: ["DOCUMENTS", "RAG", "EMBEDDING"], @@ -155,6 +156,25 @@ export function resolveAiRuntimeTarget( return {provider, purpose, model, baseUrl, apiKey, systemPromptAdditions}; } +function hasExplicitTargetConfig(provider: AiProvider, purpose: AiRuntimePurpose): boolean { + const prefix = providerPrefix(provider); + return [ + ...endpointEnvNames(provider, purpose), + ...apiKeyEnvNames(provider, purpose), + ...modelEnvNames(provider, purpose), + ...systemPromptEnvNames(provider, purpose), + ].some(name => !!env(name)) || !!env(`${prefix}_${PURPOSE_SUFFIXES[purpose][0]}_MODEL`); +} + +export function resolveOptionalAiRuntimeTarget( + provider: AiProvider, + purpose: AiRuntimePurpose, + modelOverride?: string, +): AiRuntimeTarget | undefined { + if (!hasExplicitTargetConfig(provider, purpose)) return undefined; + return resolveAiRuntimeTarget(provider, purpose, modelOverride); +} + export function sameRuntimeEndpoint(left: AiRuntimeTarget, right: AiRuntimeTarget): boolean { return left.provider === right.provider && (left.baseUrl ?? "") === (right.baseUrl ?? "") diff --git a/src/ai/conversation-pipeline.ts b/src/ai/conversation-pipeline.ts index fd4ea44..e3b0db4 100644 --- a/src/ai/conversation-pipeline.ts +++ b/src/ai/conversation-pipeline.ts @@ -16,6 +16,7 @@ import type {AttachmentKind, AiRuntimeTarget, RuntimeConfigSnapshot} from "./uni import type {OpenAIChatMessage} from "./openai-chat-message"; import type {MistralChatMessage} from "./mistral-chat-message"; import type {OllamaChatMessage} from "./ollama-chat-message"; +import {buildUserMemoryPrompt} from "./tools/user-memory.js"; export type ConversationAttachment = { kind: AttachmentKind; @@ -267,11 +268,13 @@ function buildSystemInstruction( responseLanguage: UserAiResponseLanguage, includePythonToolPrompt: boolean, additions?: string | null, + memoryInstruction?: string | null, ): string { return [ config.useSystemPrompt ? getResponseLanguageInstruction(responseLanguage) : null, config.systemPrompt && config.useSystemPrompt ? config.systemPrompt : null, additions?.trim() ? additions.trim() : null, + memoryInstruction?.trim() ? memoryInstruction.trim() : null, includePythonToolPrompt ? pythonInterpreterToolPrompt : null, ].filter(Boolean).join("\n\n"); } @@ -310,11 +313,12 @@ export async function buildConversationSnapshot( if (turn.bot) return sum; return sum + turn.attachments.filter(attachment => attachment.kind === "image").length; }, 0); + const memoryInstruction = await buildUserMemoryPrompt(msg.from?.id); return { turns, imageCount, - systemInstruction: buildSystemInstruction(config, responseLanguage, includePythonToolPrompt, runtimeTarget.systemPromptAdditions), + systemInstruction: buildSystemInstruction(config, responseLanguage, includePythonToolPrompt, runtimeTarget.systemPromptAdditions, memoryInstruction), }; } diff --git a/src/ai/document-rag-pipeline.ts b/src/ai/document-rag-pipeline.ts index f8b706c..5afdd29 100644 --- a/src/ai/document-rag-pipeline.ts +++ b/src/ai/document-rag-pipeline.ts @@ -42,6 +42,10 @@ export async function prepareDocumentRag( const documents = downloads.filter(download => download.kind === "document"); if (!documents.length) return undefined; + if (provider === AiProvider.OPENAI && config.openAiBackend === "compatible") { + return undefined; + } + switch (provider) { case AiProvider.OPENAI: { const openAi = createOpenAiClient(config.openAiChatTarget); diff --git a/src/ai/openai-chat-completions.ts b/src/ai/openai-chat-completions.ts new file mode 100644 index 0000000..d667d73 --- /dev/null +++ b/src/ai/openai-chat-completions.ts @@ -0,0 +1,66 @@ +import {isRecord} from "./unified-ai-runner.shared.js"; +import type {OpenAIChatMessage, OpenAICompatibleChatMessage} from "./openai-chat-message.js"; +import type {ToolCallData} from "./unified-ai-runner.shared.js"; + +export function responseContentToText(content: unknown): string { + if (typeof content === "string") return content; + if (!Array.isArray(content)) return ""; + + return content + .map(part => isRecord(part) && typeof part.text === "string" ? part.text : "") + .join(""); +} + +export function openAiResponseMessagesToChatCompletions(messages: OpenAIChatMessage[]): OpenAICompatibleChatMessage[] { + return messages.map((message): OpenAICompatibleChatMessage => { + if (message.role === "system") { + return {role: "system", content: responseContentToText(message.content)}; + } + + if (message.role === "assistant") { + const text = responseContentToText(message.content); + return text.length + ? {role: "assistant", content: text} + : {role: "assistant", content: null}; + } + + const content = Array.isArray(message.content) + ? (() => { + const parts = message.content.map((part): {type: "text"; text: string} | {type: "image_url"; image_url: {url: string}} => { + if (isRecord(part) && part.type === "input_image") { + return { + type: "image_url", + image_url: {url: String(part.image_url ?? "")}, + }; + } + + return { + type: "text", + text: isRecord(part) && typeof part.text === "string" ? part.text : "", + }; + }); + + return parts.every(part => part.type === "text") + ? parts.map(part => part.text).join("") + : parts; + })() + : message.content; + + return {role: "user", content}; + }); +} + +export function buildAssistantToolMessage(calls: ToolCallData[], text: string): OpenAICompatibleChatMessage { + return { + role: "assistant", + content: text, + tool_calls: calls.map(call => ({ + id: call.id, + type: "function", + function: { + name: call.name, + arguments: call.argumentsText, + }, + })), + }; +} diff --git a/src/ai/openai-chat-message.ts b/src/ai/openai-chat-message.ts index d1c4f71..94513a3 100644 --- a/src/ai/openai-chat-message.ts +++ b/src/ai/openai-chat-message.ts @@ -2,6 +2,7 @@ import type { ResponseInputMessageContentList, ResponseOutputMessage, } from "openai/resources/responses/responses"; +import type {ChatCompletionMessageParam} from "openai/resources/chat/completions"; type OpenAIInputChatMessage = { type: "message"; @@ -17,3 +18,5 @@ type OpenAIOutputChatMessage = { } & Pick; export type OpenAIChatMessage = OpenAIInputChatMessage | OpenAIOutputChatMessage; + +export type OpenAICompatibleChatMessage = ChatCompletionMessageParam; diff --git a/src/ai/openai-upload-files.ts b/src/ai/openai-upload-files.ts new file mode 100644 index 0000000..c9a1416 --- /dev/null +++ b/src/ai/openai-upload-files.ts @@ -0,0 +1,74 @@ +import {Message} from "typescript-telegram-bot-api"; +import fs from "node:fs"; +import path from "node:path"; +import {bot} from "../index.js"; +import {Environment} from "../common/environment.js"; +import {logError} from "../util/utils.js"; +import {errorMessage} from "./unified-ai-runner.shared.js"; +import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files.js"; + +export async function tryToUploadFiles( + msg: Message, + toolResults: string[] +): Promise< + | { found: false } + | { found: true, uploaded: true } + | { found: boolean, uploaded: false, error: string, toolIndex: number } +> { + let sendFileAttachment: { + result: SendFileAttachmentResult & { success: true }, + toolIndex: number + } | null = null; + + let found = false; + + try { + for (const [index, toolResult] of toolResults.entries()) { + const raw = JSON.parse(toolResult); + const res = SendFileAttachmentResultSchema.safeParse(raw); + + if (res.success) { + found = true; + + if (res.data.success) { + sendFileAttachment = {result: res.data, toolIndex: index}; + } + } + } + + if (!found) { + return {found: false}; + } + + const attachmentRoot = Environment.FILE_TOOLS_ROOT_DIR; + const attachmentPath = attachmentRoot + ? path.join( + attachmentRoot, + String(msg.from?.id), + sendFileAttachment?.result?.attachment?.relativePath ?? "", + ) + : ""; + + if (!fs.existsSync(attachmentPath)) { + throw new Error(`Attachment file does not exist: ${attachmentPath}`); + } + + await bot.sendDocument({ + chat_id: msg.chat.id, + reply_parameters: { + message_id: msg.message_id, + }, + document: fs.createReadStream(attachmentPath), + }); + + return {found: true, uploaded: true}; + } catch (e) { + logError(e instanceof Error ? e : String(e)); + return { + found: found, + uploaded: false, + error: errorMessage(e instanceof Error ? e : String(e)), + toolIndex: sendFileAttachment?.toolIndex ?? -1 + }; + } +} diff --git a/src/ai/provider-adapter-contract.ts b/src/ai/provider-adapter-contract.ts index 7e4d8fd..d7984eb 100644 --- a/src/ai/provider-adapter-contract.ts +++ b/src/ai/provider-adapter-contract.ts @@ -14,6 +14,12 @@ function normalizeToolArguments(value: unknown): string { return JSON.stringify(value ?? {}); } +function normalizeToolArgumentsChunk(value: unknown): string { + if (typeof value === "string") return value; + if (value === undefined || value === null) return ""; + return JSON.stringify(value); +} + export function extractOpenAiToolCalls(response: unknown): ToolCallData[] { const output = isRecord(response) && Array.isArray(response.output) ? response.output : []; @@ -32,6 +38,86 @@ export function extractOpenAiTextDelta(input: unknown): string { return event?.type === "response.output_text.delta" ? event.delta ?? "" : ""; } +export function extractOpenAiChatTextDelta(input: unknown): string { + const event = isRecord(input) ? input : undefined; + const choice = event && Array.isArray(event.choices) && isRecord(event.choices[0]) ? event.choices[0] : undefined; + const delta = isRecord(choice?.delta) ? choice.delta : undefined; + const content = delta && typeof delta.content === "string" ? delta.content : ""; + return content; +} + +export function normalizeStreamingTextDelta(existingText: string, deltaText: string): string { + if (!deltaText) return ""; + if (!existingText) return deltaText; + + if (deltaText.startsWith(existingText)) { + return deltaText.slice(existingText.length); + } + + return deltaText; +} + +export function extractOpenAiChatToolCalls(response: unknown): ToolCallData[] { + const record = isRecord(response) ? response : undefined; + const choice = record && Array.isArray(record.choices) && isRecord(record.choices[0]) ? record.choices[0] : undefined; + const message = isRecord(choice?.message) ? choice.message : undefined; + const toolCalls = message && Array.isArray(message.tool_calls) ? message.tool_calls : []; + + return toolCalls + .filter((item, index) => isRecord(item) && ((typeof item.id === "string") || typeof item.index === "number" || index >= 0)) + .map((item, index) => { + const call = isRecord(item) ? item : {}; + const fn = isRecord(call.function) ? call.function : undefined; + const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : ""; + return { + id: normalizeToolCallId(call.id, `openai_chat_${typeof call.index === "number" ? call.index : index}`), + name, + argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments), + }; + }) + .filter(call => call.name.length > 0); +} + +export function extractOpenAiChatStreamingToolCalls(input: unknown): ToolCallData[] { + const event = isRecord(input) ? input : undefined; + const choice = event && Array.isArray(event.choices) && isRecord(event.choices[0]) ? event.choices[0] : undefined; + const delta = isRecord(choice?.delta) ? choice.delta : undefined; + const toolCalls = Array.isArray(delta?.tool_calls) ? delta.tool_calls : []; + + return toolCalls + .map((item, index) => { + const call = isRecord(item) ? item : {}; + const fn = isRecord(call.function) ? call.function : undefined; + const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : ""; + return { + id: normalizeToolCallId(call.id, `openai_chat_${typeof call.index === "number" ? call.index : index}`), + name, + argumentsText: normalizeToolArgumentsChunk(fn?.arguments ?? call.arguments), + }; + }) + .filter(call => call.id.length > 0); +} + +export function mergeToolCallChunks(existing: ToolCallData[], chunks: ToolCallData[]): ToolCallData[] { + const merged = new Map(existing.map(call => [call.id, {...call}])); + + for (const chunk of chunks) { + const current = merged.get(chunk.id); + if (!current) { + merged.set(chunk.id, {...chunk}); + continue; + } + + merged.set(chunk.id, { + id: current.id, + name: current.name || chunk.name, + argumentsText: current.argumentsText + (chunk.argumentsText ?? ""), + }); + } + + return [...merged.values()]; +} + export function extractOpenAiStreamingToolCalls(input: unknown): ToolCallData[] { const event = input as ResponseStreamEvent | undefined; if (event?.type === "response.output_item.added" && isRecord(event.item) && event.item.type === "function_call") { diff --git a/src/ai/provider-model-runtime.ts b/src/ai/provider-model-runtime.ts index 60ddc6e..00b2faf 100644 --- a/src/ai/provider-model-runtime.ts +++ b/src/ai/provider-model-runtime.ts @@ -196,7 +196,8 @@ export async function getRuntimeCapabilities( target?: AiRuntimeTarget ): Promise { const runtimeTarget = target ?? resolveAiRuntimeTarget(provider, "chat", model ?? getRuntimeModel(provider)); - const result = await getModelCapabilities(provider, runtimeTarget.model, target?.purpose ?? "chat") ?? buildCapabilities({}); + const targetPurpose = target?.purpose && target.purpose !== "memoryCompress" ? target.purpose : "chat"; + const result = await getModelCapabilities(provider, runtimeTarget.model, targetPurpose) ?? buildCapabilities({}); for (const capabilityName of CAPABILITY_NAMES) { if (provider === AiProvider.OPENAI && (capabilityName === "vision" || capabilityName === "ocr")) { diff --git a/src/ai/tool-mappers.ts b/src/ai/tool-mappers.ts index 4b099cc..842438f 100644 --- a/src/ai/tool-mappers.ts +++ b/src/ai/tool-mappers.ts @@ -3,6 +3,7 @@ import {AiProvider} from "../model/ai-provider.js"; import {getTools} from "./tools/registry.js"; import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search.js"; import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator.js"; +import {toolSchemaNames} from "./tool-schema-utils.js"; export type AiProviderName = "ollama" | "openai" | "mistral"; @@ -26,6 +27,11 @@ export function getOpenAITools(forCreator?: boolean): AiTool[] { })); } +export function getOpenAICompatibleTools(forCreator?: boolean): AiTool[] { + // The compatible chat.completions backend only accepts plain function tools. + return getOpenAITools(forCreator); +} + export type OpenAiResponseTool = { type: "function"; name: string; @@ -79,3 +85,20 @@ export function getProviderTools(provider: AiProvider, forCreator?: boolean): Ai return getOpenAITools(forCreator); } } + +export function ensureToolsSelected(availableTools: readonly T[], selectedTools: readonly T[], toolNames: readonly string[]): T[] { + const selected = [...selectedTools]; + const selectedNames = new Set(selected.flatMap(tool => toolSchemaNames(tool as never))); + + for (const toolName of toolNames) { + if (selectedNames.has(toolName)) continue; + + const extraTool = availableTools.find(tool => toolSchemaNames(tool as never).includes(toolName)); + if (extraTool) { + selected.unshift(extraTool); + selectedNames.add(toolName); + } + } + + return selected; +} diff --git a/src/ai/tool-ranker-metadata.ts b/src/ai/tool-ranker-metadata.ts index 7cf47f0..765ee5a 100644 --- a/src/ai/tool-ranker-metadata.ts +++ b/src/ai/tool-ranker-metadata.ts @@ -102,6 +102,100 @@ export const TOOL_RANKER_TOOL_INFOS = { example("где определён BotService?", ["search_files"]), ], ), + read_user_info: tool( + "read_user_info", + "Read persistent user memory from user.md.", + "Use before editing or when the user asks what you remember about them.", + [ + example("что ты помнишь обо мне?", ["read_user_info"]), + example("покажи мою память", ["read_user_info"]), + ], + ), + read_system_info: tool( + "read_system_info", + "Read persistent assistant memory from system.md.", + "Use before editing or when the user asks what instructions you remember about yourself.", + [ + example("что ты помнишь о себе?", ["read_system_info"]), + example("покажи память о тебе", ["read_system_info"]), + ], + ), + add_user_info: tool( + "add_user_info", + "Append a durable fact about the user to persistent memory.", + "Use when the user asks to remember a new fact, preference, identity detail, or profile information about themselves.", + [ + example("запомни, что меня зовут Иван", ["add_user_info"]), + example("запомни, что я люблю чай", ["add_user_info"]), + example("remember that I like short answers", ["add_user_info"]), + ], + ), + add_system_info: tool( + "add_system_info", + "Append a durable instruction about the assistant to persistent memory.", + "Use when the user asks to remember a new assistant identity, style, or behavior instruction.", + [ + example("тебя зовут Евлампий", ["add_system_info"]), + example("ты ИИ помощник", ["add_system_info"]), + example("remember you are a concise assistant", ["add_system_info"]), + ], + ), + remove_user_info: tool( + "remove_user_info", + "Remove a specific user fact from persistent memory.", + "Use when the user asks to forget, delete, or remove a specific fact about themselves.", + [ + example("забудь, что я люблю кофе", ["remove_user_info"]), + example("удали из памяти, что я живу в Москве", ["remove_user_info"]), + example("forget that I work at ACME", ["remove_user_info"]), + ], + ), + remove_system_info: tool( + "remove_system_info", + "Remove a specific assistant instruction from persistent memory.", + "Use when the user asks to forget or remove a specific instruction about the assistant.", + [ + example("забудь, что тебя зовут Евлампий", ["remove_system_info"]), + example("убери правило отвечать коротко", ["remove_system_info"]), + example("forget that you are a concise assistant", ["remove_system_info"]), + ], + ), + replace_user_info: tool( + "replace_user_info", + "Replace the full user memory with a new compact version.", + "Use when the user wants to overwrite all remembered user info, for example when they say to forget everything and keep only the new fact.", + [ + example("забудь всё обо мне и запиши только это: меня зовут Иван", ["replace_user_info"]), + example("замени всю память обо мне на: люблю чай и короткие ответы", ["replace_user_info"]), + ], + ), + replace_system_info: tool( + "replace_system_info", + "Replace the full assistant memory with a new compact version.", + "Use when the user wants to overwrite all remembered assistant info or instructions.", + [ + example("забудь всё о себе и запиши только это: тебя зовут Евлампий", ["replace_system_info"]), + example("замени инструкцию о себе на: ты краткий ИИ помощник", ["replace_system_info"]), + ], + ), + delete_user_info: tool( + "delete_user_info", + "Delete user.md entirely.", + "Use when the user explicitly asks to delete all remembered user info, not just a fragment.", + [ + example("удали всю память обо мне", ["delete_user_info"]), + example("forget all user memory", ["delete_user_info"]), + ], + ), + delete_system_info: tool( + "delete_system_info", + "Delete system.md entirely.", + "Use when the user explicitly asks to delete all remembered assistant info, not just a fragment.", + [ + example("удали всю память о себе", ["delete_system_info"]), + example("forget all assistant memory", ["delete_system_info"]), + ], + ), create_file: tool( "create_file", "Create a new small file.", @@ -443,6 +537,16 @@ function buildPriorityLines(tools: readonly ToolRankerToolInfo[]): string[] { pushIfAvailable("read_file", "known local file path -> read_file"); pushIfAvailable("list_directory", "project structure or directory listing -> list_directory"); pushIfAvailable("search_files", "local file/content search or unknown file path -> search_files"); + pushIfAvailable("read_user_info", "inspect remembered user info -> read_user_info"); + pushIfAvailable("read_system_info", "inspect remembered assistant info -> read_system_info"); + pushIfAvailable("add_user_info", "remember a new user fact -> add_user_info"); + pushIfAvailable("add_system_info", "remember a new assistant instruction -> add_system_info"); + pushIfAvailable("remove_user_info", "forget a user fact -> remove_user_info"); + pushIfAvailable("remove_system_info", "forget an assistant instruction -> remove_system_info"); + pushIfAvailable("replace_user_info", "overwrite all user memory -> replace_user_info"); + pushIfAvailable("replace_system_info", "overwrite all assistant memory -> replace_system_info"); + pushIfAvailable("delete_user_info", "delete all user memory -> delete_user_info"); + pushIfAvailable("delete_system_info", "delete all assistant memory -> delete_system_info"); pushIfAvailable("edit_file_patch", "targeted existing file edit -> edit_file_patch"); pushIfAvailable("update_file", "full existing file replacement -> update_file"); pushIfAvailable("create_file", "small new file -> create_file"); diff --git a/src/ai/tools/create-note.ts b/src/ai/tools/create-note.ts index 79281e8..5e7ed00 100644 --- a/src/ai/tools/create-note.ts +++ b/src/ai/tools/create-note.ts @@ -1,11 +1,11 @@ -import {AiTool} from "../tool-types"; +import {AiTool} from "../tool-types.js"; import path from "node:path"; import {readFile, writeFile} from "node:fs/promises"; -import {NOTES_HEADER, notesDir, notesRootFile} from "../../index"; -import {asNonEmptyString} from "./utils"; +import {NOTES_HEADER, notesDir, notesRootFile} from "../../index.js"; +import {asNonEmptyString} from "./utils.js"; import fs from "node:fs"; -import {toolsLogger} from "./tool-logger"; -import {AiJsonObject} from "../tool-types"; +import {toolsLogger} from "./tool-logger.js"; +import {AiJsonObject} from "../tool-types.js"; const logger = toolsLogger.child("create-note"); diff --git a/src/ai/tools/datetime.ts b/src/ai/tools/datetime.ts index 27b972d..608b3a0 100644 --- a/src/ai/tools/datetime.ts +++ b/src/ai/tools/datetime.ts @@ -1,5 +1,5 @@ import {AiTool} from "../tool-types"; -import {asNonEmptyString} from "./utils"; +import {asNonEmptyString} from "./utils.js"; import {AiJsonObject} from "../tool-types"; export const getCurrentDateTimeTool = { diff --git a/src/ai/tools/files.ts b/src/ai/tools/files.ts index 7874a15..8772ca8 100644 --- a/src/ai/tools/files.ts +++ b/src/ai/tools/files.ts @@ -3,8 +3,8 @@ import fs from "node:fs"; import path from "node:path"; import {z} from "zod"; -import {Environment} from "../../common/environment"; -import {AiJsonObject, AiJsonValue, AiTool} from "../tool-types"; +import {Environment} from "../../common/environment.js"; +import {AiJsonObject, AiJsonValue, AiTool} from "../tool-types.js"; import { MAX_COPY_ENTRIES, MAX_COPY_TOTAL_BYTES, @@ -23,8 +23,8 @@ import { MAX_PATCH_SEARCH_BYTES, MAX_STREAM_WRITE_IDLE_MS, MAX_STREAM_WRITE_SESSIONS, -} from "./limits"; -import {asBoolean, asNonEmptyString, asPositiveInt, asString} from "./utils"; +} from "./limits.js"; +import {asBoolean, asNonEmptyString, asPositiveInt, asString} from "./utils.js"; // ============================================================================= // Public types and schemas diff --git a/src/ai/tools/market-rates.ts b/src/ai/tools/market-rates.ts index 0fd7369..de0fd31 100644 --- a/src/ai/tools/market-rates.ts +++ b/src/ai/tools/market-rates.ts @@ -1,7 +1,7 @@ -import {AiTool} from "../tool-types"; +import {AiTool} from "../tool-types.js"; import axios from "axios"; -import {toolsLogger} from "./tool-logger"; -import {AiJsonObject} from "../tool-types"; +import {toolsLogger} from "./tool-logger.js"; +import {AiJsonObject} from "../tool-types.js"; const logger = toolsLogger.child("market-rates"); diff --git a/src/ai/tools/notes.ts b/src/ai/tools/notes.ts index 096aa66..85cdc8b 100644 --- a/src/ai/tools/notes.ts +++ b/src/ai/tools/notes.ts @@ -1,11 +1,11 @@ -import {AiTool} from "../tool-types"; +import {AiTool} from "../tool-types.js"; import path from "node:path"; import {readdir, readFile, stat, unlink, writeFile} from "node:fs/promises"; -import {notesDir, notesRootFile} from "../../index"; -import {asNonEmptyString} from "./utils"; -import {toolsLogger} from "./tool-logger"; +import {notesDir, notesRootFile} from "../../index.js"; +import {asNonEmptyString} from "./utils.js"; +import {toolsLogger} from "./tool-logger.js"; import {z} from "zod"; -import {AiJsonObject} from "../tool-types"; +import {AiJsonObject} from "../tool-types.js"; const logger = toolsLogger.child("notes"); diff --git a/src/ai/tools/registry.ts b/src/ai/tools/registry.ts index 4d89019..a7b47fc 100644 --- a/src/ai/tools/registry.ts +++ b/src/ai/tools/registry.ts @@ -1,17 +1,17 @@ -import {Environment} from "../../common/environment"; -import {AiTool} from "../tool-types"; -import {WEB_SEARCH_TOOL_NAME, webSearch, webSearchTool, webSearchToolPrompt} from "./web-search"; -import {getCurrentDateTime, getCurrentDateTimeTool} from "./datetime"; -import {shellExecute, shellExecuteTool} from "./shell"; -import {ToolHandler} from "./types"; -import {getWeather, getWeatherTool} from "./weather"; +import {Environment} from "../../common/environment.js"; +import {AiTool} from "../tool-types.js"; +import {WEB_SEARCH_TOOL_NAME, webSearch, webSearchTool, webSearchToolPrompt} from "./web-search.js"; +import {getCurrentDateTime, getCurrentDateTimeTool} from "./datetime.js"; +import {shellExecute, shellExecuteTool} from "./shell.js"; +import {ToolHandler} from "./types.js"; +import {getWeather, getWeatherTool} from "./weather.js"; import { GET_FINANCIAL_MARKET_DATA_TOOL_NAME, getFinancialMarketData, getFinancialMarketDataToolPrompt, getMarketRates -} from "./market-rates"; -import {pythonInterpreterTool, runPythonInterpreter} from "./python-interpretator"; +} from "./market-rates.js"; +import {pythonInterpreterTool, runPythonInterpreter} from "./python-interpretator.js"; import { beginFileWrite, beginFileWriteTool, @@ -44,12 +44,14 @@ import { updateFileTool, writeFileChunk, writeFileChunkTool -} from "./files"; +} from "./files.js"; +import {executeMemoryTool, memoryToolPrompt, memoryTools, type MemoryToolName} from "./user-memory.js"; import {getMcpToolHandlers, getMcpToolPrompts, getMcpTools} from "../mcp/mcp-registry.js"; export const defaultTools: AiTool[] = [ getCurrentDateTimeTool, getFinancialMarketData, + ...memoryTools, ]; export const fileTools = [ @@ -169,6 +171,20 @@ export const getToolHandlers = () => { if (isLocalToolEnabled("get_datetime")) handlers.get_datetime = getCurrentDateTime; if (isLocalToolEnabled("get_financial_market_data")) handlers.get_financial_market_data = getMarketRates; + for (const tool of memoryTools) { + if (!isLocalToolEnabled(tool.function.name)) continue; + handlers[tool.function.name] = async (args, context) => { + const userId = typeof args?.userId === "number" ? args.userId : undefined; + if (!userId) { + return {success: false, error: "Missing userId"}; + } + + return executeMemoryTool(tool.function.name as MemoryToolName, { + userId, + content: typeof args?.content === "string" ? args.content : undefined, + }, context); + }; + } if (isLocalToolEnabled("read_file")) handlers.read_file = readFile; if (isLocalToolEnabled("list_directory")) handlers.list_directory = listDirectory; @@ -186,7 +202,7 @@ export const getToolHandlers = () => { if (isLocalToolEnabled("rename_path")) handlers.rename_path = renamePath; if (isLocalToolEnabled("delete_path")) handlers.delete_path = deletePath; - if (isLocalToolEnabled("python_interpreter")) handlers.python_interpreter = runPythonInterpreter; + if (isLocalToolEnabled("python_interpreter")) handlers.python_interpreter = (args, _context) => runPythonInterpreter(args); if (isLocalToolEnabled("shell_execute")) handlers.shell_execute = shellExecute; if (isLocalToolEnabled("web_search")) handlers.web_search = webSearch; if (isLocalToolEnabled("get_weather")) handlers.get_weather = getWeather; @@ -200,6 +216,8 @@ export function getToolPrompts(toolNames: string[]): string[] { } const prompts: string[] = []; + const memoryToolNames = new Set(memoryTools.map(tool => tool.function.name)); + let memoryPromptAdded = false; for (const toolName of toolNames) { if (!isLocalToolEnabled(toolName)) { @@ -212,6 +230,14 @@ export function getToolPrompts(toolNames: string[]): string[] { continue; } + if (memoryToolNames.has(toolName)) { + if (!memoryPromptAdded) { + prompts.push(memoryToolPrompt); + memoryPromptAdded = true; + } + continue; + } + switch (toolName) { case GET_FINANCIAL_MARKET_DATA_TOOL_NAME: prompts.push(getFinancialMarketDataToolPrompt); diff --git a/src/ai/tools/runtime.ts b/src/ai/tools/runtime.ts index 913e208..d7be454 100644 --- a/src/ai/tools/runtime.ts +++ b/src/ai/tools/runtime.ts @@ -1,14 +1,19 @@ -import {getToolHandlers} from "./registry"; -import {normalizeToolArguments} from "./utils"; -import {PYTHON_INTERPRETER_TOOL_NAME, PythonInterpreterInputFile, runPythonInterpreter} from "./python-interpretator"; -import {toolsLogger} from "./tool-logger"; -import {AiJsonObject, AiJsonValue} from "../tool-types"; +import {getToolHandlers} from "./registry.js"; +import {normalizeToolArguments} from "./utils.js"; +import {PYTHON_INTERPRETER_TOOL_NAME, PythonInterpreterInputFile, runPythonInterpreter} from "./python-interpretator.js"; +import {toolsLogger} from "./tool-logger.js"; +import {AiJsonObject, AiJsonValue} from "../tool-types.js"; +import type {MemoryRuntimeContext} from "./user-memory.js"; +import type {AiRuntimeTarget} from "../ai-runtime-target.js"; +import type {AiProvider} from "../../model/ai-provider.js"; const logger = toolsLogger.child("runtime"); export type ToolRuntimeContext = { pythonInputFiles?: PythonInterpreterInputFile[]; -}; + provider?: AiProvider; + runtimeTarget?: AiRuntimeTarget; +} & MemoryRuntimeContext; function stringifyToolResult(result: AiJsonValue): string { if (typeof result === "string") return result; @@ -48,7 +53,7 @@ export async function executeToolCall( } const arguments1 = normalizeToolArguments(args, userId); - const result = await handler(arguments1); + const result = await handler(arguments1, context); const s = stringifyToolResult(result); logger.debug("execute.done", {name, chars: s.length, duration: logger.duration(startedAt)}); return s; diff --git a/src/ai/tools/search-notes.ts b/src/ai/tools/search-notes.ts index b8da252..56af2dd 100644 --- a/src/ai/tools/search-notes.ts +++ b/src/ai/tools/search-notes.ts @@ -1,10 +1,10 @@ -import {AiTool} from "../tool-types"; +import {AiTool} from "../tool-types.js"; import path from "node:path"; import {readdir, readFile} from "node:fs/promises"; -import {notesDir, notesRootFile} from "../../index"; -import {asNonEmptyString} from "./utils"; -import {toolsLogger} from "./tool-logger"; -import {AiJsonObject, AiJsonValue} from "../tool-types"; +import {notesDir, notesRootFile} from "../../index.js"; +import {asNonEmptyString} from "./utils.js"; +import {toolsLogger} from "./tool-logger.js"; +import {AiJsonObject, AiJsonValue} from "../tool-types.js"; const logger = toolsLogger.child("search-notes"); diff --git a/src/ai/tools/shell.ts b/src/ai/tools/shell.ts index cb56221..b124ebe 100644 --- a/src/ai/tools/shell.ts +++ b/src/ai/tools/shell.ts @@ -1,6 +1,6 @@ import {AiTool} from "../tool-types"; -import {runCommand} from "../../util/utils"; -import {asNonEmptyString} from "./utils"; +import {runCommand} from "../../util/utils.js"; +import {asNonEmptyString} from "./utils.js"; import {AiJsonObject} from "../tool-types"; export const shellExecuteTool = { diff --git a/src/ai/tools/types.ts b/src/ai/tools/types.ts index 221bd20..8d14e3b 100644 --- a/src/ai/tools/types.ts +++ b/src/ai/tools/types.ts @@ -1,3 +1,4 @@ import {AiJsonObject, AiJsonValue} from "../tool-types"; +import type {ToolRuntimeContext} from "./runtime.js"; -export type ToolHandler = (args?: AiJsonObject) => Promise | AiJsonValue | string | null | undefined; +export type ToolHandler = (args?: AiJsonObject, context?: ToolRuntimeContext) => Promise | AiJsonValue | string | null | undefined; diff --git a/src/ai/tools/user-memory.ts b/src/ai/tools/user-memory.ts new file mode 100644 index 0000000..735f641 --- /dev/null +++ b/src/ai/tools/user-memory.ts @@ -0,0 +1,582 @@ +import path from "node:path"; +import {readFile, rename, writeFile, mkdir, rm} from "node:fs/promises"; +import {AiProvider} from "../../model/ai-provider.js"; +import {Environment} from "../../common/environment.js"; +import {createMistralClient, createOllamaClient, createOpenAiClient, resolveOptionalAiRuntimeTarget, type AiRuntimeTarget} from "../ai-runtime-target.js"; +import {AiTool} from "../tool-types.js"; +import {toolsLogger} from "./tool-logger.js"; +import {asNonEmptyString} from "./utils.js"; + +const logger = toolsLogger.child("user-memory"); + +function memoryDir(): string { + return path.join(Environment.DATA_PATH, "memory"); +} + +export const USER_MEMORY_MAX_CHARS = 1000; + +export type MemoryScope = "user" | "system"; +export type MemoryAction = "add" | "replace" | "remove"; + +export type MemoryRuntimeContext = { + provider?: AiProvider; + runtimeTarget?: AiRuntimeTarget; +}; + +export type MemoryOperationResult = + | {success: true; scope: MemoryScope; filePath: string; content: string; chars: number; compressed: boolean} + | {success: false; scope: MemoryScope; error: string}; + +type CompressionRunResult = { + content: string; +}; + +export type MemoryCompressionRunner = (params: { + target: AiRuntimeTarget; + scope: MemoryScope; + currentText: string; + limit: number; +}) => Promise; + +function extractMistralText(content: unknown): string { + if (typeof content === "string") return content; + if (!Array.isArray(content)) return ""; + + return content + .map(part => { + if (typeof part === "string") return part; + if (part && typeof part === "object" && "text" in part && typeof (part as {text?: unknown}).text === "string") { + return (part as {text: string}).text; + } + return ""; + }) + .join(""); +} + +export type MemoryToolName = + | "read_user_info" + | "read_system_info" + | "add_user_info" + | "add_system_info" + | "remove_user_info" + | "remove_system_info" + | "replace_user_info" + | "replace_system_info" + | "delete_user_info" + | "delete_system_info"; + +export const MEMORY_TOOL_NAMES: MemoryToolName[] = [ + "read_user_info", + "read_system_info", + "add_user_info", + "add_system_info", + "remove_user_info", + "remove_system_info", + "replace_user_info", + "replace_system_info", + "delete_user_info", + "delete_system_info", +]; + +type MemoryToolSpec = { + name: MemoryToolName; + scope: MemoryScope; + kind: "read" | "write" | "delete"; + action?: MemoryAction; + description: string; + prompt: string; +}; + +const MEMORY_TOOL_SPECS: MemoryToolSpec[] = [ + { + name: "read_user_info", + scope: "user", + kind: "read", + description: "Read persistent user memory from user.md.", + prompt: `Use when you need to inspect remembered user facts before editing or answering.`, + }, + { + name: "read_system_info", + scope: "system", + kind: "read", + description: "Read persistent assistant memory from system.md.", + prompt: `Use when you need to inspect remembered assistant instructions before editing or answering.`, + }, + { + name: "add_user_info", + scope: "user", + kind: "write", + action: "add", + description: "Append a durable fact about the user to user.md.", + prompt: `Use for new user facts, preferences, identity details, and profile information. Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "add_system_info", + scope: "system", + kind: "write", + action: "add", + description: "Append a durable instruction about the assistant to system.md.", + prompt: `Use for new assistant identity, style, or behavior instructions. Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "remove_user_info", + scope: "user", + kind: "write", + action: "remove", + description: "Remove a specific user fact or fragment from user.md.", + prompt: `Use when the user asks to forget something about themselves. Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "remove_system_info", + scope: "system", + kind: "write", + action: "remove", + description: "Remove a specific assistant instruction or fragment from system.md.", + prompt: `Use when the user asks to forget something about the assistant. Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "replace_user_info", + scope: "user", + kind: "write", + action: "replace", + description: "Replace user.md completely with a new compact version.", + prompt: `Use when the user wants to overwrite all remembered user info, such as "forget everything about me and remember only this". Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "replace_system_info", + scope: "system", + kind: "write", + action: "replace", + description: "Replace system.md completely with a new compact version.", + prompt: `Use when the user wants to overwrite all remembered assistant info or instructions. Keep the result at or below ${USER_MEMORY_MAX_CHARS} characters.`, + }, + { + name: "delete_user_info", + scope: "user", + kind: "delete", + description: "Delete the user memory file user.md.", + prompt: `Use when the user asks to delete all remembered user info and remove the memory file entirely.`, + }, + { + name: "delete_system_info", + scope: "system", + kind: "delete", + description: "Delete the assistant memory file system.md.", + prompt: `Use when the user asks to delete all remembered assistant info and remove the memory file entirely.`, + }, +]; + +export const memoryToolPrompt = [ + "Use the memory tools to manage persistent per-user memory.", + "- `read_*` shows the current file content before editing.", + "- `user.md` stores durable facts about the user.", + "- `system.md` stores durable facts/instructions about the assistant itself.", + "- `add_*` appends a new fact or instruction.", + "- `remove_*` removes a specific fact or fragment.", + "- `replace_*` rewrites the whole file when the user wants to overwrite memory.", + "- `delete_*` removes the file entirely.", + `- Keep each file at or below ${USER_MEMORY_MAX_CHARS} characters.`, +].join("\n"); + +function createMemoryTool(spec: MemoryToolSpec): AiTool { + return { + type: "function", + function: { + name: spec.name, + description: spec.description, + parameters: { + type: "object", + properties: spec.kind === "read" || spec.kind === "delete" ? {} : { + content: { + type: "string", + description: spec.action === "remove" + ? "Exact text or fragment to remove from memory." + : "Text to append or replace in memory.", + }, + }, + required: spec.kind === "read" || spec.kind === "delete" ? [] : ["content"], + }, + }, + } satisfies AiTool; +} + +export const memoryTools = MEMORY_TOOL_SPECS.map(createMemoryTool); + +function normalizeUserId(userId: number): number | null { + return Number.isSafeInteger(userId) && userId > 0 ? userId : null; +} + +function normalizeMemoryText(value: string): string { + return value.replaceAll("\r\n", "\n"); +} + +function getMemoryUserDir(userId: number): string { + return path.join(memoryDir(), String(userId)); +} + +export function getMemoryFilePath(userId: number, scope: MemoryScope): string { + return path.join(getMemoryUserDir(userId), `${scope}.md`); +} + +async function ensureMemoryDir(userId: number): Promise { + const dir = getMemoryUserDir(userId); + await mkdir(dir, {recursive: true}); + return dir; +} + +async function readMemoryFile(userId: number, scope: MemoryScope): Promise { + const filePath = getMemoryFilePath(userId, scope); + try { + return normalizeMemoryText(await readFile(filePath, "utf-8")); + } catch (error) { + if (error instanceof Error && "code" in error && (error as NodeJS.ErrnoException).code === "ENOENT") { + return ""; + } + + throw error; + } +} + +async function writeMemoryFile(userId: number, scope: MemoryScope, content: string): Promise { + const normalized = normalizeMemoryText(content); + const filePath = getMemoryFilePath(userId, scope); + await ensureMemoryDir(userId); + + const tempPath = `${filePath}.tmp-${process.pid}-${Date.now()}`; + await writeFile(tempPath, normalized, "utf-8"); + await rename(tempPath, filePath); + return filePath; +} + +function trimToLimit(content: string, limit = USER_MEMORY_MAX_CHARS): string { + if (content.length <= limit) return content; + return content.slice(0, limit).trimEnd(); +} + +function stripCodeFences(content: string): string { + const trimmed = content.trim(); + const fenced = trimmed.match(/^```(?:markdown|md)?\s*([\s\S]*?)\s*```$/i); + if (fenced?.[1]) return fenced[1].trim(); + return trimmed; +} + +function sameTarget(left: AiRuntimeTarget | undefined, right: AiRuntimeTarget | undefined): boolean { + if (!left || !right) return false; + return left.provider === right.provider + && left.model === right.model + && (left.baseUrl ?? "") === (right.baseUrl ?? "") + && (left.apiKey ?? "") === (right.apiKey ?? ""); +} + +async function compressWithTarget(params: { + target: AiRuntimeTarget; + scope: MemoryScope; + currentText: string; + limit: number; +}): Promise { + const {target, scope, currentText, limit} = params; + + const systemPrompt = [ + "You compress persistent memory for a chat bot.", + "Return only the rewritten Markdown text.", + "Preserve important facts, preferences, identities, instructions, and durable context.", + "Remove noise, duplication, stale details, and low-value filler.", + `Keep the result at or below ${limit} characters.`, + "Do not add explanations, bullet labels, or code fences.", + ].join("\n"); + + const userPrompt = [ + `Memory scope: ${scope}`, + `Character limit: ${limit}`, + "Current memory:", + currentText.trim() || "(empty)", + "", + "Rewrite it as compact Markdown only.", + ].join("\n"); + + logger.info("compress.start", {provider: target.provider, model: target.model, scope, chars: currentText.length}); + + switch (target.provider) { + case AiProvider.OPENAI: { + const client = createOpenAiClient(target); + const response = await client.chat.completions.create({ + model: target.model, + temperature: 0, + messages: [ + {role: "system", content: systemPrompt}, + {role: "user", content: userPrompt}, + ], + }); + const text = response.choices[0]?.message?.content ?? ""; + return {content: stripCodeFences(text)}; + } + case AiProvider.MISTRAL: { + const client = createMistralClient(target); + const response = await client.chat.complete({ + model: target.model, + temperature: 0, + messages: [ + {role: "system", content: systemPrompt}, + {role: "user", content: userPrompt}, + ], + } as Parameters[0]); + const text = extractMistralText(response.choices?.[0]?.message?.content); + return {content: stripCodeFences(text)}; + } + case AiProvider.OLLAMA: { + const client = createOllamaClient(target); + const response = await client.chat({ + model: target.model, + stream: false, + options: {temperature: 0}, + messages: [ + {role: "system", content: systemPrompt}, + {role: "user", content: userPrompt}, + ], + }); + const text = typeof response.message?.content === "string" ? response.message.content : ""; + return {content: stripCodeFences(text)}; + } + } +} + +export async function compressMemoryWithFallback(params: { + provider?: AiProvider; + currentTarget?: AiRuntimeTarget; + scope: MemoryScope; + currentText: string; + limit?: number; +}, runner: MemoryCompressionRunner = async (input) => (await compressWithTarget(input)).content): Promise<{content: string; compressed: boolean; usedTarget?: AiRuntimeTarget}> { + const limit = params.limit ?? USER_MEMORY_MAX_CHARS; + const trimmed = normalizeMemoryText(params.currentText); + if (trimmed.length <= limit) { + return {content: trimmed, compressed: false}; + } + + const explicitTarget = params.provider ? resolveOptionalAiRuntimeTarget(params.provider, "memoryCompress") : undefined; + const targets = [explicitTarget, params.currentTarget].filter((target, index, list): target is AiRuntimeTarget => !!target && list.findIndex(item => sameTarget(item, target)) === index); + + for (const target of targets) { + try { + const content = trimToLimit(await runner({target, scope: params.scope, currentText: trimmed, limit}), limit); + if (content.length <= limit) { + return {content, compressed: true, usedTarget: target}; + } + } catch (error) { + logger.warn("compress.failed", { + provider: params.provider, + scope: params.scope, + target: target.model, + error: error instanceof Error ? error.message : String(error), + }); + } + } + + return {content: trimToLimit(trimmed, limit), compressed: true}; +} + +async function compressMemoryIfNeeded(params: { + userId: number; + scope: MemoryScope; + content: string; + context?: MemoryRuntimeContext; + limit?: number; +}): Promise<{content: string; compressed: boolean}> { + const {scope, context, limit = USER_MEMORY_MAX_CHARS} = params; + const result = await compressMemoryWithFallback({ + provider: context?.provider, + currentTarget: context?.runtimeTarget, + scope, + currentText: params.content, + limit, + }); + + if (!result.compressed) { + return result; + } + + if (result.content.length > limit) { + return {content: trimToLimit(result.content, limit), compressed: true}; + } + + return {content: result.content, compressed: true}; +} + +async function finalizeMemoryWrite(params: { + userId: number; + scope: MemoryScope; + content: string; + context?: MemoryRuntimeContext; +}): Promise<{filePath: string; content: string; compressed: boolean}> { + const {userId, scope, context} = params; + const compressed = await compressMemoryIfNeeded({userId, scope, content: params.content, context}); + const filePath = await writeMemoryFile(userId, scope, compressed.content); + return {filePath, content: compressed.content, compressed: compressed.compressed}; +} + +function findMemoryToolSpec(toolName: string): MemoryToolSpec | undefined { + return MEMORY_TOOL_SPECS.find(spec => spec.name === toolName); +} + +function isMemoryWriteTool(spec: MemoryToolSpec): spec is MemoryToolSpec & {kind: "write"; action: MemoryAction} { + return spec.kind === "write"; +} + +export async function buildUserMemoryPrompt(userId: number | undefined | null): Promise { + const normalizedUserId = typeof userId === "number" ? normalizeUserId(userId) : null; + if (!normalizedUserId) return undefined; + + const [userMemoryResult, systemMemoryResult] = await Promise.all([ + readUserMemory(normalizedUserId, "user"), + readUserMemory(normalizedUserId, "system"), + ]); + + const userMemory = userMemoryResult.success ? userMemoryResult.content : ""; + const systemMemory = systemMemoryResult.success ? systemMemoryResult.content : ""; + + const blocks: string[] = []; + if (systemMemory.trim()) { + blocks.push([ + "## Assistant memory (system.md)", + "This is information about the assistant and its behavior.", + systemMemory.trim(), + ].join("\n")); + } + if (userMemory.trim()) { + blocks.push([ + "## User memory (user.md)", + "This is information about the user.", + userMemory.trim(), + ].join("\n")); + } + + return blocks.length ? blocks.join("\n\n") : undefined; +} + +export async function readUserMemory(userId: number, scope: MemoryScope): Promise { + const normalizedUserId = normalizeUserId(userId); + if (!normalizedUserId) { + return {success: false, scope, error: "Invalid userId"}; + } + + try { + const content = await readMemoryFile(normalizedUserId, scope); + return { + success: true, + scope, + filePath: getMemoryFilePath(normalizedUserId, scope), + content, + chars: content.length, + compressed: false, + }; + } catch (error) { + return {success: false, scope, error: error instanceof Error ? error.message : String(error)}; + } +} + +export async function updateUserMemory(args: { + userId: number; + scope: MemoryScope; + action: MemoryAction; + content?: string; + context?: MemoryRuntimeContext; +}): Promise { + const normalizedUserId = normalizeUserId(args.userId); + if (!normalizedUserId) { + return {success: false, scope: args.scope, error: "Invalid userId"}; + } + + try { + const current = await readMemoryFile(normalizedUserId, args.scope); + let next = current; + + switch (args.action) { + case "add": { + const content = normalizeMemoryText(asNonEmptyString(args.content) ?? ""); + if (!content.trim()) { + return {success: false, scope: args.scope, error: "No content provided"}; + } + next = [current.trimEnd(), content.trim()].filter(Boolean).join(current.trim() ? "\n\n" : ""); + break; + } + case "replace": { + const content = normalizeMemoryText(asNonEmptyString(args.content) ?? ""); + next = content; + break; + } + case "remove": { + const needle = normalizeMemoryText(asNonEmptyString(args.content) ?? ""); + if (!needle.trim()) { + return {success: false, scope: args.scope, error: "No text to remove provided"}; + } + if (!current.includes(needle)) { + return {success: false, scope: args.scope, error: "Text not found in memory"}; + } + next = current.split(needle).join("").trim(); + break; + } + } + + const finalized = await finalizeMemoryWrite({userId: normalizedUserId, scope: args.scope, content: next, context: args.context}); + logger.debug("write.done", { + userId: normalizedUserId, + scope: args.scope, + chars: finalized.content.length, + compressed: finalized.compressed, + filePath: finalized.filePath, + }); + + return { + success: true, + scope: args.scope, + filePath: finalized.filePath, + content: finalized.content, + chars: finalized.content.length, + compressed: finalized.compressed, + }; + } catch (error) { + return {success: false, scope: args.scope, error: error instanceof Error ? error.message : String(error)}; + } +} + +export async function executeMemoryTool(toolName: MemoryToolName, args: {userId: number; content?: string}, context?: MemoryRuntimeContext): Promise { + const spec = findMemoryToolSpec(toolName); + if (!spec) { + return {success: false, scope: "user", error: `Unknown memory tool: ${toolName}`}; + } + + if (spec.kind === "read") { + return readUserMemory(args.userId, spec.scope); + } + + if (spec.kind === "delete") { + return deleteUserMemory(args.userId, spec.scope); + } + + if (!isMemoryWriteTool(spec)) { + return {success: false, scope: spec.scope, error: `Unsupported memory tool: ${toolName}`}; + } + + return updateUserMemory({ + userId: args.userId, + scope: spec.scope, + action: spec.action, + content: args.content, + context, + }); +} + +export async function deleteUserMemory(userId: number, scope: MemoryScope): Promise { + const normalizedUserId = normalizeUserId(userId); + if (!normalizedUserId) { + return {success: false, scope, error: "Invalid userId"}; + } + + const filePath = getMemoryFilePath(normalizedUserId, scope); + try { + await rm(filePath, {force: true}); + return {success: true, scope, filePath, content: "", chars: 0, compressed: false}; + } catch (error) { + return {success: false, scope, error: error instanceof Error ? error.message : String(error)}; + } +} diff --git a/src/ai/tools/utils.ts b/src/ai/tools/utils.ts index e48332a..38a787c 100644 --- a/src/ai/tools/utils.ts +++ b/src/ai/tools/utils.ts @@ -1,5 +1,5 @@ import {Ollama} from "ollama"; -import {toolsLogger} from "./tool-logger"; +import {toolsLogger} from "./tool-logger.js"; import {AiJsonObject, AiJsonValue} from "../tool-types"; import type {BoundaryValue} from "../../common/boundary-types"; diff --git a/src/ai/tools/weather.ts b/src/ai/tools/weather.ts index b4907d0..beb8b3d 100644 --- a/src/ai/tools/weather.ts +++ b/src/ai/tools/weather.ts @@ -1,11 +1,11 @@ import axios from "axios"; -import {toolsLogger} from "./tool-logger"; +import {toolsLogger} from "./tool-logger.js"; const logger = toolsLogger.child("weather"); -import {Environment} from "../../common/environment"; -import {logError} from "../../util/utils"; -import {AiJsonObject, AiTool} from "../tool-types"; -import {asNonEmptyString} from "./utils"; +import {Environment} from "../../common/environment.js"; +import {logError} from "../../util/utils.js"; +import {AiJsonObject, AiTool} from "../tool-types.js"; +import {asNonEmptyString} from "./utils.js"; export const getWeatherTool = { type: "function", diff --git a/src/ai/tools/web-search.ts b/src/ai/tools/web-search.ts index e50f5b2..90e50a6 100644 --- a/src/ai/tools/web-search.ts +++ b/src/ai/tools/web-search.ts @@ -1,11 +1,11 @@ import axios from "axios"; -import {toolsLogger} from "./tool-logger"; +import {toolsLogger} from "./tool-logger.js"; const logger = toolsLogger.child("brave-search"); -import {Environment} from "../../common/environment"; -import {logError} from "../../util/utils"; -import {AiJsonObject, AiJsonValue, AiTool} from "../tool-types"; -import {asBoolean, asNonEmptyString} from "./utils"; +import {Environment} from "../../common/environment.js"; +import {logError} from "../../util/utils.js"; +import {AiJsonObject, AiJsonValue, AiTool} from "../tool-types.js"; +import {asBoolean, asNonEmptyString} from "./utils.js"; type BraveSearchProfile = { name?: string; diff --git a/src/ai/unified-ai-response-pipeline.ts b/src/ai/unified-ai-response-pipeline.ts index 1e6b015..6347e7a 100644 --- a/src/ai/unified-ai-response-pipeline.ts +++ b/src/ai/unified-ai-response-pipeline.ts @@ -19,6 +19,7 @@ import { } from "./unified-ai-runner.shared"; import {runToolRankStage} from "./tool-rank-stage"; import {runOpenAi} from "./unified-ai-runner.openai"; +import {runOpenAiCompatible} from "./unified-ai-runner.openai-compatible"; import {runOllama} from "./unified-ai-runner.ollama"; import {runMistral} from "./unified-ai-runner.mistral"; import {summarizeModelOutput} from "./response-model-output"; @@ -80,6 +81,21 @@ async function runProviderModelCall(params: { switch (options.provider) { case AiProvider.OPENAI: + if (config.openAiBackend === "compatible") { + await runOpenAiCompatible( + options.msg, + prepared.chatMessages as OpenAIChatMessage[], + streamMessage, + signal, + options.stream ?? true, + options.msg, + config, + prepared.toolContext, + downloads, + ); + return; + } + await runOpenAi( options.msg, prepared.chatMessages as OpenAIChatMessage[], diff --git a/src/ai/unified-ai-runner.mistral.ts b/src/ai/unified-ai-runner.mistral.ts index 0a47746..7732584 100644 --- a/src/ai/unified-ai-runner.mistral.ts +++ b/src/ai/unified-ai-runner.mistral.ts @@ -7,6 +7,8 @@ import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../loggi import {AiProvider} from "../model/ai-provider"; import {getProviderAdapter} from "./provider-adapters"; import {runToolRankStage} from "./tool-rank-stage"; +import {ensureToolsSelected} from "./tool-mappers.js"; +import {MEMORY_TOOL_NAMES} from "./tools/user-memory.js"; import { MAX_TOOL_ROUNDS, @@ -66,7 +68,7 @@ export async function runMistral( streamMessage, signal, }); - const filteredTools = rankResult.filteredTools; + const filteredTools = ensureToolsSelected(availableTools, rankResult.filteredTools, MEMORY_TOOL_NAMES); const requestTools = filteredTools.length ? filteredTools : undefined; streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? ""); @@ -113,7 +115,11 @@ export async function runMistral( userId: msg.from?.id, toolCalls: calls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.MISTRAL, + runtimeTarget: config.mistralChatTarget, + }, toolMemory, adapter, appendTargets: [messages, requestMessages], @@ -183,7 +189,11 @@ export async function runMistral( userId: msg.from?.id, toolCalls: calls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.MISTRAL, + runtimeTarget: config.mistralChatTarget, + }, toolMemory, adapter, appendTargets: [messages, requestMessages], diff --git a/src/ai/unified-ai-runner.ollama.ts b/src/ai/unified-ai-runner.ollama.ts index d672432..1ceb4b0 100644 --- a/src/ai/unified-ai-runner.ollama.ts +++ b/src/ai/unified-ai-runner.ollama.ts @@ -15,6 +15,8 @@ import {createOllamaClient} from "./ai-runtime-target"; import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; import {getProviderAdapter} from "./provider-adapters"; import {runToolRankStage} from "./tool-rank-stage"; +import {ensureToolsSelected} from "./tool-mappers.js"; +import {MEMORY_TOOL_NAMES} from "./tools/user-memory.js"; import { allToolSchemaNames, @@ -203,7 +205,7 @@ export async function runOllama( signal, }); - const filteredTools = [...new Set(rankResult.filteredTools as Tool[])]; + const filteredTools = [...new Set(ensureToolsSelected(availableOllamaTools, rankResult.filteredTools as Tool[], MEMORY_TOOL_NAMES) as Tool[])]; activeToolNames = filteredTools.map(t => t.function.name ?? ""); if (filteredTools.length > 0) { request.tools = [...filteredTools]; @@ -297,7 +299,11 @@ export async function runOllama( userId: msg.from?.id, toolCalls: calls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.OLLAMA, + runtimeTarget: target, + }, toolMemory, adapter, appendTargets: [messages], @@ -429,7 +435,11 @@ export async function runOllama( userId: msg.from?.id, toolCalls: calls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.OLLAMA, + runtimeTarget: target, + }, toolMemory, adapter, appendTargets: [messages], diff --git a/src/ai/unified-ai-runner.openai-compatible.ts b/src/ai/unified-ai-runner.openai-compatible.ts new file mode 100644 index 0000000..11101b9 --- /dev/null +++ b/src/ai/unified-ai-runner.openai-compatible.ts @@ -0,0 +1,419 @@ +import {Message} from "typescript-telegram-bot-api"; +import type { + ChatCompletionCreateParamsNonStreaming, + ChatCompletionCreateParamsStreaming, + ChatCompletionTool, +} from "openai/resources/chat/completions"; +import {Environment} from "../common/environment.js"; +import {TelegramStreamMessage} from "./telegram-stream-message"; +import {ToolRuntimeContext} from "./tools/runtime"; +import {OpenAIChatMessage, OpenAICompatibleChatMessage} from "./openai-chat-message"; +import {createOpenAiClient} from "./ai-runtime-target"; +import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; +import type {BoundaryValue} from "../common/boundary-types.js"; +import { + AsyncIterableStream, + buildSystemInstruction, + MAX_TOOL_ROUNDS, + OpenAiChatCompletionResponseLike, + OpenAiChatCompletionStreamChunkLike, + RuntimeConfigSnapshot, + safeJsonParseObject, + ToolCallData, + ToolExecutionMemory, +} from "./unified-ai-runner.shared"; +import {mergeToolCallChunks, normalizeStreamingTextDelta} from "./provider-adapter-contract.js"; +import {buildUserMemoryPrompt} from "./tools/user-memory.js"; +import {executeToolBatchWithAdapter} from "./tool-batch-runner"; +import {decideToolLoopContinuation} from "./tool-loop-control"; +import {runToolLoopRounds} from "./tool-loop-runner"; +import {runSingleModelRequest} from "./model-call-stage"; +import {ensureToolsSelected, getOpenAICompatibleTools} from "./tool-mappers.js"; +import {MEMORY_TOOL_NAMES} from "./tools/user-memory.js"; +import {logError} from "../util/utils"; +import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings"; +import {AiDownloadedFile} from "./telegram-attachments"; +import {AiProvider} from "../model/ai-provider"; +import {getProviderAdapter} from "./provider-adapters"; +import {runToolRankStage} from "./tool-rank-stage"; +import type {AiProviderAdapter} from "./provider-adapters.js"; +import {tryToUploadFiles} from "./openai-upload-files.js"; +import {buildAssistantToolMessage, openAiResponseMessagesToChatCompletions} from "./openai-chat-completions.js"; + +function describeOpenAiCompatibleError(error: unknown): Record { + const err = error as { + message?: unknown; + status?: unknown; + code?: unknown; + type?: unknown; + error?: unknown; + } | undefined; + + return { + errorSummary: typeof err?.message === "string" ? err.message : String(error), + httpStatus: err?.status, + errorCode: err?.code, + errorType: err?.type, + }; +} + +async function executeChatCompletionWithOptionalToolFallback(params: { + openAi: ReturnType; + request: ChatCompletionCreateParamsNonStreaming | ChatCompletionCreateParamsStreaming; + signal: AbortSignal; + stream: boolean; +}): Promise { + try { + return await params.openAi.chat.completions.create(params.request as never, {signal: params.signal}) as T; + } catch (error) { + const requestWithTools = params.request as {tools?: unknown[]}; + if (!requestWithTools.tools || !Array.isArray(requestWithTools.tools) || requestWithTools.tools.length === 0) { + aiLog("error", "openai_compatible.request.failed", { + stream: params.stream, + hasTools: false, + error: describeOpenAiCompatibleError(error), + }); + throw error; + } + + aiLog("warn", "openai_compatible.tools.retry_without_tools", { + stream: params.stream, + error: describeOpenAiCompatibleError(error), + }); + + const retryRequest = {...params.request} as ChatCompletionCreateParamsNonStreaming | ChatCompletionCreateParamsStreaming & {tools?: unknown[]}; + delete retryRequest.tools; + + try { + return await params.openAi.chat.completions.create(retryRequest as never, {signal: params.signal}) as T; + } catch (retryError) { + aiLog("error", "openai_compatible.request.retry_without_tools.failed", { + stream: params.stream, + hasTools: true, + error: describeOpenAiCompatibleError(retryError), + }); + throw retryError; + } + } +} + +function makeChatCompletionAdapter(): AiProviderAdapter { + const baseAdapter = getProviderAdapter(AiProvider.OPENAI); + + return { + ...baseAdapter, + callModel: baseAdapter.callModel.bind(baseAdapter), + mapMessages(messages: readonly unknown[]): unknown[] { + return openAiResponseMessagesToChatCompletions(messages as OpenAIChatMessage[]); + }, + rankTools(config: RuntimeConfigSnapshot, options?: {forCreator?: boolean; vectorStoreIds?: string[]}): readonly BoundaryValue[] { + void config; + void options?.vectorStoreIds; + return getOpenAICompatibleTools(options?.forCreator) as BoundaryValue[]; + }, + extractTextDelta(input: unknown): string { + const chunk = input as OpenAiChatCompletionStreamChunkLike | undefined; + return chunk?.choices?.[0]?.delta?.content ?? ""; + }, + extractToolCalls(input: unknown): ToolCallData[] { + const response = input as OpenAiChatCompletionResponseLike | undefined; + const toolCalls = response?.choices?.[0]?.message?.tool_calls ?? []; + + return toolCalls + .map((call, index) => ({ + id: typeof call?.id === "string" && call.id.trim().length > 0 ? call.id : `openai_chat_${index}`, + name: typeof call?.function?.name === "string" ? call.function.name : typeof call?.name === "string" ? call.name : "", + argumentsText: typeof call?.function?.arguments === "string" + ? call.function.arguments + : JSON.stringify(call?.function?.arguments ?? call?.arguments ?? {}), + })) + .filter(call => call.name.length > 0); + }, + extractStreamingToolCalls(input: unknown): ToolCallData[] { + const chunk = input as OpenAiChatCompletionStreamChunkLike | undefined; + const toolCalls = chunk?.choices?.[0]?.delta?.tool_calls ?? []; + + return toolCalls + .map((call, index) => ({ + id: typeof call?.id === "string" && call.id.trim().length > 0 + ? call.id + : `openai_chat_${typeof call?.index === "number" ? call.index : index}`, + name: typeof call?.function?.name === "string" ? call.function.name : typeof call?.name === "string" ? call.name : "", + argumentsText: typeof call?.function?.arguments === "string" + ? call.function.arguments + : call?.function?.arguments + ? JSON.stringify(call.function.arguments) + : typeof call?.arguments === "string" + ? call.arguments + : "", + })) + .filter(call => call.id.length > 0); + }, + appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void { + for (const [index, call] of calls.entries()) { + messages.push({ + role: "tool", + tool_call_id: call.id, + content: results[index] ?? "", + }); + } + }, + finalize: baseAdapter.finalize.bind(baseAdapter), + }; +} + +export async function runOpenAiCompatible( + msg: Message, + messages: OpenAIChatMessage[], + streamMessage: TelegramStreamMessage, + signal: AbortSignal, + stream: boolean, + sourceMessage: Message, + config: RuntimeConfigSnapshot, + toolContext: ToolRuntimeContext, + downloads: AiDownloadedFile[] = [], +): Promise { + void downloads; + const runnerStartedAt = Date.now(); + const openAi = createOpenAiClient(config.openAiChatTarget); + const adapter = makeChatCompletionAdapter(); + const systemPrompt = buildSystemInstruction( + config, + DEFAULT_AI_RESPONSE_LANGUAGE, + false, + config.openAiChatTarget.systemPromptAdditions, + await buildUserMemoryPrompt(msg.from?.id), + ); + let conversationMessages = [...openAiResponseMessagesToChatCompletions(messages)]; + + if (systemPrompt.trim().length) { + conversationMessages.unshift({role: "system", content: systemPrompt}); + } + + const availableTools = getOpenAICompatibleTools(msg.from?.id === Environment.CREATOR_ID) as ChatCompletionTool[]; + + aiLog("info", "openai_compatible.run.start", { + stream, + target: aiLogProviderTarget(config.openAiChatTarget), + inputMessages: messages.length, + sourceMessage: aiLogMessageIdentity(sourceMessage), + hasToolInputFiles: !!toolContext.pythonInputFiles?.length, + backend: config.openAiBackend, + }); + + const toolMemory: ToolExecutionMemory = new Map(); + + try { + await runToolLoopRounds({ + maxRounds: MAX_TOOL_ROUNDS, + onRound: async (round) => { + const roundStartedAt = Date.now(); + aiLog("debug", "openai_compatible.round.start", {round, inputMessages: conversationMessages.length, stream}); + + const rankResult = await runToolRankStage({ + provider: AiProvider.OPENAI, + model: config.openAiChatTarget.model, + round, + config, + availableTools: availableTools as readonly BoundaryValue[], + messages, + streamMessage, + signal, + }); + + const requestTools = ensureToolsSelected( + availableTools, + rankResult.filteredTools as ChatCompletionTool[], + MEMORY_TOOL_NAMES, + ); + + if (!stream) { + const request: ChatCompletionCreateParamsNonStreaming = { + model: config.openAiChatTarget.model, + messages: conversationMessages, + tools: requestTools.length ? requestTools : undefined, + }; + + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => executeChatCompletionWithOptionalToolFallback({ + openAi, + request, + signal, + stream: false, + })), + }) as OpenAiChatCompletionResponseLike; + + const message = response.choices?.[0]?.message; + const responseText = typeof message?.content === "string" ? message.content : ""; + streamMessage.append(responseText); + aiLog("debug", "openai_compatible.response.received", { + round, + duration: aiLogDuration(roundStartedAt), + textChars: responseText.length, + hasToolCalls: !!message?.tool_calls?.length, + }); + + const calls = adapter.extractToolCalls(response); + aiLog(calls.length ? "info" : "success", calls.length ? "openai_compatible.tool_calls" : "openai_compatible.run.done", { + round, + duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), + calls: calls.map(call => ({ + id: call.id, + name: call.name, + arguments: safeJsonParseObject(call.argumentsText) + })), + }); + if (!calls.length) return {shouldContinue: false}; + + const toolCalls = calls.map(call => ({ + id: call.id, + name: call.name, + argumentsText: call.argumentsText, + })); + const toolMessages: OpenAICompatibleChatMessage[] = []; + const toolResults = await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls, + streamMessage, + toolContext: { + ...toolContext, + provider: AiProvider.OPENAI, + runtimeTarget: config.openAiChatTarget, + }, + toolMemory, + adapter, + appendTargets: [toolMessages], + }); + + const uploadFilesResult = await tryToUploadFiles(msg, toolResults); + if (uploadFilesResult.found && !uploadFilesResult.uploaded && uploadFilesResult.toolIndex >= 0) { + const toolMessage = toolMessages[uploadFilesResult.toolIndex]; + if (toolMessage && toolMessage.role === "tool") { + toolMessage.content = "Error: " + uploadFilesResult.error; + } + } + + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "openai_compatible.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + + conversationMessages = [...conversationMessages, buildAssistantToolMessage(calls, responseText), ...toolMessages]; + return {shouldContinue: true}; + } + + const request: ChatCompletionCreateParamsStreaming = { + model: config.openAiChatTarget.model, + messages: conversationMessages, + stream: true, + tools: requestTools.length ? requestTools : undefined, + }; + + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => executeChatCompletionWithOptionalToolFallback>({ + openAi, + request, + signal, + stream: true, + })), + }) as AsyncIterableStream; + + aiLog("debug", "openai_compatible.stream.open", {round}); + + let responseText = ""; + let toolCallState: ToolCallData[] = []; + for await (const chunk of response) { + if (signal.aborted) throw new Error("Aborted"); + + const deltaText = adapter.extractTextDelta(chunk); + if (deltaText) { + const appendedText = normalizeStreamingTextDelta(responseText, deltaText); + responseText += appendedText; + streamMessage.append(appendedText); + } + + const streamedCalls = adapter.extractStreamingToolCalls(chunk); + if (streamedCalls.length) { + toolCallState = mergeToolCallChunks(toolCallState, streamedCalls); + const activeCalls = toolCallState.filter(call => call.name.length > 0); + aiLog("info", "openai_compatible.stream.tool_call.added", { + round, + toolCalls: activeCalls.map(aiLogToolCall), + }); + streamMessage.setStatus(Environment.getUseToolText(activeCalls)); + await streamMessage.flush(); + } + } + + const calls = toolCallState.filter(call => call.name.length > 0); + aiLog(calls.length ? "info" : "success", calls.length ? "openai_compatible.tool_calls" : "openai_compatible.stream.done", { + round, + duration: aiLogDuration(roundStartedAt), + textChars: responseText.length, + calls: calls.map(call => ({ + id: call.id, + name: call.name, + arguments: safeJsonParseObject(call.argumentsText) + })), + }); + if (!calls.length) return {shouldContinue: false}; + + streamMessage.clearStatus(); + await streamMessage.flush(); + + const toolMessages: OpenAICompatibleChatMessage[] = []; + const toolResults = await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls: calls, + streamMessage, + toolContext: { + ...toolContext, + provider: AiProvider.OPENAI, + runtimeTarget: config.openAiChatTarget, + }, + toolMemory, + adapter, + appendTargets: [toolMessages], + }); + + const uploadFilesResult = await tryToUploadFiles(msg, toolResults); + if (uploadFilesResult.found && !uploadFilesResult.uploaded && uploadFilesResult.toolIndex >= 0) { + const toolMessage = toolMessages[uploadFilesResult.toolIndex]; + if (toolMessage && toolMessage.role === "tool") { + toolMessage.content = "Error: " + uploadFilesResult.error; + } + } + + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "openai_compatible.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + + conversationMessages = [...conversationMessages, buildAssistantToolMessage(calls, responseText), ...toolMessages]; + return {shouldContinue: true}; + }, + }); + } catch (error) { + aiLog("error", "openai_compatible.run.failed", { + duration: aiLogDuration(runnerStartedAt), + error: describeOpenAiCompatibleError(error), + }); + throw error; + } finally { + await adapter.finalize().catch(logError); + } +} diff --git a/src/ai/unified-ai-runner.openai.ts b/src/ai/unified-ai-runner.openai.ts index ce3f776..0a94037 100644 --- a/src/ai/unified-ai-runner.openai.ts +++ b/src/ai/unified-ai-runner.openai.ts @@ -12,6 +12,7 @@ import type { } from "openai/resources/responses/responses"; import {createOpenAiClient} from "./ai-runtime-target"; import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; +import {buildUserMemoryPrompt} from "./tools/user-memory.js"; import { AsyncIterableStream, @@ -29,23 +30,21 @@ import { showOpenAiGeneratedImage, ToolCallData, ToolExecutionMemory, - errorMessage, allToolSchemaNames } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {decideToolLoopContinuation} from "./tool-loop-control"; import {runToolLoopRounds} from "./tool-loop-runner"; import {runSingleModelRequest} from "./model-call-stage"; -import {bot} from "../index"; -import fs from "node:fs"; -import path from "node:path"; +import {ensureToolsSelected} from "./tool-mappers.js"; +import {MEMORY_TOOL_NAMES} from "./tools/user-memory.js"; import {logError} from "../util/utils"; -import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files"; import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings"; import {AiDownloadedFile} from "./telegram-attachments"; import {AiProvider} from "../model/ai-provider"; import {getProviderAdapter} from "./provider-adapters"; import {runToolRankStage} from "./tool-rank-stage"; +import {tryToUploadFiles} from "./openai-upload-files.js"; export async function runOpenAi( msg: Message, @@ -75,6 +74,7 @@ export async function runOpenAi( DEFAULT_AI_RESPONSE_LANGUAGE, false, config.openAiChatTarget.systemPromptAdditions, + await buildUserMemoryPrompt(msg.from?.id), ); aiLog("info", "openai.run.start", { @@ -115,9 +115,13 @@ export async function runOpenAi( tools.unshift(fileSearchTool); } } - return tools.length ? tools : undefined; + const withMemory = ensureToolsSelected(availableTools, tools, MEMORY_TOOL_NAMES); + return withMemory.length ? withMemory : undefined; })() - : (filteredTools.length ? filteredTools : undefined); + : (() => { + const withMemory = ensureToolsSelected(availableTools, filteredTools, MEMORY_TOOL_NAMES); + return withMemory.length ? withMemory : undefined; + })(); if (!stream) { const request: ResponseCreateParamsNonStreaming = { @@ -187,7 +191,11 @@ export async function runOpenAi( userId: msg.from?.id, toolCalls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.OPENAI, + runtimeTarget: config.openAiChatTarget, + }, toolMemory, adapter, appendTargets: [toolOutputs], @@ -397,7 +405,11 @@ export async function runOpenAi( userId: msg.from?.id, toolCalls, streamMessage, - toolContext, + toolContext: { + ...toolContext, + provider: AiProvider.OPENAI, + runtimeTarget: config.openAiChatTarget, + }, toolMemory, adapter, appendTargets: [toolOutputs], @@ -504,72 +516,6 @@ async function cleanupOpenAiDocumentRag(openAi: OpenAI, vectorStoreId: string, f } } -async function tryToUploadFiles( - msg: Message, - toolResults: string[] -): Promise< - | { found: false } - | { found: true, uploaded: true } - | { found: boolean, uploaded: false, error: string, toolIndex: number } -> { - let sendFileAttachment: { - result: SendFileAttachmentResult & { success: true }, - toolIndex: number - } | null = null; - - let found = false; - - try { - for (const [index, toolResult] of toolResults.entries()) { - const raw = JSON.parse(toolResult); - const res = SendFileAttachmentResultSchema.safeParse(raw); - - if (res.success) { - found = true; - - if (res.data.success) { - sendFileAttachment = {result: res.data, toolIndex: index}; - } - } - } - - if (!found) { - return {found: false}; - } - - const attachmentRoot = Environment.FILE_TOOLS_ROOT_DIR; - const attachmentPath = attachmentRoot - ? path.join( - attachmentRoot, - String(msg.from?.id), - sendFileAttachment?.result?.attachment?.relativePath ?? "", - ) - : ""; - - if (!fs.existsSync(attachmentPath)) { - throw new Error(`Attachment file does not exist: ${attachmentPath}`); - } - - await bot.sendDocument({ - chat_id: msg.chat.id, - reply_parameters: { - message_id: msg.message_id, - }, - document: fs.createReadStream(attachmentPath), - }); - - return {found: true, uploaded: true}; - } catch (e) { - logError(e instanceof Error ? e : String(e)); - return { - found: found, - uploaded: false, - error: errorMessage(e instanceof Error ? e : String(e)), - toolIndex: sendFileAttachment?.toolIndex ?? -1 - }; - } -} - // function openAiResponseContentToText(content: string | readonly { text?: string; refusal?: string }[]): string { // if (typeof content === "string") return content; // if (!Array.isArray(content)) return ""; diff --git a/src/ai/unified-ai-runner.shared.ts b/src/ai/unified-ai-runner.shared.ts index 4649b6e..79eadaf 100644 --- a/src/ai/unified-ai-runner.shared.ts +++ b/src/ai/unified-ai-runner.shared.ts @@ -4,7 +4,7 @@ import path from "node:path"; import type {BoundaryValue} from "../common/boundary-types"; import {AiProvider} from "../model/ai-provider.js"; import {ToolRankerFallbackPolicy} from "../common/policies.js"; -import {Environment} from "../common/environment.js"; +import {Environment, type OpenAiBackend} from "../common/environment.js"; import {delay, logError, replyToMessage} from "../util/utils.js"; import {MessageStore} from "../common/message-store.js"; import type {OpenAiResponseTool} from "./tool-mappers.js"; @@ -274,6 +274,7 @@ export type RuntimeConfigSnapshot = { openAiChatTarget: AiRuntimeTarget; openAiImageTarget: AiRuntimeTarget; openAiToolRankerTarget?: AiRuntimeTarget; + openAiBackend: OpenAiBackend; }; export function snapshotRuntimeConfig(): RuntimeConfigSnapshot { @@ -307,9 +308,14 @@ export function snapshotRuntimeConfig(): RuntimeConfigSnapshot { openAiChatTarget: resolveAiRuntimeTarget(AiProvider.OPENAI, "chat"), openAiImageTarget: resolveAiRuntimeTarget(AiProvider.OPENAI, "outputImages"), openAiToolRankerTarget: resolveAiRuntimeTarget(AiProvider.OPENAI, "toolRank"), + openAiBackend: Environment.OPENAI_BACKEND, }; } +export function isOpenAiCompatibleBackend(config: RuntimeConfigSnapshot): boolean { + return config.openAiBackend === "compatible"; +} + export function getMessageImageParts(part: MessagePart): MessageImagePart[] { if (part.imageParts?.length) return part.imageParts; return (part.images ?? []).map(data => ({data, mimeType: "image/jpeg"})); @@ -382,11 +388,13 @@ export function buildSystemInstruction( responseLanguage: UserAiResponseLanguage, includePythonToolPrompt: boolean, additions?: string | null, + memoryInstruction?: string | null, ): string { return [ config.useSystemPrompt ? getResponseLanguageInstruction(responseLanguage) : null, config.systemPrompt && config.useSystemPrompt ? config.systemPrompt : null, additions?.trim() ? additions.trim() : null, + memoryInstruction?.trim() ? memoryInstruction.trim() : null, includePythonToolPrompt ? pythonInterpreterToolPrompt : null, ].filter(Boolean).join("\n\n"); } @@ -1117,19 +1125,31 @@ export async function executeTool( } } -export function toolResourceKeys(toolCall: ToolCallData): string[] { +export function toolResourceKeys(toolCall: ToolCallData, userId?: number | undefined | null): string[] { const args = safeJsonParseObject(toolCall.argumentsText); const pathValue = typeof args.path === "string" ? args.path : undefined; const sourcePath = typeof args.sourcePath === "string" ? args.sourcePath : undefined; const targetPath = typeof args.targetPath === "string" ? args.targetPath : undefined; + const memoryScope = toolCall.name.endsWith("_user_info") ? "user" + : toolCall.name.endsWith("_system_info") ? "system" + : undefined; switch (toolCall.name) { + case "read_user_info": + case "read_system_info": case "get_datetime": case "web_search": case "get_weather": case "read_file": case "list_directory": return []; + case "add_user_info": + case "add_system_info": + case "remove_user_info": + case "remove_system_info": + case "replace_user_info": + case "replace_system_info": + return userId && memoryScope ? [`memory:${userId}:${memoryScope}`] : []; case "create_file": case "create_directory": case "update_file": @@ -1162,7 +1182,7 @@ export async function executeScheduledTool( message: TelegramStreamMessage, context: ToolRuntimeContext, ): Promise { - const keys = toolResourceKeys(toolCall); + const keys = toolResourceKeys(toolCall, userId); if (!keys.length) return executeTool(userId, toolCall, message, context); return runWithToolLocks(keys, () => executeTool(userId, toolCall, message, context)); } diff --git a/src/ai/unified-ai-runner.tool-ranker.ts b/src/ai/unified-ai-runner.tool-ranker.ts index c74b28d..e144aa6 100644 --- a/src/ai/unified-ai-runner.tool-ranker.ts +++ b/src/ai/unified-ai-runner.tool-ranker.ts @@ -1,4 +1,4 @@ -import {ChatCompletionMessageParam} from "openai/resources/chat/completions"; +import type {ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam} from "openai/resources/chat/completions"; import {ChatRequest} from "ollama"; import {BoundaryValue} from "../common/boundary-types.js"; import {ToolRankerFallbackPolicy} from "../common/policies.js"; @@ -107,7 +107,7 @@ export class ToolRanker { target: aiLogProviderTarget(target), fallbackTarget: aiLogProviderTarget(mainModelTarget), duration: aiLogDuration(startedAt), - error: failureMessage, + errorSummary: failureMessage, }); const fallbackRanker = buildToolRankerPrompt( @@ -142,7 +142,7 @@ export class ToolRanker { target: aiLogProviderTarget(target), fallbackTarget: aiLogProviderTarget(mainModelTarget), duration: aiLogDuration(startedAt), - error: fallbackErrorMessage, + errorSummary: fallbackErrorMessage, }); failureMessage = fallbackErrorMessage; @@ -155,7 +155,7 @@ export class ToolRanker { target: aiLogProviderTarget(target), fallbackPolicy, duration: aiLogDuration(startedAt), - error: failureMessage, + errorSummary: failureMessage, }); return resolveToolRankerFallbackSelection({ @@ -227,12 +227,19 @@ export class ToolRanker { {role: "user", content: userQuery}, ] satisfies ChatCompletionMessageParam[]; - // gpt-5 family ranker targets reject temperature=0; use the model default instead. - const response = await openAi.chat.completions.create({ + // OpenAI-compatible servers often reject `response_format`, so keep JSON mode + // only for official OpenAI endpoints. + const request: ChatCompletionCreateParamsNonStreaming = { model: target.model, messages, - response_format: {type: "json_object"}, - }); + }; + + if (!target.baseUrl) { + // gpt-5 family ranker targets reject temperature=0; use the model default instead. + request.response_format = {type: "json_object"}; + } + + const response = await openAi.chat.completions.create(request); return response.choices[0]?.message?.content?.trim() ?? ""; } diff --git a/src/common/environment.ts b/src/common/environment.ts index a998b47..dcc1f44 100644 --- a/src/common/environment.ts +++ b/src/common/environment.ts @@ -14,6 +14,13 @@ import type {ToolCallData} from "../ai/unified-ai-runner.js"; import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js"; import {Localization, type LocalizationParams} from "./localization.js"; +export const OpenAiBackendModes = { + OFFICIAL: "official", + COMPATIBLE: "compatible", +} as const; + +export type OpenAiBackend = typeof OpenAiBackendModes[keyof typeof OpenAiBackendModes]; + function parseBooleanLike(value: string): boolean { const normalized = value.trim().toLowerCase(); return ["true", "t", "y", "1"].includes(normalized); @@ -245,6 +252,10 @@ const RuntimeEnvSchema = z.object({ OPENAI_BASE_URL: optionalStringSchema, OPENAI_API_KEY: optionalStringSchema, + OPENAI_BACKEND: enumWithDefaultSchema( + OpenAiBackendModes, + OpenAiBackendModes.OFFICIAL, + ), OPENAI_MODEL: stringWithDefaultSchema("gpt-4.1-nano"), OPENAI_IMAGE_MODEL: stringWithDefaultSchema("gpt-image-1-mini"), OPENAI_TRANSCRIPTION_MODEL: stringWithDefaultSchema("gpt-4o-mini-transcribe"), @@ -343,6 +354,7 @@ export class Environment { static OPENAI_BASE_URL?: string; static OPENAI_API_KEY?: string; + static OPENAI_BACKEND: OpenAiBackend = OpenAiBackendModes.OFFICIAL; static OPENAI_MODEL: string = ""; static OPENAI_IMAGE_MODEL: string = ""; static OPENAI_TRANSCRIPTION_MODEL: string = ""; @@ -1881,6 +1893,7 @@ export class Environment { Environment.OPENAI_BASE_URL = env.OPENAI_BASE_URL; Environment.OPENAI_API_KEY = env.OPENAI_API_KEY; + Environment.OPENAI_BACKEND = env.OPENAI_BACKEND; Environment.OPENAI_MODEL = env.OPENAI_MODEL; Environment.OPENAI_IMAGE_MODEL = env.OPENAI_IMAGE_MODEL; Environment.OPENAI_TRANSCRIPTION_MODEL = env.OPENAI_TRANSCRIPTION_MODEL; @@ -2081,6 +2094,10 @@ export class Environment { this.OPENAI_API_KEY = newAIApiKey; } + static setOpenAIBackend(newBackend: OpenAiBackend): void { + this.OPENAI_BACKEND = newBackend; + } + static setOpenAIModel(newModel: string): void { this.OPENAI_MODEL = newModel; } diff --git a/src/index.ts b/src/index.ts index 7473b84..3e90ecd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -194,6 +194,7 @@ export const filesDir = path.join(Environment.DATA_PATH, "files"); export const NOTES_HEADER = "## Notes\n"; export const notesDir = path.join(Environment.DATA_PATH, "notes"); export const notesRootFile = path.join(notesDir, "index.md"); +export const memoryDir = path.join(Environment.DATA_PATH, "memory"); const logger = appLogger.child("main"); @@ -262,7 +263,7 @@ async function main() { }); await measureStartupStep("environment.load", () => Environment.load()); - const dirsToCheck = [cacheDir, photoDir, photoGenDir, documentDir, audioDir, videoDir, videoNotesDir, videoTempDir, notesDir, filesDir]; + const dirsToCheck = [cacheDir, photoDir, photoGenDir, documentDir, audioDir, videoDir, videoNotesDir, videoTempDir, notesDir, memoryDir, filesDir]; await measureStartupStep("prepare_directories", () => { const created: string[] = []; for (const dir of dirsToCheck) { diff --git a/test/openai-backend.test.mjs b/test/openai-backend.test.mjs new file mode 100644 index 0000000..d3eb8fc --- /dev/null +++ b/test/openai-backend.test.mjs @@ -0,0 +1,14 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {Environment} = await import("../dist/common/environment.js"); + +test("openai backend defaults to official", () => { + assert.equal(Environment.OPENAI_BACKEND, "official"); +}); + +test("openai backend setter updates runtime config", () => { + Environment.setOpenAIBackend("compatible"); + assert.equal(Environment.OPENAI_BACKEND, "compatible"); + Environment.setOpenAIBackend("official"); +}); diff --git a/test/openai-compatible-integration.test.mjs b/test/openai-compatible-integration.test.mjs new file mode 100644 index 0000000..af5f7f9 --- /dev/null +++ b/test/openai-compatible-integration.test.mjs @@ -0,0 +1,43 @@ +import test from "node:test"; +import assert from "node:assert/strict"; +import {OpenAI} from "openai"; + +const {extractOpenAiChatToolCalls} = await import("../dist/ai/provider-adapter-contract.js"); + +const baseURL = process.env.OPENAI_COMPATIBLE_TEST_BASE_URL; +const model = process.env.OPENAI_COMPATIBLE_TEST_MODEL; +const apiKey = process.env.OPENAI_COMPATIBLE_TEST_API_KEY ?? process.env.OPENAI_API_KEY ?? "test"; + +test("openai-compatible chat.completions tool loop works on a real server", {skip: !baseURL || !model}, async () => { + const client = new OpenAI({baseURL, apiKey}); + + const response = await client.chat.completions.create({ + model, + temperature: 0, + messages: [ + {role: "system", content: "You must call the ping tool exactly once. Do not answer in plain text."}, + {role: "user", content: "ping"}, + ], + tools: [{ + type: "function", + function: { + name: "ping", + description: "Return a ping token.", + parameters: { + type: "object", + properties: {}, + additionalProperties: false, + }, + }, + }], + tool_choice: { + type: "function", + function: {name: "ping"}, + }, + }); + + const calls = extractOpenAiChatToolCalls(response); + + assert.equal(calls.length, 1); + assert.equal(calls[0].name, "ping"); +}); diff --git a/test/provider-adapter-contract.test.mjs b/test/provider-adapter-contract.test.mjs index d100690..b8423b7 100644 --- a/test/provider-adapter-contract.test.mjs +++ b/test/provider-adapter-contract.test.mjs @@ -5,6 +5,11 @@ const { extractOpenAiToolCalls, extractOpenAiStreamingToolCalls, extractOpenAiTextDelta, + extractOpenAiChatToolCalls, + extractOpenAiChatStreamingToolCalls, + extractOpenAiChatTextDelta, + mergeToolCallChunks, + normalizeStreamingTextDelta, extractMistralToolCalls, extractMistralTextDelta, extractOllamaToolCalls, @@ -42,6 +47,62 @@ test("openai contract extracts text delta and function calls", () => { assert.equal(streamed[0].name, "search_files"); }); +test("openai chat contract extracts text delta and tool calls", () => { + assert.equal(extractOpenAiChatTextDelta({choices: [{delta: {content: "hello chat"}}]}), "hello chat"); + assert.equal(normalizeStreamingTextDelta("hel", "hello"), "lo"); + assert.equal(normalizeStreamingTextDelta("hel", "lo"), "lo"); + + const calls = extractOpenAiChatToolCalls({ + choices: [{ + message: { + tool_calls: [{ + id: "chat-1", + function: { + name: "read_user_info", + arguments: "{\"userId\":123}", + }, + }], + }, + }], + }); + + assert.equal(calls.length, 1); + assert.equal(calls[0].id, "chat-1"); + assert.equal(calls[0].name, "read_user_info"); + + const streamed = extractOpenAiChatStreamingToolCalls({ + choices: [{ + delta: { + tool_calls: [{ + index: 0, + id: "chat-2", + function: { + name: "write_note", + arguments: "{\"text\":", + }, + }], + }, + }], + }); + + assert.equal(streamed.length, 1); + assert.equal(streamed[0].id, "chat-2"); + assert.equal(streamed[0].name, "write_note"); + assert.equal(streamed[0].argumentsText, "{\"text\":"); + + const merged = mergeToolCallChunks([ + {id: "chat-2", name: "", argumentsText: "{\"text\":"}, + ], [{ + id: "chat-2", + name: "write_note", + argumentsText: "\"hello\"}", + }]); + + assert.equal(merged.length, 1); + assert.equal(merged[0].name, "write_note"); + assert.equal(merged[0].argumentsText, "{\"text\":\"hello\"}"); +}); + test("mistral contract extracts content and tool calls", () => { assert.equal(extractMistralTextDelta({ content: [{text: "hello"}, {text: " world"}], diff --git a/test/tool-ranker.test.mjs b/test/tool-ranker.test.mjs index 37fb28e..76265ea 100644 --- a/test/tool-ranker.test.mjs +++ b/test/tool-ranker.test.mjs @@ -86,6 +86,19 @@ test("prompt includes search files routing example for usage search", () => { assert.ok(prompt.includes(JSON.stringify({toolNames: ["search_files"]}))); }); +test("prompt includes memory routing examples for remember requests", () => { + const prompt = promptFor("no_tool", "read_user_info", "add_user_info", "remove_user_info", "replace_user_info", "delete_user_info"); + + assert.ok(prompt.includes("что ты помнишь обо мне?")); + assert.ok(prompt.includes("запомни, что меня зовут Иван")); + assert.ok(prompt.includes("забудь, что я люблю кофе")); + assert.ok(prompt.includes("забудь всё обо мне и запиши только это")); + assert.ok(prompt.includes("удали всю память обо мне")); + assert.ok(prompt.includes("inspect remembered user info -> read_user_info")); + assert.ok(prompt.includes("remember a new user fact -> add_user_info")); + assert.ok(prompt.includes(JSON.stringify({toolNames: ["add_user_info"]}))); +}); + test("prompt includes edit file patch routing example for targeted edits", () => { const prompt = promptFor("no_tool", "edit_file_patch"); diff --git a/test/user-memory.test.mjs b/test/user-memory.test.mjs new file mode 100644 index 0000000..85a81e3 --- /dev/null +++ b/test/user-memory.test.mjs @@ -0,0 +1,197 @@ +import test from "node:test"; +import assert from "node:assert/strict"; +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; + +const {Environment} = await import("../dist/common/environment.js"); +const { + buildUserMemoryPrompt, + compressMemoryWithFallback, + deleteUserMemory, + getMemoryFilePath, + readUserMemory, + updateUserMemory, +} = await import("../dist/ai/tools/user-memory.js"); +const {AiProvider} = await import("../dist/model/ai-provider.js"); + +function makeTempDataPath() { + return fs.mkdtempSync(path.join(os.tmpdir(), "tg-chat-bot-memory-")); +} + +function withEnv(vars, fn) { + const snapshot = new Map(); + for (const [key, value] of Object.entries(vars)) { + snapshot.set(key, process.env[key]); + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } + + return Promise.resolve(fn()).finally(() => { + for (const [key, value] of snapshot.entries()) { + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } + }); +} + +test("memory storage supports append replace and remove", async () => { + const oldDataPath = Environment.DATA_PATH; + Environment.DATA_PATH = makeTempDataPath(); + + try { + const userId = 475823381; + + let result = await updateUserMemory({ + userId, + scope: "user", + action: "replace", + content: "# Profile\nLikes tea", + }); + assert.equal(result.success, true); + + result = await updateUserMemory({ + userId, + scope: "user", + action: "add", + content: "Prefers concise answers", + }); + assert.equal(result.success, true); + assert.match(result.content, /Likes tea/); + assert.match(result.content, /Prefers concise answers/); + + result = await updateUserMemory({ + userId, + scope: "user", + action: "remove", + content: "Prefers concise answers", + }); + assert.equal(result.success, true); + assert.doesNotMatch(result.content, /Prefers concise answers/); + + const readback = await readUserMemory(userId, "user"); + assert.equal(readback.success, true); + assert.equal(readback.filePath, getMemoryFilePath(userId, "user")); + assert.match(readback.content, /Likes tea/); + } finally { + Environment.DATA_PATH = oldDataPath; + } +}); + +test("memory delete removes the file", async () => { + const oldDataPath = Environment.DATA_PATH; + Environment.DATA_PATH = makeTempDataPath(); + + try { + const userId = 999; + await updateUserMemory({userId, scope: "user", action: "replace", content: "hello"}); + const deleted = await deleteUserMemory(userId, "user"); + assert.equal(deleted.success, true); + const readback = await readUserMemory(userId, "user"); + assert.equal(readback.success, true); + assert.equal(readback.content, ""); + } finally { + Environment.DATA_PATH = oldDataPath; + } +}); + +test("memory prompt combines system and user files", async () => { + const oldDataPath = Environment.DATA_PATH; + Environment.DATA_PATH = makeTempDataPath(); + + try { + const userId = 1234; + + await updateUserMemory({ + userId, + scope: "system", + action: "replace", + content: "Ты зовешься Евлампий.", + }); + await updateUserMemory({ + userId, + scope: "user", + action: "replace", + content: "Пользователь любит короткие ответы.", + }); + + const prompt = await buildUserMemoryPrompt(userId); + assert(prompt); + assert.equal(prompt?.includes("## Assistant memory (system.md)"), true); + assert.equal(prompt?.includes("This is information about the assistant and its behavior."), true); + assert.equal(prompt?.includes("## User memory (user.md)"), true); + assert.equal(prompt?.includes("This is information about the user."), true); + assert(prompt.indexOf("## Assistant memory (system.md)") < prompt.indexOf("## User memory (user.md)")); + } finally { + Environment.DATA_PATH = oldDataPath; + } +}); + +test("memory compression falls back to current target when explicit target fails", async () => { + await withEnv({ + OLLAMA_MEMORY_COMPRESS_MODEL: "memory-compress-model", + OLLAMA_CHAT_MODEL: "chat-model", + }, async () => { + const calls = []; + const result = await compressMemoryWithFallback( + { + provider: AiProvider.OLLAMA, + currentTarget: { + provider: AiProvider.OLLAMA, + purpose: "chat", + model: "chat-model", + }, + scope: "system", + currentText: "x".repeat(1200), + limit: 1000, + }, + async ({target}) => { + calls.push(target.model); + if (target.model === "memory-compress-model") { + throw new Error("boom"); + } + return "short summary"; + }, + ); + + assert.deepEqual(calls, ["memory-compress-model", "chat-model"]); + assert.equal(result.content, "short summary"); + assert.equal(result.compressed, true); + }); +}); + +test("memory compression uses current target when no separate target exists", async () => { + await withEnv({ + OLLAMA_MEMORY_COMPRESS_MODEL: undefined, + OLLAMA_CHAT_MODEL: "chat-model", + }, async () => { + const calls = []; + const result = await compressMemoryWithFallback( + { + provider: AiProvider.OLLAMA, + currentTarget: { + provider: AiProvider.OLLAMA, + purpose: "chat", + model: "chat-model", + }, + scope: "user", + currentText: "x".repeat(1200), + limit: 1000, + }, + async ({target}) => { + calls.push(target.model); + return "summary"; + }, + ); + + assert.deepEqual(calls, ["chat-model"]); + assert.equal(result.content, "summary"); + assert.equal(result.compressed, true); + }); +});