Add shared tool loop stop policy
This commit is contained in:
+1
-1
@@ -84,7 +84,7 @@
|
||||
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
||||
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
||||
- [x] Stage `tool_loop` должен добавлять tool results в provider adapter.
|
||||
- [ ] Stage `tool_loop` должен управлять max rounds.
|
||||
- [x] Stage `tool_loop` должен управлять max rounds.
|
||||
- [ ] Stage `tool_loop` должен сохранять tool result artifacts.
|
||||
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
||||
- [ ] Убрать tool loop из `runOpenAi`.
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import type {ToolCallData} from "./unified-ai-runner.shared.js";
|
||||
|
||||
export type ToolLoopStopReason = "no_tool_calls" | "max_rounds_reached";
|
||||
|
||||
export type ToolLoopContinuation = {
|
||||
continue: boolean;
|
||||
reason?: ToolLoopStopReason;
|
||||
remainingRounds: number;
|
||||
};
|
||||
|
||||
export function decideToolLoopContinuation(params: {
|
||||
round: number;
|
||||
maxRounds: number;
|
||||
toolCalls: readonly ToolCallData[];
|
||||
}): ToolLoopContinuation {
|
||||
const remainingRounds = Math.max(params.maxRounds - params.round - 1, 0);
|
||||
|
||||
if (!params.toolCalls.length) {
|
||||
return {
|
||||
continue: false,
|
||||
reason: "no_tool_calls",
|
||||
remainingRounds,
|
||||
};
|
||||
}
|
||||
|
||||
if (remainingRounds === 0) {
|
||||
return {
|
||||
continue: false,
|
||||
reason: "max_rounds_reached",
|
||||
remainingRounds,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
continue: true,
|
||||
remainingRounds,
|
||||
};
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
ToolExecutionMemory
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {Message} from "typescript-telegram-bot-api";
|
||||
|
||||
export async function runMistral(
|
||||
@@ -111,6 +112,17 @@ export async function runMistral(
|
||||
adapter,
|
||||
appendTargets: [messages, requestMessages],
|
||||
});
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -168,6 +180,17 @@ export async function runMistral(
|
||||
adapter,
|
||||
appendTargets: [messages, requestMessages],
|
||||
});
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
await adapter.finalize().catch(() => undefined);
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
ToolExecutionMemory
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {getToolPrompts} from "./tools/registry";
|
||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||
import {getModelCapabilities} from "./provider-model-runtime";
|
||||
@@ -296,6 +297,18 @@ export async function runOllama(
|
||||
appendTargets: [messages],
|
||||
});
|
||||
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -414,6 +427,18 @@ export async function runOllama(
|
||||
appendTargets: [messages],
|
||||
});
|
||||
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
|
||||
let successGetNoteFileResult: GetNoteFileResult | undefined = undefined;
|
||||
|
||||
for (const toolResult of toolResults) {
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
allToolSchemaNames
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {bot} from "../index";
|
||||
import fs from "node:fs";
|
||||
import path from "node:path";
|
||||
@@ -204,6 +205,18 @@ export async function runOpenAi(
|
||||
}
|
||||
}
|
||||
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
|
||||
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
||||
continue;
|
||||
}
|
||||
@@ -400,6 +413,18 @@ export async function runOpenAi(
|
||||
}
|
||||
}
|
||||
|
||||
const continuation = decideToolLoopContinuation({
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
toolCalls: calls,
|
||||
});
|
||||
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
|
||||
round,
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
|
||||
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
||||
}
|
||||
} finally {
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
|
||||
const {decideToolLoopContinuation} = await import("../dist/ai/tool-loop-control.js");
|
||||
|
||||
test("tool loop continuation stops when there are no tool calls", () => {
|
||||
const decision = decideToolLoopContinuation({
|
||||
round: 0,
|
||||
maxRounds: 3,
|
||||
toolCalls: [],
|
||||
});
|
||||
|
||||
assert.equal(decision.continue, false);
|
||||
assert.equal(decision.reason, "no_tool_calls");
|
||||
assert.equal(decision.remainingRounds, 2);
|
||||
});
|
||||
|
||||
test("tool loop continuation stops on the last allowed round", () => {
|
||||
const decision = decideToolLoopContinuation({
|
||||
round: 2,
|
||||
maxRounds: 3,
|
||||
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
|
||||
});
|
||||
|
||||
assert.equal(decision.continue, false);
|
||||
assert.equal(decision.reason, "max_rounds_reached");
|
||||
assert.equal(decision.remainingRounds, 0);
|
||||
});
|
||||
|
||||
test("tool loop continuation allows further rounds when tools remain and rounds are left", () => {
|
||||
const decision = decideToolLoopContinuation({
|
||||
round: 1,
|
||||
maxRounds: 3,
|
||||
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
|
||||
});
|
||||
|
||||
assert.equal(decision.continue, true);
|
||||
assert.equal(decision.reason, undefined);
|
||||
assert.equal(decision.remainingRounds, 1);
|
||||
});
|
||||
Reference in New Issue
Block a user