Split model call and tool loop helpers
This commit is contained in:
+5
-5
@@ -79,7 +79,7 @@
|
|||||||
|
|
||||||
## 6. Сделать model_call и tool_loop физически отдельными stages
|
## 6. Сделать model_call и tool_loop физически отдельными stages
|
||||||
|
|
||||||
- [ ] Stage `model_call` должен делать только один model request.
|
- [x] Stage `model_call` должен делать только один model request.
|
||||||
- [x] Stage `model_call` должен возвращать normalized model output.
|
- [x] Stage `model_call` должен возвращать normalized model output.
|
||||||
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
- [x] Stage `tool_loop` должен решать, есть ли tool calls.
|
||||||
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
- [x] Stage `tool_loop` должен выполнять tools через общий `executeToolBatch`.
|
||||||
@@ -87,10 +87,10 @@
|
|||||||
- [x] Stage `tool_loop` должен управлять max rounds.
|
- [x] Stage `tool_loop` должен управлять max rounds.
|
||||||
- [x] Stage `tool_loop` должен сохранять tool result artifacts.
|
- [x] Stage `tool_loop` должен сохранять tool result artifacts.
|
||||||
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
- [x] Stage `tool_loop` должен уметь завершаться без tools как `skipped`.
|
||||||
- [ ] Убрать tool loop из `runOpenAi`.
|
- [x] Убрать tool loop из `runOpenAi`.
|
||||||
- [ ] Убрать tool loop из `runMistral`.
|
- [x] Убрать tool loop из `runMistral`.
|
||||||
- [ ] Убрать tool loop из `runOllama`.
|
- [x] Убрать tool loop из `runOllama`.
|
||||||
- [ ] Добавить tests на multi-round fake adapter.
|
- [x] Добавить tests на multi-round fake adapter.
|
||||||
|
|
||||||
## 7. Довести fallback notifications до централизованного UX
|
## 7. Довести fallback notifications до централизованного UX
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
/**
 * Model-call stage: performs one model request by invoking the supplied
 * callback a single time and handing its resolved value back to the caller.
 *
 * The helper adds no retry, caching or tool handling of its own; any
 * rejection (or synchronous throw) from `execute` surfaces as a rejection of
 * the returned promise.
 *
 * @param params.execute Callback that issues the model request.
 * @returns The value `execute` resolved with.
 */
export async function runSingleModelRequest<T>(params: {
  execute: () => Promise<T>;
}): Promise<T> {
  // Invoked through `params` so the callback keeps its original receiver.
  const modelOutput = await params.execute();
  return modelOutput;
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
/**
 * Result reported by a single tool-loop round.
 *
 * - `shouldContinue`: `true` asks the runner to start another round,
 *   `false` ends the loop.
 * - `maxRoundsReached`: only consulted when `shouldContinue` is `false`;
 *   when truthy the runner fires `onMaxRoundsReached` for that round before
 *   returning.
 */
export type ToolLoopRoundOutcome = {
  shouldContinue: boolean;
  maxRoundsReached?: boolean;
};

/**
 * Drives the tool loop: calls `onRound` with round indices 0, 1, 2, … while
 * the index is below `maxRounds` and the handler keeps asking to continue.
 *
 * The optional `onMaxRoundsReached` hook is awaited in two situations:
 * - a round stops the loop and flags `maxRoundsReached` (hook receives that
 *   round's index);
 * - every allowed round asked to continue, i.e. the budget ran out (hook
 *   receives the index of the last round that actually ran).
 *
 * If the budget allows no rounds at all (`maxRounds` is zero, negative or
 * NaN) nothing is executed and the hook is not fired, because there is no
 * executed round to report.
 *
 * @param params.maxRounds Upper bound (exclusive) for the round index.
 * @param params.onRound Handler for one round; its outcome steers the loop.
 * @param params.onMaxRoundsReached Optional hook, may be sync or async.
 * @returns Resolves once the loop has finished; rejections from `onRound`
 *   or the hook propagate to the caller.
 */
export async function runToolLoopRounds(params: {
  maxRounds: number;
  onRound: (round: number) => Promise<ToolLoopRoundOutcome>;
  onMaxRoundsReached?: (round: number) => Promise<void> | void;
}): Promise<void> {
  // Index of the most recent round that was started; -1 means "none yet".
  let lastExecutedRound = -1;

  for (let round = 0; round < params.maxRounds; round++) {
    lastExecutedRound = round;
    const outcome = await params.onRound(round);
    if (outcome.shouldContinue) {
      continue;
    }

    if (outcome.maxRoundsReached) {
      await params.onMaxRoundsReached?.(round);
    }
    return;
  }

  // Budget exhausted. Report the round that really ran instead of deriving
  // it from `maxRounds`, which would yield -1 for an empty budget or a
  // fractional index for a non-integer budget.
  if (lastExecutedRound < 0) {
    return;
  }
  await params.onMaxRoundsReached?.(lastExecutedRound);
}
|
||||||
@@ -19,6 +19,8 @@ import {
|
|||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
|
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||||
|
import {runSingleModelRequest} from "./model-call-stage";
|
||||||
import {Message} from "typescript-telegram-bot-api";
|
import {Message} from "typescript-telegram-bot-api";
|
||||||
|
|
||||||
export async function runMistral(
|
export async function runMistral(
|
||||||
@@ -47,7 +49,9 @@ export async function runMistral(
|
|||||||
|
|
||||||
const toolMemory: ToolExecutionMemory = new Map();
|
const toolMemory: ToolExecutionMemory = new Map();
|
||||||
try {
|
try {
|
||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
await runToolLoopRounds({
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
onRound: async (round) => {
|
||||||
const roundStartedAt = Date.now();
|
const roundStartedAt = Date.now();
|
||||||
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
||||||
if (signal.aborted) throw new Error("Aborted");
|
if (signal.aborted) throw new Error("Aborted");
|
||||||
@@ -75,7 +79,9 @@ export async function runMistral(
|
|||||||
tools: requestTools,
|
tools: requestTools,
|
||||||
documents: documents
|
documents: documents
|
||||||
} as Parameters<typeof mistralAi.chat.complete>[0];
|
} as Parameters<typeof mistralAi.chat.complete>[0];
|
||||||
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
|
const response = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => mistralAi.chat.complete(request, {signal})),
|
||||||
|
});
|
||||||
const message = response.choices?.[0]?.message;
|
const message = response.choices?.[0]?.message;
|
||||||
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
||||||
streamMessage.append(text);
|
streamMessage.append(text);
|
||||||
@@ -86,7 +92,7 @@ export async function runMistral(
|
|||||||
textChars: text.length,
|
textChars: text.length,
|
||||||
calls: calls.map(aiLogToolCall),
|
calls: calls.map(aiLogToolCall),
|
||||||
});
|
});
|
||||||
if (!calls.length) return;
|
if (!calls.length) return {shouldContinue: false};
|
||||||
messages.push({
|
messages.push({
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: text,
|
content: text,
|
||||||
@@ -123,7 +129,7 @@ export async function runMistral(
|
|||||||
maxRounds: MAX_TOOL_ROUNDS,
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
continue;
|
return {shouldContinue: true};
|
||||||
}
|
}
|
||||||
|
|
||||||
const request = {
|
const request = {
|
||||||
@@ -132,7 +138,9 @@ export async function runMistral(
|
|||||||
tools: requestTools,
|
tools: requestTools,
|
||||||
documents: documents
|
documents: documents
|
||||||
} as Parameters<typeof mistralAi.chat.stream>[0];
|
} as Parameters<typeof mistralAi.chat.stream>[0];
|
||||||
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
|
const streamResponse = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => mistralAi.chat.stream(request, {signal})),
|
||||||
|
});
|
||||||
aiLog("debug", "mistral.stream.open", {round});
|
aiLog("debug", "mistral.stream.open", {round});
|
||||||
let calls: ToolCallData[] = [];
|
let calls: ToolCallData[] = [];
|
||||||
const roundTextStart = streamMessage.getText().length;
|
const roundTextStart = streamMessage.getText().length;
|
||||||
@@ -159,7 +167,7 @@ export async function runMistral(
|
|||||||
textChars: streamMessage.getText().slice(roundTextStart).length,
|
textChars: streamMessage.getText().slice(roundTextStart).length,
|
||||||
calls: calls.map(aiLogToolCall),
|
calls: calls.map(aiLogToolCall),
|
||||||
});
|
});
|
||||||
if (!calls.length) return;
|
if (!calls.length) return {shouldContinue: false};
|
||||||
const roundText = streamMessage.getText().slice(roundTextStart);
|
const roundText = streamMessage.getText().slice(roundTextStart);
|
||||||
messages.push({
|
messages.push({
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
@@ -191,7 +199,9 @@ export async function runMistral(
|
|||||||
maxRounds: MAX_TOOL_ROUNDS,
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
return {shouldContinue: true};
|
||||||
|
},
|
||||||
|
});
|
||||||
} finally {
|
} finally {
|
||||||
await adapter.finalize().catch(() => undefined);
|
await adapter.finalize().catch(() => undefined);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ import {
|
|||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
|
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||||
|
import {runSingleModelRequest} from "./model-call-stage";
|
||||||
import {getToolPrompts} from "./tools/registry";
|
import {getToolPrompts} from "./tools/registry";
|
||||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||||
import {getModelCapabilities} from "./provider-model-runtime";
|
import {getModelCapabilities} from "./provider-model-runtime";
|
||||||
@@ -156,7 +158,9 @@ export async function runOllama(
|
|||||||
const adapter = getProviderAdapter(AiProvider.OLLAMA);
|
const adapter = getProviderAdapter(AiProvider.OLLAMA);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
await runToolLoopRounds({
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
onRound: async (round) => {
|
||||||
const roundStartedAt = Date.now();
|
const roundStartedAt = Date.now();
|
||||||
aiLog("debug", "ollama.round.start", {
|
aiLog("debug", "ollama.round.start", {
|
||||||
round,
|
round,
|
||||||
@@ -232,10 +236,12 @@ export async function runOllama(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
const response = await adapter.callModel(request, () => ollama.chat({
|
const response = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => ollama.chat({
|
||||||
...request,
|
...request,
|
||||||
stream: false
|
stream: false
|
||||||
}));
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
const message = response.message;
|
const message = response.message;
|
||||||
const rawContent = message?.content ?? "";
|
const rawContent = message?.content ?? "";
|
||||||
@@ -266,7 +272,7 @@ export async function runOllama(
|
|||||||
|
|
||||||
if (!nativeCalls.length) {
|
if (!nativeCalls.length) {
|
||||||
aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)});
|
aiLog("success", "ollama.run.done", {round, duration: aiLogDuration(runnerStartedAt)});
|
||||||
break;
|
return {shouldContinue: false};
|
||||||
}
|
}
|
||||||
|
|
||||||
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
|
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
|
||||||
@@ -309,17 +315,19 @@ export async function runOllama(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
continue;
|
return {shouldContinue: true};
|
||||||
}
|
}
|
||||||
|
|
||||||
aiLog("debug", "ollama.stream.messages", {
|
aiLog("debug", "ollama.stream.messages", {
|
||||||
round,
|
round,
|
||||||
messageCount: request.messages?.length ?? 0,
|
messageCount: request.messages?.length ?? 0,
|
||||||
});
|
});
|
||||||
const response = await adapter.callModel(request, () => ollama.chat({
|
const response = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => ollama.chat({
|
||||||
...request,
|
...request,
|
||||||
stream: true
|
stream: true
|
||||||
}));
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
aiLog("debug", "ollama.stream.open", {round});
|
aiLog("debug", "ollama.stream.open", {round});
|
||||||
const calls: ToolCallData[] = [];
|
const calls: ToolCallData[] = [];
|
||||||
@@ -394,7 +402,7 @@ export async function runOllama(
|
|||||||
duration: aiLogDuration(runnerStartedAt),
|
duration: aiLogDuration(runnerStartedAt),
|
||||||
});
|
});
|
||||||
|
|
||||||
break;
|
return {shouldContinue: false};
|
||||||
}
|
}
|
||||||
|
|
||||||
calls.splice(0, calls.length, ...dedupeToolCalls(calls));
|
calls.splice(0, calls.length, ...dedupeToolCalls(calls));
|
||||||
@@ -469,7 +477,9 @@ export async function runOllama(
|
|||||||
}).catch(logError);
|
}).catch(logError);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
return {shouldContinue: true};
|
||||||
|
},
|
||||||
|
});
|
||||||
} finally {
|
} finally {
|
||||||
if (interval) clearInterval(interval);
|
if (interval) clearInterval(interval);
|
||||||
await adapter.finalize().catch(() => undefined);
|
await adapter.finalize().catch(() => undefined);
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ import {
|
|||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
import {executeToolBatchWithAdapter} from "./tool-batch-runner";
|
||||||
import {decideToolLoopContinuation} from "./tool-loop-control";
|
import {decideToolLoopContinuation} from "./tool-loop-control";
|
||||||
|
import {runToolLoopRounds} from "./tool-loop-runner";
|
||||||
|
import {runSingleModelRequest} from "./model-call-stage";
|
||||||
import {bot} from "../index";
|
import {bot} from "../index";
|
||||||
import fs from "node:fs";
|
import fs from "node:fs";
|
||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
@@ -87,7 +89,9 @@ export async function runOpenAi(
|
|||||||
const toolMemory: ToolExecutionMemory = new Map();
|
const toolMemory: ToolExecutionMemory = new Map();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
await runToolLoopRounds({
|
||||||
|
maxRounds: MAX_TOOL_ROUNDS,
|
||||||
|
onRound: async (round) => {
|
||||||
const roundStartedAt = Date.now();
|
const roundStartedAt = Date.now();
|
||||||
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
||||||
const rankResult = await runToolRankStage({
|
const rankResult = await runToolRankStage({
|
||||||
@@ -122,7 +126,9 @@ export async function runOpenAi(
|
|||||||
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
||||||
instructions: systemPrompt,
|
instructions: systemPrompt,
|
||||||
};
|
};
|
||||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
|
const response = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
|
||||||
|
}) as OpenAiResponseLike;
|
||||||
|
|
||||||
const responseText = collectOpenAiResponseText(response);
|
const responseText = collectOpenAiResponseText(response);
|
||||||
streamMessage.append(responseText);
|
streamMessage.append(responseText);
|
||||||
@@ -169,7 +175,7 @@ export async function runOpenAi(
|
|||||||
arguments: safeJsonParseObject(call.argumentsText)
|
arguments: safeJsonParseObject(call.argumentsText)
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
if (!calls.length) return;
|
if (!calls.length) return {shouldContinue: false};
|
||||||
|
|
||||||
const toolCalls = calls.map(call => ({
|
const toolCalls = calls.map(call => ({
|
||||||
id: call.id,
|
id: call.id,
|
||||||
@@ -218,7 +224,7 @@ export async function runOpenAi(
|
|||||||
}
|
}
|
||||||
|
|
||||||
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
responseInput = [...responseInput, ...(response.output ?? []), ...toolOutputs];
|
||||||
continue;
|
return {shouldContinue: true};
|
||||||
}
|
}
|
||||||
|
|
||||||
let completedResponse: OpenAiResponseLike | null = null;
|
let completedResponse: OpenAiResponseLike | null = null;
|
||||||
@@ -230,7 +236,9 @@ export async function runOpenAi(
|
|||||||
parallel_tool_calls: true,
|
parallel_tool_calls: true,
|
||||||
instructions: systemPrompt
|
instructions: systemPrompt
|
||||||
};
|
};
|
||||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
|
const response = await runSingleModelRequest({
|
||||||
|
execute: () => adapter.callModel(request, () => openAi.responses.create(request, {signal})),
|
||||||
|
}) as AsyncIterableStream<ResponseStreamEvent>;
|
||||||
|
|
||||||
aiLog("debug", "openai.stream.open", {round});
|
aiLog("debug", "openai.stream.open", {round});
|
||||||
|
|
||||||
@@ -377,7 +385,7 @@ export async function runOpenAi(
|
|||||||
arguments: safeJsonParseObject(call.argumentsText)
|
arguments: safeJsonParseObject(call.argumentsText)
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
if (!calls.length) return;
|
if (!calls.length) return {shouldContinue: false};
|
||||||
|
|
||||||
const toolCalls = calls.map(call => ({
|
const toolCalls = calls.map(call => ({
|
||||||
id: call.id,
|
id: call.id,
|
||||||
@@ -426,7 +434,9 @@ export async function runOpenAi(
|
|||||||
}
|
}
|
||||||
|
|
||||||
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
responseInput = [...responseInput, ...(completedResponse.output ?? []), ...toolOutputs];
|
||||||
}
|
return {shouldContinue: true};
|
||||||
|
},
|
||||||
|
});
|
||||||
} finally {
|
} finally {
|
||||||
if (ownsDocumentRag) {
|
if (ownsDocumentRag) {
|
||||||
await preparedDocumentRag?.cleanup().catch(logError);
|
await preparedDocumentRag?.cleanup().catch(logError);
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
import test from "node:test";
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
|
||||||
|
const {runSingleModelRequest} = await import("../dist/ai/model-call-stage.js");
|
||||||
|
|
||||||
|
// The model-call wrapper must invoke its callback one time only and pass the
// callback's resolved value through unchanged.
test("single model request wrapper executes exactly once", async () => {
  let invocations = 0;
  const execute = async () => {
    invocations += 1;
    return "ok";
  };

  const output = await runSingleModelRequest({execute});

  assert.equal(output, "ok");
  assert.equal(invocations, 1);
});
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import test from "node:test";
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
|
||||||
|
const {runToolLoopRounds} = await import("../dist/ai/tool-loop-runner.js");
|
||||||
|
|
||||||
|
// A handler that returns shouldContinue=false on round 1 must end the loop
// there, even though the budget would allow five rounds.
test("tool loop runner stops when handler requests it", async () => {
  const visited = [];
  const onRound = async (round) => {
    visited.push(round);
    const keepGoing = round < 1;
    return {shouldContinue: keepGoing};
  };

  await runToolLoopRounds({maxRounds: 5, onRound});

  assert.deepEqual(visited, [0, 1]);
});
|
||||||
|
|
||||||
|
// When every round asks to continue, the runner must execute exactly
// `maxRounds` rounds and then report the final round index to the hook.
test("tool loop runner calls max rounds hook when handler never stops", async () => {
  const visited = [];
  let reportedRound = -1;

  const onRound = async (round) => {
    visited.push(round);
    return {shouldContinue: true};
  };
  const onMaxRoundsReached = (round) => {
    reportedRound = round;
  };

  await runToolLoopRounds({maxRounds: 3, onRound, onMaxRoundsReached});

  assert.deepEqual(visited, [0, 1, 2]);
  assert.equal(reportedRound, 2);
});
|
||||||
Reference in New Issue
Block a user