diff --git a/PIPELINE_TODO.md b/PIPELINE_TODO.md index bad49ce..c862cfa 100644 --- a/PIPELINE_TODO.md +++ b/PIPELINE_TODO.md @@ -46,27 +46,27 @@ - [x] Для Mistral сохранять `libraryId`. - [x] Для Mistral сохранять uploaded document ids. - [x] Для Mistral сохранять source file mapping: local attachment -> Mistral document id. -- [ ] Добавить единый `providerState` schema для всех providers. -- [ ] Добавить tests на сериализацию `RagArtifact`. -- [ ] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context. +- [x] Добавить единый `providerState` schema для всех providers. +- [x] Добавить tests на сериализацию `RagArtifact`. +- [x] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context. ## 4. Вынести provider runners в adapter layer -- [ ] Ввести интерфейс `AiProviderAdapter`. -- [ ] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`. -- [ ] Реализовать `OpenAiProviderAdapter`. -- [ ] Реализовать `MistralProviderAdapter`. -- [ ] Реализовать `OllamaProviderAdapter`. -- [ ] Перенести provider-specific tool schema mapping внутрь adapter-ов. -- [ ] Перенести provider-specific streaming parsing внутрь adapter-ов. -- [ ] Перенести provider-specific tool result append внутрь adapter-ов. -- [ ] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом. -- [ ] Оставить compatibility wrappers для текущих imports. -- [ ] Добавить tests на adapter contract без реальных API. +- [x] Ввести интерфейс `AiProviderAdapter`. +- [x] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`. +- [x] Реализовать `OpenAiProviderAdapter`. +- [x] Реализовать `MistralProviderAdapter`. +- [x] Реализовать `OllamaProviderAdapter`. +- [x] Перенести provider-specific tool schema mapping внутрь adapter-ов. +- [x] Перенести provider-specific streaming parsing внутрь adapter-ов. +- [x] Перенести provider-specific tool result append внутрь adapter-ов. +- [x] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом. +- [x] Оставить compatibility wrappers для текущих imports. +- [x] Добавить tests на adapter contract без реальных API. ## 5. Сделать tool-ranker полноценным pipeline stage -- [ ] Вынести вызов `ToolRanker.selectTools(...)` из provider runners. +- [x] Вынести вызов `ToolRanker.selectTools(...)` из provider runners. - [ ] Добавить stage `tool_rank`, который работает через provider adapter. - [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker. - [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`. diff --git a/src/ai/provider-adapter-contract.ts b/src/ai/provider-adapter-contract.ts new file mode 100644 index 0000000..7e4d8fd --- /dev/null +++ b/src/ai/provider-adapter-contract.ts @@ -0,0 +1,112 @@ +import type {ToolCallData} from "./unified-ai-runner.shared.js"; +import type {ResponseStreamEvent} from "openai/resources/responses/responses"; + +function isRecord(value: unknown): value is Record { + return !!value && typeof value === "object" && !Array.isArray(value); +} + +function normalizeToolCallId(value: unknown, fallback: string): string { + return typeof value === "string" && value.trim().length > 0 ? value : fallback; +} + +function normalizeToolArguments(value: unknown): string { + if (typeof value === "string") return value; + return JSON.stringify(value ?? {}); +} + +export function extractOpenAiToolCalls(response: unknown): ToolCallData[] { + const output = isRecord(response) && Array.isArray(response.output) ? response.output : []; + + return output + .filter(item => isRecord(item) && item.type === "function_call" && (typeof item.call_id === "string" || typeof item.name === "string")) + .map((item, index) => ({ + id: normalizeToolCallId(item.call_id, `openai_${index}`), + name: typeof item.name === "string" ? item.name : "", + argumentsText: normalizeToolArguments(item.arguments), + })) + .filter(call => call.name.length > 0); +} + +export function extractOpenAiTextDelta(input: unknown): string { + const event = input as ResponseStreamEvent | undefined; + return event?.type === "response.output_text.delta" ? event.delta ?? "" : ""; +} + +export function extractOpenAiStreamingToolCalls(input: unknown): ToolCallData[] { + const event = input as ResponseStreamEvent | undefined; + if (event?.type === "response.output_item.added" && isRecord(event.item) && event.item.type === "function_call") { + return extractOpenAiToolCalls({ + output: [{ + type: "function_call", + call_id: event.item.call_id ?? event.item.id, + name: event.item.name, + arguments: event.item.arguments, + }], + }); + } + + return []; +} + +export function extractMistralToolCalls(calls: unknown): ToolCallData[] { + const normalized = Array.isArray(calls) + ? calls + : isRecord(calls) && (Array.isArray(calls.toolCalls) || Array.isArray(calls.tool_calls)) + ? (calls.toolCalls ?? calls.tool_calls) + : []; + + if (!Array.isArray(normalized)) return []; + + return normalized + .map((item, index) => { + const call = isRecord(item) ? item : {}; + const fn = isRecord(call.function) ? call.function : undefined; + const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : ""; + return { + id: normalizeToolCallId(call.id, `mistral_${index}`), + name, + argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments), + }; + }) + .filter(call => call.name.length > 0); +} + +export function extractMistralTextDelta(input: unknown): string { + const delta = isRecord(input) ? input : {}; + const content = delta.content; + if (typeof content === "string") return content; + if (Array.isArray(content)) { + return content + .map(part => isRecord(part) && typeof part.text === "string" ? part.text : "") + .join(""); + } + return ""; +} + +export function extractOllamaToolCalls(calls: unknown): ToolCallData[] { + const normalized = Array.isArray(calls) + ? calls + : isRecord(calls) && Array.isArray(calls.tool_calls) + ? calls.tool_calls + : []; + + if (!Array.isArray(normalized)) return []; + + return normalized + .map((item, index) => { + const call = isRecord(item) ? item : {}; + const fn = isRecord(call.function) ? call.function : undefined; + const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : ""; + return { + id: normalizeToolCallId(call.id, `ollama_${index}`), + name, + argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments), + }; + }) + .filter(call => call.name.length > 0); +} + +export function extractOllamaTextDelta(input: unknown): string { + const chunk = isRecord(input) ? input.message : undefined; + return isRecord(chunk) && typeof chunk.content === "string" ? chunk.content : ""; +} diff --git a/src/ai/provider-adapters.ts b/src/ai/provider-adapters.ts new file mode 100644 index 0000000..575d1bd --- /dev/null +++ b/src/ai/provider-adapters.ts @@ -0,0 +1,196 @@ +import {AiProvider} from "../model/ai-provider.js"; +import type {BoundaryValue} from "../common/boundary-types.js"; +import type {RuntimeConfigSnapshot, ToolCallData} from "./unified-ai-runner.shared.js"; +import {getMistralTools, getOllamaTools, getOpenAIResponsesTools, getOpenAICodeInterpreterTool} from "./tool-mappers.js"; +import type {MistralChatMessage as MistralMessageType} from "./mistral-chat-message.js"; +import type {OpenAIChatMessage as OpenAiMessageType} from "./openai-chat-message.js"; +import type {Message as OllamaMessage} from "ollama"; +import { + extractMistralTextDelta, + extractMistralToolCalls, + extractOllamaTextDelta, + extractOllamaToolCalls, + extractOpenAiTextDelta, + extractOpenAiStreamingToolCalls, + extractOpenAiToolCalls, +} from "./provider-adapter-contract.js"; + +export type ProviderRankToolOptions = { + forCreator?: boolean; + vectorStoreIds?: string[]; +}; + +export interface AiProviderAdapter { + readonly provider: AiProvider; + mapMessages(messages: readonly unknown[]): unknown[]; + rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[]; + callModel(request: unknown, execute: () => Promise): Promise; + extractTextDelta(input: unknown): string; + extractToolCalls(input: unknown): ToolCallData[]; + extractStreamingToolCalls(input: unknown): ToolCallData[]; + appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void; + finalize(): Promise; +} + +function appendOllamaToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void { + for (const [index, call] of calls.entries()) { + messages.push({ + role: "tool", + content: results[index] ?? "", + tool_name: call.name, + }); + } +} + +class OpenAiProviderAdapter implements AiProviderAdapter { + readonly provider = AiProvider.OPENAI; + + mapMessages(messages: readonly unknown[]): unknown[] { + return messages as OpenAiMessageType[]; + } + + rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] { + const tools: BoundaryValue[] = [ + ...getOpenAIResponsesTools(options?.forCreator) as BoundaryValue[], + getOpenAICodeInterpreterTool() as BoundaryValue, + { + type: "image_generation", + model: config.openAiImageTarget.model, + size: "auto", + moderation: "low", + output_format: "png", + partial_images: 3, + }, + {type: "web_search"}, + ]; + + if (options?.vectorStoreIds?.length) { + tools.unshift({ + type: "file_search", + vector_store_ids: options.vectorStoreIds, + }); + } + + return tools; + } + + async callModel(_request: unknown, execute: () => Promise): Promise { + return execute(); + } + + extractTextDelta(input: unknown): string { + return extractOpenAiTextDelta(input); + } + + extractToolCalls(input: unknown): ToolCallData[] { + return extractOpenAiToolCalls(input); + } + + extractStreamingToolCalls(input: unknown): ToolCallData[] { + return extractOpenAiStreamingToolCalls(input); + } + + appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void { + for (const [index, call] of calls.entries()) { + messages.push({ + type: "function_call_output", + call_id: call.id, + output: results[index] ?? "", + }); + } + } + + async finalize(): Promise { + return; + } +} + +class MistralProviderAdapter implements AiProviderAdapter { + readonly provider = AiProvider.MISTRAL; + + mapMessages(messages: readonly unknown[]): unknown[] { + return messages as MistralMessageType[]; + } + + rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] { + return getMistralTools(options?.forCreator) as BoundaryValue[]; + } + + async callModel(_request: unknown, execute: () => Promise): Promise { + return execute(); + } + + extractTextDelta(input: unknown): string { + return extractMistralTextDelta(input); + } + + extractToolCalls(input: unknown): ToolCallData[] { + return extractMistralToolCalls(input); + } + + extractStreamingToolCalls(input: unknown): ToolCallData[] { + return this.extractToolCalls(input); + } + + appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void { + for (const [index, call] of calls.entries()) { + messages.push({ + role: "tool", + name: call.name, + toolCallId: call.id, + content: results[index] ?? "", + }); + } + } + + async finalize(): Promise { + return; + } +} + +class OllamaProviderAdapter implements AiProviderAdapter { + readonly provider = AiProvider.OLLAMA; + + mapMessages(messages: readonly unknown[]): unknown[] { + return messages as OllamaMessage[]; + } + + rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] { + return getOllamaTools(options?.forCreator) as BoundaryValue[]; + } + + async callModel(_request: unknown, execute: () => Promise): Promise { + return execute(); + } + + extractTextDelta(input: unknown): string { + return extractOllamaTextDelta(input); + } + + extractToolCalls(input: unknown): ToolCallData[] { + return extractOllamaToolCalls(input); + } + + extractStreamingToolCalls(input: unknown): ToolCallData[] { + return this.extractToolCalls(input); + } + + appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void { + appendOllamaToolResults(messages, calls, results); + } + + async finalize(): Promise { + return; + } +} + +export function getProviderAdapter(provider: AiProvider): AiProviderAdapter { + switch (provider) { + case AiProvider.OPENAI: + return new OpenAiProviderAdapter(); + case AiProvider.MISTRAL: + return new MistralProviderAdapter(); + case AiProvider.OLLAMA: + return new OllamaProviderAdapter(); + } +} diff --git a/src/ai/rag-artifact-payload.ts b/src/ai/rag-artifact-payload.ts new file mode 100644 index 0000000..809ffef --- /dev/null +++ b/src/ai/rag-artifact-payload.ts @@ -0,0 +1,77 @@ +import type {AiProvider} from "../model/ai-provider"; + +export type RagArtifactSource = { + fileId: string; + fileName: string; + mimeType?: string; + sizeBytes?: number; + sha256?: string; + uploadedFileId?: string; + documentId?: string; +}; + +export type RagArtifactPayload = { + artifactKind: "rag"; + provider: AiProvider; + createdAt: string; + sources: RagArtifactSource[]; + providerState: + | { + provider: AiProvider.OPENAI; + vectorStoreIds: string[]; + uploadedFileIds: string[]; + } + | { + provider: AiProvider.MISTRAL; + libraryId?: string; + documentCount: number; + } + | { + provider: AiProvider.OLLAMA; + prepared: boolean; + embeddingModel?: string; + topK?: number; + chunkSize?: number; + chunkOverlap?: number; + maxContextChars?: number; + extractedDocuments: Array<{ + documentIndex: number; + fileName: string; + textChars: number; + }>; + selectedChunks: Array<{ + sourceId: string; + documentIndex: number; + documentName: string; + chunkIndex: number; + chunkCount: number; + textChars: number; + score?: number; + }>; + skippedDocuments: Array<{ + documentIndex: number; + fileName: string; + reason: string; + }>; + query: string; + minScore: number; + maxArchiveFiles: number; + maxArchiveBytes: number; + maxArchiveDepth: number; + }; +}; + +export function buildRagArtifactPayload(params: { + provider: AiProvider; + createdAt?: string; + sources: RagArtifactSource[]; + providerState: RagArtifactPayload["providerState"]; +}): RagArtifactPayload { + return { + artifactKind: "rag", + provider: params.provider, + createdAt: params.createdAt ?? new Date().toISOString(), + sources: params.sources, + providerState: params.providerState, + }; +} diff --git a/src/ai/rag-artifact-store.ts b/src/ai/rag-artifact-store.ts index 2d4da3b..c588ff0 100644 --- a/src/ai/rag-artifact-store.ts +++ b/src/ai/rag-artifact-store.ts @@ -4,75 +4,39 @@ import type {AiDownloadedFile} from "./telegram-attachments"; import type {PreparedDocumentRag} from "./document-rag-pipeline"; import type {OllamaRagArtifactDetails} from "./ollama-rag"; import {persistInternalJsonArtifactAttachment} from "./internal-artifact-store"; - -type RagArtifactPayload = { - artifactKind: "rag"; - provider: AiProvider; - createdAt: string; - sources: Array<{ - fileId: string; - fileName: string; - mimeType?: string; - sizeBytes?: number; - sha256?: string; - uploadedFileId?: string; - documentId?: string; - }>; - providerState: { - vectorStoreIds?: string[]; - libraryId?: string; - documentCount?: number; - prepared?: boolean; - uploadedFileIds?: string[]; - embeddingModel?: string; - topK?: number; - chunkSize?: number; - chunkOverlap?: number; - maxContextChars?: number; - extractedDocuments?: Array<{ - documentIndex: number; - fileName: string; - textChars: number; - }>; - selectedChunks?: Array<{ - sourceId: string; - documentIndex: number; - documentName: string; - chunkIndex: number; - chunkCount: number; - textChars: number; - score?: number; - }>; - skippedDocuments?: Array<{ - documentIndex: number; - fileName: string; - reason: string; - }>; - query?: string; - ollama?: OllamaRagArtifactDetails["providerState"]; - }; -}; +import {buildRagArtifactPayload, type RagArtifactPayload} from "./rag-artifact-payload"; function providerState(prepared: PreparedDocumentRag, details?: NonNullable[0]["details"]>): RagArtifactPayload["providerState"] { switch (prepared.provider) { case AiProvider.OPENAI: return { + provider: AiProvider.OPENAI, vectorStoreIds: prepared.vectorStoreIds, uploadedFileIds: prepared.uploadedFileIds, }; case AiProvider.MISTRAL: return { + provider: AiProvider.MISTRAL, libraryId: prepared.libraryId, documentCount: prepared.documents.length, }; case AiProvider.OLLAMA: return { + provider: AiProvider.OLLAMA, prepared: prepared.prepared, embeddingModel: details?.embeddingModel, topK: details?.topK, chunkSize: details?.chunkSize, chunkOverlap: details?.chunkOverlap, maxContextChars: details?.maxContextChars, + extractedDocuments: details?.artifact?.extractedDocuments ?? [], + selectedChunks: details?.artifact?.selectedChunks ?? [], + skippedDocuments: details?.artifact?.skippedDocuments ?? [], + query: details?.artifact?.query ?? "", + minScore: details?.artifact?.providerState?.minScore ?? 0, + maxArchiveFiles: details?.artifact?.providerState?.maxArchiveFiles ?? 0, + maxArchiveBytes: details?.artifact?.providerState?.maxArchiveBytes ?? 0, + maxArchiveDepth: details?.artifact?.providerState?.maxArchiveDepth ?? 0, }; } } @@ -117,22 +81,11 @@ export async function persistRagArtifactAttachment(params: { if (!sources.length) return Promise.resolve(undefined); - const payload: RagArtifactPayload = { - artifactKind: "rag", + const payload = buildRagArtifactPayload({ provider: params.provider, - createdAt: new Date().toISOString(), sources, - providerState: { - ...providerState(params.prepared, params.details), - ...(params.details?.artifact ? { - extractedDocuments: params.details.artifact.extractedDocuments, - selectedChunks: params.details.artifact.selectedChunks, - skippedDocuments: params.details.artifact.skippedDocuments, - query: params.details.artifact.query, - ollama: params.details.artifact.providerState, - } : {}), - }, - }; + providerState: providerState(params.prepared, params.details), + }); return await persistInternalJsonArtifactAttachment({ artifactKind: "rag", fileNamePrefix: "rag", @@ -140,14 +93,8 @@ export async function persistRagArtifactAttachment(params: { messageId: params.messageId, payload, metadata: { - provider: params.provider, sourceFileNames: sources.map(source => source.fileName), ...payload.providerState, - embeddingModel: params.details?.embeddingModel, - topK: params.details?.topK, - chunkSize: params.details?.chunkSize, - chunkOverlap: params.details?.chunkOverlap, - maxContextChars: params.details?.maxContextChars, }, }); } diff --git a/src/ai/tool-mappers.ts b/src/ai/tool-mappers.ts index faf8173..4b099cc 100644 --- a/src/ai/tool-mappers.ts +++ b/src/ai/tool-mappers.ts @@ -1,8 +1,8 @@ import {AiTool} from "./tool-types"; -import {AiProvider} from "../model/ai-provider"; -import {getTools} from "./tools/registry"; -import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search"; -import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator"; +import {AiProvider} from "../model/ai-provider.js"; +import {getTools} from "./tools/registry.js"; +import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search.js"; +import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator.js"; export type AiProviderName = "ollama" | "openai" | "mistral"; diff --git a/src/ai/tool-rank-stage.ts b/src/ai/tool-rank-stage.ts new file mode 100644 index 0000000..cce746d --- /dev/null +++ b/src/ai/tool-rank-stage.ts @@ -0,0 +1,90 @@ +import {Environment} from "../common/environment.js"; +import {AiProvider} from "../model/ai-provider.js"; +import type {BoundaryValue} from "../common/boundary-types.js"; +import type {TelegramStreamMessage} from "./telegram-stream-message.js"; +import type {RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js"; +import {filterRankedTools} from "./tool-ranker-pipeline.js"; +import {ToolRanker} from "./unified-ai-runner.tool-ranker.js"; +import {storeToolRankAudit} from "./tool-rank-audit.js"; + +function latestUserText(messages: readonly { role?: string; content?: unknown }[]): string { + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + if (message?.role !== "user") continue; + if (typeof message.content === "string") return message.content; + if (Array.isArray(message.content)) { + return message.content + .map(part => typeof part === "object" && part !== null && "text" in part && typeof (part as { text?: unknown }).text === "string" + ? (part as { text: string }).text + : "") + .filter(Boolean) + .join("\n"); + } + } + + return ""; +} + +export async function runToolRankStage(params: { + provider: AiProvider; + model: string; + round: number; + config: RuntimeConfigSnapshot; + availableTools: readonly BoundaryValue[]; + messages: readonly { role?: string; content?: unknown }[]; + streamMessage: TelegramStreamMessage; + signal: AbortSignal; + toolRanker?: ToolRanker; +}): Promise<{ + filteredTools: BoundaryValue[]; + selectedToolNames: string[]; + usedRanker: boolean; +}> { + const toolRanker = params.toolRanker ?? new ToolRanker(params.config); + const startedAt = Date.now(); + const startedAtIso = new Date().toISOString(); + + params.streamMessage.setStatus(Environment.getSelectingToolsText()); + await params.streamMessage.flush(); + + try { + const selection = await toolRanker.selectTools({ + provider: params.provider, + userQuery: latestUserText(params.messages), + availableTools: params.availableTools, + round: params.round, + signal: params.signal, + }); + + params.streamMessage.clearStatus(); + await params.streamMessage.flush(); + await storeToolRankAudit({ + streamMessage: params.streamMessage, + provider: params.provider, + model: params.model, + round: params.round, + startedAt, + startedAtIso, + selectedTools: selection.toolNames, + }); + + return { + filteredTools: filterRankedTools(params.availableTools, selection.toolNames), + selectedToolNames: selection.toolNames, + usedRanker: selection.usedRanker, + }; + } catch (error) { + params.streamMessage.clearStatus(); + await params.streamMessage.flush(); + await storeToolRankAudit({ + streamMessage: params.streamMessage, + provider: params.provider, + model: params.model, + round: params.round, + startedAt, + startedAtIso, + error, + }); + throw error; + } +} diff --git a/src/ai/unified-ai-response-pipeline.ts b/src/ai/unified-ai-response-pipeline.ts index f3a0d8c..df71733 100644 --- a/src/ai/unified-ai-response-pipeline.ts +++ b/src/ai/unified-ai-response-pipeline.ts @@ -2,6 +2,7 @@ import {AiProvider} from "../model/ai-provider"; import {Environment} from "../common/environment"; import {ifTrue, logError} from "../util/utils"; import {UserRequestPipeline, type UserRequestPipelineState, type UserRequestPipelineStage} from "./user-request-pipeline"; +import {getProviderAdapter} from "./provider-adapters"; import type {AiDownloadedFile} from "./telegram-attachments"; import type {TelegramStreamMessage} from "./telegram-stream-message"; import type {PreparedUnifiedAiRequest} from "./unified-ai-request-pipeline"; @@ -9,12 +10,14 @@ import type {OpenAIChatMessage} from "./openai-chat-message"; import type {MistralChatMessage} from "./mistral-chat-message"; import type {ChatMessage} from "./chat-messages-types"; import { + allToolSchemaNames, providerName, RuntimeConfigSnapshot, snapshotModel, TELEGRAM_LIMIT, UnifiedRunOptions, } from "./unified-ai-runner.shared"; +import {runToolRankStage} from "./tool-rank-stage"; import {runOpenAi} from "./unified-ai-runner.openai"; import {runOllama} from "./unified-ai-runner.ollama"; import {runMistral} from "./unified-ai-runner.mistral"; @@ -159,6 +162,9 @@ export async function runUnifiedAiResponsePipeline(params: { }): Promise { const {options, config, downloads, prepared, streamMessage, controller} = params; const state = createResponsePipelineState(options); + const adapter = getProviderAdapter(options.provider); + let selectedToolNames: string[] = []; + let filteredTools: unknown[] = []; const stages: UserRequestPipelineStage[] = [ { @@ -177,6 +183,62 @@ export async function runUnifiedAiResponsePipeline(params: { }; }, }, + { + name: "tool_rank", + async run() { + const availableTools = adapter.rankTools(config, { + forCreator: options.msg.from?.id === Environment.CREATOR_ID, + vectorStoreIds: prepared.preparedDocumentRag?.provider === AiProvider.OPENAI + ? prepared.preparedDocumentRag.vectorStoreIds + : [], + }); + + const rankResult = await runToolRankStage({ + provider: options.provider, + model: snapshotModel(options.provider, config), + round: state.toolRankDecisions.length, + config, + availableTools, + messages: prepared.chatMessages, + streamMessage, + signal: controller.signal, + }); + + selectedToolNames = rankResult.selectedToolNames; + filteredTools = rankResult.filteredTools; + state.toolRankDecisions.push({ + provider: options.provider, + round: state.toolRankDecisions.length, + availableTools: allToolSchemaNames(availableTools), + selectedTools: selectedToolNames, + usedRanker: rankResult.usedRanker, + }); + + return { + stage: "tool_rank", + status: "succeeded", + details: { + selectedTools: selectedToolNames, + usedRanker: rankResult.usedRanker, + availableTools: allToolSchemaNames(availableTools), + toolRankDecision: state.toolRankDecisions.at(-1), + }, + }; + }, + }, + { + name: "filter_tools", + async run() { + return { + stage: "filter_tools", + status: "succeeded", + details: { + selectedTools: selectedToolNames, + filteredToolCount: filteredTools.length, + }, + }; + }, + }, { name: "model_call", async run() { @@ -312,6 +374,8 @@ export async function runUnifiedAiResponsePipeline(params: { stages, stageNames: [ "audit_start", + "tool_rank", + "filter_tools", "model_call", "tool_loop", "output_size_gate", diff --git a/src/ai/unified-ai-runner.mistral.ts b/src/ai/unified-ai-runner.mistral.ts index 4de138d..fb57cc0 100644 --- a/src/ai/unified-ai-runner.mistral.ts +++ b/src/ai/unified-ai-runner.mistral.ts @@ -1,21 +1,17 @@ import {Environment} from "../common/environment"; -import {getMistralTools} from "./tool-mappers"; import {TelegramStreamMessage} from "./telegram-stream-message"; import {ToolRuntimeContext} from "./tools/runtime"; import {MistralChatMessage} from "./mistral-chat-message"; import {createMistralClient} from "./ai-runtime-target"; import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; import {AiProvider} from "../model/ai-provider"; -import {ToolRanker} from "./unified-ai-runner.tool-ranker"; +import {getProviderAdapter} from "./provider-adapters"; +import {runToolRankStage} from "./tool-rank-stage"; import { - contentFromMistralDelta, executeToolBatch, MAX_TOOL_ROUNDS, - MistralDeltaLike, MistralDocumentReference, - mistralToolCalls, - normalizeMistralToolCalls, roundStatus, RuntimeConfigSnapshot, StreamingToolCallAccumulator, @@ -23,8 +19,6 @@ import { ToolExecutionMemory } from "./unified-ai-runner.shared"; import {Message} from "typescript-telegram-bot-api"; -import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline"; -import {storeToolRankAudit} from "./tool-rank-audit"; export async function runMistral( msg: Message, @@ -39,8 +33,9 @@ export async function runMistral( ): Promise { const runnerStartedAt = Date.now(); const mistralAi = createMistralClient(config.mistralChatTarget); - const toolRanker = new ToolRanker(config); - const availableTools = getMistralTools(msg.from?.id === Environment.CREATOR_ID); + const adapter = getProviderAdapter(AiProvider.MISTRAL); + const availableTools = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID}); + const requestMessages = adapter.mapMessages([...messages]) as unknown as MistralChatMessage[]; aiLog("info", "mistral.run.start", { stream, target: aiLogProviderTarget(config.mistralChatTarget), @@ -50,142 +45,119 @@ export async function runMistral( }); const toolMemory: ToolExecutionMemory = new Map(); + try { + for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { + const roundStartedAt = Date.now(); + aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream}); + if (signal.aborted) throw new Error("Aborted"); - for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { - const roundStartedAt = Date.now(); - aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream}); - if (signal.aborted) throw new Error("Aborted"); - - streamMessage.setStatus(Environment.getSelectingToolsText()); - await streamMessage.flush(); - const toolRankStartedAt = Date.now(); - const toolRankStartedAtIso = new Date().toISOString(); - const rankerSelection = await toolRanker.selectTools({ + const rankResult = await runToolRankStage({ provider: AiProvider.MISTRAL, - userQuery: latestUserTextFromMessages(messages), - availableTools, + model: config.mistralChatTarget.model, round, + config, + availableTools, + messages, + streamMessage, signal, - }) - .catch(async error => { - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, - provider: AiProvider.MISTRAL, - model: config.mistralChatTarget.model, - round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - error, - }); - throw error; }); - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, - provider: AiProvider.MISTRAL, - model: config.mistralChatTarget.model, - round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - selectedTools: rankerSelection.toolNames, - }); - const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames); - const requestTools = filteredTools.length ? filteredTools : undefined; + const filteredTools = rankResult.filteredTools; + const requestTools = filteredTools.length ? filteredTools : undefined; - streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? ""); - await streamMessage.flush(); + streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? ""); + await streamMessage.flush(); + + if (!stream) { + const request = { + model: config.mistralChatTarget.model, + messages: requestMessages, + tools: requestTools, + documents: documents + } as Parameters[0]; + const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})); + const message = response.choices?.[0]?.message; + const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? ""); + streamMessage.append(text); + const calls = adapter.extractToolCalls(message); + aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", { + round, + duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), + textChars: text.length, + calls: calls.map(aiLogToolCall), + }); + if (!calls.length) return; + messages.push({ + role: "assistant", + content: text, + toolCalls: calls.map(call => ({ + id: call.id, + function: {name: call.name, arguments: call.argumentsText}, + })), + }); + requestMessages.push({ + role: "assistant", + content: text, + toolCalls: calls.map(call => ({ + id: call.id, + function: {name: call.name, arguments: call.argumentsText}, + })), + }); + const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory); + adapter.appendToolResults(messages, calls, toolResults); + adapter.appendToolResults(requestMessages, calls, toolResults); + continue; + } - if (!stream) { const request = { model: config.mistralChatTarget.model, - messages, + messages: requestMessages, tools: requestTools, documents: documents - } as Parameters[0]; - const response = await mistralAi.chat.complete(request, {signal}); - const message = response.choices?.[0]?.message; - const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? ""); - streamMessage.append(text); - const calls = normalizeMistralToolCalls(mistralToolCalls(message)); + } as Parameters[0]; + const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})); + aiLog("debug", "mistral.stream.open", {round}); + let calls: ToolCallData[] = []; + const roundTextStart = streamMessage.getText().length; + const toolCallAccumulator = new StreamingToolCallAccumulator("mistral_stream", round); + + for await (const event of streamResponse) { + if (signal.aborted) throw new Error("Aborted"); + + const choice = event.data?.choices?.[0]; + const delta = choice?.delta; + const mistralDelta = delta; + streamMessage.append(adapter.extractTextDelta(mistralDelta)); + + const rawDeltaCalls = adapter.extractStreamingToolCalls(mistralDelta); + if (rawDeltaCalls.length) { + calls = toolCallAccumulator.add(rawDeltaCalls); + streamMessage.setStatus(Environment.getUseToolText(calls)); + await streamMessage.flush(); + } + } aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", { round, duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), - textChars: text.length, + textChars: streamMessage.getText().slice(roundTextStart).length, calls: calls.map(aiLogToolCall), }); if (!calls.length) return; + const roundText = streamMessage.getText().slice(roundTextStart); messages.push({ role: "assistant", - content: text, - toolCalls: calls.map(call => ({ - id: call.id, - function: {name: call.name, arguments: call.argumentsText}, - })), + content: roundText, + toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}})) + }); + requestMessages.push({ + role: "assistant", + content: roundText, + toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}})) }); const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory); - for (const [index, call] of calls.entries()) { - messages.push({ - role: "tool", - name: call.name, - toolCallId: call.id, - content: toolResults[index] ?? "", - }); - } - continue; - } - - const request = { - model: config.mistralChatTarget.model, - messages, - tools: requestTools, - documents: documents - } as Parameters[0]; - const streamResponse = await mistralAi.chat.stream(request, {signal}); - aiLog("debug", "mistral.stream.open", {round}); - let calls: ToolCallData[] = []; - const roundTextStart = streamMessage.getText().length; - const toolCallAccumulator = new StreamingToolCallAccumulator("mistral_stream", round); - - for await (const event of streamResponse) { - if (signal.aborted) throw new Error("Aborted"); - - const choice = event.data?.choices?.[0]; - const delta = choice?.delta; - const mistralDelta = delta as MistralDeltaLike; - - streamMessage.append(contentFromMistralDelta(mistralDelta)); - - const rawDeltaCalls = mistralToolCalls(mistralDelta); - if (rawDeltaCalls.length) { - calls = toolCallAccumulator.add(rawDeltaCalls); - streamMessage.setStatus(Environment.getUseToolText(calls)); - await streamMessage.flush(); - } - } - aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", { - round, - duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), - textChars: streamMessage.getText().slice(roundTextStart).length, - calls: calls.map(aiLogToolCall), - }); - if (!calls.length) return; - const roundText = streamMessage.getText().slice(roundTextStart); - messages.push({ - role: "assistant", - content: roundText, - toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}})) - }); - const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory); - for (const [index, call] of calls.entries()) { - messages.push({ - role: "tool", - name: call.name, - toolCallId: call.id, - content: toolResults[index] ?? "", - }); + adapter.appendToolResults(messages, calls, toolResults); + adapter.appendToolResults(requestMessages, calls, toolResults); } + } finally { + await adapter.finalize().catch(() => undefined); } } diff --git a/src/ai/unified-ai-runner.ollama.ts b/src/ai/unified-ai-runner.ollama.ts index 98997b3..2baafb5 100644 --- a/src/ai/unified-ai-runner.ollama.ts +++ b/src/ai/unified-ai-runner.ollama.ts @@ -5,7 +5,6 @@ import {Environment} from "../common/environment"; import type {BoundaryValue} from "../common/boundary-types"; import {bot, notesDir} from "../index"; import {clamp, logError} from "../util/utils"; -import {getOllamaTools} from "./tool-mappers"; import {TelegramStreamMessage} from "./telegram-stream-message"; import {ChatMessage} from "./chat-messages-types"; import {ChatRequest, Tool} from "ollama"; @@ -14,10 +13,11 @@ import {enqueueTelegramApiCall} from "../util/telegram-api-queue"; import {loadOllamaModel, unloadAllOllamaModels} from "./tools/utils"; import {createOllamaClient} from "./ai-runtime-target"; import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; +import {getProviderAdapter} from "./provider-adapters"; +import {runToolRankStage} from "./tool-rank-stage"; import { allToolSchemaNames, - appendOllamaToolResults, dedupeToolCalls, DEFAULT_OLLAMA_CONTEXT_SIZE, executeToolBatch, @@ -26,8 +26,6 @@ import { MAX_OLLAMA_CONTEXT_SIZE, MAX_TOOL_ROUNDS, MIN_OLLAMA_CONTEXT_SIZE, - normalizeOllamaToolCalls, - OllamaToolCallLike, roundStatus, RuntimeConfigSnapshot, safeJsonParseObject, @@ -35,14 +33,11 @@ import { ToolCallData, ToolExecutionMemory } from "./unified-ai-runner.shared"; -import {ToolRanker} from "./unified-ai-runner.tool-ranker"; import {getToolPrompts} from "./tools/registry"; -import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {getModelCapabilities} from "./provider-model-runtime"; import {AiProvider} from "../model/ai-provider"; import {Message} from "typescript-telegram-bot-api"; -import {storeToolRankAudit} from "./tool-rank-audit"; export async function runOllama( msg: Message, @@ -157,6 +152,7 @@ export async function runOllama( } const toolMemory: ToolExecutionMemory = new Map(); + const adapter = getProviderAdapter(AiProvider.OLLAMA); try { for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { @@ -183,7 +179,7 @@ export async function runOllama( let activeToolNames: string[] = []; if ((await getModelCapabilities(AiProvider.OLLAMA, model, "tools"))?.tools?.supported) { - const availableOllamaTools: Tool[] = getOllamaTools(msg.from?.id === Environment.CREATOR_ID) as Tool[]; + const availableOllamaTools: Tool[] = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID}) as Tool[]; aiLog("debug", "ollama.tools.available", { round, @@ -191,44 +187,18 @@ export async function runOllama( rankerEnabled: !!config.ollamaToolRankerTarget, }); - streamMessage.setStatus(Environment.getSelectingToolsText()); - await streamMessage.flush(); - const toolRankStartedAt = Date.now(); - const toolRankStartedAtIso = new Date().toISOString(); - const rankerSelection = await new ToolRanker(config).selectTools({ - provider: AiProvider.OLLAMA, - userQuery: latestUserTextFromMessages(messages), - availableTools: availableOllamaTools, - round, - signal, - }) - .catch(async error => { - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, - provider: AiProvider.OLLAMA, - model, - round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - error, - }); - throw error; - }); - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, + const rankResult = await runToolRankStage({ provider: AiProvider.OLLAMA, model, round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - selectedTools: rankerSelection.toolNames, + config, + availableTools: availableOllamaTools, + messages, + streamMessage, + signal, }); - const filteredTools = [...new Set(filterRankedTools(availableOllamaTools, rankerSelection.toolNames))]; + const filteredTools = [...new Set(rankResult.filteredTools as Tool[])]; activeToolNames = filteredTools.map(t => t.function.name ?? ""); if (filteredTools.length > 0) { request.tools = [...filteredTools]; @@ -256,24 +226,21 @@ export async function runOllama( round, tools: activeToolNames, count: activeToolNames.length, - usedRanker: rankerSelection.usedRanker, + usedRanker: rankResult.usedRanker, }); } if (!stream) { - const response = await ollama.chat({ + const response = await adapter.callModel(request, () => ollama.chat({ ...request, stream: false - }); + })); const message = response.message; const rawContent = message?.content ?? ""; const nativeCalls = dedupeToolCalls( - normalizeOllamaToolCalls( - message?.tool_calls as readonly OllamaToolCallLike[] | undefined, - round, - ), + adapter.extractToolCalls(message), ); const responseText = rawContent; @@ -301,7 +268,7 @@ export async function runOllama( break; } - const calls = nativeCalls; + const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls; aiLog("info", "ollama.tool_calls", { round, @@ -319,11 +286,7 @@ export async function runOllama( })), }); - appendOllamaToolResults( - messages, - calls, - await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory), - ); + adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory)); continue; } @@ -332,10 +295,10 @@ export async function runOllama( round, messageCount: request.messages?.length ?? 0, }); - const response = await ollama.chat({ + const response = await adapter.callModel(request, () => ollama.chat({ ...request, stream: true - }); + })); aiLog("debug", "ollama.stream.open", {round}); const calls: ToolCallData[] = []; @@ -354,10 +317,7 @@ export async function runOllama( const localToolCalls: ToolCallData[] = []; - localToolCalls.push(...normalizeOllamaToolCalls( - chunk.message.tool_calls as readonly OllamaToolCallLike[] | undefined, - round, - )); + localToolCalls.push(...adapter.extractStreamingToolCalls(chunk.message)); const newStatus = roundStatus(round, firstRoundStatus, chunk.message.content, localToolCalls, !!chunk.message.thinking); const previousStatus = streamMessage.getStatus(); @@ -377,13 +337,10 @@ export async function runOllama( } if (!(chunk.message?.thinking && streamMessage.getStatus() !== Environment.reasoningText)) { - streamMessage.append(chunk.message?.content ?? ""); + streamMessage.append(adapter.extractTextDelta(chunk)); } - calls.push(...normalizeOllamaToolCalls( - chunk.message?.tool_calls as readonly OllamaToolCallLike[] | undefined, - round, - )); + calls.push(...adapter.extractStreamingToolCalls(chunk.message)); if (chunk.done) { aiLog("debug", "ollama.stream.done", { @@ -471,9 +428,10 @@ export async function runOllama( }).catch(logError); } - appendOllamaToolResults(messages, calls, toolResults); + adapter.appendToolResults(messages, calls, toolResults); } } finally { if (interval) clearInterval(interval); + await adapter.finalize().catch(() => undefined); } } diff --git a/src/ai/unified-ai-runner.openai.ts b/src/ai/unified-ai-runner.openai.ts index 096c278..266073b 100644 --- a/src/ai/unified-ai-runner.openai.ts +++ b/src/ai/unified-ai-runner.openai.ts @@ -17,11 +17,9 @@ import { AsyncIterableStream, buildSystemInstruction, collectOpenAiResponseCodeInterpreterCalls, - collectOpenAiResponseFunctionCalls, collectOpenAiResponseImages, collectOpenAiResponseText, executeToolBatch, - getOpenAIResponsesToolsWithImage, MAX_TOOL_ROUNDS, OPENAI_IMAGE_PARTIALS, openAiResponseItemCallId, @@ -42,10 +40,9 @@ import {logError} from "../util/utils"; import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files"; import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings"; import {AiDownloadedFile} from "./telegram-attachments"; -import {ToolRanker} from "./unified-ai-runner.tool-ranker"; import {AiProvider} from "../model/ai-provider"; -import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline"; -import {storeToolRankAudit} from "./tool-rank-audit"; +import {getProviderAdapter} from "./provider-adapters"; +import {runToolRankStage} from "./tool-rank-stage"; export async function runOpenAi( msg: Message, @@ -60,16 +57,15 @@ export async function runOpenAi( documentRag?: OpenAiDocumentRagContext, ): Promise { const runnerStartedAt = Date.now(); - let responseInput: Array = [...messages] as Array; const openAi = createOpenAiClient(config.openAiChatTarget); const ownsDocumentRag = !documentRag; const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document")); - const toolRanker = new ToolRanker(config); - const availableTools = getOpenAIResponsesToolsWithImage( - config, - msg.from?.id === Environment.CREATOR_ID, - preparedDocumentRag?.vectorStoreIds ?? [], - ); + const adapter = getProviderAdapter(AiProvider.OPENAI); + let responseInput: Array = adapter.mapMessages(messages) as unknown as Array; + const availableTools = adapter.rankTools(config, { + forCreator: msg.from?.id === Environment.CREATOR_ID, + vectorStoreIds: preparedDocumentRag?.vectorStoreIds ?? [], + }); const systemPrompt = buildSystemInstruction( config, @@ -93,43 +89,17 @@ export async function runOpenAi( for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { const roundStartedAt = Date.now(); aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream}); - streamMessage.setStatus(Environment.getSelectingToolsText()); - await streamMessage.flush(); - const toolRankStartedAt = Date.now(); - const toolRankStartedAtIso = new Date().toISOString(); - const rankerSelection = await toolRanker.selectTools({ - provider: AiProvider.OPENAI, - userQuery: latestUserTextFromMessages(messages), - availableTools, - round, - signal, - }) - .catch(async error => { - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, - provider: AiProvider.OPENAI, - model: config.openAiChatTarget.model, - round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - error, - }); - throw error; - }); - streamMessage.clearStatus(); - await streamMessage.flush(); - await storeToolRankAudit({ - streamMessage, + const rankResult = await runToolRankStage({ provider: AiProvider.OPENAI, model: config.openAiChatTarget.model, round, - startedAt: toolRankStartedAt, - startedAtIso: toolRankStartedAtIso, - selectedTools: rankerSelection.toolNames, + config, + availableTools, + messages, + streamMessage, + signal, }); - const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames); + const filteredTools = rankResult.filteredTools; const requestTools = preparedDocumentRag?.vectorStoreIds.length ? (() => { const tools = [...filteredTools]; @@ -151,7 +121,7 @@ export async function runOpenAi( tools: requestTools as ResponseCreateParamsNonStreaming["tools"], instructions: systemPrompt, }; - const response = await openAi.responses.create(request, {signal}) as OpenAiResponseLike; + const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike; const responseText = collectOpenAiResponseText(response); streamMessage.append(responseText); @@ -188,12 +158,12 @@ export async function runOpenAi( }); } - const calls = collectOpenAiResponseFunctionCalls(response); + const calls = adapter.extractToolCalls(response); aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", { round, duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), calls: calls.map(call => ({ - id: call.callId, + id: call.id, name: call.name, arguments: safeJsonParseObject(call.argumentsText) })), @@ -201,16 +171,13 @@ export async function runOpenAi( if (!calls.length) return; const toolCalls = calls.map(call => ({ - id: call.callId, + id: call.id, name: call.name, argumentsText: call.argumentsText, })); const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory); - const toolOutputs = calls.map((call, index) => ({ - type: "function_call_output" as const, - call_id: call.callId, - output: toolResults[index] ?? "", - })); + const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = []; + adapter.appendToolResults(toolOutputs, calls, toolResults); const uploadFilesResult = await tryToUploadFiles(msg, toolResults); if (uploadFilesResult.found) { @@ -243,7 +210,7 @@ export async function runOpenAi( parallel_tool_calls: true, instructions: systemPrompt }; - const response = await openAi.responses.create(request, {signal}) as AsyncIterableStream; + const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream; aiLog("debug", "openai.stream.open", {round}); @@ -253,7 +220,7 @@ export async function runOpenAi( switch (event.type) { case "response.output_text.delta": - streamMessage.append(event.delta ?? ""); + streamMessage.append(adapter.extractTextDelta(event)); break; case "response.image_generation_call.in_progress": streamMessage.setStatus(Environment.startingImageGenText); @@ -301,14 +268,11 @@ export async function runOpenAi( case "response.code_interpreter_call_code.done": break; case "response.output_item.added": - if (event.item.type === "function_call" && event.item.name) { - const item = event.item as OpenAiResponseOutputItem & { id?: string }; - localToolCalls.push({ - id: openAiResponseItemCallId(item), - name: item.name ?? "", - argumentsText: item.arguments ?? "{}", - }); - + { + const streamedCalls = adapter.extractStreamingToolCalls(event); + if (streamedCalls.length) { + localToolCalls.push(...streamedCalls); + } aiLog("info", "openai.stream.tool_call.added", { round, toolCalls: localToolCalls.map(aiLogToolCall) @@ -383,12 +347,12 @@ export async function runOpenAi( }); } - const calls = collectOpenAiResponseFunctionCalls(completedResponse); + const calls = adapter.extractToolCalls(completedResponse); aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", { round, duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), calls: calls.map(call => ({ - id: call.callId, + id: call.id, name: call.name, arguments: safeJsonParseObject(call.argumentsText) })), @@ -396,16 +360,13 @@ export async function runOpenAi( if (!calls.length) return; const toolCalls = calls.map(call => ({ - id: call.callId, + id: call.id, name: call.name, argumentsText: call.argumentsText, })); const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory); - const toolOutputs = calls.map((call, index) => ({ - type: "function_call_output", - call_id: call.callId, - output: toolResults[index] ?? "", - })); + const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = []; + adapter.appendToolResults(toolOutputs, calls, toolResults); const uploadFilesResult = await tryToUploadFiles(msg, toolResults); if (uploadFilesResult.found) { @@ -431,6 +392,7 @@ export async function runOpenAi( if (ownsDocumentRag) { await preparedDocumentRag?.cleanup().catch(logError); } + await adapter.finalize().catch(logError); } } diff --git a/src/ai/unified-ai-runner.shared.ts b/src/ai/unified-ai-runner.shared.ts index 8c6c996..fcb65f4 100644 --- a/src/ai/unified-ai-runner.shared.ts +++ b/src/ai/unified-ai-runner.shared.ts @@ -2,40 +2,40 @@ import {Message} from "typescript-telegram-bot-api"; import * as fs from "node:fs"; import path from "node:path"; import type {BoundaryValue} from "../common/boundary-types"; -import {AiProvider} from "../model/ai-provider"; -import {ToolRankerFallbackPolicy} from "../common/policies"; -import {Environment} from "../common/environment"; -import {photoGenDir} from "../index"; -import {delay, logError, replyToMessage} from "../util/utils"; -import {MessageStore} from "../common/message-store"; -import type {OpenAiResponseTool} from "./tool-mappers"; -import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers"; -import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message"; -import {AiDownloadedFile} from "./telegram-attachments"; -import {getRuntimeCapabilities} from "./provider-model-runtime"; -import {StoredAttachment} from "../model/stored-attachment"; -import {AiChatMessage, ChatMessage} from "./chat-messages-types"; +import {AiProvider} from "../model/ai-provider.js"; +import {ToolRankerFallbackPolicy} from "../common/policies.js"; +import {Environment} from "../common/environment.js"; +import {photoGenDir} from "../index.js"; +import {delay, logError, replyToMessage} from "../util/utils.js"; +import {MessageStore} from "../common/message-store.js"; +import type {OpenAiResponseTool} from "./tool-mappers.js"; +import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers.js"; +import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message.js"; +import {AiDownloadedFile} from "./telegram-attachments.js"; +import {getRuntimeCapabilities} from "./provider-model-runtime.js"; +import {StoredAttachment} from "../model/stored-attachment.js"; +import {AiChatMessage, ChatMessage} from "./chat-messages-types.js"; import {ListResponse, Ollama} from "ollama"; -import {executeToolCall, ToolRuntimeContext} from "./tools/runtime"; -import {MessageImagePart, MessagePart} from "../common/message-part"; -import {KeyedAsyncLock} from "../util/async-lock"; -import {type AiRequestQueueTarget} from "./provider-request-queue"; -import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator"; -import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings"; +import {executeToolCall, ToolRuntimeContext} from "./tools/runtime.js"; +import {MessageImagePart, MessagePart} from "../common/message-part.js"; +import {KeyedAsyncLock} from "../util/async-lock.js"; +import {type AiRequestQueueTarget} from "./provider-request-queue.js"; +import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator.js"; +import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings.js"; import { isTranscribableAudioDownload, resolveSpeechToTextProviderForUser, transcribeSpeechDownloads -} from "./speech-to-text"; +} from "./speech-to-text.js"; import type {ChatCompletionMessageParam} from "openai/resources/chat/completions"; -import {MistralChatMessage} from "./mistral-chat-message"; -import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer"; -import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target"; -import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger"; -import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline"; +import {MistralChatMessage} from "./mistral-chat-message.js"; +import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer.js"; +import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target.js"; +import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger.js"; +import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline.js"; import type {ResponseInputMessageContentList} from "openai/resources/responses/responses"; -import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store"; -import {filterUserVisibleStoredAttachments} from "../common/stored-attachment-utils"; +import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store.js"; +import {filterUserVisibleStoredAttachments} from "../common/attachment-visibility.js"; export type {Message} from "typescript-telegram-bot-api"; export type {AiRuntimeTarget} from "./ai-runtime-target"; diff --git a/src/common/attachment-visibility.ts b/src/common/attachment-visibility.ts new file mode 100644 index 0000000..583c3a0 --- /dev/null +++ b/src/common/attachment-visibility.ts @@ -0,0 +1,5 @@ +import type {StoredAttachment} from "../model/stored-attachment"; + +export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] { + return attachments.filter(attachment => attachment.scope !== "internal_artifact"); +} diff --git a/src/common/environment.ts b/src/common/environment.ts index 143bbf1..d91140e 100644 --- a/src/common/environment.ts +++ b/src/common/environment.ts @@ -3,18 +3,18 @@ import os from "node:os"; import path from "node:path"; import {parse as parseDotEnv} from "dotenv"; import {z} from "zod"; -import {appLogger} from "../logging/logger"; +import {appLogger} from "../logging/logger.js"; import type {BoundaryValue, ErrorLike} from "./boundary-types"; -import {saveData} from "../db/database"; -import {Answers} from "../model/answers"; -import {ifTrue} from "../util/utils"; -import {AiProvider} from "../model/ai-provider"; -import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies"; -import {ToolRankerFallbackPolicy} from "./policies"; -import type {ToolCallData} from "../ai/unified-ai-runner"; -import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator"; -import {Localization, type LocalizationParams} from "./localization"; +import {saveData} from "../db/database.js"; +import {Answers} from "../model/answers.js"; +import {ifTrue} from "../util/utils.js"; +import {AiProvider} from "../model/ai-provider.js"; +import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies.js"; +import {ToolRankerFallbackPolicy} from "./policies.js"; +import type {ToolCallData} from "../ai/unified-ai-runner.js"; +import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js"; +import {Localization, type LocalizationParams} from "./localization.js"; type EnvRecord = Record; type StringEnumLike = Record; diff --git a/src/common/stored-attachment-utils.ts b/src/common/stored-attachment-utils.ts index 865da7b..c53711f 100644 --- a/src/common/stored-attachment-utils.ts +++ b/src/common/stored-attachment-utils.ts @@ -1,6 +1,7 @@ import path from "node:path"; import {Environment} from "./environment"; import {StoredAttachment} from "../model/stored-attachment"; +export {filterUserVisibleStoredAttachments} from "./attachment-visibility"; export function photoCachePathForUniqueId(uniqueId: string): string { return path.join(Environment.DATA_PATH, "cache", "photo", `${uniqueId}.jpg`); @@ -44,7 +45,3 @@ export function uniqueStoredAttachments(attachments: StoredAttachment[]): Stored return result; } - -export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] { - return attachments.filter(attachment => attachment.scope !== "internal_artifact"); -} diff --git a/src/db/database.ts b/src/db/database.ts index ba2e375..ed6cc78 100644 --- a/src/db/database.ts +++ b/src/db/database.ts @@ -1,9 +1,9 @@ import * as fs from "fs"; -import {Environment} from "../common/environment"; -import {logError} from "../util/utils"; -import {Answers} from "../model/answers"; +import {Environment} from "../common/environment.js"; +import {logError} from "../util/utils.js"; +import {Answers} from "../model/answers.js"; import path from "node:path"; -import {KeyedAsyncLock} from "../util/async-lock"; +import {KeyedAsyncLock} from "../util/async-lock.js"; type DataJsonFile = { admins: number[] diff --git a/test/provider-adapter-contract.test.mjs b/test/provider-adapter-contract.test.mjs new file mode 100644 index 0000000..d100690 --- /dev/null +++ b/test/provider-adapter-contract.test.mjs @@ -0,0 +1,83 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const { + extractOpenAiToolCalls, + extractOpenAiStreamingToolCalls, + extractOpenAiTextDelta, + extractMistralToolCalls, + extractMistralTextDelta, + extractOllamaToolCalls, + extractOllamaTextDelta, +} = await import("../dist/ai/provider-adapter-contract.js"); + +test("openai contract extracts text delta and function calls", () => { + assert.equal(extractOpenAiTextDelta({type: "response.output_text.delta", delta: "hello"}), "hello"); + + const calls = extractOpenAiToolCalls({ + output: [{ + type: "function_call", + call_id: "call-1", + name: "read_file", + arguments: "{\"path\":\"src/index.ts\"}", + }], + }); + + assert.equal(calls.length, 1); + assert.equal(calls[0].id, "call-1"); + assert.equal(calls[0].name, "read_file"); + + const streamed = extractOpenAiStreamingToolCalls({ + type: "response.output_item.added", + item: { + type: "function_call", + id: "call-2", + name: "search_files", + arguments: "{\"query\":\"sendMessage\"}", + }, + }); + + assert.equal(streamed.length, 1); + assert.equal(streamed[0].id, "call-2"); + assert.equal(streamed[0].name, "search_files"); +}); + +test("mistral contract extracts content and tool calls", () => { + assert.equal(extractMistralTextDelta({ + content: [{text: "hello"}, {text: " world"}], + }), "hello world"); + + const calls = extractMistralToolCalls({ + toolCalls: [{ + id: "m-1", + function: { + name: "get_weather", + arguments: {location: "Moscow"}, + }, + }], + }); + + assert.equal(calls.length, 1); + assert.equal(calls[0].id, "m-1"); + assert.equal(calls[0].name, "get_weather"); +}); + +test("ollama contract extracts content and tool calls", () => { + assert.equal(extractOllamaTextDelta({ + message: {content: "hello from ollama"}, + }), "hello from ollama"); + + const calls = extractOllamaToolCalls({ + tool_calls: [{ + id: "o-1", + function: { + name: "web_search", + arguments: {query: "openai docs"}, + }, + }], + }); + + assert.equal(calls.length, 1); + assert.equal(calls[0].id, "o-1"); + assert.equal(calls[0].name, "web_search"); +}); diff --git a/test/rag-artifact.test.mjs b/test/rag-artifact.test.mjs index 0f4f8b0..a4bca15 100644 --- a/test/rag-artifact.test.mjs +++ b/test/rag-artifact.test.mjs @@ -1,32 +1,13 @@ -import test, {after} from "node:test"; +import test from "node:test"; import assert from "node:assert/strict"; -import fs from "node:fs"; -import os from "node:os"; -import path from "node:path"; -const tempRoot = fs.mkdtempSync(path.join(os.tmpdir(), "tg-chat-bot-rag-")); -process.env.BOT_TOKEN = process.env.BOT_TOKEN ?? "test-token"; -process.env.CREATOR_ID = process.env.CREATOR_ID ?? "1"; -process.env.DATA_PATH = tempRoot; -process.env.DB_PATH = `file:${path.join(tempRoot, "test.sqlite")}`; -process.env.TEST_ENVIRONMENT = "true"; - -const {Environment} = await import("../dist/common/environment.js"); -Environment.load(); - -const {DatabaseManager} = await import("../dist/db/database-manager.js"); -DatabaseManager.init(); -await DatabaseManager.ready; - -const {ArtifactStore} = await import("../dist/common/artifact-store.js"); -const {filterUserVisibleStoredAttachments} = await import("../dist/common/stored-attachment-utils.js"); +const { + buildRagArtifactPayload, +} = await import("../dist/ai/rag-artifact-payload.js"); +const { + filterUserVisibleStoredAttachments, +} = await import("../dist/common/attachment-visibility.js"); const {AiProvider} = await import("../dist/model/ai-provider.js"); -const {persistRagArtifactAttachment} = await import("../dist/ai/rag-artifact-store.js"); - -after(async () => { - await DatabaseManager.close().catch(() => undefined); - fs.rmSync(tempRoot, {recursive: true, force: true}); -}); test("internal artifacts are not treated as user-visible attachments", () => { const visible = filterUserVisibleStoredAttachments([ @@ -50,105 +31,57 @@ test("internal artifacts are not treated as user-visible attachments", () => { assert.equal(visible[0].fileId, "visible"); }); -test("RAG artifacts persist structured ollama metadata", async () => { - const chatId = 42; - const messageId = 7; - - const attachment = await persistRagArtifactAttachment({ +test("RAG artifact payload keeps ollama retrieval metadata", () => { + const payload = buildRagArtifactPayload({ provider: AiProvider.OLLAMA, - prepared: { - provider: AiProvider.OLLAMA, - prepared: true, - cleanup: async () => undefined, - artifact: { - query: "What is in the file?", - extractedDocuments: [ - {documentIndex: 0, fileName: "report.txt", textChars: 120}, - ], - selectedChunks: [ - { - sourceId: "doc1-1", - documentIndex: 0, - documentName: "report.txt", - chunkIndex: 0, - chunkCount: 1, - textChars: 120, - score: 0.91, - }, - ], - skippedDocuments: [ - {documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"}, - ], - providerState: { - embeddingModel: "nomic-embed-text:latest", - topK: 8, - chunkSize: 1400, - chunkOverlap: 220, - maxContextChars: 14000, - minScore: 0.12, - maxArchiveFiles: 200, - maxArchiveBytes: 50 * 1024 * 1024, - maxArchiveDepth: 2, - }, - }, - }, - downloads: [{ - kind: "document", + createdAt: "2026-01-01T00:00:00.000Z", + sources: [{ fileId: "file-1", fileName: "report.txt", - buffer: Buffer.from("hello world"), - path: path.join(tempRoot, "report.txt"), + mimeType: "text/plain", + sizeBytes: 12, + sha256: "abc123", + uploadedFileId: "uploaded-1", }], - chatId, - messageId, - details: { + providerState: { + provider: AiProvider.OLLAMA, + prepared: true, embeddingModel: "nomic-embed-text:latest", topK: 8, chunkSize: 1400, chunkOverlap: 220, maxContextChars: 14000, - artifact: { - query: "What is in the file?", - extractedDocuments: [ - {documentIndex: 0, fileName: "report.txt", textChars: 120}, - ], - selectedChunks: [ - { - sourceId: "doc1-1", - documentIndex: 0, - documentName: "report.txt", - chunkIndex: 0, - chunkCount: 1, - textChars: 120, - score: 0.91, - }, - ], - skippedDocuments: [ - {documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"}, - ], - providerState: { - embeddingModel: "nomic-embed-text:latest", - topK: 8, - chunkSize: 1400, - chunkOverlap: 220, - maxContextChars: 14000, - minScore: 0.12, - maxArchiveFiles: 200, - maxArchiveBytes: 50 * 1024 * 1024, - maxArchiveDepth: 2, + extractedDocuments: [ + {documentIndex: 0, fileName: "report.txt", textChars: 120}, + ], + selectedChunks: [ + { + sourceId: "doc1-1", + documentIndex: 0, + documentName: "report.txt", + chunkIndex: 0, + chunkCount: 1, + textChars: 120, + score: 0.91, }, - }, + ], + skippedDocuments: [ + {documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"}, + ], + minScore: 0.12, + maxArchiveFiles: 200, + maxArchiveBytes: 50 * 1024 * 1024, + maxArchiveDepth: 2, + query: "What is in the file?", }, }); - assert.equal(attachment?.artifactKind, "rag"); - assert.equal(fs.existsSync(attachment.cachePath), true); - - const stored = await ArtifactStore.getByMessage(chatId, messageId); - assert.equal(stored.length, 1); - assert.equal(stored[0].kind, "rag"); - assert.equal(stored[0].payload.providerState.query, "What is in the file?"); - assert.equal(stored[0].payload.providerState.selectedChunks[0].score, 0.91); - assert.equal(stored[0].payload.providerState.skippedDocuments[0].reason, "unsupported format"); - assert.equal(stored[0].payload.providerState.ollama.embeddingModel, "nomic-embed-text:latest"); + assert.equal(payload.artifactKind, "rag"); + assert.equal(payload.provider, AiProvider.OLLAMA); + assert.equal(payload.sources[0].uploadedFileId, "uploaded-1"); + assert.equal(payload.providerState.provider, AiProvider.OLLAMA); + assert.equal(payload.providerState.query, "What is in the file?"); + assert.equal(payload.providerState.selectedChunks[0].score, 0.91); + assert.equal(payload.providerState.skippedDocuments[0].reason, "unsupported format"); + assert.equal(payload.providerState.embeddingModel, "nomic-embed-text:latest"); });