Split model call and tool loop helpers

This commit is contained in:
2026-05-18 19:55:00 +03:00
parent 57985ce87b
commit d163d72a0b
8 changed files with 149 additions and 38 deletions
+5 -5
View File
@@ -79,7 +79,7 @@
## 6. Сделать model_call и tool_loop физически отдельными stages
- [ ] Stage `model_call` должен делать только один model request.
- [x] Stage `model_call` должен делать только один model request.
- [x] Stage `model_call` должен возвращать normalized model output.
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
@@ -87,10 +87,10 @@
- [x] Stage `tool_loop` должен управлять max rounds.
- [x] Stage `tool_loop` должен сохранять tool result artifacts.
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
- [ ] Убрать tool loop из `runOpenAi`.
- [ ] Убрать tool loop из `runMistral`.
- [ ] Убрать tool loop из `runOllama`.
- [ ] Добавить tests на multi-round fake adapter.
- [x] Убрать tool loop из `runOpenAi`.
- [x] Убрать tool loop из `runMistral`.
- [x] Убрать tool loop из `runOllama`.
- [x] Добавить tests на multi-round fake adapter.
## 7. Довести fallback notifications до централизованного UX
+5
View File
@@ -0,0 +1,5 @@
export async function runSingleModelRequest<T>(params: {
execute: () => Promise<T>;
}): Promise<T> {
return await params.execute();
}
+22
View File
@@ -0,0 +1,22 @@
export type ToolLoopRoundOutcome = {
shouldContinue: boolean;
maxRoundsReached?: boolean;
};
export async function runToolLoopRounds(params: {
maxRounds: number;
onRound: (round: number) => Promise<ToolLoopRoundOutcome>;
onMaxRoundsReached?: (round: number) => Promise<void> | void;
}): Promise<void> {
for (let round = 0; round < params.maxRounds; round++) {
const outcome = await params.onRound(round);
if (!outcome.shouldContinue) {
if (outcome.maxRoundsReached) {
await params.onMaxRoundsReached?.(round);
}
return;
}
}
await params.onMaxRoundsReached?.(params.maxRounds - 1);
}
+23 -13
View File
@@ -19,6 +19,8 @@ import {
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {Message} from "typescript-telegram-bot-api";
export async function runMistral(
@@ -47,7 +49,9 @@ export async function runMistral(
const toolMemory: ToolExecutionMemory = new Map();
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now();
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
if (signal.aborted) throw new Error("Aborted");
@@ -75,7 +79,9 @@ export async function runMistral(
tools: requestTools,
documents: documents
} as Parameters<typeof mistralAi.chat.complete>[0];
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})),
});
const message = response.choices?.[0]?.message;
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
streamMessage.append(text);
@@ -86,7 +92,7 @@ export async function runMistral(
textChars: text.length,
calls: calls.map(aiLogToolCall),
});
if (!calls.length) return;
if (!calls.length) return {shouldContinue: false};
messages.push({
role: "assistant",
content: text,
@@ -123,7 +129,7 @@ export async function runMistral(
maxRounds: MAX_TOOL_ROUNDS,
});
}
continue;
return {shouldContinue: true};
}
const request = {
@@ -132,7 +138,9 @@ export async function runMistral(
tools: requestTools,
documents: documents
} as Parameters<typeof mistralAi.chat.stream>[0];
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
const streamResponse = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})),
});
aiLog("debug", "mistral.stream.open", {round});
let calls: ToolCallData[] = [];
const roundTextStart = streamMessage.getText().length;
@@ -159,7 +167,7 @@ export async function runMistral(
textChars: streamMessage.getText().slice(roundTextStart).length,
calls: calls.map(aiLogToolCall),
});
if (!calls.length) return;
if (!calls.length) return {shouldContinue: false};
const roundText = streamMessage.getText().slice(roundTextStart);
messages.push({
role: "assistant",
@@ -185,13 +193,15 @@ export async function runMistral(
maxRounds: MAX_TOOL_ROUNDS,
toolCalls: calls,
});
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
}
if (!continuation.continue && continuation.reason === "max_rounds_reached") {
aiLog("warn", "mistral.tool_loop.max_rounds_reached", {
round,
maxRounds: MAX_TOOL_ROUNDS,
});
}
return {shouldContinue: true};
},
});
} finally {
await adapter.finalize().catch(() => undefined);
}
+23 -13
View File
@@ -34,6 +34,8 @@ import {
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {getToolPrompts} from "./tools/registry";
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
import {getModelCapabilities} from "./provider-model-runtime";
@@ -156,7 +158,9 @@ export async function runOllama(
const adapter = getProviderAdapter(AiProvider.OLLAMA);
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now();
aiLog("debug", "ollama.round.start", {
round,
@@ -232,10 +236,12 @@ export async function runOllama(
}
if (!stream) {
const response = await adapter.callModel(request, () => ollama.chat({
...request,
stream: false
}));
const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => ollama.chat({
...request,
stream: false
})),
});
const message = response.message;
const rawContent = message?.content ?? "";
@@ -266,7 +272,7 @@ export async function runOllama(
if (!nativeCalls.length) {
aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)});
break;
return {shouldContinue: false};
}
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
@@ -309,17 +315,19 @@ export async function runOllama(
});
}
continue;
return {shouldContinue: true};
}
aiLog("debug", "ollama.stream.messages", {
round,
messageCount: request.messages?.length ?? 0,
});
const response = await adapter.callModel(request, () => ollama.chat({
...request,
stream: true
}));
const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => ollama.chat({
...request,
stream: true
})),
});
aiLog("debug", "ollama.stream.open", {round});
const calls: ToolCallData[] = [];
@@ -394,7 +402,7 @@ export async function runOllama(
duration: aiLogDuration(runnerStartedAt),
});
break;
return {shouldContinue: false};
}
calls.splice(0, calls.length, ...dedupeToolCalls(calls));
@@ -469,7 +477,9 @@ export async function runOllama(
}).catch(logError);
}
}
return {shouldContinue: true};
},
});
} finally {
if (interval) clearInterval(interval);
await adapter.finalize().catch(() => undefined);
+17 -7
View File
@@ -34,6 +34,8 @@ import {
} from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {bot} from "../index";
import fs from "node:fs";
import path from "node:path";
@@ -87,7 +89,9 @@ export async function runOpenAi(
const toolMemory: ToolExecutionMemory = new Map();
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now();
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
const rankResult = await runToolRankStage({
@@ -122,7 +126,9 @@ export async function runOpenAi(
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
instructions: systemPrompt,
};
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
}) as OpenAiResponseLike;
const responseText = collectOpenAiResponseText(response);
streamMessage.append(responseText);
@@ -169,7 +175,7 @@ export async function runOpenAi(
arguments: safeJsonParseObject(call.argumentsText)
})),
});
if (!calls.length) return;
if (!calls.length) return {shouldContinue: false};
const toolCalls = calls.map(call => ({
id: call.id,
@@ -218,7 +224,7 @@ export async function runOpenAi(
}
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
continue;
return {shouldContinue: true};
}
let completedResponse: OpenAiResponseLike | null = null;
@@ -230,7 +236,9 @@ export async function runOpenAi(
parallel_tool_calls: true,
instructions: systemPrompt
};
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
}) as AsyncIterableStream<ResponseStreamEvent>;
aiLog("debug", "openai.stream.open", {round});
@@ -377,7 +385,7 @@ export async function runOpenAi(
arguments: safeJsonParseObject(call.argumentsText)
})),
});
if (!calls.length) return;
if (!calls.length) return {shouldContinue: false};
const toolCalls = calls.map(call => ({
id: call.id,
@@ -426,7 +434,9 @@ export async function runOpenAi(
}
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
}
return {shouldContinue: true};
},
});
} finally {
if (ownsDocumentRag) {
await preparedDocumentRag?.cleanup().catch(logError);
+17
View File
@@ -0,0 +1,17 @@
import test from "node:test";
import assert from "node:assert/strict";
const {runSingleModelRequest} = await import("../dist/ai/model-call-stage.js");
test("single model request wrapper executes exactly once", async () => {
let calls = 0;
const result = await runSingleModelRequest({
async execute() {
calls += 1;
return "ok";
},
});
assert.equal(result, "ok");
assert.equal(calls, 1);
});
+37
View File
@@ -0,0 +1,37 @@
import test from "node:test";
import assert from "node:assert/strict";
const {runToolLoopRounds} = await import("../dist/ai/tool-loop-runner.js");
test("tool loop runner stops when handler requests it", async () => {
const rounds = [];
await runToolLoopRounds({
maxRounds: 5,
async onRound(round) {
rounds.push(round);
return {shouldContinue: round < 1};
},
});
assert.deepEqual(rounds, [0, 1]);
});
test("tool loop runner calls max rounds hook when handler never stops", async () => {
const rounds = [];
let maxRoundsReached = -1;
await runToolLoopRounds({
maxRounds: 3,
async onRound(round) {
rounds.push(round);
return {shouldContinue: true};
},
onMaxRoundsReached(round) {
maxRoundsReached = round;
},
});
assert.deepEqual(rounds, [0, 1, 2]);
assert.equal(maxRoundsReached, 2);
});