Add shared tool loop stop policy

This commit is contained in:
2026-05-18 19:24:39 +03:00
parent 13df2a1c23
commit 9a105caf0b
6 changed files with 152 additions and 1 deletions
+1 -1
View File
@@ -84,7 +84,7 @@
- [x] Stage `tool_loop` должен решать, есть ли tool calls. - [x] Stage `tool_loop` должен решать, есть ли tool calls.
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`. - [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
- [x] Stage `tool_loop` должен добавлять tool results в provider adapter. - [x] Stage `tool_loop` должен добавлять tool results в provider adapter.
- [ ] Stage `tool_loop` должен управлять max rounds. - [x] Stage `tool_loop` должен управлять max rounds.
- [ ] Stage `tool_loop` должен сохранять tool result artifacts. - [ ] Stage `tool_loop` должен сохранять tool result artifacts.
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`. - [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
- [ ] Убрать tool loop из `runOpenAi`. - [ ] Убрать tool loop из `runOpenAi`.
+38
View File
@@ -0,0 +1,38 @@
import type {ToolCallData} from "./unified-ai-runner.shared.js";
export type ToolLoopStopReason = "no_tool_calls" | "max_rounds_reached";
export type ToolLoopContinuation = {
continue: boolean;
reason?: ToolLoopStopReason;
remainingRounds: number;
};
export function decideToolLoopContinuation(params: {
round: number;
maxRounds: number;
toolCalls: readonly ToolCallData[];
}): ToolLoopContinuation {
const remainingRounds = Math.max(params.maxRounds - params.round - 1, 0);
if (!params.toolCalls.length) {
return {
continue: false,
reason: "no_tool_calls",
remainingRounds,
};
}
if (remainingRounds === 0) {
return {
continue: false,
reason: "max_rounds_reached",
remainingRounds,
};
}
return {
continue: true,
remainingRounds,
};
}
+23
View File
@@ -18,6 +18,7 @@ import {
ToolExecutionMemory ToolExecutionMemory
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {Message} from "typescript-telegram-bot-api"; import {Message} from "typescript-telegram-bot-api";
export async function runMistral( export async function runMistral(
@@ -111,6 +112,17 @@ export async function runMistral(
adapter, adapter,
appendTargets: [messages, requestMessages], appendTargets: [messages, requestMessages],
}); });
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
continue; continue;
} }
@@ -168,6 +180,17 @@ export async function runMistral(
adapter, adapter,
appendTargets: [messages, requestMessages], appendTargets: [messages, requestMessages],
}); });
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
} }
} finally { } finally {
await adapter.finalize().catch(() => undefined); await adapter.finalize().catch(() => undefined);
+25
View File
@@ -33,6 +33,7 @@ import {
ToolExecutionMemory ToolExecutionMemory
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {getToolPrompts} from "./tools/registry"; import {getToolPrompts} from "./tools/registry";
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
import {getModelCapabilities} from "./provider-model-runtime"; import {getModelCapabilities} from "./provider-model-runtime";
@@ -296,6 +297,18 @@ export async function runOllama(
appendTargets: [messages], appendTargets: [messages],
}); });
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
continue; continue;
} }
@@ -414,6 +427,18 @@ export async function runOllama(
appendTargets: [messages], appendTargets: [messages],
}); });
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
let successGetNoteFileResult: GetNoteFileResult | undefined = undefined; let successGetNoteFileResult: GetNoteFileResult | undefined = undefined;
for (const toolResult of toolResults) { for (const toolResult of toolResults) {
+25
View File
@@ -33,6 +33,7 @@ import {
allToolSchemaNames allToolSchemaNames
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {bot} from "../index"; import {bot} from "../index";
import fs from "node:fs"; import fs from "node:fs";
import path from "node:path"; import path from "node:path";
@@ -204,6 +205,18 @@ export async function runOpenAi(
} }
} }
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs]; responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
continue; continue;
} }
@@ -400,6 +413,18 @@ export async function runOpenAi(
} }
} }
const continuation = decideToolLoopContinuation({
round,
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
} }
} finally { } finally {
+40
View File
@@ -0,0 +1,40 @@
import test from "node:test";
import assert from "node:assert/strict";
const {decideToolLoopContinuation} = await import("../dist/ai/tool-loop-control.js");
test("tool loop continuation stops when there are no tool calls", () => {
const decision = decideToolLoopContinuation({
round: 0,
maxRounds: 3,
toolCalls: [],
});
assert.equal(decision.continue, false);
assert.equal(decision.reason, "no_tool_calls");
assert.equal(decision.remainingRounds, 2);
});
test("tool loop continuation stops on the last allowed round", () => {
const decision = decideToolLoopContinuation({
round: 2,
maxRounds: 3,
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
});
assert.equal(decision.continue, false);
assert.equal(decision.reason, "max_rounds_reached");
assert.equal(decision.remainingRounds, 0);
});
test("tool loop continuation allows further rounds when tools remain and rounds are left", () => {
const decision = decideToolLoopContinuation({
round: 1,
maxRounds: 3,
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
});
assert.equal(decision.continue, true);
assert.equal(decision.reason, undefined);
assert.equal(decision.remainingRounds, 1);
});