/**
 * Executes one batch of tool calls and appends the results to every
 * requested target through the provider adapter.
 *
 * The batch is run exactly once via the shared `executeToolBatch`. After it
 * resolves, `adapter.appendToolResults` is called once per entry of
 * `appendTargets`, in the given order, each time with the same `toolCalls`
 * and the same results. If `executeToolBatch` rejects, no target is touched
 * and the rejection propagates to the caller.
 *
 * @param params.userId - Forwarded unchanged to `executeToolBatch`.
 * @param params.toolCalls - Calls to execute; the same array is handed to
 *   `adapter.appendToolResults` for every target.
 * @param params.streamMessage - Forwarded unchanged to `executeToolBatch`.
 * @param params.toolContext - Forwarded unchanged to `executeToolBatch`.
 * @param params.toolMemory - Forwarded unchanged to `executeToolBatch`.
 * @param params.adapter - Adapter whose `appendToolResults` writes the
 *   results into each target.
 * @param params.appendTargets - Arrays that should receive the results.
 *   When omitted, the batch is still executed but nothing is appended.
 * @returns Whatever `executeToolBatch` resolved with, so callers can still
 *   inspect the raw tool results.
 */
export async function executeToolBatchWithAdapter(params: {
    userId: number | undefined | null;
    toolCalls: ToolCallData[];
    streamMessage: TelegramStreamMessage;
    toolContext: ToolRuntimeContext;
    toolMemory: ToolExecutionMemory;
    adapter: AiProviderAdapter;
    appendTargets?: unknown[][];
    // The result type is derived from the shared runner instead of being
    // restated, so this wrapper always returns exactly what it returns.
}): Promise<Awaited<ReturnType<typeof executeToolBatch>>> {
    const {userId, toolCalls, streamMessage, toolContext, toolMemory, adapter, appendTargets} = params;

    const results = await executeToolBatch(userId, toolCalls, streamMessage, toolContext, toolMemory);

    // Append only after the whole batch has finished, so a failing batch
    // never leaves some targets updated and others not.
    for (const target of appendTargets ?? []) {
        adapter.appendToolResults(target, toolCalls, results);
    }

    return results;
}
a/src/ai/unified-ai-runner.ollama.ts b/src/ai/unified-ai-runner.ollama.ts index 2baafb5..1807228 100644 --- a/src/ai/unified-ai-runner.ollama.ts +++ b/src/ai/unified-ai-runner.ollama.ts @@ -20,7 +20,6 @@ import { allToolSchemaNames, dedupeToolCalls, DEFAULT_OLLAMA_CONTEXT_SIZE, - executeToolBatch, isOllamaModelActive, isRecord, MAX_OLLAMA_CONTEXT_SIZE, @@ -33,6 +32,7 @@ import { ToolCallData, ToolExecutionMemory } from "./unified-ai-runner.shared"; +import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {getToolPrompts} from "./tools/registry"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {getModelCapabilities} from "./provider-model-runtime"; @@ -286,7 +286,15 @@ export async function runOllama( })), }); - adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory)); + await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls: calls, + streamMessage, + toolContext, + toolMemory, + adapter, + appendTargets: [messages], + }); continue; } @@ -396,7 +404,15 @@ export async function runOllama( })), }); - const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory); + const toolResults = await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls: calls, + streamMessage, + toolContext, + toolMemory, + adapter, + appendTargets: [messages], + }); let successGetNoteFileResult: GetNoteFileResult | undefined = undefined; @@ -428,7 +444,6 @@ export async function runOllama( }).catch(logError); } - adapter.appendToolResults(messages, calls, toolResults); } } finally { if (interval) clearInterval(interval); diff --git a/src/ai/unified-ai-runner.openai.ts b/src/ai/unified-ai-runner.openai.ts index 266073b..fd375c9 100644 --- a/src/ai/unified-ai-runner.openai.ts +++ b/src/ai/unified-ai-runner.openai.ts @@ -19,7 +19,6 @@ import { collectOpenAiResponseCodeInterpreterCalls, 
collectOpenAiResponseImages, collectOpenAiResponseText, - executeToolBatch, MAX_TOOL_ROUNDS, OPENAI_IMAGE_PARTIALS, openAiResponseItemCallId, @@ -33,6 +32,7 @@ import { errorMessage, allToolSchemaNames } from "./unified-ai-runner.shared"; +import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {bot} from "../index"; import fs from "node:fs"; import path from "node:path"; @@ -87,51 +87,247 @@ export async function runOpenAi( try { for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { - const roundStartedAt = Date.now(); - aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream}); - const rankResult = await runToolRankStage({ - provider: AiProvider.OPENAI, - model: config.openAiChatTarget.model, - round, - config, - availableTools, - messages, - streamMessage, - signal, - }); - const filteredTools = rankResult.filteredTools; - const requestTools = preparedDocumentRag?.vectorStoreIds.length - ? (() => { - const tools = [...filteredTools]; - const hasFileSearch = allToolSchemaNames(tools).includes("file_search"); - if (!hasFileSearch) { - const fileSearchTool = availableTools.find(tool => allToolSchemaNames([tool]).includes("file_search")); - if (fileSearchTool) { - tools.unshift(fileSearchTool); + const roundStartedAt = Date.now(); + aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream}); + const rankResult = await runToolRankStage({ + provider: AiProvider.OPENAI, + model: config.openAiChatTarget.model, + round, + config, + availableTools, + messages, + streamMessage, + signal, + }); + const filteredTools = rankResult.filteredTools; + const requestTools = preparedDocumentRag?.vectorStoreIds.length + ? 
(() => { + const tools = [...filteredTools]; + const hasFileSearch = allToolSchemaNames(tools).includes("file_search"); + if (!hasFileSearch) { + const fileSearchTool = availableTools.find(tool => allToolSchemaNames([tool]).includes("file_search")); + if (fileSearchTool) { + tools.unshift(fileSearchTool); + } + } + return tools.length ? tools : undefined; + })() + : (filteredTools.length ? filteredTools : undefined); + + if (!stream) { + const request: ResponseCreateParamsNonStreaming = { + model: config.openAiChatTarget.model, + input: responseInput as ResponseInputItem[], + tools: requestTools as ResponseCreateParamsNonStreaming["tools"], + instructions: systemPrompt, + }; + const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike; + + const responseText = collectOpenAiResponseText(response); + streamMessage.append(responseText); + aiLog("debug", "openai.response.received", { + round, + duration: aiLogDuration(roundStartedAt), + textChars: responseText.length, + outputItems: response?.output?.length ?? 0, + }); + const images = collectOpenAiResponseImages(response); + if (images.length) { + await showOpenAiGeneratedImage( + streamMessage, + sourceMessage, + images[images.length - 1], + `final_${round}`, + Environment.getImageGenDoneText(config.openAiImageTarget.model), + true, + ); + } + + const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(response); + if (codeInterpreterCalls.length) { + aiLog("info", "openai.code_interpreter_calls", { + round, + duration: aiLogDuration(roundStartedAt), + calls: codeInterpreterCalls.map(call => ({ + id: call.id, + status: call.status, + containerId: call.containerId, + codeChars: call.code?.length ?? 0, + outputItems: call.outputs.length, + })), + }); + } + + const calls = adapter.extractToolCalls(response); + aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", { + round, + duration: calls.length ? 
aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), + calls: calls.map(call => ({ + id: call.id, + name: call.name, + arguments: safeJsonParseObject(call.argumentsText) + })), + }); + if (!calls.length) return; + + const toolCalls = calls.map(call => ({ + id: call.id, + name: call.name, + argumentsText: call.argumentsText, + })); + const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = []; + const toolResults = await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls, + streamMessage, + toolContext, + toolMemory, + adapter, + appendTargets: [toolOutputs], + }); + + const uploadFilesResult = await tryToUploadFiles(msg, toolResults); + if (uploadFilesResult.found) { + if (!uploadFilesResult.uploaded) { + const old = toolOutputs[uploadFilesResult.toolIndex]; + const callId = old?.call_id; + if (uploadFilesResult.toolIndex >= 0) { + delete toolOutputs[uploadFilesResult.toolIndex]; + } + if (callId) { + toolOutputs.push({ + type: "function_call_output" as const, + call_id: callId, + output: "Error: " + uploadFilesResult.error + }); + } } } - return tools.length ? tools : undefined; - })() - : (filteredTools.length ? filteredTools : undefined); - if (!stream) { - const request: ResponseCreateParamsNonStreaming = { + responseInput = [...responseInput, ...(response.output ?? 
[]), ...toolOutputs]; + continue; + } + + let completedResponse: OpenAiResponseLike | null = null; + const request: ResponseCreateParamsStreaming = { model: config.openAiChatTarget.model, input: responseInput as ResponseInputItem[], - tools: requestTools as ResponseCreateParamsNonStreaming["tools"], - instructions: systemPrompt, + stream: true, + tools: requestTools as ResponseCreateParamsStreaming["tools"], + parallel_tool_calls: true, + instructions: systemPrompt }; - const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike; + const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream; - const responseText = collectOpenAiResponseText(response); - streamMessage.append(responseText); - aiLog("debug", "openai.response.received", { + aiLog("debug", "openai.stream.open", {round}); + + let localToolCalls: ToolCallData[] = []; + for await (const event of response) { + if (signal.aborted) throw new Error("Aborted"); + + switch (event.type) { + case "response.output_text.delta": + streamMessage.append(adapter.extractTextDelta(event)); + break; + case "response.image_generation_call.in_progress": + streamMessage.setStatus(Environment.startingImageGenText); + await streamMessage.flush(); + break; + case "response.image_generation_call.generating": + streamMessage.setStatus(Environment.imageGenText); + await streamMessage.flush(); + break; + case "response.image_generation_call.partial_image": { + const iteration = (event.partial_image_index ?? 
0) + 1; + await showOpenAiGeneratedImage( + streamMessage, + sourceMessage, + event.partial_image_b64, + `partial_${round}_${iteration}`, + Environment.getPartialImageGenText(iteration, OPENAI_IMAGE_PARTIALS), + false, + ); + break; + } + case "response.image_generation_call.completed": + streamMessage.setStatus(Environment.finalizingImageGenText); + await streamMessage.flush(); + break; + case "response.file_search_call.in_progress": + case "response.file_search_call.searching": + streamMessage.setStatus(Environment.getUseToolText(["file_search"])); + await streamMessage.flush(); + break; + case "response.file_search_call.completed": + streamMessage.clearStatus(); + await streamMessage.flush(); + break; + case "response.code_interpreter_call.in_progress": + case "response.code_interpreter_call.interpreting": + streamMessage.setStatus(Environment.getUseToolText(["code_interpreter"])); + await streamMessage.flush(); + break; + case "response.code_interpreter_call.completed": + streamMessage.clearStatus(); + await streamMessage.flush(); + break; + case "response.code_interpreter_call_code.delta": + case "response.code_interpreter_call_code.done": + break; + case "response.output_item.added": + { + const streamedCalls = adapter.extractStreamingToolCalls(event); + if (streamedCalls.length) { + localToolCalls.push(...streamedCalls); + } + aiLog("info", "openai.stream.tool_call.added", { + round, + toolCalls: localToolCalls.map(aiLogToolCall) + }); + streamMessage.setStatus(Environment.getUseToolText(localToolCalls)); + await streamMessage.flush(); + } + break; + case "response.output_item.done": + if (event.item.type === "function_call" && event.item.name) { + const item = event.item as OpenAiResponseOutputItem & { id?: string }; + const itemId = openAiResponseItemCallId(item); + const index = localToolCalls.findIndex(c => c.id === itemId); + if (index !== -1) { + localToolCalls.splice(index, 1); + if (localToolCalls.length === 0) { + streamMessage.clearStatus(); + } 
else { + streamMessage.setStatus(Environment.getUseToolText(localToolCalls)); + } + await streamMessage.flush(); + } + } + break; + case "response.function_call_arguments.delta": + break; + case "response.function_call_arguments.done": + break; + + case "response.completed": + completedResponse = event.response as OpenAiResponseLike; + break; + case "response.failed": + throw new Error(event.response?.error?.message ?? "OpenAI response failed"); + case "error": + throw new Error(event.message ?? event?.message ?? "OpenAI stream error"); + } + } + + if (!completedResponse) throw new Error("OpenAI did not return the final response.completed event."); + + aiLog("debug", "openai.stream.completed", { round, duration: aiLogDuration(roundStartedAt), - textChars: responseText.length, - outputItems: response?.output?.length ?? 0, + outputItems: completedResponse?.output?.length ?? 0, }); - const images = collectOpenAiResponseImages(response); + + const images = collectOpenAiResponseImages(completedResponse); if (images.length) { await showOpenAiGeneratedImage( streamMessage, @@ -143,7 +339,7 @@ export async function runOpenAi( ); } - const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(response); + const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(completedResponse); if (codeInterpreterCalls.length) { aiLog("info", "openai.code_interpreter_calls", { round, @@ -158,7 +354,7 @@ export async function runOpenAi( }); } - const calls = adapter.extractToolCalls(response); + const calls = adapter.extractToolCalls(completedResponse); aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", { round, duration: calls.length ? 
aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), @@ -175,9 +371,16 @@ export async function runOpenAi( name: call.name, argumentsText: call.argumentsText, })); - const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory); const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = []; - adapter.appendToolResults(toolOutputs, calls, toolResults); + const toolResults = await executeToolBatchWithAdapter({ + userId: msg.from?.id, + toolCalls, + streamMessage, + toolContext, + toolMemory, + adapter, + appendTargets: [toolOutputs], + }); const uploadFilesResult = await tryToUploadFiles(msg, toolResults); if (uploadFilesResult.found) { @@ -197,196 +400,7 @@ export async function runOpenAi( } } - responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs]; - continue; - } - - let completedResponse: OpenAiResponseLike | null = null; - const request: ResponseCreateParamsStreaming = { - model: config.openAiChatTarget.model, - input: responseInput as ResponseInputItem[], - stream: true, - tools: requestTools as ResponseCreateParamsStreaming["tools"], - parallel_tool_calls: true, - instructions: systemPrompt - }; - const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream; - - aiLog("debug", "openai.stream.open", {round}); - - let localToolCalls: ToolCallData[] = []; - for await (const event of response) { - if (signal.aborted) throw new Error("Aborted"); - - switch (event.type) { - case "response.output_text.delta": - streamMessage.append(adapter.extractTextDelta(event)); - break; - case "response.image_generation_call.in_progress": - streamMessage.setStatus(Environment.startingImageGenText); - await streamMessage.flush(); - break; - case "response.image_generation_call.generating": - streamMessage.setStatus(Environment.imageGenText); - await streamMessage.flush(); - break; - case 
"response.image_generation_call.partial_image": { - const iteration = (event.partial_image_index ?? 0) + 1; - await showOpenAiGeneratedImage( - streamMessage, - sourceMessage, - event.partial_image_b64, - `partial_${round}_${iteration}`, - Environment.getPartialImageGenText(iteration, OPENAI_IMAGE_PARTIALS), - false, - ); - break; - } - case "response.image_generation_call.completed": - streamMessage.setStatus(Environment.finalizingImageGenText); - await streamMessage.flush(); - break; - case "response.file_search_call.in_progress": - case "response.file_search_call.searching": - streamMessage.setStatus(Environment.getUseToolText(["file_search"])); - await streamMessage.flush(); - break; - case "response.file_search_call.completed": - streamMessage.clearStatus(); - await streamMessage.flush(); - break; - case "response.code_interpreter_call.in_progress": - case "response.code_interpreter_call.interpreting": - streamMessage.setStatus(Environment.getUseToolText(["code_interpreter"])); - await streamMessage.flush(); - break; - case "response.code_interpreter_call.completed": - streamMessage.clearStatus(); - await streamMessage.flush(); - break; - case "response.code_interpreter_call_code.delta": - case "response.code_interpreter_call_code.done": - break; - case "response.output_item.added": - { - const streamedCalls = adapter.extractStreamingToolCalls(event); - if (streamedCalls.length) { - localToolCalls.push(...streamedCalls); - } - aiLog("info", "openai.stream.tool_call.added", { - round, - toolCalls: localToolCalls.map(aiLogToolCall) - }); - streamMessage.setStatus(Environment.getUseToolText(localToolCalls)); - await streamMessage.flush(); - } - break; - case "response.output_item.done": - if (event.item.type === "function_call" && event.item.name) { - const item = event.item as OpenAiResponseOutputItem & { id?: string }; - const itemId = openAiResponseItemCallId(item); - const index = localToolCalls.findIndex(c => c.id === itemId); - if (index !== -1) { - 
localToolCalls.splice(index, 1); - if (localToolCalls.length === 0) { - streamMessage.clearStatus(); - } else { - streamMessage.setStatus(Environment.getUseToolText(localToolCalls)); - } - await streamMessage.flush(); - } - } - break; - case "response.function_call_arguments.delta": - break; - case "response.function_call_arguments.done": - break; - - case "response.completed": - completedResponse = event.response as OpenAiResponseLike; - break; - case "response.failed": - throw new Error(event.response?.error?.message ?? "OpenAI response failed"); - case "error": - throw new Error(event.message ?? event?.message ?? "OpenAI stream error"); - } - } - - if (!completedResponse) throw new Error("OpenAI did not return the final response.completed event."); - - aiLog("debug", "openai.stream.completed", { - round, - duration: aiLogDuration(roundStartedAt), - outputItems: completedResponse?.output?.length ?? 0, - }); - - const images = collectOpenAiResponseImages(completedResponse); - if (images.length) { - await showOpenAiGeneratedImage( - streamMessage, - sourceMessage, - images[images.length - 1], - `final_${round}`, - Environment.getImageGenDoneText(config.openAiImageTarget.model), - true, - ); - } - - const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(completedResponse); - if (codeInterpreterCalls.length) { - aiLog("info", "openai.code_interpreter_calls", { - round, - duration: aiLogDuration(roundStartedAt), - calls: codeInterpreterCalls.map(call => ({ - id: call.id, - status: call.status, - containerId: call.containerId, - codeChars: call.code?.length ?? 0, - outputItems: call.outputs.length, - })), - }); - } - - const calls = adapter.extractToolCalls(completedResponse); - aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", { - round, - duration: calls.length ? 
aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt), - calls: calls.map(call => ({ - id: call.id, - name: call.name, - arguments: safeJsonParseObject(call.argumentsText) - })), - }); - if (!calls.length) return; - - const toolCalls = calls.map(call => ({ - id: call.id, - name: call.name, - argumentsText: call.argumentsText, - })); - const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory); - const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = []; - adapter.appendToolResults(toolOutputs, calls, toolResults); - - const uploadFilesResult = await tryToUploadFiles(msg, toolResults); - if (uploadFilesResult.found) { - if (!uploadFilesResult.uploaded) { - const old = toolOutputs[uploadFilesResult.toolIndex]; - const callId = old?.call_id; - if (uploadFilesResult.toolIndex >= 0) { - delete toolOutputs[uploadFilesResult.toolIndex]; - } - if (callId) { - toolOutputs.push({ - type: "function_call_output" as const, - call_id: callId, - output: "Error: " + uploadFilesResult.error - }); - } - } - } - - responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; + responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; } } finally { if (ownsDocumentRag) {