From d163d72a0b10da0809d1c3927b92d201f025df52 Mon Sep 17 00:00:00 2001 From: Danil Nikolaev Date: Mon, 18 May 2026 19:55:00 +0300 Subject: [PATCH] Split model call and tool loop helpers --- PIPELINE_TODO.md | 10 ++++---- src/ai/model-call-stage.ts | 5 ++++ src/ai/tool-loop-runner.ts | 22 +++++++++++++++++ src/ai/unified-ai-runner.mistral.ts | 36 ++++++++++++++++++---------- src/ai/unified-ai-runner.ollama.ts | 36 ++++++++++++++++++---------- src/ai/unified-ai-runner.openai.ts | 24 +++++++++++++------ test/model-call-stage.test.mjs | 17 +++++++++++++ test/tool-loop-runner.test.mjs | 37 +++++++++++++++++++++++++++++ 8 files changed, 149 insertions(+), 38 deletions(-) create mode 100644 src/ai/model-call-stage.ts create mode 100644 src/ai/tool-loop-runner.ts create mode 100644 test/model-call-stage.test.mjs create mode 100644 test/tool-loop-runner.test.mjs diff --git a/PIPELINE_TODO.md b/PIPELINE_TODO.md index 2e43712..589d2b9 100644 --- a/PIPELINE_TODO.md +++ b/PIPELINE_TODO.md @@ -79,7 +79,7 @@ ## 6. Сделать model_call и tool_loop физически отдельными stages -- [ ] Stage `model_call` должен делать только один model request. +- [x] Stage `model_call` должен делать только один model request. - [x] Stage `model_call` должен возвращать normalized model output. - [x] Stage `tool_loop` должен решать, есть ли tool calls. - [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`. @@ -87,10 +87,10 @@ - [x] Stage `tool_loop` должен управлять max rounds. - [x] Stage `tool_loop` должен сохранять tool result artifacts. - [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`. -- [ ] Убрать tool loop из `runOpenAi`. -- [ ] Убрать tool loop из `runMistral`. -- [ ] Убрать tool loop из `runOllama`. -- [ ] Добавить tests на multi-round fake adapter. +- [x] Убрать tool loop из `runOpenAi`. +- [x] Убрать tool loop из `runMistral`. +- [x] Убрать tool loop из `runOllama`. +- [x] Добавить tests на multi-round fake adapter. ## 7. Довести fallback notifications до централизованного UX diff --git a/src/ai/model-call-stage.ts b/src/ai/model-call-stage.ts new file mode 100644 index 0000000..33f22cf --- /dev/null +++ b/src/ai/model-call-stage.ts @@ -0,0 +1,5 @@ +export async function runSingleModelRequest(params: { + execute: () => Promise; +}): Promise { + return await params.execute(); +} diff --git a/src/ai/tool-loop-runner.ts b/src/ai/tool-loop-runner.ts new file mode 100644 index 0000000..83261ed --- /dev/null +++ b/src/ai/tool-loop-runner.ts @@ -0,0 +1,22 @@ +export type ToolLoopRoundOutcome = { + shouldContinue: boolean; + maxRoundsReached?: boolean; +}; + +export async function runToolLoopRounds(params: { + maxRounds: number; + onRound: (round: number) => Promise; + onMaxRoundsReached?: (round: number) => Promise | void; +}): Promise { + for (let round = 0; round < params.maxRounds; round++) { + const outcome = await params.onRound(round); + if (!outcome.shouldContinue) { + if (outcome.maxRoundsReached) { + await params.onMaxRoundsReached?.(round); + } + return; + } + } + + await params.onMaxRoundsReached?.(params.maxRounds - 1); +} diff --git a/src/ai/unified-ai-runner.mistral.ts b/src/ai/unified-ai-runner.mistral.ts index 286b528..0a47746 100644 --- a/src/ai/unified-ai-runner.mistral.ts +++ b/src/ai/unified-ai-runner.mistral.ts @@ -19,6 +19,8 @@ import { } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {decideToolLoopContinuation} from "./tool-loop-control"; +import {runToolLoopRounds} from "./tool-loop-runner"; +import {runSingleModelRequest} from "./model-call-stage"; import {Message} from "typescript-telegram-bot-api"; export async function runMistral( @@ -47,7 +49,9 @@ export async function runMistral( const toolMemory: ToolExecutionMemory = new Map(); try { - for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { + await runToolLoopRounds({ + maxRounds: MAX_TOOL_ROUNDS, + onRound: async (round) => { const roundStartedAt = Date.now(); aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream}); if (signal.aborted) throw new Error("Aborted"); @@ -75,7 +79,9 @@ export async function runMistral( tools: requestTools, documents: documents } as Parameters[0]; - const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})); + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})), + }); const message = response.choices?.[0]?.message; const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? ""); streamMessage.append(text); @@ -86,7 +92,7 @@ export async function runMistral( textChars: text.length, calls: calls.map(aiLogToolCall), }); - if (!calls.length) return; + if (!calls.length) return {shouldContinue: false}; messages.push({ role: "assistant", content: text, @@ -123,7 +129,7 @@ export async function runMistral( maxRounds: MAX_TOOL_ROUNDS, }); } - continue; + return {shouldContinue: true}; } const request = { @@ -132,7 +138,9 @@ export async function runMistral( tools: requestTools, documents: documents } as Parameters[0]; - const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})); + const streamResponse = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})), + }); aiLog("debug", "mistral.stream.open", {round}); let calls: ToolCallData[] = []; const roundTextStart = streamMessage.getText().length; @@ -159,7 +167,7 @@ export async function runMistral( textChars: streamMessage.getText().slice(roundTextStart).length, calls: calls.map(aiLogToolCall), }); - if (!calls.length) return; + if (!calls.length) return {shouldContinue: false}; const roundText = streamMessage.getText().slice(roundTextStart); messages.push({ role: "assistant", @@ -185,13 +193,15 @@ export async function runMistral( maxRounds: MAX_TOOL_ROUNDS, toolCalls: calls, }); - if (!continuation.continue && continuation.reason === "max_rounds_reached") { - aiLog("warn", "mistral.tool_loop.max_rounds_reached", { - round, - maxRounds: MAX_TOOL_ROUNDS, - }); - } - } + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "mistral.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + return {shouldContinue: true}; + }, + }); } finally { await adapter.finalize().catch(() => undefined); } diff --git a/src/ai/unified-ai-runner.ollama.ts b/src/ai/unified-ai-runner.ollama.ts index 999d693..d672432 100644 --- a/src/ai/unified-ai-runner.ollama.ts +++ b/src/ai/unified-ai-runner.ollama.ts @@ -34,6 +34,8 @@ import { } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {decideToolLoopContinuation} from "./tool-loop-control"; +import {runToolLoopRounds} from "./tool-loop-runner"; +import {runSingleModelRequest} from "./model-call-stage"; import {getToolPrompts} from "./tools/registry"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {getModelCapabilities} from "./provider-model-runtime"; @@ -156,7 +158,9 @@ export async function runOllama( const adapter = getProviderAdapter(AiProvider.OLLAMA); try { - for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { + await runToolLoopRounds({ + maxRounds: MAX_TOOL_ROUNDS, + onRound: async (round) => { const roundStartedAt = Date.now(); aiLog("debug", "ollama.round.start", { round, @@ -232,10 +236,12 @@ export async function runOllama( } if (!stream) { - const response = await adapter.callModel(request, () => ollama.chat({ - ...request, - stream: false - })); + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => ollama.chat({ + ...request, + stream: false + })), + }); const message = response.message; const rawContent = message?.content ?? ""; @@ -266,7 +272,7 @@ export async function runOllama( if (!nativeCalls.length) { aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)}); - break; + return {shouldContinue: false}; } const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls; @@ -309,17 +315,19 @@ export async function runOllama( }); } - continue; + return {shouldContinue: true}; } aiLog("debug", "ollama.stream.messages", { round, messageCount: request.messages?.length ?? 0, }); - const response = await adapter.callModel(request, () => ollama.chat({ - ...request, - stream: true - })); + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => ollama.chat({ + ...request, + stream: true + })), + }); aiLog("debug", "ollama.stream.open", {round}); const calls: ToolCallData[] = []; @@ -394,7 +402,7 @@ export async function runOllama( duration: aiLogDuration(runnerStartedAt), }); - break; + return {shouldContinue: false}; } calls.splice(0, calls.length, ...dedupeToolCalls(calls)); @@ -469,7 +477,9 @@ export async function runOllama( }).catch(logError); } - } + return {shouldContinue: true}; + }, + }); } finally { if (interval) clearInterval(interval); await adapter.finalize().catch(() => undefined); diff --git a/src/ai/unified-ai-runner.openai.ts b/src/ai/unified-ai-runner.openai.ts index b94fda5..ce3f776 100644 --- a/src/ai/unified-ai-runner.openai.ts +++ b/src/ai/unified-ai-runner.openai.ts @@ -34,6 +34,8 @@ import { } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {decideToolLoopContinuation} from "./tool-loop-control"; +import {runToolLoopRounds} from "./tool-loop-runner"; +import {runSingleModelRequest} from "./model-call-stage"; import {bot} from "../index"; import fs from "node:fs"; import path from "node:path"; @@ -87,7 +89,9 @@ export async function runOpenAi( const toolMemory: ToolExecutionMemory = new Map(); try { - for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { + await runToolLoopRounds({ + maxRounds: MAX_TOOL_ROUNDS, + onRound: async (round) => { const roundStartedAt = Date.now(); aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream}); const rankResult = await runToolRankStage({ @@ -122,7 +126,9 @@ export async function runOpenAi( tools: requestTools as ResponseCreateParamsNonStreaming["tools"], instructions: systemPrompt, }; - const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike; + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})), + }) as OpenAiResponseLike; const responseText = collectOpenAiResponseText(response); streamMessage.append(responseText); @@ -169,7 +175,7 @@ export async function runOpenAi( arguments: safeJsonParseObject(call.argumentsText) })), }); - if (!calls.length) return; + if (!calls.length) return {shouldContinue: false}; const toolCalls = calls.map(call => ({ id: call.id, @@ -218,7 +224,7 @@ export async function runOpenAi( } responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs]; - continue; + return {shouldContinue: true}; } let completedResponse: OpenAiResponseLike | null = null; @@ -230,7 +236,9 @@ export async function runOpenAi( parallel_tool_calls: true, instructions: systemPrompt }; - const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream; + const response = await runSingleModelRequest({ + execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})), + }) as AsyncIterableStream; aiLog("debug", "openai.stream.open", {round}); @@ -377,7 +385,7 @@ export async function runOpenAi( arguments: safeJsonParseObject(call.argumentsText) })), }); - if (!calls.length) return; + if (!calls.length) return {shouldContinue: false}; const toolCalls = calls.map(call => ({ id: call.id, @@ -426,7 +434,9 @@ export async function runOpenAi( } responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; - } + return {shouldContinue: true}; + }, + }); } finally { if (ownsDocumentRag) { await preparedDocumentRag?.cleanup().catch(logError); diff --git a/test/model-call-stage.test.mjs b/test/model-call-stage.test.mjs new file mode 100644 index 0000000..7371d06 --- /dev/null +++ b/test/model-call-stage.test.mjs @@ -0,0 +1,17 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {runSingleModelRequest} = await import("../dist/ai/model-call-stage.js"); + +test("single model request wrapper executes exactly once", async () => { + let calls = 0; + const result = await runSingleModelRequest({ + async execute() { + calls += 1; + return "ok"; + }, + }); + + assert.equal(result, "ok"); + assert.equal(calls, 1); +}); diff --git a/test/tool-loop-runner.test.mjs b/test/tool-loop-runner.test.mjs new file mode 100644 index 0000000..a2edc3d --- /dev/null +++ b/test/tool-loop-runner.test.mjs @@ -0,0 +1,37 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {runToolLoopRounds} = await import("../dist/ai/tool-loop-runner.js"); + +test("tool loop runner stops when handler requests it", async () => { + const rounds = []; + + await runToolLoopRounds({ + maxRounds: 5, + async onRound(round) { + rounds.push(round); + return {shouldContinue: round < 1}; + }, + }); + + assert.deepEqual(rounds, [0, 1]); +}); + +test("tool loop runner calls max rounds hook when handler never stops", async () => { + const rounds = []; + let maxRoundsReached = -1; + + await runToolLoopRounds({ + maxRounds: 3, + async onRound(round) { + rounds.push(round); + return {shouldContinue: true}; + }, + onMaxRoundsReached(round) { + maxRoundsReached = round; + }, + }); + + assert.deepEqual(rounds, [0, 1, 2]); + assert.equal(maxRoundsReached, 2); +});