From 9a105caf0bbd22a0f82fe20ee780c02183ed029f Mon Sep 17 00:00:00 2001 From: Danil Nikolaev Date: Mon, 18 May 2026 19:24:39 +0300 Subject: [PATCH] Add shared tool loop stop policy --- PIPELINE_TODO.md | 2 +- src/ai/tool-loop-control.ts | 38 +++++++++++++++++++++++++++ src/ai/unified-ai-runner.mistral.ts | 23 +++++++++++++++++ src/ai/unified-ai-runner.ollama.ts | 25 ++++++++++++++++++ src/ai/unified-ai-runner.openai.ts | 25 ++++++++++++++++++ test/tool-loop-control.test.mjs | 40 +++++++++++++++++++++++++++++ 6 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 src/ai/tool-loop-control.ts create mode 100644 test/tool-loop-control.test.mjs diff --git a/PIPELINE_TODO.md b/PIPELINE_TODO.md index ec7fd82..9a98198 100644 --- a/PIPELINE_TODO.md +++ b/PIPELINE_TODO.md @@ -84,7 +84,7 @@ - [x] Stage `tool_loop` должен решать, есть ли tool calls. - [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`. - [x] Stage `tool_loop` должен добавлять tool results в provider adapter. -- [ ] Stage `tool_loop` должен управлять max rounds. +- [x] Stage `tool_loop` должен управлять max rounds. - [ ] Stage `tool_loop` должен сохранять tool result artifacts. - [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`. - [ ] Убрать tool loop из `runOpenAi`. diff --git a/src/ai/tool-loop-control.ts b/src/ai/tool-loop-control.ts new file mode 100644 index 0000000..ae7a280 --- /dev/null +++ b/src/ai/tool-loop-control.ts @@ -0,0 +1,38 @@ +import type {ToolCallData} from "./unified-ai-runner.shared.js"; + +export type ToolLoopStopReason = "no_tool_calls" | "max_rounds_reached"; + +export type ToolLoopContinuation = { + continue: boolean; + reason?: ToolLoopStopReason; + remainingRounds: number; +}; + +export function decideToolLoopContinuation(params: { + round: number; + maxRounds: number; + toolCalls: readonly ToolCallData[]; +}): ToolLoopContinuation { + const remainingRounds = Math.max(params.maxRounds - params.round - 1, 0); + + if (!params.toolCalls.length) { + return { + continue: false, + reason: "no_tool_calls", + remainingRounds, + }; + } + + if (remainingRounds === 0) { + return { + continue: false, + reason: "max_rounds_reached", + remainingRounds, + }; + } + + return { + continue: true, + remainingRounds, + }; +} diff --git a/src/ai/unified-ai-runner.mistral.ts b/src/ai/unified-ai-runner.mistral.ts index 74a0312..286b528 100644 --- a/src/ai/unified-ai-runner.mistral.ts +++ b/src/ai/unified-ai-runner.mistral.ts @@ -18,6 +18,7 @@ import { ToolExecutionMemory } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; +import {decideToolLoopContinuation} from "./tool-loop-control"; import {Message} from "typescript-telegram-bot-api"; export async function runMistral( @@ -111,6 +112,17 @@ export async function runMistral( adapter, appendTargets: [messages, requestMessages], }); + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "mistral.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } continue; } @@ -168,6 +180,17 @@ export async function runMistral( adapter, appendTargets: [messages, requestMessages], }); + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "mistral.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } } } finally { await adapter.finalize().catch(() => undefined); diff --git a/src/ai/unified-ai-runner.ollama.ts b/src/ai/unified-ai-runner.ollama.ts index 1807228..999d693 100644 --- a/src/ai/unified-ai-runner.ollama.ts +++ b/src/ai/unified-ai-runner.ollama.ts @@ -33,6 +33,7 @@ import { ToolExecutionMemory } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; +import {decideToolLoopContinuation} from "./tool-loop-control"; import {getToolPrompts} from "./tools/registry"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {getModelCapabilities} from "./provider-model-runtime"; @@ -296,6 +297,18 @@ export async function runOllama( appendTargets: [messages], }); + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "ollama.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + continue; } @@ -414,6 +427,18 @@ export async function runOllama( appendTargets: [messages], }); + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "ollama.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + let successGetNoteFileResult: GetNoteFileResult | undefined = undefined; for (const toolResult of toolResults) { diff --git a/src/ai/unified-ai-runner.openai.ts b/src/ai/unified-ai-runner.openai.ts index fd375c9..b94fda5 100644 --- a/src/ai/unified-ai-runner.openai.ts +++ b/src/ai/unified-ai-runner.openai.ts @@ -33,6 +33,7 @@ import { allToolSchemaNames } from "./unified-ai-runner.shared"; import {executeToolBatchWithAdapter} from "./tool-batch-runner"; +import {decideToolLoopContinuation} from "./tool-loop-control"; import {bot} from "../index"; import fs from "node:fs"; import path from "node:path"; @@ -204,6 +205,18 @@ export async function runOpenAi( } } + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "openai.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs]; continue; } @@ -400,6 +413,18 @@ export async function runOpenAi( } } + const continuation = decideToolLoopContinuation({ + round, + maxRounds: MAX_TOOL_ROUNDS, + toolCalls: calls, + }); + if (!continuation.continue && continuation.reason === "max_rounds_reached") { + aiLog("warn", "openai.tool_loop.max_rounds_reached", { + round, + maxRounds: MAX_TOOL_ROUNDS, + }); + } + responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; } } finally { diff --git a/test/tool-loop-control.test.mjs b/test/tool-loop-control.test.mjs new file mode 100644 index 0000000..2feed12 --- /dev/null +++ b/test/tool-loop-control.test.mjs @@ -0,0 +1,40 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {decideToolLoopContinuation} = await import("../dist/ai/tool-loop-control.js"); + +test("tool loop continuation stops when there are no tool calls", () => { + const decision = decideToolLoopContinuation({ + round: 0, + maxRounds: 3, + toolCalls: [], + }); + + assert.equal(decision.continue, false); + assert.equal(decision.reason, "no_tool_calls"); + assert.equal(decision.remainingRounds, 2); +}); + +test("tool loop continuation stops on the last allowed round", () => { + const decision = decideToolLoopContinuation({ + round: 2, + maxRounds: 3, + toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}], + }); + + assert.equal(decision.continue, false); + assert.equal(decision.reason, "max_rounds_reached"); + assert.equal(decision.remainingRounds, 0); +}); + +test("tool loop continuation allows further rounds when tools remain and rounds are left", () => { + const decision = decideToolLoopContinuation({ + round: 1, + maxRounds: 3, + toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}], + }); + + assert.equal(decision.continue, true); + assert.equal(decision.reason, undefined); + assert.equal(decision.remainingRounds, 1); +});