diff --git a/PIPELINE_TODO.md b/PIPELINE_TODO.md index c862cfa..9f7d9bc 100644 --- a/PIPELINE_TODO.md +++ b/PIPELINE_TODO.md @@ -67,13 +67,13 @@ ## 5. Сделать tool-ranker полноценным pipeline stage - [x] Вынести вызов `ToolRanker.selectTools(...)` из provider runners. -- [ ] Добавить stage `tool_rank`, который работает через provider adapter. -- [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker. -- [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`. -- [ ] Сохранять `ToolRankDecision` в `request_audit.details`. +- [x] Добавить stage `tool_rank`, который работает через provider adapter. +- [x] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker. +- [x] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`. +- [x] Сохранять `ToolRankDecision` в `request_audit.details`. - [ ] Убрать дублирующий ручной `tool-rank-audit.ts`, если stage полностью заменит его. -- [ ] Сохранить status UX: `🧩 Выбираю подходящие инструменты...`. -- [ ] Гарантировать `clearStatus()` после ranker success/failure. +- [x] Сохранить status UX: `🧩 Выбираю подходящие инструменты...`. +- [x] Гарантировать `clearStatus()` после ranker success/failure. - [ ] Добавить fallback через `PipelineFallbackExecutor`: main model, all tools, no tools. - [ ] Добавить tests на fallback ranker policy. diff --git a/src/ai/tool-rank-audit.ts b/src/ai/tool-rank-audit.ts index 9d22e81..595fce0 100644 --- a/src/ai/tool-rank-audit.ts +++ b/src/ai/tool-rank-audit.ts @@ -10,7 +10,9 @@ export async function storeToolRankAudit(params: { round: number; startedAt: number; startedAtIso: string; + availableTools: string[]; selectedTools?: string[]; + usedRanker?: boolean; error?: unknown; }): Promise { const event: PipelineAuditEvent = { @@ -23,7 +25,16 @@ export async function storeToolRankAudit(params: { model: params.model, details: { round: params.round, + availableTools: params.availableTools, selectedTools: params.selectedTools ?? [], + usedRanker: params.usedRanker ?? false, + toolRankDecision: { + provider: params.provider, + round: params.round, + availableTools: params.availableTools, + selectedTools: params.selectedTools ?? [], + usedRanker: params.usedRanker ?? false, + }, }, error: params.error instanceof Error ? params.error.message : params.error ? String(params.error) : undefined, }; diff --git a/src/ai/tool-rank-stage.ts b/src/ai/tool-rank-stage.ts index cce746d..dc3a795 100644 --- a/src/ai/tool-rank-stage.ts +++ b/src/ai/tool-rank-stage.ts @@ -1,11 +1,9 @@ -import {Environment} from "../common/environment.js"; import {AiProvider} from "../model/ai-provider.js"; import type {BoundaryValue} from "../common/boundary-types.js"; import type {TelegramStreamMessage} from "./telegram-stream-message.js"; import type {RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js"; -import {filterRankedTools} from "./tool-ranker-pipeline.js"; -import {ToolRanker} from "./unified-ai-runner.tool-ranker.js"; -import {storeToolRankAudit} from "./tool-rank-audit.js"; +import {allToolSchemaNames, toolSchemaNames} from "./tool-schema-utils.js"; +import type {ToolRanker} from "./unified-ai-runner.tool-ranker.js"; function latestUserText(messages: readonly { role?: string; content?: unknown }[]): string { for (let i = messages.length - 1; i >= 0; i--) { @@ -35,16 +33,33 @@ export async function runToolRankStage(params: { streamMessage: TelegramStreamMessage; signal: AbortSignal; toolRanker?: ToolRanker; + storeAudit?: (params: { + streamMessage: TelegramStreamMessage; + provider: AiProvider; + model: string; + round: number; + startedAt: number; + startedAtIso: string; + availableTools: string[]; + selectedTools?: string[]; + usedRanker?: boolean; + error?: unknown; + }) => Promise; }): Promise<{ filteredTools: BoundaryValue[]; selectedToolNames: string[]; usedRanker: boolean; }> { - const toolRanker = params.toolRanker ?? new ToolRanker(params.config); + const toolRanker = params.toolRanker ?? new (await import("./unified-ai-runner.tool-ranker.js")).ToolRanker(params.config); const startedAt = Date.now(); const startedAtIso = new Date().toISOString(); + const storeAudit = params.storeAudit ?? (await import("./tool-rank-audit.js")).storeToolRankAudit; + const filterSelectedTools = (selectedToolNames: readonly string[]): BoundaryValue[] => { + const selected = new Set(selectedToolNames); + return params.availableTools.filter(tool => toolSchemaNames(tool).some(name => selected.has(name))); + }; - params.streamMessage.setStatus(Environment.getSelectingToolsText()); + params.streamMessage.setStatus("🧩 Выбираю подходящие инструменты..."); await params.streamMessage.flush(); try { @@ -58,31 +73,34 @@ export async function runToolRankStage(params: { params.streamMessage.clearStatus(); await params.streamMessage.flush(); - await storeToolRankAudit({ + await storeAudit({ streamMessage: params.streamMessage, provider: params.provider, model: params.model, round: params.round, startedAt, startedAtIso, + availableTools: allToolSchemaNames(params.availableTools), selectedTools: selection.toolNames, + usedRanker: selection.usedRanker, }); return { - filteredTools: filterRankedTools(params.availableTools, selection.toolNames), + filteredTools: filterSelectedTools(selection.toolNames), selectedToolNames: selection.toolNames, usedRanker: selection.usedRanker, }; } catch (error) { params.streamMessage.clearStatus(); await params.streamMessage.flush(); - await storeToolRankAudit({ + await storeAudit({ streamMessage: params.streamMessage, provider: params.provider, model: params.model, round: params.round, startedAt, startedAtIso, + availableTools: allToolSchemaNames(params.availableTools), error, }); throw error; diff --git a/src/ai/tool-schema-utils.ts b/src/ai/tool-schema-utils.ts new file mode 100644 index 0000000..bb6ea8a --- /dev/null +++ b/src/ai/tool-schema-utils.ts @@ -0,0 +1,33 @@ +import type {BoundaryValue} from "../common/boundary-types.js"; + +function isRecord(value: BoundaryValue): value is Record { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +function asOptionalString(value: BoundaryValue): string | undefined { + return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined; +} + +export function toolSchemaName(tool: BoundaryValue): string | undefined { + if (!isRecord(tool)) return undefined; + const fn = isRecord(tool.function) ? tool.function : undefined; + const directName = fn?.name ?? tool.name ?? (typeof tool.type === "string" && tool.type !== "function" ? tool.type : undefined); + return asOptionalString(directName); +} + +export function toolSchemaNames(tool: BoundaryValue): string[] { + if (!isRecord(tool)) return []; + + if (Array.isArray(tool.functionDeclarations)) { + return tool.functionDeclarations + .map(declaration => isRecord(declaration) ? asOptionalString(declaration.name) : undefined) + .filter((name): name is string => !!name); + } + + const name = toolSchemaName(tool); + return name ? [name] : []; +} + +export function allToolSchemaNames(tools: readonly BoundaryValue[]): string[] { + return [...new Set(tools.flatMap(toolSchemaNames))]; +} diff --git a/src/ai/unified-ai-runner.shared.ts b/src/ai/unified-ai-runner.shared.ts index fcb65f4..8c34a08 100644 --- a/src/ai/unified-ai-runner.shared.ts +++ b/src/ai/unified-ai-runner.shared.ts @@ -5,7 +5,6 @@ import type {BoundaryValue} from "../common/boundary-types"; import {AiProvider} from "../model/ai-provider.js"; import {ToolRankerFallbackPolicy} from "../common/policies.js"; import {Environment} from "../common/environment.js"; -import {photoGenDir} from "../index.js"; import {delay, logError, replyToMessage} from "../util/utils.js"; import {MessageStore} from "../common/message-store.js"; import type {OpenAiResponseTool} from "./tool-mappers.js"; @@ -72,6 +71,10 @@ export const MAX_OLLAMA_CONTEXT_SIZE = 262144; export const DEFAULT_OLLAMA_CONTEXT_SIZE = 32768; export const toolResourceLocks = new KeyedAsyncLock(); +function photoGenDir(): string { + return path.join(Environment.DATA_PATH, "cache", "photo", "gen"); +} + export type UnifiedRunOptions = { provider: AiProvider; msg: Message; @@ -1523,7 +1526,7 @@ export function writeOpenAiGeneratedImage(sourceMessage: Message, b64: string, l } { const buffer = Buffer.from(b64, "base64"); const fileName = `${sourceMessage.chat.id}_${sourceMessage.message_id}_${Date.now()}_${label}.png`; - const cachePath = path.join(photoGenDir, fileName); + const cachePath = path.join(photoGenDir(), fileName); fs.writeFileSync(cachePath, buffer); return {buffer, cachePath, fileName}; } diff --git a/src/index.ts b/src/index.ts index 69d93f7..91145a5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,9 +1,9 @@ import "dotenv/config"; -import {appLogger} from "./logging/logger"; -import {Environment} from "./common/environment"; +import {appLogger} from "./logging/logger.js"; +import {Environment} from "./common/environment.js"; import {BotCommand, TelegramBot, User} from "typescript-telegram-bot-api"; -import {Command} from "./base/command"; -import type {LogDetails} from "./logging/logger"; +import {Command} from "./base/command.js"; +import type {LogDetails} from "./logging/logger.js"; import { initSystemSpecs, logError, @@ -13,68 +13,68 @@ import { processInlineQuery, processMyChatMember, processNewMessage -} from "./util/utils"; -import {Ae} from "./commands/ae"; -import {Help} from "./commands/help"; -import {Ignore} from "./commands/ignore"; -import {Unignore} from "./commands/unignore"; -import {Ping} from "./commands/ping"; -import {RandomString} from "./commands/random-string"; -import {SystemInfo} from "./commands/system-info"; -import {Test} from "./commands/test"; -import {readData, retrieveAnswers} from "./db/database"; -import {Uptime} from "./commands/uptime"; -import {WhatBetter} from "./commands/what-better"; -import {When} from "./commands/when"; -import {RandomInt} from "./commands/random-int"; -import {Ban} from "./commands/ban"; -import {Quote} from "./commands/quote"; -import {OllamaSearch} from "./commands/ollama-search"; -import {Id} from "./commands/id"; -import {AdminsAdd} from "./commands/admins-add"; -import {AdminsRemove} from "./commands/admins-remove"; -import {Shutdown} from "./commands/shutdown"; -import {Leave} from "./commands/leave"; -import {OllamaChat} from "./commands/ollama-chat"; -import {Start} from "./commands/start"; -import {Choice} from "./commands/choice"; -import {Coin} from "./commands/coin"; -import {Qr} from "./commands/qr"; -import {Distort} from "./commands/distort"; -import {Dice} from "./commands/dice"; -import {Unban} from "./commands/unban"; -import {Title} from "./commands/title"; -import {MessageDao} from "./db/message-dao"; -import {DatabaseManager} from "./db/database-manager"; -import {UserDao} from "./db/user-dao"; -import {UserStore} from "./common/user-store"; -import {CallbackCommand} from "./base/callback-command"; -import {AiCancel} from "./callback_commands/ai-cancel"; -import {AiRegenerate} from "./callback_commands/ai-regenerate"; -import {MistralChat} from "./commands/mistral-chat"; -import {Transliteration} from "./commands/transliteration"; -import {OllamaListModels} from "./commands/ollama-list-models"; -import {OllamaGetModel} from "./commands/ollama-get-model"; -import {OllamaSetModel} from "./commands/ollama-set-model"; -import {MistralGetModel} from "./commands/mistral-get-model"; -import {MistralSetModel} from "./commands/mistral-set-model"; -import {MistralListModels} from "./commands/mistral-list-models"; -import {Debug} from "./commands/debug"; +} from "./util/utils.js"; +import {Ae} from "./commands/ae.js"; +import {Help} from "./commands/help.js"; +import {Ignore} from "./commands/ignore.js"; +import {Unignore} from "./commands/unignore.js"; +import {Ping} from "./commands/ping.js"; +import {RandomString} from "./commands/random-string.js"; +import {SystemInfo} from "./commands/system-info.js"; +import {Test} from "./commands/test.js"; +import {readData, retrieveAnswers} from "./db/database.js"; +import {Uptime} from "./commands/uptime.js"; +import {WhatBetter} from "./commands/what-better.js"; +import {When} from "./commands/when.js"; +import {RandomInt} from "./commands/random-int.js"; +import {Ban} from "./commands/ban.js"; +import {Quote} from "./commands/quote.js"; +import {OllamaSearch} from "./commands/ollama-search.js"; +import {Id} from "./commands/id.js"; +import {AdminsAdd} from "./commands/admins-add.js"; +import {AdminsRemove} from "./commands/admins-remove.js"; +import {Shutdown} from "./commands/shutdown.js"; +import {Leave} from "./commands/leave.js"; +import {OllamaChat} from "./commands/ollama-chat.js"; +import {Start} from "./commands/start.js"; +import {Choice} from "./commands/choice.js"; +import {Coin} from "./commands/coin.js"; +import {Qr} from "./commands/qr.js"; +import {Distort} from "./commands/distort.js"; +import {Dice} from "./commands/dice.js"; +import {Unban} from "./commands/unban.js"; +import {Title} from "./commands/title.js"; +import {MessageDao} from "./db/message-dao.js"; +import {DatabaseManager} from "./db/database-manager.js"; +import {UserDao} from "./db/user-dao.js"; +import {UserStore} from "./common/user-store.js"; +import {CallbackCommand} from "./base/callback-command.js"; +import {AiCancel} from "./callback_commands/ai-cancel.js"; +import {AiRegenerate} from "./callback_commands/ai-regenerate.js"; +import {MistralChat} from "./commands/mistral-chat.js"; +import {Transliteration} from "./commands/transliteration.js"; +import {OllamaListModels} from "./commands/ollama-list-models.js"; +import {OllamaGetModel} from "./commands/ollama-get-model.js"; +import {OllamaSetModel} from "./commands/ollama-set-model.js"; +import {MistralGetModel} from "./commands/mistral-get-model.js"; +import {MistralSetModel} from "./commands/mistral-set-model.js"; +import {MistralListModels} from "./commands/mistral-list-models.js"; +import {Debug} from "./commands/debug.js"; import fs from "node:fs"; import path from "node:path"; -import {OpenAIChat} from "./commands/openai-chat"; -import {OpenAIListModels} from "./commands/openai-list-models"; -import {OpenAIGetModel} from "./commands/openai-get-model"; -import {OpenAISetModel} from "./commands/openai-set-model"; -import {Info} from "./commands/info"; -import {AdminsList} from "./commands/admins-list"; -import {ExportDb} from "./commands/export-db"; -import {ImportDb} from "./commands/import-db"; -import {Settings} from "./commands/settings"; -import {UserSettingsCallback} from "./callback_commands/user-settings"; -import {TextToSpeech} from "./commands/text-to-speech"; -import {SpeechToText} from "./commands/speech-to-text"; -import {cleanupInternalArtifactCache} from "./ai/internal-artifact-store"; +import {OpenAIChat} from "./commands/openai-chat.js"; +import {OpenAIListModels} from "./commands/openai-list-models.js"; +import {OpenAIGetModel} from "./commands/openai-get-model.js"; +import {OpenAISetModel} from "./commands/openai-set-model.js"; +import {Info} from "./commands/info.js"; +import {AdminsList} from "./commands/admins-list.js"; +import {ExportDb} from "./commands/export-db.js"; +import {ImportDb} from "./commands/import-db.js"; +import {Settings} from "./commands/settings.js"; +import {UserSettingsCallback} from "./callback_commands/user-settings.js"; +import {TextToSpeech} from "./commands/text-to-speech.js"; +import {SpeechToText} from "./commands/speech-to-text.js"; +import {cleanupInternalArtifactCache} from "./ai/internal-artifact-store.js"; process.setUncaughtExceptionCaptureCallback(logError); diff --git a/src/util/utils.ts b/src/util/utils.ts index 2a1e8c9..bbbc31c 100644 --- a/src/util/utils.ts +++ b/src/util/utils.ts @@ -1,7 +1,7 @@ import * as si from "systeminformation"; -import {appLogger} from "../logging/logger"; -import {Command} from "../base/command"; -import {CallbackCommand} from "../base/callback-command"; +import {appLogger} from "../logging/logger.js"; +import {Command} from "../base/command.js"; +import {CallbackCommand} from "../base/callback-command.js"; import { CallbackQuery, ChatMember, @@ -15,39 +15,39 @@ import { TelegramBot, User } from "typescript-telegram-bot-api"; -import {Environment} from "../common/environment"; -import {TelegramError} from "typescript-telegram-bot-api/dist/errors"; -import {bot, botUser, callbackCommands, commands, messageDao, photoDir} from "../index"; +import {Environment} from "../common/environment.js"; +import {TelegramError} from "typescript-telegram-bot-api/dist/errors.js"; +import {bot, botUser, callbackCommands, commands, messageDao, photoDir} from "../index.js"; import os from "os"; import axios from "axios"; -import {MessageAudioPart, MessageImagePart, MessagePart} from "../common/message-part"; -import {StoredMessage} from "../model/stored-message"; +import {MessageAudioPart, MessageImagePart, MessagePart} from "../common/message-part.js"; +import {StoredMessage} from "../model/stored-message.js"; import sharp from "sharp"; -import {UserStore} from "../common/user-store"; +import {UserStore} from "../common/user-store.js"; import fs from "node:fs"; import path from "node:path"; -import {MessageStore} from "../common/message-store"; -import {SystemInfo} from "../commands/system-info"; -import {PrefixResponse} from "../commands/prefix-response"; -import {ChatCommand} from "../base/chat-command"; -import {AiProvider} from "../model/ai-provider"; -import {SendOptions} from "../model/send-options"; -import {EditOptions} from "../model/edit-options"; -import {StoredUser} from "../model/stored-user"; -import {StoredAttachment} from "../model/stored-attachment"; -import {AiDownloadedFile} from "../ai/telegram-attachments"; -import {runUnifiedAi} from "../ai/unified-ai-runner"; -import {enqueueTelegramApiCall} from "./telegram-api-queue"; -import {AsyncSemaphore, KeyedAsyncLock} from "./async-lock"; -import {resolveEffectiveAiProviderForUser, resolveInterfaceLocaleForUser} from "../common/user-ai-settings"; -import {Localization} from "../common/localization"; -import {createOllamaClient, resolveAiRuntimeTarget} from "../ai/ai-runtime-target"; -import {RandomUtils} from "./random-utils"; -import {HtmlUtils} from "./html-utils"; -import {ShellCommandResult, ShellCommandRunner} from "./shell-command-runner"; -import type {BoundaryValue, ErrorLike} from "../common/boundary-types"; -import {createStoredImageAttachment, photoCachePathForUniqueId, uniqueStoredAttachments} from "../common/stored-attachment-utils"; -import {runTelegramMessageAttachmentPipeline} from "../ai/user-request-pipeline"; +import {MessageStore} from "../common/message-store.js"; +import {SystemInfo} from "../commands/system-info.js"; +import {PrefixResponse} from "../commands/prefix-response.js"; +import {ChatCommand} from "../base/chat-command.js"; +import {AiProvider} from "../model/ai-provider.js"; +import {SendOptions} from "../model/send-options.js"; +import {EditOptions} from "../model/edit-options.js"; +import {StoredUser} from "../model/stored-user.js"; +import {StoredAttachment} from "../model/stored-attachment.js"; +import {AiDownloadedFile} from "../ai/telegram-attachments.js"; +import {runUnifiedAi} from "../ai/unified-ai-runner.js"; +import {enqueueTelegramApiCall} from "./telegram-api-queue.js"; +import {AsyncSemaphore, KeyedAsyncLock} from "./async-lock.js"; +import {resolveEffectiveAiProviderForUser, resolveInterfaceLocaleForUser} from "../common/user-ai-settings.js"; +import {Localization} from "../common/localization.js"; +import {createOllamaClient, resolveAiRuntimeTarget} from "../ai/ai-runtime-target.js"; +import {RandomUtils} from "./random-utils.js"; +import {HtmlUtils} from "./html-utils.js"; +import {ShellCommandResult, ShellCommandRunner} from "./shell-command-runner.js"; +import type {BoundaryValue, ErrorLike} from "../common/boundary-types.js"; +import {createStoredImageAttachment, photoCachePathForUniqueId, uniqueStoredAttachments} from "../common/stored-attachment-utils.js"; +import {runTelegramMessageAttachmentPipeline} from "../ai/user-request-pipeline/index.js"; const imageProcessingSemaphore = new AsyncSemaphore(2); const fileWriteLocks = new KeyedAsyncLock(); diff --git a/test/tool-rank-stage.test.mjs b/test/tool-rank-stage.test.mjs new file mode 100644 index 0000000..d121372 --- /dev/null +++ b/test/tool-rank-stage.test.mjs @@ -0,0 +1,126 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {runToolRankStage} = await import("../dist/ai/tool-rank-stage.js"); + +function createStreamMessage() { + const events = []; + const state = { + status: "", + events, + setStatus(value) { + state.status = value; + }, + clearStatus() { + state.status = ""; + }, + async flush() {}, + async storePipelineAudit(batch) { + events.push(...batch); + }, + }; + + return state; +} + +function createAuditRecorder() { + const events = []; + return { + events, + async storeAudit(params) { + events.push({ + stage: "tool_rank", + status: params.error ? "failed" : "succeeded", + details: { + round: params.round, + availableTools: params.availableTools, + selectedTools: params.selectedTools ?? [], + usedRanker: params.usedRanker ?? false, + toolRankDecision: { + provider: params.provider, + round: params.round, + availableTools: params.availableTools, + selectedTools: params.selectedTools ?? [], + usedRanker: params.usedRanker ?? false, + }, + }, + }); + }, + }; +} + +test("tool rank stage clears status after success and stores decision audit", async () => { + const streamMessage = createStreamMessage(); + const audit = createAuditRecorder(); + const result = await runToolRankStage({ + provider: "OLLAMA", + model: "test-model", + round: 0, + config: { + toolRankerFallbackPolicy: "NO_TOOLS", + }, + availableTools: [{name: "read_file"}], + messages: [{role: "user", content: "прочитай src/index.ts"}], + streamMessage, + signal: new AbortController().signal, + storeAudit: audit.storeAudit, + toolRanker: { + async selectTools() { + return { + toolNames: ["read_file"], + usedRanker: true, + }; + }, + }, + }); + + assert.deepEqual(result.selectedToolNames, ["read_file"]); + assert.deepEqual(result.filteredTools, [{name: "read_file"}]); + assert.equal(result.usedRanker, true); + assert.equal(streamMessage.status, ""); + assert.equal(audit.events.length, 1); + assert.equal(audit.events[0].stage, "tool_rank"); + assert.equal(audit.events[0].status, "succeeded"); + assert.deepEqual(audit.events[0].details.toolRankDecision, { + provider: "OLLAMA", + round: 0, + availableTools: ["read_file"], + selectedTools: ["read_file"], + usedRanker: true, + }); +}); + +test("tool rank stage clears status after failure", async () => { + const streamMessage = createStreamMessage(); + const audit = createAuditRecorder(); + await assert.rejects(() => runToolRankStage({ + provider: "OLLAMA", + model: "test-model", + round: 1, + config: { + toolRankerFallbackPolicy: "NO_TOOLS", + }, + availableTools: [{name: "read_file"}], + messages: [{role: "user", content: "прочитай src/index.ts"}], + streamMessage, + signal: new AbortController().signal, + storeAudit: audit.storeAudit, + toolRanker: { + async selectTools() { + throw new Error("ranker failed"); + }, + }, + }), /ranker failed/); + + assert.equal(streamMessage.status, ""); + assert.equal(audit.events.length, 1); + assert.equal(audit.events[0].stage, "tool_rank"); + assert.equal(audit.events[0].status, "failed"); + assert.deepEqual(audit.events[0].details.toolRankDecision, { + provider: "OLLAMA", + round: 1, + availableTools: ["read_file"], + selectedTools: [], + usedRanker: false, + }); +});