Add unified request pipeline stages
This commit is contained in:
@@ -17,11 +17,9 @@ import {
|
||||
AsyncIterableStream,
|
||||
buildSystemInstruction,
|
||||
collectOpenAiResponseCodeInterpreterCalls,
|
||||
collectOpenAiResponseFunctionCalls,
|
||||
collectOpenAiResponseImages,
|
||||
collectOpenAiResponseText,
|
||||
executeToolBatch,
|
||||
getOpenAIResponsesToolsWithImage,
|
||||
MAX_TOOL_ROUNDS,
|
||||
OPENAI_IMAGE_PARTIALS,
|
||||
openAiResponseItemCallId,
|
||||
@@ -42,10 +40,9 @@ import {logError} from "../util/utils";
|
||||
import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files";
|
||||
import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings";
|
||||
import {AiDownloadedFile} from "./telegram-attachments";
|
||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
||||
import {getProviderAdapter} from "./provider-adapters";
|
||||
import {runToolRankStage} from "./tool-rank-stage";
|
||||
|
||||
export async function runOpenAi(
|
||||
msg: Message,
|
||||
@@ -60,16 +57,15 @@ export async function runOpenAi(
|
||||
documentRag?: OpenAiDocumentRagContext,
|
||||
): Promise<void> {
|
||||
const runnerStartedAt = Date.now();
|
||||
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = [...messages] as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
||||
const openAi = createOpenAiClient(config.openAiChatTarget);
|
||||
const ownsDocumentRag = !documentRag;
|
||||
const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document"));
|
||||
const toolRanker = new ToolRanker(config);
|
||||
const availableTools = getOpenAIResponsesToolsWithImage(
|
||||
config,
|
||||
msg.from?.id === Environment.CREATOR_ID,
|
||||
preparedDocumentRag?.vectorStoreIds ?? [],
|
||||
);
|
||||
const adapter = getProviderAdapter(AiProvider.OPENAI);
|
||||
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = adapter.mapMessages(messages) as unknown as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
||||
const availableTools = adapter.rankTools(config, {
|
||||
forCreator: msg.from?.id === Environment.CREATOR_ID,
|
||||
vectorStoreIds: preparedDocumentRag?.vectorStoreIds ?? [],
|
||||
});
|
||||
|
||||
const systemPrompt = buildSystemInstruction(
|
||||
config,
|
||||
@@ -93,43 +89,17 @@ export async function runOpenAi(
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
||||
await streamMessage.flush();
|
||||
const toolRankStartedAt = Date.now();
|
||||
const toolRankStartedAtIso = new Date().toISOString();
|
||||
const rankerSelection = await toolRanker.selectTools({
|
||||
provider: AiProvider.OPENAI,
|
||||
userQuery: latestUserTextFromMessages(messages),
|
||||
availableTools,
|
||||
round,
|
||||
signal,
|
||||
})
|
||||
.catch(async error => {
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.OPENAI,
|
||||
model: config.openAiChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
error,
|
||||
});
|
||||
throw error;
|
||||
});
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
const rankResult = await runToolRankStage({
|
||||
provider: AiProvider.OPENAI,
|
||||
model: config.openAiChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
selectedTools: rankerSelection.toolNames,
|
||||
config,
|
||||
availableTools,
|
||||
messages,
|
||||
streamMessage,
|
||||
signal,
|
||||
});
|
||||
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
|
||||
const filteredTools = rankResult.filteredTools;
|
||||
const requestTools = preparedDocumentRag?.vectorStoreIds.length
|
||||
? (() => {
|
||||
const tools = [...filteredTools];
|
||||
@@ -151,7 +121,7 @@ export async function runOpenAi(
|
||||
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
||||
instructions: systemPrompt,
|
||||
};
|
||||
const response = await openAi.responses.create(request, {signal}) as OpenAiResponseLike;
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
|
||||
|
||||
const responseText = collectOpenAiResponseText(response);
|
||||
streamMessage.append(responseText);
|
||||
@@ -188,12 +158,12 @@ export async function runOpenAi(
|
||||
});
|
||||
}
|
||||
|
||||
const calls = collectOpenAiResponseFunctionCalls(response);
|
||||
const calls = adapter.extractToolCalls(response);
|
||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||
round,
|
||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||
calls: calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
@@ -201,16 +171,13 @@ export async function runOpenAi(
|
||||
if (!calls.length) return;
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
argumentsText: call.argumentsText,
|
||||
}));
|
||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||
const toolOutputs = calls.map((call, index) => ({
|
||||
type: "function_call_output" as const,
|
||||
call_id: call.callId,
|
||||
output: toolResults[index] ?? "",
|
||||
}));
|
||||
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||
|
||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||
if (uploadFilesResult.found) {
|
||||
@@ -243,7 +210,7 @@ export async function runOpenAi(
|
||||
parallel_tool_calls: true,
|
||||
instructions: systemPrompt
|
||||
};
|
||||
const response = await openAi.responses.create(request, {signal}) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
|
||||
aiLog("debug", "openai.stream.open", {round});
|
||||
|
||||
@@ -253,7 +220,7 @@ export async function runOpenAi(
|
||||
|
||||
switch (event.type) {
|
||||
case "response.output_text.delta":
|
||||
streamMessage.append(event.delta ?? "");
|
||||
streamMessage.append(adapter.extractTextDelta(event));
|
||||
break;
|
||||
case "response.image_generation_call.in_progress":
|
||||
streamMessage.setStatus(Environment.startingImageGenText);
|
||||
@@ -301,14 +268,11 @@ export async function runOpenAi(
|
||||
case "response.code_interpreter_call_code.done":
|
||||
break;
|
||||
case "response.output_item.added":
|
||||
if (event.item.type === "function_call" && event.item.name) {
|
||||
const item = event.item as OpenAiResponseOutputItem & { id?: string };
|
||||
localToolCalls.push({
|
||||
id: openAiResponseItemCallId(item),
|
||||
name: item.name ?? "",
|
||||
argumentsText: item.arguments ?? "{}",
|
||||
});
|
||||
|
||||
{
|
||||
const streamedCalls = adapter.extractStreamingToolCalls(event);
|
||||
if (streamedCalls.length) {
|
||||
localToolCalls.push(...streamedCalls);
|
||||
}
|
||||
aiLog("info", "openai.stream.tool_call.added", {
|
||||
round,
|
||||
toolCalls: localToolCalls.map(aiLogToolCall)
|
||||
@@ -383,12 +347,12 @@ export async function runOpenAi(
|
||||
});
|
||||
}
|
||||
|
||||
const calls = collectOpenAiResponseFunctionCalls(completedResponse);
|
||||
const calls = adapter.extractToolCalls(completedResponse);
|
||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||
round,
|
||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||
calls: calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
@@ -396,16 +360,13 @@ export async function runOpenAi(
|
||||
if (!calls.length) return;
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
argumentsText: call.argumentsText,
|
||||
}));
|
||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||
const toolOutputs = calls.map((call, index) => ({
|
||||
type: "function_call_output",
|
||||
call_id: call.callId,
|
||||
output: toolResults[index] ?? "",
|
||||
}));
|
||||
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||
|
||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||
if (uploadFilesResult.found) {
|
||||
@@ -431,6 +392,7 @@ export async function runOpenAi(
|
||||
if (ownsDocumentRag) {
|
||||
await preparedDocumentRag?.cleanup().catch(logError);
|
||||
}
|
||||
await adapter.finalize().catch(logError);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user