Split model call and tool loop helpers

This commit is contained in:
2026-05-18 19:55:00 +03:00
parent 57985ce87b
commit d163d72a0b
8 changed files with 149 additions and 38 deletions
+5 -5
View File
@@ -79,7 +79,7 @@
## 6. Сделать model_call и tool_loop физически отдельными stages ## 6. Сделать model_call и tool_loop физически отдельными stages
- [ ] Stage `model_call` должен делать только один model request. - [x] Stage `model_call` должен делать только один model request.
- [x] Stage `model_call` должен возвращать normalized model output. - [x] Stage `model_call` должен возвращать normalized model output.
- [x] Stage `tool_loop` должен решать, есть ли tool calls. - [x] Stage `tool_loop` должен решать, есть ли tool calls.
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`. - [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
@@ -87,10 +87,10 @@
- [x] Stage `tool_loop` должен управлять max rounds. - [x] Stage `tool_loop` должен управлять max rounds.
- [x] Stage `tool_loop` должен сохранять tool result artifacts. - [x] Stage `tool_loop` должен сохранять tool result artifacts.
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`. - [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
- [ ] Убрать tool loop из `runOpenAi`. - [x] Убрать tool loop из `runOpenAi`.
- [ ] Убрать tool loop из `runMistral`. - [x] Убрать tool loop из `runMistral`.
- [ ] Убрать tool loop из `runOllama`. - [x] Убрать tool loop из `runOllama`.
- [ ] Добавить tests на multi-round fake adapter. - [x] Добавить tests на multi-round fake adapter.
## 7. Довести fallback notifications до централизованного UX ## 7. Довести fallback notifications до централизованного UX
+5
View File
@@ -0,0 +1,5 @@
/**
 * Model-call stage: performs exactly one model request.
 *
 * The supplied `execute` callback is invoked a single time and its resolved
 * value is handed back unchanged. No retry, looping or tool handling happens
 * here; a rejection (or synchronous throw) from `execute` surfaces as a
 * rejection of the returned promise.
 *
 * @param params.execute Callback that issues the model request.
 * @returns The value the callback resolved with.
 */
export async function runSingleModelRequest<T>(params: {
    execute: () => Promise<T>;
}): Promise<T> {
    // Called through `params` so the callback keeps `params` as its receiver.
    const modelOutput = await params.execute();
    return modelOutput;
}
+22
View File
@@ -0,0 +1,22 @@
/**
 * Result reported by a single tool-loop round.
 */
export type ToolLoopRoundOutcome = {
    /** `true` to run another round (budget permitting), `false` to stop the loop. */
    shouldContinue: boolean;
    /**
     * Only consulted when `shouldContinue` is `false`: if set, the loop invokes
     * `onMaxRoundsReached` for the current round before returning.
     */
    maxRoundsReached?: boolean;
};

/**
 * Drives the tool loop: calls `onRound` with round indices 0, 1, 2, ... while
 * the index is below `maxRounds` and the handler keeps asking to continue.
 *
 * `onMaxRoundsReached` is invoked (and awaited) in two situations:
 *  - the handler stops the loop and flags `maxRoundsReached` itself;
 *  - the round budget runs out while the handler still wants to continue.
 * In both cases the hook receives the index of the round that actually ran last.
 *
 * If the budget does not allow even one round (`maxRounds` is zero, negative or
 * `NaN`), nothing runs and the hook is not invoked, because there is no round
 * to report. Previously the hook received `maxRounds - 1` in that case, i.e. a
 * negative or `NaN` index, and a fractional index for non-integer budgets.
 *
 * Errors thrown by `onRound` or by the hook propagate to the caller.
 *
 * @param params.maxRounds Upper bound (exclusive) for the round index.
 * @param params.onRound Handler executed once per round.
 * @param params.onMaxRoundsReached Optional hook described above.
 */
export async function runToolLoopRounds(params: {
    maxRounds: number;
    onRound: (round: number) => Promise<ToolLoopRoundOutcome>;
    onMaxRoundsReached?: (round: number) => Promise<void> | void;
}): Promise<void> {
    // Index of the most recent round that was started; undefined until one runs.
    let lastRound: number | undefined;

    for (let round = 0; round < params.maxRounds; round++) {
        lastRound = round;
        const outcome = await params.onRound(round);
        if (outcome.shouldContinue) {
            continue;
        }
        // The handler decided to stop; it may also report that it hit the limit.
        if (outcome.maxRoundsReached) {
            await params.onMaxRoundsReached?.(round);
        }
        return;
    }

    // Budget exhausted while the handler still wanted more rounds. Report the
    // round that really ran instead of computing it from `maxRounds`.
    if (lastRound !== undefined) {
        await params.onMaxRoundsReached?.(lastRound);
    }
}
+17 -7
View File
@@ -19,6 +19,8 @@ import {
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control"; import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {Message} from "typescript-telegram-bot-api"; import {Message} from "typescript-telegram-bot-api";
export async function runMistral( export async function runMistral(
@@ -47,7 +49,9 @@ export async function runMistral(
const toolMemory: ToolExecutionMemory = new Map(); const toolMemory: ToolExecutionMemory = new Map();
try { try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now(); const roundStartedAt = Date.now();
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream}); aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
if (signal.aborted) throw new Error("Aborted"); if (signal.aborted) throw new Error("Aborted");
@@ -75,7 +79,9 @@ export async function runMistral(
tools: requestTools, tools: requestTools,
documents: documents documents: documents
} as Parameters<typeof mistralAi.chat.complete>[0]; } as Parameters<typeof mistralAi.chat.complete>[0];
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})); const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})),
});
const message = response.choices?.[0]?.message; const message = response.choices?.[0]?.message;
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? ""); const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
streamMessage.append(text); streamMessage.append(text);
@@ -86,7 +92,7 @@ export async function runMistral(
textChars: text.length, textChars: text.length,
calls: calls.map(aiLogToolCall), calls: calls.map(aiLogToolCall),
}); });
if (!calls.length) return; if (!calls.length) return {shouldContinue: false};
messages.push({ messages.push({
role: "assistant", role: "assistant",
content: text, content: text,
@@ -123,7 +129,7 @@ export async function runMistral(
maxRounds: MAX_TOOL_ROUNDS, maxRounds: MAX_TOOL_ROUNDS,
}); });
} }
continue; return {shouldContinue: true};
} }
const request = { const request = {
@@ -132,7 +138,9 @@ export async function runMistral(
tools: requestTools, tools: requestTools,
documents: documents documents: documents
} as Parameters<typeof mistralAi.chat.stream>[0]; } as Parameters<typeof mistralAi.chat.stream>[0];
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})); const streamResponse = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})),
});
aiLog("debug", "mistral.stream.open", {round}); aiLog("debug", "mistral.stream.open", {round});
let calls: ToolCallData[] = []; let calls: ToolCallData[] = [];
const roundTextStart = streamMessage.getText().length; const roundTextStart = streamMessage.getText().length;
@@ -159,7 +167,7 @@ export async function runMistral(
textChars: streamMessage.getText().slice(roundTextStart).length, textChars: streamMessage.getText().slice(roundTextStart).length,
calls: calls.map(aiLogToolCall), calls: calls.map(aiLogToolCall),
}); });
if (!calls.length) return; if (!calls.length) return {shouldContinue: false};
const roundText = streamMessage.getText().slice(roundTextStart); const roundText = streamMessage.getText().slice(roundTextStart);
messages.push({ messages.push({
role: "assistant", role: "assistant",
@@ -191,7 +199,9 @@ export async function runMistral(
maxRounds: MAX_TOOL_ROUNDS, maxRounds: MAX_TOOL_ROUNDS,
}); });
} }
} return {shouldContinue: true};
},
});
} finally { } finally {
await adapter.finalize().catch(() => undefined); await adapter.finalize().catch(() => undefined);
} }
+19 -9
View File
@@ -34,6 +34,8 @@ import {
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control"; import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {getToolPrompts} from "./tools/registry"; import {getToolPrompts} from "./tools/registry";
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes"; import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
import {getModelCapabilities} from "./provider-model-runtime"; import {getModelCapabilities} from "./provider-model-runtime";
@@ -156,7 +158,9 @@ export async function runOllama(
const adapter = getProviderAdapter(AiProvider.OLLAMA); const adapter = getProviderAdapter(AiProvider.OLLAMA);
try { try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now(); const roundStartedAt = Date.now();
aiLog("debug", "ollama.round.start", { aiLog("debug", "ollama.round.start", {
round, round,
@@ -232,10 +236,12 @@ export async function runOllama(
} }
if (!stream) { if (!stream) {
const response = await adapter.callModel(request, () => ollama.chat({ const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => ollama.chat({
...request, ...request,
stream: false stream: false
})); })),
});
const message = response.message; const message = response.message;
const rawContent = message?.content ?? ""; const rawContent = message?.content ?? "";
@@ -266,7 +272,7 @@ export async function runOllama(
if (!nativeCalls.length) { if (!nativeCalls.length) {
aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)}); aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)});
break; return {shouldContinue: false};
} }
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls; const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
@@ -309,17 +315,19 @@ export async function runOllama(
}); });
} }
continue; return {shouldContinue: true};
} }
aiLog("debug", "ollama.stream.messages", { aiLog("debug", "ollama.stream.messages", {
round, round,
messageCount: request.messages?.length ?? 0, messageCount: request.messages?.length ?? 0,
}); });
const response = await adapter.callModel(request, () => ollama.chat({ const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => ollama.chat({
...request, ...request,
stream: true stream: true
})); })),
});
aiLog("debug", "ollama.stream.open", {round}); aiLog("debug", "ollama.stream.open", {round});
const calls: ToolCallData[] = []; const calls: ToolCallData[] = [];
@@ -394,7 +402,7 @@ export async function runOllama(
duration: aiLogDuration(runnerStartedAt), duration: aiLogDuration(runnerStartedAt),
}); });
break; return {shouldContinue: false};
} }
calls.splice(0, calls.length, ...dedupeToolCalls(calls)); calls.splice(0, calls.length, ...dedupeToolCalls(calls));
@@ -469,7 +477,9 @@ export async function runOllama(
}).catch(logError); }).catch(logError);
} }
} return {shouldContinue: true};
},
});
} finally { } finally {
if (interval) clearInterval(interval); if (interval) clearInterval(interval);
await adapter.finalize().catch(() => undefined); await adapter.finalize().catch(() => undefined);
+17 -7
View File
@@ -34,6 +34,8 @@ import {
} from "./unified-ai-runner.shared"; } from "./unified-ai-runner.shared";
import {executeToolBatchWithAdapter} from "./tool-batch-runner"; import {executeToolBatchWithAdapter} from "./tool-batch-runner";
import {decideToolLoopContinuation} from "./tool-loop-control"; import {decideToolLoopContinuation} from "./tool-loop-control";
import {runToolLoopRounds} from "./tool-loop-runner";
import {runSingleModelRequest} from "./model-call-stage";
import {bot} from "../index"; import {bot} from "../index";
import fs from "node:fs"; import fs from "node:fs";
import path from "node:path"; import path from "node:path";
@@ -87,7 +89,9 @@ export async function runOpenAi(
const toolMemory: ToolExecutionMemory = new Map(); const toolMemory: ToolExecutionMemory = new Map();
try { try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) { await runToolLoopRounds({
maxRounds: MAX_TOOL_ROUNDS,
onRound: async (round) => {
const roundStartedAt = Date.now(); const roundStartedAt = Date.now();
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream}); aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
const rankResult = await runToolRankStage({ const rankResult = await runToolRankStage({
@@ -122,7 +126,9 @@ export async function runOpenAi(
tools: requestTools as ResponseCreateParamsNonStreaming["tools"], tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
instructions: systemPrompt, instructions: systemPrompt,
}; };
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike; const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
}) as OpenAiResponseLike;
const responseText = collectOpenAiResponseText(response); const responseText = collectOpenAiResponseText(response);
streamMessage.append(responseText); streamMessage.append(responseText);
@@ -169,7 +175,7 @@ export async function runOpenAi(
arguments: safeJsonParseObject(call.argumentsText) arguments: safeJsonParseObject(call.argumentsText)
})), })),
}); });
if (!calls.length) return; if (!calls.length) return {shouldContinue: false};
const toolCalls = calls.map(call => ({ const toolCalls = calls.map(call => ({
id: call.id, id: call.id,
@@ -218,7 +224,7 @@ export async function runOpenAi(
} }
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs]; responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
continue; return {shouldContinue: true};
} }
let completedResponse: OpenAiResponseLike | null = null; let completedResponse: OpenAiResponseLike | null = null;
@@ -230,7 +236,9 @@ export async function runOpenAi(
parallel_tool_calls: true, parallel_tool_calls: true,
instructions: systemPrompt instructions: systemPrompt
}; };
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>; const response = await runSingleModelRequest({
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
}) as AsyncIterableStream<ResponseStreamEvent>;
aiLog("debug", "openai.stream.open", {round}); aiLog("debug", "openai.stream.open", {round});
@@ -377,7 +385,7 @@ export async function runOpenAi(
arguments: safeJsonParseObject(call.argumentsText) arguments: safeJsonParseObject(call.argumentsText)
})), })),
}); });
if (!calls.length) return; if (!calls.length) return {shouldContinue: false};
const toolCalls = calls.map(call => ({ const toolCalls = calls.map(call => ({
id: call.id, id: call.id,
@@ -426,7 +434,9 @@ export async function runOpenAi(
} }
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs]; responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
} return {shouldContinue: true};
},
});
} finally { } finally {
if (ownsDocumentRag) { if (ownsDocumentRag) {
await preparedDocumentRag?.cleanup().catch(logError); await preparedDocumentRag?.cleanup().catch(logError);
+17
View File
@@ -0,0 +1,17 @@
import test from "node:test";
import assert from "node:assert/strict";
const {runSingleModelRequest} = await import("../dist/ai/model-call-stage.js");
// The model-call stage must invoke the supplied callback a single time and
// hand its resolved value back untouched.
test("single model request wrapper executes exactly once", async () => {
    let invocationCount = 0;
    const execute = async () => {
        invocationCount += 1;
        return "ok";
    };

    const output = await runSingleModelRequest({execute});

    assert.equal(output, "ok");
    assert.equal(invocationCount, 1);
});
+37
View File
@@ -0,0 +1,37 @@
import test from "node:test";
import assert from "node:assert/strict";
const {runToolLoopRounds} = await import("../dist/ai/tool-loop-runner.js");
// The handler asks to continue only after round 0, so the runner must execute
// rounds 0 and 1 and then stop, well before the budget of 5 is used up.
test("tool loop runner stops when handler requests it", async () => {
    const visitedRounds = [];
    const onRound = async (round) => {
        visitedRounds.push(round);
        return {shouldContinue: round < 1};
    };

    await runToolLoopRounds({maxRounds: 5, onRound});

    assert.deepEqual(visitedRounds, [0, 1]);
});
// A handler that always wants another round exhausts the budget of 3; the
// max-rounds hook must then be called with the index of the final round.
test("tool loop runner calls max rounds hook when handler never stops", async () => {
    const visitedRounds = [];
    let reportedRound = -1;

    const onRound = async (round) => {
        visitedRounds.push(round);
        return {shouldContinue: true};
    };
    const onMaxRoundsReached = (round) => {
        reportedRound = round;
    };

    await runToolLoopRounds({maxRounds: 3, onRound, onMaxRoundsReached});

    assert.deepEqual(visitedRounds, [0, 1, 2]);
    assert.equal(reportedRound, 2);
});