Add shared tool loop stop policy
This commit is contained in:
+1
-1
@@ -84,7 +84,7 @@
|
|||||||
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
||||||
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
||||||
- [x] Stage `tool_loop` должен добавлять tool results в provider adapter.
|
- [x] Stage `tool_loop` должен добавлять tool results в provider adapter.
|
||||||
- [ ] Stage `tool_loop` должен управлять max rounds.
|
- [x] Stage `tool_loop` должен управлять max rounds.
|
||||||
- [ ] Stage `tool_loop` должен сохранять tool result artifacts.
|
- [ ] Stage `tool_loop` должен сохранять tool result artifacts.
|
||||||
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
||||||
- [ ] Убрать tool loop из `runOpenAi`.
|
- [ ] Убрать tool loop из `runOpenAi`.
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
import type {ToolCallData} from "./unified-ai-runner.shared.js";
|
||||||
|
|
||||||
|
export type ToolLoopStopReason = "no_tool_calls" | "max_rounds_reached";
|
||||||
|
|
||||||
|
export type ToolLoopContinuation = {
|
||||||
|
continue: boolean;
|
||||||
|
reason?: ToolLoopStopReason;
|
||||||
|
remainingRounds: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function decideToolLoopContinuation(params: {
|
||||||
|
round: number;
|
||||||
|
maxRounds: number;
|
||||||
|
toolCalls: readonly ToolCallData[];
|
||||||
|
}): ToolLoopContinuation {
|
||||||
|
const remainingRounds = Math.max(params.maxRounds - params.round - 1, 0);
|
||||||
|
|
||||||
|
if (!params.toolCalls.length) {
|
||||||
|
return {
|
||||||
|
continue: false,
|
||||||
|
reason: "no_tool_calls",
|
||||||
|
remainingRounds,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remainingRounds === 0) {
|
||||||
|
return {
|
||||||
|
continue: false,
|
||||||
|
reason: "max_rounds_reached",
|
||||||
|
remainingRounds,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
continue: true,
|
||||||
|
remainingRounds,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ import {
|
|||||||
ToolExecutionMemory
|
ToolExecutionMemory
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
import {Message} from "typescript-telegram-bot-api";
|
import {Message} from "typescript-telegram-bot-api";
|
||||||
|
|
||||||
export async function runMistral(
|
export async function runMistral(
|
||||||
@@ -111,6 +112,17 @@ export async function runMistral(
|
|||||||
adapter,
|
adapter,
|
||||||
appendTargets: [messages, requestMessages],
|
appendTargets: [messages, requestMessages],
|
||||||
});
|
});
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,6 +180,17 @@ export async function runMistral(
|
|||||||
adapter,
|
adapter,
|
||||||
appendTargets: [messages, requestMessages],
|
appendTargets: [messages, requestMessages],
|
||||||
});
|
});
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
await adapter.finalize().catch(() => undefined);
|
await adapter.finalize().catch(() => undefined);
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ import {
|
|||||||
ToolExecutionMemory
|
ToolExecutionMemory
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
import {getToolPrompts} from "./tools/registry";
|
import {getToolPrompts} from "./tools/registry";
|
||||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||||
import {getModelCapabilities} from "./provider-model-runtime";
|
import {getModelCapabilities} from "./provider-model-runtime";
|
||||||
@@ -296,6 +297,18 @@ export async function runOllama(
|
|||||||
appendTargets: [messages],
|
appendTargets: [messages],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,6 +427,18 @@ export async function runOllama(
|
|||||||
appendTargets: [messages],
|
appendTargets: [messages],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "ollama.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let successGetNoteFileResult: GetNoteFileResult | undefined = undefined;
|
let successGetNoteFileResult: GetNoteFileResult | undefined = undefined;
|
||||||
|
|
||||||
for (const toolResult of toolResults) {
|
for (const toolResult of toolResults) {
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ import {
|
|||||||
allToolSchemaNames
|
allToolSchemaNames
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
import {bot} from "../index";
|
import {bot} from "../index";
|
||||||
import fs from "node:fs";
|
import fs from "node:fs";
|
||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
@@ -204,6 +205,18 @@ export async function runOpenAi(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -400,6 +413,18 @@ export async function runOpenAi(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const continuation = decideToolLoopContinuation({
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
toolCalls: calls,
|
||||||
|
});
|
||||||
|
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
|
||||||
|
aiLog("warn", "openai.tool_loop.max_rounds_reached", {
|
||||||
|
round,
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import test from "node:test";
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
|
||||||
|
const {decideToolLoopContinuation} = await import("../dist/ai/tool-loop-control.js");
|
||||||
|
|
||||||
|
test("tool loop continuation stops when there are no tool calls", () => {
|
||||||
|
const decision = decideToolLoopContinuation({
|
||||||
|
round: 0,
|
||||||
|
maxRounds: 3,
|
||||||
|
toolCalls: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(decision.continue, false);
|
||||||
|
assert.equal(decision.reason, "no_tool_calls");
|
||||||
|
assert.equal(decision.remainingRounds, 2);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("tool loop continuation stops on the last allowed round", () => {
|
||||||
|
const decision = decideToolLoopContinuation({
|
||||||
|
round: 2,
|
||||||
|
maxRounds: 3,
|
||||||
|
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(decision.continue, false);
|
||||||
|
assert.equal(decision.reason, "max_rounds_reached");
|
||||||
|
assert.equal(decision.remainingRounds, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("tool loop continuation allows further rounds when tools remain and rounds are left", () => {
|
||||||
|
const decision = decideToolLoopContinuation({
|
||||||
|
round: 1,
|
||||||
|
maxRounds: 3,
|
||||||
|
toolCalls: [{id: "call-1", name: "read_file", argumentsText: "{}"}],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(decision.continue, true);
|
||||||
|
assert.equal(decision.reason, undefined);
|
||||||
|
assert.equal(decision.remainingRounds, 1);
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user