Extract shared tool batch adapter helper

This commit is contained in:
2026-05-18 19:18:22 +03:00
parent 9352ade19f
commit 13df2a1c23
5 changed files with 312 additions and 243 deletions
+2 -2
View File
@@ -82,8 +82,8 @@
- [ ] Stage `model_call` должен делать только один model request.
- [x] Stage `model_call` должен возвращать normalized model output.
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
- [ ] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
- [ ] Stage `tool_loop` должен добавлять tool results в provider adapter.
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
- [x] Stage `tool_loop` должен добавлять tool results в provider adapter.
- [ ] Stage `tool_loop` должен управлять max rounds.
- [ ] Stage `tool_loop` должен сохранять tool result artifacts.
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
+28
View File
@@ -0,0 +1,28 @@
import type {AiProviderAdapter} from "./provider-adapters.js";
import {executeToolBatch, type ToolCallData, type ToolExecutionMemory} from "./unified-ai-runner.shared.js";
import type {TelegramStreamMessage} from "./telegram-stream-message.js";
import type {ToolRuntimeContext} from "./tools/runtime.js";
export async function executeToolBatchWithAdapter(params: {
userId: number | undefined | null;
toolCalls: ToolCallData[];
streamMessage: TelegramStreamMessage;
toolContext: ToolRuntimeContext;
toolMemory: ToolExecutionMemory;
adapter: AiProviderAdapter;
appendTargets?: unknown[][];
}): Promise<string[]> {
const results = await executeToolBatch(
params.userId,
params.toolCalls,
params.streamMessage,
params.toolContext,
params.toolMemory,
);
for (const target of params.appendTargets ?? []) {
params.adapter.appendToolResults(target, params.toolCalls, results);
}
return results;
}
+19 -7
View File
@@ -9,7 +9,6 @@ import {getProviderAdapter} from "./provider-adapters";
import {runToolRankStage} from "./tool-rank-stage";
import {
executeToolBatch,
MAX_TOOL_ROUNDS,
MistralDocumentReference,
roundStatus,
@@ -18,6 +17,7 @@ import {
ToolCallData,
ToolExecutionMemory
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {Message} from "typescript-telegram-bot-api";
export async function runMistral(
@@ -102,9 +102,15 @@ export async function runMistral(
function: {name: call.name, arguments: call.argumentsText},
})),
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
adapter.appendToolResults(messages, calls, toolResults);
adapter.appendToolResults(requestMessages, calls, toolResults);
await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls: calls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [messages, requestMessages],
});
continue;
}
@@ -153,9 +159,15 @@ export async function runMistral(
content: roundText,
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
adapter.appendToolResults(messages, calls, toolResults);
adapter.appendToolResults(requestMessages, calls, toolResults);
await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls: calls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [messages, requestMessages],
});
}
} finally {
await adapter.finalize().catch(() => undefined);
+19 -4
View File
@@ -20,7 +20,6 @@ import {
allToolSchemaNames,
dedupeToolCalls,
DEFAULT_OLLAMA_CONTEXT_SIZE,
executeToolBatch,
isOllamaModelActive,
isRecord,
MAX_OLLAMA_CONTEXT_SIZE,
@@ -33,6 +32,7 @@ import {
ToolCallData,
ToolExecutionMemory
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {getToolPrompts} from "./tools/registry";
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
import {getModelCapabilities} from "./provider-model-runtime";
@@ -286,7 +286,15 @@ export async function runOllama(
})),
});
adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory));
await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls: calls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [messages],
});
continue;
}
@@ -396,7 +404,15 @@ export async function runOllama(
})),
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
const toolResults = await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls: calls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [messages],
});
let successGetNoteFileResult: GetNoteFileResult | undefined = undefined;
@@ -428,7 +444,6 @@ export async function runOllama(
}).catch(logError);
}
adapter.appendToolResults(messages, calls, toolResults);
}
} finally {
if (interval) clearInterval(interval);
+244 -230
View File
@@ -19,7 +19,6 @@ import {
collectOpenAiResponseCodeInterpreterCalls,
collectOpenAiResponseImages,
collectOpenAiResponseText,
executeToolBatch,
MAX_TOOL_ROUNDS,
OPENAI_IMAGE_PARTIALS,
openAiResponseItemCallId,
@@ -33,6 +32,7 @@ import {
errorMessage,
allToolSchemaNames
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {bot} from "../index";
import fs from "node:fs";
import path from "node:path";
@@ -87,51 +87,247 @@ export async function runOpenAi(
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
const roundStartedAt = Date.now();
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
const rankResult = await runToolRankStage({
provider: AiProvider.OPENAI,
model: config.openAiChatTarget.model,
round,
config,
availableTools,
messages,
streamMessage,
signal,
});
const filteredTools = rankResult.filteredTools;
const requestTools = preparedDocumentRag?.vectorStoreIds.length
? (() => {
const tools = [...filteredTools];
const hasFileSearch = allToolSchemaNames(tools).includes("file_search");
if (!hasFileSearch) {
const fileSearchTool = availableTools.find(tool => allToolSchemaNames([tool]).includes("file_search"));
if (fileSearchTool) {
tools.unshift(fileSearchTool);
const roundStartedAt = Date.now();
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
const rankResult = await runToolRankStage({
provider: AiProvider.OPENAI,
model: config.openAiChatTarget.model,
round,
config,
availableTools,
messages,
streamMessage,
signal,
});
const filteredTools = rankResult.filteredTools;
const requestTools = preparedDocumentRag?.vectorStoreIds.length
? (() => {
const tools = [...filteredTools];
const hasFileSearch = allToolSchemaNames(tools).includes("file_search");
if (!hasFileSearch) {
const fileSearchTool = availableTools.find(tool => allToolSchemaNames([tool]).includes("file_search"));
if (fileSearchTool) {
tools.unshift(fileSearchTool);
}
}
return tools.length ? tools : undefined;
})()
: (filteredTools.length ? filteredTools : undefined);
if (!stream) {
const request: ResponseCreateParamsNonStreaming = {
model: config.openAiChatTarget.model,
input: responseInput as ResponseInputItem[],
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
instructions: systemPrompt,
};
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
const responseText = collectOpenAiResponseText(response);
streamMessage.append(responseText);
aiLog("debug", "openai.response.received", {
round,
duration: aiLogDuration(roundStartedAt),
textChars: responseText.length,
outputItems: response?.output?.length ?? 0,
});
const images = collectOpenAiResponseImages(response);
if (images.length) {
await showOpenAiGeneratedImage(
streamMessage,
sourceMessage,
images[images.length - 1],
`final_${round}`,
Environment.getImageGenDoneText(config.openAiImageTarget.model),
true,
);
}
const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(response);
if (codeInterpreterCalls.length) {
aiLog("info", "openai.code_interpreter_calls", {
round,
duration: aiLogDuration(roundStartedAt),
calls: codeInterpreterCalls.map(call => ({
id: call.id,
status: call.status,
containerId: call.containerId,
codeChars: call.code?.length ?? 0,
outputItems: call.outputs.length,
})),
});
}
const calls = adapter.extractToolCalls(response);
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
calls: calls.map(call => ({
id: call.id,
name: call.name,
arguments: safeJsonParseObject(call.argumentsText)
})),
});
if (!calls.length) return;
const toolCalls = calls.map(call => ({
id: call.id,
name: call.name,
argumentsText: call.argumentsText,
}));
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
const toolResults = await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [toolOutputs],
});
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
if (uploadFilesResult.found) {
if (!uploadFilesResult.uploaded) {
const old = toolOutputs[uploadFilesResult.toolIndex];
const callId = old?.call_id;
if (uploadFilesResult.toolIndex >= 0) {
delete toolOutputs[uploadFilesResult.toolIndex];
}
if (callId) {
toolOutputs.push({
type: "function_call_output" as const,
call_id: callId,
output: "Error: " + uploadFilesResult.error
});
}
}
}
return tools.length ? tools : undefined;
})()
: (filteredTools.length ? filteredTools : undefined);
if (!stream) {
const request: ResponseCreateParamsNonStreaming = {
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
continue;
}
let completedResponse: OpenAiResponseLike | null = null;
const request: ResponseCreateParamsStreaming = {
model: config.openAiChatTarget.model,
input: responseInput as ResponseInputItem[],
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
instructions: systemPrompt,
stream: true,
tools: requestTools as ResponseCreateParamsStreaming["tools"],
parallel_tool_calls: true,
instructions: systemPrompt
};
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
const responseText = collectOpenAiResponseText(response);
streamMessage.append(responseText);
aiLog("debug", "openai.response.received", {
aiLog("debug", "openai.stream.open", {round});
let localToolCalls: ToolCallData[] = [];
for await (const event of response) {
if (signal.aborted) throw new Error("Aborted");
switch (event.type) {
case "response.output_text.delta":
streamMessage.append(adapter.extractTextDelta(event));
break;
case "response.image_generation_call.in_progress":
streamMessage.setStatus(Environment.startingImageGenText);
await streamMessage.flush();
break;
case "response.image_generation_call.generating":
streamMessage.setStatus(Environment.imageGenText);
await streamMessage.flush();
break;
case "response.image_generation_call.partial_image": {
const iteration = (event.partial_image_index ?? 0) + 1;
await showOpenAiGeneratedImage(
streamMessage,
sourceMessage,
event.partial_image_b64,
`partial_${round}_${iteration}`,
Environment.getPartialImageGenText(iteration, OPENAI_IMAGE_PARTIALS),
false,
);
break;
}
case "response.image_generation_call.completed":
streamMessage.setStatus(Environment.finalizingImageGenText);
await streamMessage.flush();
break;
case "response.file_search_call.in_progress":
case "response.file_search_call.searching":
streamMessage.setStatus(Environment.getUseToolText(["file_search"]));
await streamMessage.flush();
break;
case "response.file_search_call.completed":
streamMessage.clearStatus();
await streamMessage.flush();
break;
case "response.code_interpreter_call.in_progress":
case "response.code_interpreter_call.interpreting":
streamMessage.setStatus(Environment.getUseToolText(["code_interpreter"]));
await streamMessage.flush();
break;
case "response.code_interpreter_call.completed":
streamMessage.clearStatus();
await streamMessage.flush();
break;
case "response.code_interpreter_call_code.delta":
case "response.code_interpreter_call_code.done":
break;
case "response.output_item.added":
{
const streamedCalls = adapter.extractStreamingToolCalls(event);
if (streamedCalls.length) {
localToolCalls.push(...streamedCalls);
}
aiLog("info", "openai.stream.tool_call.added", {
round,
toolCalls: localToolCalls.map(aiLogToolCall)
});
streamMessage.setStatus(Environment.getUseToolText(localToolCalls));
await streamMessage.flush();
}
break;
case "response.output_item.done":
if (event.item.type === "function_call" && event.item.name) {
const item = event.item as OpenAiResponseOutputItem & { id?: string };
const itemId = openAiResponseItemCallId(item);
const index = localToolCalls.findIndex(c => c.id === itemId);
if (index !== -1) {
localToolCalls.splice(index, 1);
if (localToolCalls.length === 0) {
streamMessage.clearStatus();
} else {
streamMessage.setStatus(Environment.getUseToolText(localToolCalls));
}
await streamMessage.flush();
}
}
break;
case "response.function_call_arguments.delta":
break;
case "response.function_call_arguments.done":
break;
case "response.completed":
completedResponse = event.response as OpenAiResponseLike;
break;
case "response.failed":
throw new Error(event.response?.error?.message ?? "OpenAI response failed");
case "error":
throw new Error(event.message ?? event?.message ?? "OpenAI stream error");
}
}
if (!completedResponse) throw new Error("OpenAI did not return the final response.completed event.");
aiLog("debug", "openai.stream.completed", {
round,
duration: aiLogDuration(roundStartedAt),
textChars: responseText.length,
outputItems: response?.output?.length ?? 0,
outputItems: completedResponse?.output?.length ?? 0,
});
const images = collectOpenAiResponseImages(response);
const images = collectOpenAiResponseImages(completedResponse);
if (images.length) {
await showOpenAiGeneratedImage(
streamMessage,
@@ -143,7 +339,7 @@ export async function runOpenAi(
);
}
const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(response);
const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(completedResponse);
if (codeInterpreterCalls.length) {
aiLog("info", "openai.code_interpreter_calls", {
round,
@@ -158,7 +354,7 @@ export async function runOpenAi(
});
}
const calls = adapter.extractToolCalls(response);
const calls = adapter.extractToolCalls(completedResponse);
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
@@ -175,9 +371,16 @@ export async function runOpenAi(
name: call.name,
argumentsText: call.argumentsText,
}));
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
adapter.appendToolResults(toolOutputs, calls, toolResults);
const toolResults = await executeToolBatchWithAdapter({
userId: msg.from?.id,
toolCalls,
streamMessage,
toolContext,
toolMemory,
adapter,
appendTargets: [toolOutputs],
});
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
if (uploadFilesResult.found) {
@@ -197,196 +400,7 @@ export async function runOpenAi(
}
}
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
continue;
}
let completedResponse: OpenAiResponseLike | null = null;
const request: ResponseCreateParamsStreaming = {
model: config.openAiChatTarget.model,
input: responseInput as ResponseInputItem[],
stream: true,
tools: requestTools as ResponseCreateParamsStreaming["tools"],
parallel_tool_calls: true,
instructions: systemPrompt
};
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
aiLog("debug", "openai.stream.open", {round});
let localToolCalls: ToolCallData[] = [];
for await (const event of response) {
if (signal.aborted) throw new Error("Aborted");
switch (event.type) {
case "response.output_text.delta":
streamMessage.append(adapter.extractTextDelta(event));
break;
case "response.image_generation_call.in_progress":
streamMessage.setStatus(Environment.startingImageGenText);
await streamMessage.flush();
break;
case "response.image_generation_call.generating":
streamMessage.setStatus(Environment.imageGenText);
await streamMessage.flush();
break;
case "response.image_generation_call.partial_image": {
const iteration = (event.partial_image_index ?? 0) + 1;
await showOpenAiGeneratedImage(
streamMessage,
sourceMessage,
event.partial_image_b64,
`partial_${round}_${iteration}`,
Environment.getPartialImageGenText(iteration, OPENAI_IMAGE_PARTIALS),
false,
);
break;
}
case "response.image_generation_call.completed":
streamMessage.setStatus(Environment.finalizingImageGenText);
await streamMessage.flush();
break;
case "response.file_search_call.in_progress":
case "response.file_search_call.searching":
streamMessage.setStatus(Environment.getUseToolText(["file_search"]));
await streamMessage.flush();
break;
case "response.file_search_call.completed":
streamMessage.clearStatus();
await streamMessage.flush();
break;
case "response.code_interpreter_call.in_progress":
case "response.code_interpreter_call.interpreting":
streamMessage.setStatus(Environment.getUseToolText(["code_interpreter"]));
await streamMessage.flush();
break;
case "response.code_interpreter_call.completed":
streamMessage.clearStatus();
await streamMessage.flush();
break;
case "response.code_interpreter_call_code.delta":
case "response.code_interpreter_call_code.done":
break;
case "response.output_item.added":
{
const streamedCalls = adapter.extractStreamingToolCalls(event);
if (streamedCalls.length) {
localToolCalls.push(...streamedCalls);
}
aiLog("info", "openai.stream.tool_call.added", {
round,
toolCalls: localToolCalls.map(aiLogToolCall)
});
streamMessage.setStatus(Environment.getUseToolText(localToolCalls));
await streamMessage.flush();
}
break;
case "response.output_item.done":
if (event.item.type === "function_call" && event.item.name) {
const item = event.item as OpenAiResponseOutputItem & { id?: string };
const itemId = openAiResponseItemCallId(item);
const index = localToolCalls.findIndex(c => c.id === itemId);
if (index !== -1) {
localToolCalls.splice(index, 1);
if (localToolCalls.length === 0) {
streamMessage.clearStatus();
} else {
streamMessage.setStatus(Environment.getUseToolText(localToolCalls));
}
await streamMessage.flush();
}
}
break;
case "response.function_call_arguments.delta":
break;
case "response.function_call_arguments.done":
break;
case "response.completed":
completedResponse = event.response as OpenAiResponseLike;
break;
case "response.failed":
throw new Error(event.response?.error?.message ?? "OpenAI response failed");
case "error":
throw new Error(event.message ?? event?.message ?? "OpenAI stream error");
}
}
if (!completedResponse) throw new Error("OpenAI did not return the final response.completed event.");
aiLog("debug", "openai.stream.completed", {
round,
duration: aiLogDuration(roundStartedAt),
outputItems: completedResponse?.output?.length ?? 0,
});
const images = collectOpenAiResponseImages(completedResponse);
if (images.length) {
await showOpenAiGeneratedImage(
streamMessage,
sourceMessage,
images[images.length - 1],
`final_${round}`,
Environment.getImageGenDoneText(config.openAiImageTarget.model),
true,
);
}
const codeInterpreterCalls = collectOpenAiResponseCodeInterpreterCalls(completedResponse);
if (codeInterpreterCalls.length) {
aiLog("info", "openai.code_interpreter_calls", {
round,
duration: aiLogDuration(roundStartedAt),
calls: codeInterpreterCalls.map(call => ({
id: call.id,
status: call.status,
containerId: call.containerId,
codeChars: call.code?.length ?? 0,
outputItems: call.outputs.length,
})),
});
}
const calls = adapter.extractToolCalls(completedResponse);
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
calls: calls.map(call => ({
id: call.id,
name: call.name,
arguments: safeJsonParseObject(call.argumentsText)
})),
});
if (!calls.length) return;
const toolCalls = calls.map(call => ({
id: call.id,
name: call.name,
argumentsText: call.argumentsText,
}));
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
adapter.appendToolResults(toolOutputs, calls, toolResults);
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
if (uploadFilesResult.found) {
if (!uploadFilesResult.uploaded) {
const old = toolOutputs[uploadFilesResult.toolIndex];
const callId = old?.call_id;
if (uploadFilesResult.toolIndex >= 0) {
delete toolOutputs[uploadFilesResult.toolIndex];
}
if (callId) {
toolOutputs.push({
type: "function_call_output" as const,
call_id: callId,
output: "Error: " + uploadFilesResult.error
});
}
}
}
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
}
} finally {
if (ownsDocumentRag) {