From 58f5a645fd983f59a8cf55efa8669f68e5d79998 Mon Sep 17 00:00:00 2001 From: Danil Nikolaev Date: Mon, 18 May 2026 16:23:32 +0300 Subject: [PATCH] Add tool ranker fallback policy tests --- PIPELINE_TODO.md | 2 +- src/ai/ai-runtime-target.ts | 6 +-- src/ai/tool-ranker-fallback.ts | 23 +++++++++++ src/ai/tool-ranker-pipeline.ts | 10 ++--- src/ai/tools/python-interpretator.ts | 8 ++-- src/ai/tools/tool-logger.ts | 2 +- src/ai/unified-ai-runner.tool-ranker.ts | 51 +++++++++++++------------ src/common/environment.ts | 15 ++++++-- src/common/localization.ts | 2 +- src/logging/ai-logger.ts | 2 +- src/model/ai-capability-info.ts | 2 +- src/model/ai-model-capabilities.ts | 2 +- test/tool-ranker-fallback.test.mjs | 46 ++++++++++++++++++++++ 13 files changed, 125 insertions(+), 46 deletions(-) create mode 100644 src/ai/tool-ranker-fallback.ts create mode 100644 test/tool-ranker-fallback.test.mjs diff --git a/PIPELINE_TODO.md b/PIPELINE_TODO.md index 54a4c7c..f4f79a1 100644 --- a/PIPELINE_TODO.md +++ b/PIPELINE_TODO.md @@ -75,7 +75,7 @@ - [x] Сохранить status UX: `🧩 Выбираю подходящие инструменты...`. - [x] Гарантировать `clearStatus()` после ranker success/failure. - [ ] Добавить fallback через `PipelineFallbackExecutor`: main model, all tools, no tools. -- [ ] Добавить tests на fallback ranker policy. +- [x] Добавить tests на fallback ranker policy. ## 6. Сделать model_call и tool_loop физически отдельными stages diff --git a/src/ai/ai-runtime-target.ts b/src/ai/ai-runtime-target.ts index ba45d47..173aa91 100644 --- a/src/ai/ai-runtime-target.ts +++ b/src/ai/ai-runtime-target.ts @@ -1,9 +1,9 @@ import {Mistral} from "@mistralai/mistralai"; import {Ollama} from "ollama"; import {OpenAI} from "openai"; -import {Environment} from "../common/environment"; -import {AiModelCapabilities} from "../model/ai-model-capabilities"; -import {AiProvider} from "../model/ai-provider"; +import {Environment} from "../common/environment.js"; +import {AiModelCapabilities} from "../model/ai-model-capabilities.js"; +import {AiProvider} from "../model/ai-provider.js"; export type AiCapabilityName = keyof AiModelCapabilities; export type AiRuntimePurpose = AiCapabilityName | "chat"; diff --git a/src/ai/tool-ranker-fallback.ts b/src/ai/tool-ranker-fallback.ts new file mode 100644 index 0000000..696f971 --- /dev/null +++ b/src/ai/tool-ranker-fallback.ts @@ -0,0 +1,23 @@ +import {ToolRankerFallbackPolicy} from "../common/policies.js"; + +export type ToolRankerFallbackSelection = { + toolNames: string[]; + usedRanker: boolean; +}; + +export function resolveToolRankerFallbackSelection(params: { + fallbackPolicy: ToolRankerFallbackPolicy; + availableToolNames: readonly string[]; +}): ToolRankerFallbackSelection { + if (params.fallbackPolicy === ToolRankerFallbackPolicy.NO_TOOLS) { + return { + toolNames: [], + usedRanker: false, + }; + } + + return { + toolNames: [...params.availableToolNames], + usedRanker: false, + }; +} diff --git a/src/ai/tool-ranker-pipeline.ts b/src/ai/tool-ranker-pipeline.ts index 97d2e46..8ab849b 100644 --- a/src/ai/tool-ranker-pipeline.ts +++ b/src/ai/tool-ranker-pipeline.ts @@ -1,12 +1,12 @@ -import type {BoundaryValue} from "../common/boundary-types"; -import type {AiRuntimeTarget} from "./ai-runtime-target"; -import {AiProvider} from "../model/ai-provider"; -import {RuntimeConfigSnapshot, toolSchemaNames} from "./unified-ai-runner.shared"; +import type {BoundaryValue} from "../common/boundary-types.js"; +import type {AiRuntimeTarget} from "./ai-runtime-target.js"; +import {AiProvider} from "../model/ai-provider.js"; +import {RuntimeConfigSnapshot, toolSchemaNames} from "./unified-ai-runner.shared.js"; import { buildToolRankerSystemPrompt, getToolRankerAvailableToolInfos, type ToolRankerToolInfo, -} from "./tool-ranker-metadata"; +} from "./tool-ranker-metadata.js"; export type ToolRankerMessage = { role?: string; diff --git a/src/ai/tools/python-interpretator.ts b/src/ai/tools/python-interpretator.ts index de0c3f9..ef08c44 100644 --- a/src/ai/tools/python-interpretator.ts +++ b/src/ai/tools/python-interpretator.ts @@ -2,11 +2,11 @@ import {spawn} from "node:child_process"; import {copyFile, lstat, mkdir, readdir, rm, writeFile} from "node:fs/promises"; import os from "node:os"; import path from "node:path"; -import {AiTool} from "../tool-types"; -import {Environment} from "../../common/environment"; -import {toolsLogger} from "./tool-logger"; +import {AiTool} from "../tool-types.js"; +import {Environment} from "../../common/environment.js"; +import {toolsLogger} from "./tool-logger.js"; import {randomUUID} from "node:crypto"; -import {AiJsonObject} from "../tool-types"; +import {AiJsonObject} from "../tool-types.js"; const logger = toolsLogger.child("python-interpreter"); diff --git a/src/ai/tools/tool-logger.ts b/src/ai/tools/tool-logger.ts index a056778..5d27d41 100644 --- a/src/ai/tools/tool-logger.ts +++ b/src/ai/tools/tool-logger.ts @@ -1,3 +1,3 @@ -import {appLogger} from "../../logging/logger"; +import {appLogger} from "../../logging/logger.js"; export const toolsLogger = appLogger.child("ai-tools"); diff --git a/src/ai/unified-ai-runner.tool-ranker.ts b/src/ai/unified-ai-runner.tool-ranker.ts index 4223978..c74b28d 100644 --- a/src/ai/unified-ai-runner.tool-ranker.ts +++ b/src/ai/unified-ai-runner.tool-ranker.ts @@ -1,20 +1,21 @@ import {ChatCompletionMessageParam} from "openai/resources/chat/completions"; import {ChatRequest} from "ollama"; -import {BoundaryValue} from "../common/boundary-types"; -import {ToolRankerFallbackPolicy} from "../common/policies"; -import {AiProvider} from "../model/ai-provider"; -import {createMistralClient, createOllamaClient, createOpenAiClient, sameRuntimeEndpoint} from "./ai-runtime-target"; -import {aiLog, aiLogDuration, aiLogProviderTarget} from "../logging/ai-logger"; -import {providerChatTarget, RuntimeConfigSnapshot} from "./unified-ai-runner.shared"; +import {BoundaryValue} from "../common/boundary-types.js"; +import {ToolRankerFallbackPolicy} from "../common/policies.js"; +import {AiProvider} from "../model/ai-provider.js"; +import {createMistralClient, createOllamaClient, createOpenAiClient, sameRuntimeEndpoint} from "./ai-runtime-target.js"; +import {aiLog, aiLogDuration, aiLogProviderTarget} from "../logging/ai-logger.js"; +import {providerChatTarget, RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js"; import { buildRankerContext, buildRankerTarget, buildToolRankerPrompt, filterRankedTools, ToolRankerSelection, -} from "./tool-ranker-pipeline"; -import {allToolSchemaNames} from "./unified-ai-runner.shared"; -import {sanitizeToolRankerResult} from "./tool-ranker-metadata"; +} from "./tool-ranker-pipeline.js"; +import {allToolSchemaNames} from "./unified-ai-runner.shared.js"; +import {sanitizeToolRankerResult} from "./tool-ranker-metadata.js"; +import {resolveToolRankerFallbackSelection} from "./tool-ranker-fallback.js"; export class ToolRanker { constructor(private readonly config: RuntimeConfigSnapshot) { @@ -27,8 +28,15 @@ export class ToolRanker { round: number; signal: AbortSignal; messages?: readonly { role?: string; content?: string | readonly { text?: string }[] }[]; + runRanker?: ( + provider: AiProvider, + target: NonNullable>, + prompt: string, + userQuery: string, + ) => Promise; }): Promise { const {availableTools, provider, round, signal, userQuery} = args; + const runRanker = args.runRanker ?? this.runRanker.bind(this); const availableNames = allToolSchemaNames(availableTools); const fallbackPolicy = this.config.toolRankerFallbackPolicy; const configuredTarget = buildRankerTarget(this.config, provider); @@ -41,11 +49,10 @@ export class ToolRanker { const target = configuredTarget ?? (fallbackPolicy === ToolRankerFallbackPolicy.MAIN_MODEL ? mainModelTarget : undefined); if (!target) { - if (fallbackPolicy === ToolRankerFallbackPolicy.NO_TOOLS) { - return {toolNames: [], usedRanker: false}; - } - - return {toolNames: availableNames, usedRanker: false}; + return resolveToolRankerFallbackSelection({ + fallbackPolicy, + availableToolNames: availableNames, + }); } const startedAt = Date.now(); @@ -63,7 +70,7 @@ export class ToolRanker { try { if (signal.aborted) throw new Error("Aborted"); - const raw = await this.runRanker(provider, target, ranker.prompt, userQuery); + const raw = await runRanker(provider, target, ranker.prompt, userQuery); if (signal.aborted) throw new Error("Aborted"); const selectedNames = sanitizeToolRankerResult({ raw, @@ -106,7 +113,7 @@ export class ToolRanker { const fallbackRanker = buildToolRankerPrompt( buildRankerContext(this.config, provider, mainModelTarget, round, userQuery, availableTools), ); - const raw = await this.runRanker(provider, mainModelTarget, fallbackRanker.prompt, userQuery); + const raw = await runRanker(provider, mainModelTarget, fallbackRanker.prompt, userQuery); const selectedNames = sanitizeToolRankerResult({ raw, availableToolNames: availableNames, @@ -151,14 +158,10 @@ export class ToolRanker { error: failureMessage, }); - if (fallbackPolicy === ToolRankerFallbackPolicy.NO_TOOLS) { - return {toolNames: [], usedRanker: false}; - } - - return { - toolNames: availableNames, - usedRanker: false, - }; + return resolveToolRankerFallbackSelection({ + fallbackPolicy, + availableToolNames: availableNames, + }); } } diff --git a/src/common/environment.ts b/src/common/environment.ts index d91140e..dc45976 100644 --- a/src/common/environment.ts +++ b/src/common/environment.ts @@ -6,9 +6,7 @@ import {z} from "zod"; import {appLogger} from "../logging/logger.js"; import type {BoundaryValue, ErrorLike} from "./boundary-types"; -import {saveData} from "../db/database.js"; import {Answers} from "../model/answers.js"; -import {ifTrue} from "../util/utils.js"; import {AiProvider} from "../model/ai-provider.js"; import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies.js"; import {ToolRankerFallbackPolicy} from "./policies.js"; @@ -16,6 +14,11 @@ import type {ToolCallData} from "../ai/unified-ai-runner.js"; import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js"; import {Localization, type LocalizationParams} from "./localization.js"; +function parseBooleanLike(value: string): boolean { + const normalized = value.trim().toLowerCase(); + return ["true", "t", "y", "1"].includes(normalized); +} + type EnvRecord = Record; type StringEnumLike = Record; type StringEnumValue = T[keyof T]; @@ -53,7 +56,7 @@ function booleanWithDefaultSchema(defaultValue: boolean) { return defaultValue; } - return ifTrue(normalized); + return parseBooleanLike(normalized); }, z.boolean()) .default(defaultValue) .catch(defaultValue); @@ -62,7 +65,7 @@ function booleanWithDefaultSchema(defaultValue: boolean) { const optionalBooleanSchema = z .preprocess(value => { const normalized = normalizeString(value as BoundaryValue); - return normalized === undefined ? undefined : ifTrue(normalized); + return normalized === undefined ? undefined : parseBooleanLike(normalized); }, z.boolean().optional()) .optional() .catch(undefined); @@ -1939,6 +1942,7 @@ export class Environment { if (!has) { this.ADMIN_IDS.add(id); + const {saveData} = await import("../db/database.js"); await saveData(); } @@ -1950,6 +1954,7 @@ export class Environment { if (has) { this.ADMIN_IDS.delete(id); + const {saveData} = await import("../db/database.js"); await saveData(); } @@ -1966,6 +1971,7 @@ export class Environment { } this.MUTED_IDS.add(id); + const {saveData} = await import("../db/database.js"); await saveData(); return true; } @@ -1976,6 +1982,7 @@ export class Environment { } this.MUTED_IDS.delete(id); + const {saveData} = await import("../db/database.js"); await saveData(); return true; } diff --git a/src/common/localization.ts b/src/common/localization.ts index cab74d0..edf84a4 100644 --- a/src/common/localization.ts +++ b/src/common/localization.ts @@ -1,7 +1,7 @@ import {AsyncLocalStorage} from "node:async_hooks"; import fs from "node:fs"; import path from "node:path"; -import {appLogger} from "../logging/logger"; +import {appLogger} from "../logging/logger.js"; const logger = appLogger.child("localization"); diff --git a/src/logging/ai-logger.ts b/src/logging/ai-logger.ts index f6be159..d0b5348 100644 --- a/src/logging/ai-logger.ts +++ b/src/logging/ai-logger.ts @@ -1,5 +1,5 @@ import {Message} from "typescript-telegram-bot-api"; -import {createLogger, formatDuration, LogDetails, LogLevel} from "./logger"; +import {createLogger, formatDuration, LogDetails, LogLevel} from "./logger.js"; export type AiRunnerLogLevel = LogLevel; export type AiRunnerLogDetails = LogDetails; diff --git a/src/model/ai-capability-info.ts b/src/model/ai-capability-info.ts index 291374a..19afc5d 100644 --- a/src/model/ai-capability-info.ts +++ b/src/model/ai-capability-info.ts @@ -1,4 +1,4 @@ -import {AiProvider} from "./ai-provider"; +import {AiProvider} from "./ai-provider.js"; export type AiEndpointInfo = { provider?: AiProvider; diff --git a/src/model/ai-model-capabilities.ts b/src/model/ai-model-capabilities.ts index 175171d..c3b8db4 100644 --- a/src/model/ai-model-capabilities.ts +++ b/src/model/ai-model-capabilities.ts @@ -1,4 +1,4 @@ -import {AiCapabilityInfo} from "./ai-capability-info"; +import {AiCapabilityInfo} from "./ai-capability-info.js"; export class AiModelCapabilities { chat: AiCapabilityInfo | undefined; diff --git a/test/tool-ranker-fallback.test.mjs b/test/tool-ranker-fallback.test.mjs new file mode 100644 index 0000000..2662c6c --- /dev/null +++ b/test/tool-ranker-fallback.test.mjs @@ -0,0 +1,46 @@ +import test from "node:test"; +import assert from "node:assert/strict"; + +const {ToolRankerFallbackPolicy} = await import("../dist/common/policies.js"); +const {resolveToolRankerFallbackSelection} = await import("../dist/ai/tool-ranker-fallback.js"); + +const availableToolNames = ["read_file", "search_files"]; + +test("tool ranker fallback returns no tools when policy is NO_TOOLS", () => { + assert.deepEqual( + resolveToolRankerFallbackSelection({ + fallbackPolicy: ToolRankerFallbackPolicy.NO_TOOLS, + availableToolNames, + }), + { + toolNames: [], + usedRanker: false, + }, + ); +}); + +test("tool ranker fallback returns all tools when policy is ALL_TOOLS", () => { + assert.deepEqual( + resolveToolRankerFallbackSelection({ + fallbackPolicy: ToolRankerFallbackPolicy.ALL_TOOLS, + availableToolNames, + }), + { + toolNames: ["read_file", "search_files"], + usedRanker: false, + }, + ); +}); + +test("tool ranker fallback keeps all tools when policy is MAIN_MODEL", () => { + assert.deepEqual( + resolveToolRankerFallbackSelection({ + fallbackPolicy: ToolRankerFallbackPolicy.MAIN_MODEL, + availableToolNames, + }), + { + toolNames: ["read_file", "search_files"], + usedRanker: false, + }, + ); +});