Split model call and tool loop helpers
This commit is contained in:
+5
-5
@@ -79,7 +79,7 @@
|
||||
|
||||
## 6. Сделать model_call и tool_loop физически отдельными stages
|
||||
|
||||
- [ ] Stage `model_call` должен делать только один model request.
|
||||
- [x] Stage `model_call` должен делать только один model request.
|
||||
- [x] Stage `model_call` должен возвращать normalized model output.
|
||||
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
||||
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
||||
@@ -87,10 +87,10 @@
|
||||
- [x] Stage `tool_loop` должен управлять max rounds.
|
||||
- [x] Stage `tool_loop` должен сохранять tool result artifacts.
|
||||
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
||||
- [ ] Убрать tool loop из `runOpenAi`.
|
||||
- [ ] Убрать tool loop из `runMistral`.
|
||||
- [ ] Убрать tool loop из `runOllama`.
|
||||
- [ ] Добавить tests на multi-round fake adapter.
|
||||
- [x] Убрать tool loop из `runOpenAi`.
|
||||
- [x] Убрать tool loop из `runMistral`.
|
||||
- [x] Убрать tool loop из `runOllama`.
|
||||
- [x] Добавить tests на multi-round fake adapter.
|
||||
|
||||
## 7. Довести fallback notifications до централизованного UX
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
/**
 * Model-call stage helper: performs exactly one model request.
 *
 * Invokes `params.execute` a single time and hands its resolved value back
 * untouched. No retry, looping, or result transformation happens here; a
 * rejection from `execute` propagates to the caller unchanged.
 *
 * @param params.execute Callback that issues the model request.
 * @returns The value the promise returned by `execute` resolves to.
 */
export async function runSingleModelRequest<T>(params: {
    execute: () => Promise<T>;
}): Promise<T> {
    return await params.execute();
}
|
||||
@@ -0,0 +1,22 @@
|
||||
/**
 * Result a round handler returns to the tool-loop driver.
 */
export type ToolLoopRoundOutcome = {
    /** `true` to run another round, `false` to stop the loop. */
    shouldContinue: boolean;
    /**
     * Only consulted when `shouldContinue` is `false`: if set, the driver
     * invokes `onMaxRoundsReached` for the current round before returning.
     */
    maxRoundsReached?: boolean;
};

/**
 * Drives a bounded sequence of tool-loop rounds.
 *
 * Rounds are numbered from 0 and executed sequentially while
 * `round < maxRounds`. The loop ends when a handler returns
 * `shouldContinue: false` or when the round budget is used up.
 *
 * `onMaxRoundsReached` is invoked in two situations:
 * - a handler stops with `maxRoundsReached: true` (called with that round);
 * - every allowed round ran and the last one still asked to continue
 *   (called with the index of the last round that actually ran).
 *
 * If no round runs at all (for example `maxRounds` is 0, negative or `NaN`)
 * the hook is not invoked, so it never receives a round index that was not
 * executed.
 *
 * @param params.maxRounds Upper bound on the number of rounds.
 * @param params.onRound Handler for a single round; receives the round index.
 * @param params.onMaxRoundsReached Optional hook described above.
 * @returns Resolves once the loop has finished; errors thrown by the
 *          handler or the hook propagate to the caller.
 */
export async function runToolLoopRounds(params: {
    maxRounds: number;
    onRound: (round: number) => Promise<ToolLoopRoundOutcome>;
    onMaxRoundsReached?: (round: number) => Promise<void> | void;
}): Promise<void> {
    const {maxRounds, onRound, onMaxRoundsReached} = params;

    // Index of the most recent round that actually executed; -1 means none.
    let lastRound = -1;

    for (let round = 0; round < maxRounds; round++) {
        lastRound = round;
        const outcome = await onRound(round);
        if (outcome.shouldContinue) continue;

        // The handler asked to stop; it may also flag that the limit was hit.
        if (outcome.maxRoundsReached) {
            await onMaxRoundsReached?.(round);
        }
        return;
    }

    // Budget exhausted while the handler still wanted to continue. Report the
    // last executed round rather than `maxRounds - 1`, which would be negative
    // or NaN when the loop never ran and fractional for a non-integer limit.
    if (lastRound >= 0) {
        await onMaxRoundsReached?.(lastRound);
    }
}
|
||||
@@ -19,6 +19,8 @@ import {
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||
import {runSingleModelRequest} from "./model-call-stage";
|
||||
import {Message} from "typescript-telegram-bot-api";
|
||||
|
||||
export async function runMistral(
|
||||
@@ -47,7 +49,9 @@ export async function runMistral(
|
||||
|
||||
const toolMemory: ToolExecutionMemory = new Map();
|
||||
try {
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
await runToolLoopRounds({
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
onRound: async (round) => {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
||||
if (signal.aborted) throw new Error("Aborted");
|
||||
@@ -75,7 +79,9 @@ export async function runMistral(
|
||||
tools: requestTools,
|
||||
documents: documents
|
||||
} as Parameters<typeof mistralAi.chat.complete>[0];
|
||||
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
|
||||
const response = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})),
|
||||
});
|
||||
const message = response.choices?.[0]?.message;
|
||||
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
||||
streamMessage.append(text);
|
||||
@@ -86,7 +92,7 @@ export async function runMistral(
|
||||
textChars: text.length,
|
||||
calls: calls.map(aiLogToolCall),
|
||||
});
|
||||
if (!calls.length) return;
|
||||
if (!calls.length) return {shouldContinue: false};
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: text,
|
||||
@@ -123,7 +129,7 @@ export async function runMistral(
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
return {shouldContinue: true};
|
||||
}
|
||||
|
||||
const request = {
|
||||
@@ -132,7 +138,9 @@ export async function runMistral(
|
||||
tools: requestTools,
|
||||
documents: documents
|
||||
} as Parameters<typeof mistralAi.chat.stream>[0];
|
||||
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
|
||||
const streamResponse = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})),
|
||||
});
|
||||
aiLog("debug", "mistral.stream.open", {round});
|
||||
let calls: ToolCallData[] = [];
|
||||
const roundTextStart = streamMessage.getText().length;
|
||||
@@ -159,7 +167,7 @@ export async function runMistral(
|
||||
textChars: streamMessage.getText().slice(roundTextStart).length,
|
||||
calls: calls.map(aiLogToolCall),
|
||||
});
|
||||
if (!calls.length) return;
|
||||
if (!calls.length) return {shouldContinue: false};
|
||||
const roundText = streamMessage.getText().slice(roundTextStart);
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
@@ -191,7 +199,9 @@ export async function runMistral(
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
});
|
||||
}
|
||||
}
|
||||
return {shouldContinue: true};
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
await adapter.finalize().catch(() => undefined);
|
||||
}
|
||||
|
||||
@@ -34,6 +34,8 @@ import {
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||
import {runSingleModelRequest} from "./model-call-stage";
|
||||
import {getToolPrompts} from "./tools/registry";
|
||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||
import {getModelCapabilities} from "./provider-model-runtime";
|
||||
@@ -156,7 +158,9 @@ export async function runOllama(
|
||||
const adapter = getProviderAdapter(AiProvider.OLLAMA);
|
||||
|
||||
try {
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
await runToolLoopRounds({
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
onRound: async (round) => {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "ollama.round.start", {
|
||||
round,
|
||||
@@ -232,10 +236,12 @@ export async function runOllama(
|
||||
}
|
||||
|
||||
if (!stream) {
|
||||
const response = await adapter.callModel(request, () => ollama.chat({
|
||||
const response = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => ollama.chat({
|
||||
...request,
|
||||
stream: false
|
||||
}));
|
||||
})),
|
||||
});
|
||||
|
||||
const message = response.message;
|
||||
const rawContent = message?.content ?? "";
|
||||
@@ -266,7 +272,7 @@ export async function runOllama(
|
||||
|
||||
if (!nativeCalls.length) {
|
||||
aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)});
|
||||
break;
|
||||
return {shouldContinue: false};
|
||||
}
|
||||
|
||||
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
|
||||
@@ -309,17 +315,19 @@ export async function runOllama(
|
||||
});
|
||||
}
|
||||
|
||||
continue;
|
||||
return {shouldContinue: true};
|
||||
}
|
||||
|
||||
aiLog("debug", "ollama.stream.messages", {
|
||||
round,
|
||||
messageCount: request.messages?.length ?? 0,
|
||||
});
|
||||
const response = await adapter.callModel(request, () => ollama.chat({
|
||||
const response = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => ollama.chat({
|
||||
...request,
|
||||
stream: true
|
||||
}));
|
||||
})),
|
||||
});
|
||||
|
||||
aiLog("debug", "ollama.stream.open", {round});
|
||||
const calls: ToolCallData[] = [];
|
||||
@@ -394,7 +402,7 @@ export async function runOllama(
|
||||
duration: aiLogDuration(runnerStartedAt),
|
||||
});
|
||||
|
||||
break;
|
||||
return {shouldContinue: false};
|
||||
}
|
||||
|
||||
calls.splice(0, calls.length, ...dedupeToolCalls(calls));
|
||||
@@ -469,7 +477,9 @@ export async function runOllama(
|
||||
}).catch(logError);
|
||||
}
|
||||
|
||||
}
|
||||
return {shouldContinue: true};
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
if (interval) clearInterval(interval);
|
||||
await adapter.finalize().catch(() => undefined);
|
||||
|
||||
@@ -34,6 +34,8 @@ import {
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||
import {runSingleModelRequest} from "./model-call-stage";
|
||||
import {bot} from "../index";
|
||||
import fs from "node:fs";
|
||||
import path from "node:path";
|
||||
@@ -87,7 +89,9 @@ export async function runOpenAi(
|
||||
const toolMemory: ToolExecutionMemory = new Map();
|
||||
|
||||
try {
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
await runToolLoopRounds({
|
||||
maxRounds: MAX_TOOL_ROUNDS,
|
||||
onRound: async (round) => {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
||||
const rankResult = await runToolRankStage({
|
||||
@@ -122,7 +126,9 @@ export async function runOpenAi(
|
||||
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
||||
instructions: systemPrompt,
|
||||
};
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
|
||||
const response = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
|
||||
}) as OpenAiResponseLike;
|
||||
|
||||
const responseText = collectOpenAiResponseText(response);
|
||||
streamMessage.append(responseText);
|
||||
@@ -169,7 +175,7 @@ export async function runOpenAi(
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
});
|
||||
if (!calls.length) return;
|
||||
if (!calls.length) return {shouldContinue: false};
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.id,
|
||||
@@ -218,7 +224,7 @@ export async function runOpenAi(
|
||||
}
|
||||
|
||||
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
||||
continue;
|
||||
return {shouldContinue: true};
|
||||
}
|
||||
|
||||
let completedResponse: OpenAiResponseLike | null = null;
|
||||
@@ -230,7 +236,9 @@ export async function runOpenAi(
|
||||
parallel_tool_calls: true,
|
||||
instructions: systemPrompt
|
||||
};
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
const response = await runSingleModelRequest({
|
||||
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
|
||||
}) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
|
||||
aiLog("debug", "openai.stream.open", {round});
|
||||
|
||||
@@ -377,7 +385,7 @@ export async function runOpenAi(
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
});
|
||||
if (!calls.length) return;
|
||||
if (!calls.length) return {shouldContinue: false};
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.id,
|
||||
@@ -426,7 +434,9 @@ export async function runOpenAi(
|
||||
}
|
||||
|
||||
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
||||
}
|
||||
return {shouldContinue: true};
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
if (ownsDocumentRag) {
|
||||
await preparedDocumentRag?.cleanup().catch(logError);
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
|
||||
const {runSingleModelRequest} = await import("../dist/ai/model-call-stage.js");
|
||||
|
||||
// Checks that the wrapper calls `execute` a single time and returns its value.
test("single model request wrapper executes exactly once", async () => {
    let invocations = 0;
    const execute = async () => {
        invocations += 1;
        return "ok";
    };

    const output = await runSingleModelRequest({execute});

    assert.equal(output, "ok");
    assert.equal(invocations, 1);
});
|
||||
@@ -0,0 +1,37 @@
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
|
||||
const {runToolLoopRounds} = await import("../dist/ai/tool-loop-runner.js");
|
||||
|
||||
// The handler asks to continue after round 0 and to stop after round 1, so
// exactly two rounds must run out of the five allowed.
test("tool loop runner stops when handler requests it", async () => {
    const rounds = [];
    let maxRoundsHookCalls = 0;

    await runToolLoopRounds({
        maxRounds: 5,
        async onRound(round) {
            rounds.push(round);
            return {shouldContinue: round < 1};
        },
        onMaxRoundsReached() {
            maxRoundsHookCalls += 1;
        },
    });

    assert.deepEqual(rounds, [0, 1]);
    // A plain early stop (no `maxRoundsReached` flag) must not be reported
    // as hitting the round limit.
    assert.equal(maxRoundsHookCalls, 0);
});
|
||||
|
||||
// With a handler that always continues, all three rounds run and the hook
// is told about the final round index.
test("tool loop runner calls max rounds hook when handler never stops", async () => {
    const visited = [];
    let reportedRound = -1;

    const onRound = async (round) => {
        visited.push(round);
        return {shouldContinue: true};
    };
    const onMaxRoundsReached = (round) => {
        reportedRound = round;
    };

    await runToolLoopRounds({maxRounds: 3, onRound, onMaxRoundsReached});

    assert.deepEqual(visited, [0, 1, 2]);
    assert.equal(reportedRound, 2);
});
|
||||
Reference in New Issue
Block a user