Add unified request pipeline stages

This commit is contained in:
2026-05-18 15:45:39 +03:00
parent 8cff086a8e
commit 8aede4b053
18 changed files with 905 additions and 509 deletions
+15 -15
View File
@@ -46,27 +46,27 @@
- [x] Для Mistral сохранять `libraryId`.
- [x] Для Mistral сохранять uploaded document ids.
- [x] Для Mistral сохранять source file mapping: local attachment -> Mistral document id.
- [ ] Добавить единый `providerState` schema для всех providers.
- [ ] Добавить tests на сериализацию `RagArtifact`.
- [ ] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
- [x] Добавить единый `providerState` schema для всех providers.
- [x] Добавить tests на сериализацию `RagArtifact`.
- [x] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
## 4. Вынести provider runners в adapter layer
- [ ] Ввести интерфейс `AiProviderAdapter`.
- [ ] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
- [ ] Реализовать `OpenAiProviderAdapter`.
- [ ] Реализовать `MistralProviderAdapter`.
- [ ] Реализовать `OllamaProviderAdapter`.
- [ ] Перенести provider-specific tool schema mapping внутрь adapter-ов.
- [ ] Перенести provider-specific streaming parsing внутрь adapter-ов.
- [ ] Перенести provider-specific tool result append внутрь adapter-ов.
- [ ] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
- [ ] Оставить compatibility wrappers для текущих imports.
- [ ] Добавить tests на adapter contract без реальных API.
- [x] Ввести интерфейс `AiProviderAdapter`.
- [x] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
- [x] Реализовать `OpenAiProviderAdapter`.
- [x] Реализовать `MistralProviderAdapter`.
- [x] Реализовать `OllamaProviderAdapter`.
- [x] Перенести provider-specific tool schema mapping внутрь adapter-ов.
- [x] Перенести provider-specific streaming parsing внутрь adapter-ов.
- [x] Перенести provider-specific tool result append внутрь adapter-ов.
- [x] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
- [x] Оставить compatibility wrappers для текущих imports.
- [x] Добавить tests на adapter contract без реальных API.
## 5. Сделать tool-ranker полноценным pipeline stage
- [ ] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
- [x] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
- [ ] Добавить stage `tool_rank`, который работает через provider adapter.
- [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker.
- [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`.
+112
View File
@@ -0,0 +1,112 @@
import type {ToolCallData} from "./unified-ai-runner.shared.js";
import type {ResponseStreamEvent} from "openai/resources/responses/responses";
/**
 * Type guard for object-like values that are not arrays.
 *
 * @param value Anything received from a provider payload.
 * @returns `true` when `value` is a non-null object and not an array.
 */
function isRecord(value: unknown): value is Record<string, unknown> {
    if (value === null || typeof value !== "object") {
        return false;
    }
    return !Array.isArray(value);
}
/**
 * Picks a usable tool-call id.
 *
 * @param value Candidate id taken from a provider payload.
 * @param fallback Id to use when `value` is not a string or is blank.
 * @returns `value` unchanged (not trimmed) when it is a string containing
 *          at least one non-whitespace character; otherwise `fallback`.
 */
function normalizeToolCallId(value: unknown, fallback: string): string {
    if (typeof value !== "string") {
        return fallback;
    }
    const isBlank = value.trim().length === 0;
    return isBlank ? fallback : value;
}
/**
 * Converts raw tool-call arguments into their textual (JSON) form.
 *
 * @param value Arguments as delivered by the provider: either an already
 *              serialized string or a structured value.
 * @returns The string unchanged when `value` is a string; otherwise the JSON
 *          serialization of `value`, with `null`/`undefined` treated as `{}`.
 */
function normalizeToolArguments(value: unknown): string {
    if (typeof value === "string") return value;
    // JSON.stringify yields `undefined` (not a string) for functions and symbols;
    // fall back to an empty object so the declared `string` contract always holds.
    const serialized: string | undefined = JSON.stringify(value ?? {});
    return serialized ?? "{}";
}
/**
 * Collects function calls from an OpenAI response object.
 *
 * Only entries of `response.output` whose `type` is `"function_call"` and that
 * carry a string `call_id` or a string `name` are considered. Entries without a
 * non-empty name are dropped from the result, but they still consume a position
 * in the `openai_<n>` fallback-id numbering.
 *
 * @param response Response-like value; anything without an `output` array yields `[]`.
 * @returns Normalized tool calls in their original order.
 */
export function extractOpenAiToolCalls(response: unknown): ToolCallData[] {
    if (!isRecord(response) || !Array.isArray(response.output)) {
        return [];
    }

    const collected: ToolCallData[] = [];
    let position = 0;
    for (const item of response.output) {
        if (!isRecord(item) || item.type !== "function_call") continue;
        if (typeof item.call_id !== "string" && typeof item.name !== "string") continue;

        // The fallback id is numbered by position among accepted candidates,
        // counted before the empty-name check below.
        const candidate = {
            id: normalizeToolCallId(item.call_id, `openai_${position}`),
            name: typeof item.name === "string" ? item.name : "",
            argumentsText: normalizeToolArguments(item.arguments),
        };
        position += 1;

        if (candidate.name.length > 0) {
            collected.push(candidate);
        }
    }
    return collected;
}
/**
 * Reads the text fragment carried by an OpenAI streaming event.
 *
 * @param input A stream event (or anything else; non-matching input is tolerated).
 * @returns The event's `delta` when its `type` is `"response.output_text.delta"`,
 *          otherwise an empty string.
 */
export function extractOpenAiTextDelta(input: unknown): string {
    const event = input as ResponseStreamEvent | undefined;
    if (event?.type !== "response.output_text.delta") {
        return "";
    }
    return event.delta ?? "";
}
/**
 * Extracts a tool call announced by an OpenAI streaming event.
 *
 * Only `"response.output_item.added"` events whose item is a `"function_call"`
 * produce a result; the item is re-shaped into a one-element `output` list and
 * handed to `extractOpenAiToolCalls`. The item's `id` is used when `call_id`
 * is null or undefined.
 *
 * @param input A stream event (or any other value).
 * @returns Zero or one normalized tool call.
 */
export function extractOpenAiStreamingToolCalls(input: unknown): ToolCallData[] {
    const event = input as ResponseStreamEvent | undefined;
    if (event?.type !== "response.output_item.added") {
        return [];
    }
    if (!isRecord(event.item) || event.item.type !== "function_call") {
        return [];
    }

    const announcedCall = {
        type: "function_call",
        call_id: event.item.call_id ?? event.item.id,
        name: event.item.name,
        arguments: event.item.arguments,
    };
    return extractOpenAiToolCalls({output: [announcedCall]});
}
/**
 * Normalizes Mistral tool calls into `ToolCallData`.
 *
 * Accepts either the tool-call array itself or an object (message or delta)
 * exposing it as `toolCalls` or `tool_calls`. Whichever of the two properties
 * actually holds an array is used, with `toolCalls` preferred when both do.
 * The function name and arguments are read from the nested `function` object
 * when present, otherwise from the call itself.
 *
 * @param calls Tool-call list or a container object; other values yield `[]`.
 * @returns Calls with a non-empty name; a blank or missing id becomes `mistral_<index>`.
 */
export function extractMistralToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls)) {
        // Check each property on its own: a non-array `toolCalls` must not
        // shadow a valid `tool_calls` array (the previous `??` fallback did).
        if (Array.isArray(calls.toolCalls)) {
            rawCalls = calls.toolCalls;
        } else if (Array.isArray(calls.tool_calls)) {
            rawCalls = calls.tool_calls;
        }
    }

    return rawCalls
        .map((item, index) => {
            const call: Record<string, unknown> = isRecord(item) ? item : {};
            const fn = isRecord(call.function) ? call.function : undefined;
            const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : "";
            return {
                id: normalizeToolCallId(call.id, `mistral_${index}`),
                name,
                argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
            };
        })
        .filter(call => call.name.length > 0);
}
/**
 * Reads the text carried by a Mistral delta/message-like object.
 *
 * @param input Object whose `content` is either a string or an array of parts.
 * @returns The string content as-is, or the concatenation (no separator) of every
 *          part's string `text`; an empty string for any other shape.
 */
export function extractMistralTextDelta(input: unknown): string {
    const content = isRecord(input) ? input.content : undefined;
    if (typeof content === "string") {
        return content;
    }
    if (!Array.isArray(content)) {
        return "";
    }

    let text = "";
    for (const part of content) {
        if (isRecord(part) && typeof part.text === "string") {
            text += part.text;
        }
    }
    return text;
}
/**
 * Normalizes Ollama tool calls into `ToolCallData`.
 *
 * Accepts either the tool-call array itself or an object exposing it as
 * `tool_calls`. The function name and arguments are read from the nested
 * `function` object when present, otherwise from the call itself.
 *
 * @param calls Tool-call list or a container object; other values yield `[]`.
 * @returns Calls with a non-empty name; a blank or missing id becomes `ollama_<index>`,
 *          where the index is the position in the raw list.
 */
export function extractOllamaToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls) && Array.isArray(calls.tool_calls)) {
        rawCalls = calls.tool_calls;
    }

    const extracted: ToolCallData[] = [];
    rawCalls.forEach((item, index) => {
        const call: Record<string, unknown> = isRecord(item) ? item : {};
        const fn = isRecord(call.function) ? call.function : undefined;

        let name = "";
        if (typeof fn?.name === "string") {
            name = fn.name;
        } else if (typeof call.name === "string") {
            name = call.name;
        }

        const normalized = {
            id: normalizeToolCallId(call.id, `ollama_${index}`),
            name,
            argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
        };
        if (normalized.name.length > 0) {
            extracted.push(normalized);
        }
    });
    return extracted;
}
/**
 * Reads the text carried by an Ollama chunk.
 *
 * @param input Chunk-like object with a `message` object.
 * @returns `message.content` when it is a string, otherwise an empty string.
 */
export function extractOllamaTextDelta(input: unknown): string {
    if (!isRecord(input)) {
        return "";
    }
    const message = input.message;
    if (!isRecord(message)) {
        return "";
    }
    return typeof message.content === "string" ? message.content : "";
}
+196
View File
@@ -0,0 +1,196 @@
import {AiProvider} from "../model/ai-provider.js";
import type {BoundaryValue} from "../common/boundary-types.js";
import type {RuntimeConfigSnapshot, ToolCallData} from "./unified-ai-runner.shared.js";
import {getMistralTools, getOllamaTools, getOpenAIResponsesTools, getOpenAICodeInterpreterTool} from "./tool-mappers.js";
import type {MistralChatMessage as MistralMessageType} from "./mistral-chat-message.js";
import type {OpenAIChatMessage as OpenAiMessageType} from "./openai-chat-message.js";
import type {Message as OllamaMessage} from "ollama";
import {
extractMistralTextDelta,
extractMistralToolCalls,
extractOllamaTextDelta,
extractOllamaToolCalls,
extractOpenAiTextDelta,
extractOpenAiStreamingToolCalls,
extractOpenAiToolCalls,
} from "./provider-adapter-contract.js";
/**
 * Optional inputs accepted by `AiProviderAdapter.rankTools` when it builds a
 * provider's tool list.
 */
export type ProviderRankToolOptions = {
// Forwarded to the tool-mapper helpers (`getOpenAIResponsesTools`, `getMistralTools`, `getOllamaTools`).
forCreator?: boolean;
// Read only by the OpenAI adapter: when non-empty, a `file_search` tool bound to these ids is placed first.
vectorStoreIds?: string[];
};
/**
 * Provider-specific operations used by the runners: message mapping, tool-list
 * construction, model invocation, response parsing and tool-result appending.
 * Implemented in this file for OpenAI, Mistral and Ollama.
 */
export interface AiProviderAdapter {
// Provider this adapter serves.
readonly provider: AiProvider;
// The adapters in this file return the input array cast to the provider's message type (no copy, no conversion).
mapMessages(messages: readonly unknown[]): unknown[];
// Despite the name, the adapters in this file return the full provider-specific tool list; they do not rank or filter it.
rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[];
// Wraps the actual API call. The adapters in this file ignore `request` and simply return `execute()`.
callModel<T>(request: unknown, execute: () => Promise<T>): Promise<T>;
// Text contained in a streaming event/chunk/delta; empty string when there is none.
extractTextDelta(input: unknown): string;
// Tool calls found in a complete response or message.
extractToolCalls(input: unknown): ToolCallData[];
// Tool calls found in a streaming event/delta (Mistral and Ollama delegate to `extractToolCalls`).
extractStreamingToolCalls(input: unknown): ToolCallData[];
// Mutates `messages`: pushes one provider-shaped result entry per call, pairing `calls[i]` with `results[i]` (missing results become "").
appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void;
// Cleanup hook; a no-op in all adapters in this file.
finalize(): Promise<void>;
}
/**
 * Appends one Ollama `tool` message per executed call to `messages` (in place).
 *
 * @param messages Conversation array to extend.
 * @param calls Executed tool calls; only `name` is used.
 * @param results Tool outputs, matched to `calls` by position; a missing entry becomes "".
 */
function appendOllamaToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
    let position = 0;
    for (const call of calls) {
        const content = results[position] ?? "";
        messages.push({role: "tool", content, tool_name: call.name});
        position += 1;
    }
}
/**
 * `AiProviderAdapter` for OpenAI: builds the Responses-API tool list and
 * delegates parsing to the `extractOpenAi*` helpers.
 */
class OpenAiProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OPENAI;

    /** Returns the very same array, viewed as OpenAI chat messages. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const mapped = messages as OpenAiMessageType[];
        return mapped;
    }

    /**
     * Builds the OpenAI tool list: mapped function tools, the code interpreter,
     * image generation (model taken from `config.openAiImageTarget`) and web search.
     * When `options.vectorStoreIds` is non-empty a `file_search` tool is placed first.
     */
    rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const baseTools: BoundaryValue[] = [
            ...getOpenAIResponsesTools(options?.forCreator) as BoundaryValue[],
            getOpenAICodeInterpreterTool() as BoundaryValue,
            {
                type: "image_generation",
                model: config.openAiImageTarget.model,
                size: "auto",
                moderation: "low",
                output_format: "png",
                partial_images: 3,
            },
            {type: "web_search"},
        ];

        const storeIds = options?.vectorStoreIds;
        if (!storeIds?.length) {
            return baseTools;
        }
        const fileSearch: BoundaryValue = {
            type: "file_search",
            vector_store_ids: storeIds,
        };
        baseTools.unshift(fileSearch);
        return baseTools;
    }

    /** Runs `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    extractTextDelta(input: unknown): string {
        return extractOpenAiTextDelta(input);
    }

    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiToolCalls(input);
    }

    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiStreamingToolCalls(input);
    }

    /** Pushes one `function_call_output` item per call; a missing result becomes "". */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        let position = 0;
        for (const call of calls) {
            const output = results[position] ?? "";
            messages.push({type: "function_call_output", call_id: call.id, output});
            position += 1;
        }
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
    }
}
/**
 * `AiProviderAdapter` for Mistral: exposes the mapped Mistral tools and
 * delegates parsing to the `extractMistral*` helpers.
 */
class MistralProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.MISTRAL;

    /** Returns the very same array, viewed as Mistral chat messages. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const mapped = messages as MistralMessageType[];
        return mapped;
    }

    /** Returns the Mistral tool list; `config` is not used. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const tools = getMistralTools(options?.forCreator) as BoundaryValue[];
        return tools;
    }

    /** Runs `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    extractTextDelta(input: unknown): string {
        return extractMistralTextDelta(input);
    }

    extractToolCalls(input: unknown): ToolCallData[] {
        return extractMistralToolCalls(input);
    }

    /** Streaming deltas are parsed exactly like complete messages. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /** Pushes one `tool` message per call; a missing result becomes "". */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        let position = 0;
        for (const call of calls) {
            const content = results[position] ?? "";
            messages.push({role: "tool", name: call.name, toolCallId: call.id, content});
            position += 1;
        }
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
    }
}
/**
 * `AiProviderAdapter` for Ollama: exposes the mapped Ollama tools and
 * delegates parsing to the `extractOllama*` helpers.
 */
class OllamaProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OLLAMA;

    /** Returns the very same array, viewed as Ollama messages. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const mapped = messages as OllamaMessage[];
        return mapped;
    }

    /** Returns the Ollama tool list; `config` is not used. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const tools = getOllamaTools(options?.forCreator) as BoundaryValue[];
        return tools;
    }

    /** Runs `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    extractTextDelta(input: unknown): string {
        return extractOllamaTextDelta(input);
    }

    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOllamaToolCalls(input);
    }

    /** Streaming chunks are parsed exactly like complete messages. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /** Delegates to `appendOllamaToolResults`. */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        appendOllamaToolResults(messages, calls, results);
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
    }
}
/**
 * Creates the adapter for the given provider.
 *
 * A fresh adapter instance is returned on every call.
 *
 * @param provider Provider to build an adapter for.
 * @returns The matching `AiProviderAdapter`.
 * @throws Error when `provider` is not one of the handled enum members (only
 *         possible when an unchecked value reaches this function at runtime).
 */
export function getProviderAdapter(provider: AiProvider): AiProviderAdapter {
    switch (provider) {
        case AiProvider.OPENAI:
            return new OpenAiProviderAdapter();
        case AiProvider.MISTRAL:
            return new MistralProviderAdapter();
        case AiProvider.OLLAMA:
            return new OllamaProviderAdapter();
        default:
            // Fail loudly instead of returning `undefined` to callers that
            // immediately dereference the adapter.
            throw new Error(`Unsupported AI provider: ${String(provider)}`);
    }
}
+77
View File
@@ -0,0 +1,77 @@
import type {AiProvider} from "../model/ai-provider";
/**
 * One source file listed in a RAG artifact payload (`RagArtifactPayload.sources`).
 */
export type RagArtifactSource = {
fileId: string;
fileName: string;
mimeType?: string;
sizeBytes?: number;
sha256?: string;
// NOTE(review): presumably the provider-side id of the uploaded file — confirm against the code that fills it.
uploadedFileId?: string;
// NOTE(review): presumably the provider document id mapped to this local file — confirm against the code that fills it.
documentId?: string;
};
/**
 * JSON payload of a persisted RAG artifact.
 *
 * `providerState` is a discriminated union keyed by its own `provider` field,
 * with one variant per provider. The top-level `provider` repeats that value;
 * nothing in this type forces the two to agree.
 */
export type RagArtifactPayload = {
// Constant tag identifying the artifact as a RAG artifact.
artifactKind: "rag";
provider: AiProvider;
// `buildRagArtifactPayload` fills this with an ISO-8601 timestamp when the caller gives none.
createdAt: string;
sources: RagArtifactSource[];
providerState:
// OpenAI variant.
| {
provider: AiProvider.OPENAI;
vectorStoreIds: string[];
uploadedFileIds: string[];
}
// Mistral variant.
| {
provider: AiProvider.MISTRAL;
libraryId?: string;
documentCount: number;
}
// Ollama variant: retrieval settings plus per-document and per-chunk records.
| {
provider: AiProvider.OLLAMA;
prepared: boolean;
embeddingModel?: string;
topK?: number;
chunkSize?: number;
chunkOverlap?: number;
maxContextChars?: number;
extractedDocuments: Array<{
documentIndex: number;
fileName: string;
textChars: number;
}>;
selectedChunks: Array<{
sourceId: string;
documentIndex: number;
documentName: string;
chunkIndex: number;
chunkCount: number;
textChars: number;
score?: number;
}>;
skippedDocuments: Array<{
documentIndex: number;
fileName: string;
reason: string;
}>;
query: string;
minScore: number;
maxArchiveFiles: number;
maxArchiveBytes: number;
maxArchiveDepth: number;
};
};
/**
 * Assembles a `RagArtifactPayload`.
 *
 * @param params.provider Provider recorded at the top level of the payload.
 * @param params.createdAt Timestamp to record; defaults to the current time as an ISO string.
 * @param params.sources Source files, stored by reference.
 * @param params.providerState Provider-specific state, stored by reference.
 * @returns The payload with `artifactKind` set to `"rag"`.
 */
export function buildRagArtifactPayload(params: {
    provider: AiProvider;
    createdAt?: string;
    sources: RagArtifactSource[];
    providerState: RagArtifactPayload["providerState"];
}): RagArtifactPayload {
    const {provider, sources, providerState} = params;
    const createdAt = params.createdAt ?? new Date().toISOString();
    return {artifactKind: "rag", provider, createdAt, sources, providerState};
}
+15 -68
View File
@@ -4,75 +4,39 @@ import type {AiDownloadedFile} from "./telegram-attachments";
import type {PreparedDocumentRag} from "./document-rag-pipeline";
import type {OllamaRagArtifactDetails} from "./ollama-rag";
import {persistInternalJsonArtifactAttachment} from "./internal-artifact-store";
type RagArtifactPayload = {
artifactKind: "rag";
provider: AiProvider;
createdAt: string;
sources: Array<{
fileId: string;
fileName: string;
mimeType?: string;
sizeBytes?: number;
sha256?: string;
uploadedFileId?: string;
documentId?: string;
}>;
providerState: {
vectorStoreIds?: string[];
libraryId?: string;
documentCount?: number;
prepared?: boolean;
uploadedFileIds?: string[];
embeddingModel?: string;
topK?: number;
chunkSize?: number;
chunkOverlap?: number;
maxContextChars?: number;
extractedDocuments?: Array<{
documentIndex: number;
fileName: string;
textChars: number;
}>;
selectedChunks?: Array<{
sourceId: string;
documentIndex: number;
documentName: string;
chunkIndex: number;
chunkCount: number;
textChars: number;
score?: number;
}>;
skippedDocuments?: Array<{
documentIndex: number;
fileName: string;
reason: string;
}>;
query?: string;
ollama?: OllamaRagArtifactDetails["providerState"];
};
};
import {buildRagArtifactPayload, type RagArtifactPayload} from "./rag-artifact-payload";
/**
 * Builds the provider-specific `providerState` section of a RAG artifact.
 *
 * @param prepared Prepared document-RAG result; its `provider` selects the variant.
 * @param details Optional extra details; only read for the Ollama variant, where
 *                missing lists default to `[]`, a missing query to `""` and
 *                missing numeric limits to `0`.
 * @returns The state variant matching `prepared.provider`.
 */
function providerState(prepared: PreparedDocumentRag, details?: NonNullable<Parameters<typeof persistRagArtifactAttachment>[0]["details"]>): RagArtifactPayload["providerState"] {
    switch (prepared.provider) {
        case AiProvider.OPENAI: {
            const {vectorStoreIds, uploadedFileIds} = prepared;
            return {provider: AiProvider.OPENAI, vectorStoreIds, uploadedFileIds};
        }
        case AiProvider.MISTRAL: {
            const documentCount = prepared.documents.length;
            return {provider: AiProvider.MISTRAL, libraryId: prepared.libraryId, documentCount};
        }
        case AiProvider.OLLAMA: {
            const artifact = details?.artifact;
            const limits = artifact?.providerState;
            return {
                provider: AiProvider.OLLAMA,
                prepared: prepared.prepared,
                embeddingModel: details?.embeddingModel,
                topK: details?.topK,
                chunkSize: details?.chunkSize,
                chunkOverlap: details?.chunkOverlap,
                maxContextChars: details?.maxContextChars,
                extractedDocuments: artifact?.extractedDocuments ?? [],
                selectedChunks: artifact?.selectedChunks ?? [],
                skippedDocuments: artifact?.skippedDocuments ?? [],
                query: artifact?.query ?? "",
                minScore: limits?.minScore ?? 0,
                maxArchiveFiles: limits?.maxArchiveFiles ?? 0,
                maxArchiveBytes: limits?.maxArchiveBytes ?? 0,
                maxArchiveDepth: limits?.maxArchiveDepth ?? 0,
            };
        }
    }
}
@@ -117,22 +81,11 @@ export async function persistRagArtifactAttachment(params: {
if (!sources.length) return Promise.resolve(undefined);
const payload: RagArtifactPayload = {
artifactKind: "rag",
const payload = buildRagArtifactPayload({
provider: params.provider,
createdAt: new Date().toISOString(),
sources,
providerState: {
...providerState(params.prepared, params.details),
...(params.details?.artifact ? {
extractedDocuments: params.details.artifact.extractedDocuments,
selectedChunks: params.details.artifact.selectedChunks,
skippedDocuments: params.details.artifact.skippedDocuments,
query: params.details.artifact.query,
ollama: params.details.artifact.providerState,
} : {}),
},
};
providerState: providerState(params.prepared, params.details),
});
return await persistInternalJsonArtifactAttachment({
artifactKind: "rag",
fileNamePrefix: "rag",
@@ -140,14 +93,8 @@ export async function persistRagArtifactAttachment(params: {
messageId: params.messageId,
payload,
metadata: {
provider: params.provider,
sourceFileNames: sources.map(source => source.fileName),
...payload.providerState,
embeddingModel: params.details?.embeddingModel,
topK: params.details?.topK,
chunkSize: params.details?.chunkSize,
chunkOverlap: params.details?.chunkOverlap,
maxContextChars: params.details?.maxContextChars,
},
});
}
+4 -4
View File
@@ -1,8 +1,8 @@
import {AiTool} from "./tool-types";
import {AiProvider} from "../model/ai-provider";
import {getTools} from "./tools/registry";
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search";
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator";
import {AiProvider} from "../model/ai-provider.js";
import {getTools} from "./tools/registry.js";
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search.js";
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator.js";
export type AiProviderName = "ollama" | "openai" | "mistral";
+90
View File
@@ -0,0 +1,90 @@
import {Environment} from "../common/environment.js";
import {AiProvider} from "../model/ai-provider.js";
import type {BoundaryValue} from "../common/boundary-types.js";
import type {TelegramStreamMessage} from "./telegram-stream-message.js";
import type {RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js";
import {filterRankedTools} from "./tool-ranker-pipeline.js";
import {ToolRanker} from "./unified-ai-runner.tool-ranker.js";
import {storeToolRankAudit} from "./tool-rank-audit.js";
/**
 * Finds the text of the most recent user message.
 *
 * Messages are scanned from the end. A user message with string content returns
 * that string. A user message with array content returns the non-empty string
 * `text` fields of its parts joined with "\n" (possibly ""). User messages with
 * any other content are skipped and the scan continues backwards.
 *
 * @param messages Conversation in chronological order.
 * @returns The extracted text, or "" when no user message qualifies.
 */
function latestUserText(messages: readonly { role?: string; content?: unknown }[]): string {
    let index = messages.length;
    while (index-- > 0) {
        const candidate = messages[index];
        if (candidate?.role !== "user") continue;

        const {content} = candidate;
        if (typeof content === "string") return content;
        if (!Array.isArray(content)) continue;

        const texts: string[] = [];
        for (const part of content) {
            if (typeof part !== "object" || part === null || !("text" in part)) continue;
            const text = (part as { text?: unknown }).text;
            if (typeof text === "string" && text) {
                texts.push(text);
            }
        }
        return texts.join("\n");
    }
    return "";
}
/**
 * Runs one tool-ranking round: shows the "selecting tools" status, asks the
 * ranker which tools to keep, records an audit entry and returns the filtered
 * tool list.
 *
 * Only the ranker call itself is guarded by the error path. A failure while
 * clearing the status or storing the success audit propagates as-is and does
 * not produce a second, error-flagged audit entry for a ranking that succeeded.
 *
 * @param params.provider Provider the tools belong to.
 * @param params.model Model name recorded in the audit entry.
 * @param params.round Zero-based round number passed to the ranker and the audit.
 * @param params.config Runtime configuration; used to create a `ToolRanker` when none is supplied.
 * @param params.availableTools Provider-specific tool schemas to choose from.
 * @param params.messages Conversation; the latest user text becomes the ranker query.
 * @param params.streamMessage Stream message whose status is set and cleared around the ranking.
 * @param params.signal Abort signal forwarded to the ranker.
 * @param params.toolRanker Optional ranker instance to reuse.
 * @returns The filtered tools, the selected tool names and whether the ranker was used.
 * @throws Rethrows the ranker's error after storing an error audit entry.
 */
export async function runToolRankStage(params: {
    provider: AiProvider;
    model: string;
    round: number;
    config: RuntimeConfigSnapshot;
    availableTools: readonly BoundaryValue[];
    messages: readonly { role?: string; content?: unknown }[];
    streamMessage: TelegramStreamMessage;
    signal: AbortSignal;
    toolRanker?: ToolRanker;
}): Promise<{
    filteredTools: BoundaryValue[];
    selectedToolNames: string[];
    usedRanker: boolean;
}> {
    const toolRanker = params.toolRanker ?? new ToolRanker(params.config);
    const startedAt = Date.now();
    const startedAtIso = new Date().toISOString();

    params.streamMessage.setStatus(Environment.getSelectingToolsText());
    await params.streamMessage.flush();

    let selection: Awaited<ReturnType<ToolRanker["selectTools"]>>;
    try {
        selection = await toolRanker.selectTools({
            provider: params.provider,
            userQuery: latestUserText(params.messages),
            availableTools: params.availableTools,
            round: params.round,
            signal: params.signal,
        });
    } catch (error) {
        // Ranking failed: restore the status line and record the failure before rethrowing.
        params.streamMessage.clearStatus();
        await params.streamMessage.flush();
        await storeToolRankAudit({
            streamMessage: params.streamMessage,
            provider: params.provider,
            model: params.model,
            round: params.round,
            startedAt,
            startedAtIso,
            error,
        });
        throw error;
    }

    // Success path sits outside the try block on purpose (see the function comment).
    params.streamMessage.clearStatus();
    await params.streamMessage.flush();
    await storeToolRankAudit({
        streamMessage: params.streamMessage,
        provider: params.provider,
        model: params.model,
        round: params.round,
        startedAt,
        startedAtIso,
        selectedTools: selection.toolNames,
    });

    return {
        filteredTools: filterRankedTools(params.availableTools, selection.toolNames),
        selectedToolNames: selection.toolNames,
        usedRanker: selection.usedRanker,
    };
}
+64
View File
@@ -2,6 +2,7 @@ import {AiProvider} from "../model/ai-provider";
import {Environment} from "../common/environment";
import {ifTrue, logError} from "../util/utils";
import {UserRequestPipeline, type UserRequestPipelineState, type UserRequestPipelineStage} from "./user-request-pipeline";
import {getProviderAdapter} from "./provider-adapters";
import type {AiDownloadedFile} from "./telegram-attachments";
import type {TelegramStreamMessage} from "./telegram-stream-message";
import type {PreparedUnifiedAiRequest} from "./unified-ai-request-pipeline";
@@ -9,12 +10,14 @@ import type {OpenAIChatMessage} from "./openai-chat-message";
import type {MistralChatMessage} from "./mistral-chat-message";
import type {ChatMessage} from "./chat-messages-types";
import {
allToolSchemaNames,
providerName,
RuntimeConfigSnapshot,
snapshotModel,
TELEGRAM_LIMIT,
UnifiedRunOptions,
} from "./unified-ai-runner.shared";
import {runToolRankStage} from "./tool-rank-stage";
import {runOpenAi} from "./unified-ai-runner.openai";
import {runOllama} from "./unified-ai-runner.ollama";
import {runMistral} from "./unified-ai-runner.mistral";
@@ -159,6 +162,9 @@ export async function runUnifiedAiResponsePipeline(params: {
}): Promise<void> {
const {options, config, downloads, prepared, streamMessage, controller} = params;
const state = createResponsePipelineState(options);
const adapter = getProviderAdapter(options.provider);
let selectedToolNames: string[] = [];
let filteredTools: unknown[] = [];
const stages: UserRequestPipelineStage[] = [
{
@@ -177,6 +183,62 @@ export async function runUnifiedAiResponsePipeline(params: {
};
},
},
{
name: "tool_rank",
async run() {
const availableTools = adapter.rankTools(config, {
forCreator: options.msg.from?.id === Environment.CREATOR_ID,
vectorStoreIds: prepared.preparedDocumentRag?.provider === AiProvider.OPENAI
? prepared.preparedDocumentRag.vectorStoreIds
: [],
});
const rankResult = await runToolRankStage({
provider: options.provider,
model: snapshotModel(options.provider, config),
round: state.toolRankDecisions.length,
config,
availableTools,
messages: prepared.chatMessages,
streamMessage,
signal: controller.signal,
});
selectedToolNames = rankResult.selectedToolNames;
filteredTools = rankResult.filteredTools;
state.toolRankDecisions.push({
provider: options.provider,
round: state.toolRankDecisions.length,
availableTools: allToolSchemaNames(availableTools),
selectedTools: selectedToolNames,
usedRanker: rankResult.usedRanker,
});
return {
stage: "tool_rank",
status: "succeeded",
details: {
selectedTools: selectedToolNames,
usedRanker: rankResult.usedRanker,
availableTools: allToolSchemaNames(availableTools),
toolRankDecision: state.toolRankDecisions.at(-1),
},
};
},
},
{
name: "filter_tools",
async run() {
return {
stage: "filter_tools",
status: "succeeded",
details: {
selectedTools: selectedToolNames,
filteredToolCount: filteredTools.length,
},
};
},
},
{
name: "model_call",
async run() {
@@ -312,6 +374,8 @@ export async function runUnifiedAiResponsePipeline(params: {
stages,
stageNames: [
"audit_start",
"tool_rank",
"filter_tools",
"model_call",
"tool_loop",
"output_size_gate",
+40 -68
View File
@@ -1,21 +1,17 @@
import {Environment} from "../common/environment";
import {getMistralTools} from "./tool-mappers";
import {TelegramStreamMessage} from "./telegram-stream-message";
import {ToolRuntimeContext} from "./tools/runtime";
import {MistralChatMessage} from "./mistral-chat-message";
import {createMistralClient} from "./ai-runtime-target";
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
import {AiProvider} from "../model/ai-provider";
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
import {getProviderAdapter} from "./provider-adapters";
import {runToolRankStage} from "./tool-rank-stage";
import {
contentFromMistralDelta,
executeToolBatch,
MAX_TOOL_ROUNDS,
MistralDeltaLike,
MistralDocumentReference,
mistralToolCalls,
normalizeMistralToolCalls,
roundStatus,
RuntimeConfigSnapshot,
StreamingToolCallAccumulator,
@@ -23,8 +19,6 @@ import {
ToolExecutionMemory
} from "./unified-ai-runner.shared";
import {Message} from "typescript-telegram-bot-api";
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
import {storeToolRankAudit} from "./tool-rank-audit";
export async function runMistral(
msg: Message,
@@ -39,8 +33,9 @@ export async function runMistral(
): Promise<void> {
const runnerStartedAt = Date.now();
const mistralAi = createMistralClient(config.mistralChatTarget);
const toolRanker = new ToolRanker(config);
const availableTools = getMistralTools(msg.from?.id === Environment.CREATOR_ID);
const adapter = getProviderAdapter(AiProvider.MISTRAL);
const availableTools = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID});
const requestMessages = adapter.mapMessages([...messages]) as unknown as MistralChatMessage[];
aiLog("info", "mistral.run.start", {
stream,
target: aiLogProviderTarget(config.mistralChatTarget),
@@ -50,49 +45,23 @@ export async function runMistral(
});
const toolMemory: ToolExecutionMemory = new Map();
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
const roundStartedAt = Date.now();
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
if (signal.aborted) throw new Error("Aborted");
streamMessage.setStatus(Environment.getSelectingToolsText());
await streamMessage.flush();
const toolRankStartedAt = Date.now();
const toolRankStartedAtIso = new Date().toISOString();
const rankerSelection = await toolRanker.selectTools({
const rankResult = await runToolRankStage({
provider: AiProvider.MISTRAL,
userQuery: latestUserTextFromMessages(messages),
model: config.mistralChatTarget.model,
round,
config,
availableTools,
round,
messages,
streamMessage,
signal,
})
.catch(async error => {
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.MISTRAL,
model: config.mistralChatTarget.model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
error,
});
throw error;
});
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.MISTRAL,
model: config.mistralChatTarget.model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
selectedTools: rankerSelection.toolNames,
});
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
const filteredTools = rankResult.filteredTools;
const requestTools = filteredTools.length ? filteredTools : undefined;
streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? "");
@@ -101,15 +70,15 @@ export async function runMistral(
if (!stream) {
const request = {
model: config.mistralChatTarget.model,
messages,
messages: requestMessages,
tools: requestTools,
documents: documents
} as Parameters<typeof mistralAi.chat.complete>[0];
const response = await mistralAi.chat.complete(request, {signal});
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
const message = response.choices?.[0]?.message;
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
streamMessage.append(text);
const calls = normalizeMistralToolCalls(mistralToolCalls(message));
const calls = adapter.extractToolCalls(message);
aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
@@ -125,25 +94,27 @@ export async function runMistral(
function: {name: call.name, arguments: call.argumentsText},
})),
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
for (const [index, call] of calls.entries()) {
messages.push({
role: "tool",
name: call.name,
toolCallId: call.id,
content: toolResults[index] ?? "",
requestMessages.push({
role: "assistant",
content: text,
toolCalls: calls.map(call => ({
id: call.id,
function: {name: call.name, arguments: call.argumentsText},
})),
});
}
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
adapter.appendToolResults(messages, calls, toolResults);
adapter.appendToolResults(requestMessages, calls, toolResults);
continue;
}
const request = {
model: config.mistralChatTarget.model,
messages,
messages: requestMessages,
tools: requestTools,
documents: documents
} as Parameters<typeof mistralAi.chat.stream>[0];
const streamResponse = await mistralAi.chat.stream(request, {signal});
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
aiLog("debug", "mistral.stream.open", {round});
let calls: ToolCallData[] = [];
const roundTextStart = streamMessage.getText().length;
@@ -154,11 +125,10 @@ export async function runMistral(
const choice = event.data?.choices?.[0];
const delta = choice?.delta;
const mistralDelta = delta as MistralDeltaLike;
const mistralDelta = delta;
streamMessage.append(adapter.extractTextDelta(mistralDelta));
streamMessage.append(contentFromMistralDelta(mistralDelta));
const rawDeltaCalls = mistralToolCalls(mistralDelta);
const rawDeltaCalls = adapter.extractStreamingToolCalls(mistralDelta);
if (rawDeltaCalls.length) {
calls = toolCallAccumulator.add(rawDeltaCalls);
streamMessage.setStatus(Environment.getUseToolText(calls));
@@ -178,14 +148,16 @@ export async function runMistral(
content: roundText,
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
for (const [index, call] of calls.entries()) {
messages.push({
role: "tool",
name: call.name,
toolCallId: call.id,
content: toolResults[index] ?? "",
requestMessages.push({
role: "assistant",
content: roundText,
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
});
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
adapter.appendToolResults(messages, calls, toolResults);
adapter.appendToolResults(requestMessages, calls, toolResults);
}
} finally {
await adapter.finalize().catch(() => undefined);
}
}
+24 -66
View File
@@ -5,7 +5,6 @@ import {Environment} from "../common/environment";
import type {BoundaryValue} from "../common/boundary-types";
import {bot, notesDir} from "../index";
import {clamp, logError} from "../util/utils";
import {getOllamaTools} from "./tool-mappers";
import {TelegramStreamMessage} from "./telegram-stream-message";
import {ChatMessage} from "./chat-messages-types";
import {ChatRequest, Tool} from "ollama";
@@ -14,10 +13,11 @@ import {enqueueTelegramApiCall} from "../util/telegram-api-queue";
import {loadOllamaModel, unloadAllOllamaModels} from "./tools/utils";
import {createOllamaClient} from "./ai-runtime-target";
import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
import {getProviderAdapter} from "./provider-adapters";
import {runToolRankStage} from "./tool-rank-stage";
import {
allToolSchemaNames,
appendOllamaToolResults,
dedupeToolCalls,
DEFAULT_OLLAMA_CONTEXT_SIZE,
executeToolBatch,
@@ -26,8 +26,6 @@ import {
MAX_OLLAMA_CONTEXT_SIZE,
MAX_TOOL_ROUNDS,
MIN_OLLAMA_CONTEXT_SIZE,
normalizeOllamaToolCalls,
OllamaToolCallLike,
roundStatus,
RuntimeConfigSnapshot,
safeJsonParseObject,
@@ -35,14 +33,11 @@ import {
ToolCallData,
ToolExecutionMemory
} from "./unified-ai-runner.shared";
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
import {getToolPrompts} from "./tools/registry";
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
import {getModelCapabilities} from "./provider-model-runtime";
import {AiProvider} from "../model/ai-provider";
import {Message} from "typescript-telegram-bot-api";
import {storeToolRankAudit} from "./tool-rank-audit";
export async function runOllama(
msg: Message,
@@ -157,6 +152,7 @@ export async function runOllama(
}
const toolMemory: ToolExecutionMemory = new Map();
const adapter = getProviderAdapter(AiProvider.OLLAMA);
try {
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
@@ -183,7 +179,7 @@ export async function runOllama(
let activeToolNames: string[] = [];
if ((await getModelCapabilities(AiProvider.OLLAMA, model, "tools"))?.tools?.supported) {
const availableOllamaTools: Tool[] = getOllamaTools(msg.from?.id === Environment.CREATOR_ID) as Tool[];
const availableOllamaTools: Tool[] = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID}) as Tool[];
aiLog("debug", "ollama.tools.available", {
round,
@@ -191,44 +187,18 @@ export async function runOllama(
rankerEnabled: !!config.ollamaToolRankerTarget,
});
streamMessage.setStatus(Environment.getSelectingToolsText());
await streamMessage.flush();
const toolRankStartedAt = Date.now();
const toolRankStartedAtIso = new Date().toISOString();
const rankerSelection = await new ToolRanker(config).selectTools({
const rankResult = await runToolRankStage({
provider: AiProvider.OLLAMA,
userQuery: latestUserTextFromMessages(messages),
model,
round,
config,
availableTools: availableOllamaTools,
round,
messages,
streamMessage,
signal,
})
.catch(async error => {
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.OLLAMA,
model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
error,
});
throw error;
});
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.OLLAMA,
model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
selectedTools: rankerSelection.toolNames,
});
const filteredTools = [...new Set(filterRankedTools(availableOllamaTools, rankerSelection.toolNames))];
const filteredTools = [...new Set(rankResult.filteredTools as Tool[])];
activeToolNames = filteredTools.map(t => t.function.name ?? "");
if (filteredTools.length > 0) {
request.tools = [...filteredTools];
@@ -256,24 +226,21 @@ export async function runOllama(
round,
tools: activeToolNames,
count: activeToolNames.length,
usedRanker: rankerSelection.usedRanker,
usedRanker: rankResult.usedRanker,
});
}
if (!stream) {
const response = await ollama.chat({
const response = await adapter.callModel(request, () => ollama.chat({
...request,
stream: false
});
}));
const message = response.message;
const rawContent = message?.content ?? "";
const nativeCalls = dedupeToolCalls(
normalizeOllamaToolCalls(
message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
round,
),
adapter.extractToolCalls(message),
);
const responseText = rawContent;
@@ -301,7 +268,7 @@ export async function runOllama(
break;
}
const calls = nativeCalls;
// Reuse the de-duplicated calls computed earlier in this round. Re-extracting from the raw
// message here would bypass dedupeToolCalls (so a repeated tool call could run twice) and
// would invoke the extractor a second time for no benefit: when it returns nothing,
// nativeCalls is empty as well.
const calls = nativeCalls;
aiLog("info", "ollama.tool_calls", {
round,
@@ -319,11 +286,7 @@ export async function runOllama(
})),
});
appendOllamaToolResults(
messages,
calls,
await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory),
);
adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory));
continue;
}
@@ -332,10 +295,10 @@ export async function runOllama(
round,
messageCount: request.messages?.length ?? 0,
});
const response = await ollama.chat({
const response = await adapter.callModel(request, () => ollama.chat({
...request,
stream: true
});
}));
aiLog("debug", "ollama.stream.open", {round});
const calls: ToolCallData[] = [];
@@ -354,10 +317,7 @@ export async function runOllama(
const localToolCalls: ToolCallData[] = [];
localToolCalls.push(...normalizeOllamaToolCalls(
chunk.message.tool_calls as readonly OllamaToolCallLike[] | undefined,
round,
));
localToolCalls.push(...adapter.extractStreamingToolCalls(chunk.message));
const newStatus = roundStatus(round, firstRoundStatus, chunk.message.content, localToolCalls, !!chunk.message.thinking);
const previousStatus = streamMessage.getStatus();
@@ -377,13 +337,10 @@ export async function runOllama(
}
if (!(chunk.message?.thinking && streamMessage.getStatus() !== Environment.reasoningText)) {
streamMessage.append(chunk.message?.content ?? "");
streamMessage.append(adapter.extractTextDelta(chunk));
}
calls.push(...normalizeOllamaToolCalls(
chunk.message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
round,
));
calls.push(...adapter.extractStreamingToolCalls(chunk.message));
if (chunk.done) {
aiLog("debug", "ollama.stream.done", {
@@ -471,9 +428,10 @@ export async function runOllama(
}).catch(logError);
}
appendOllamaToolResults(messages, calls, toolResults);
adapter.appendToolResults(messages, calls, toolResults);
}
} finally {
if (interval) clearInterval(interval);
await adapter.finalize().catch(() => undefined);
}
}
+34 -72
View File
@@ -17,11 +17,9 @@ import {
AsyncIterableStream,
buildSystemInstruction,
collectOpenAiResponseCodeInterpreterCalls,
collectOpenAiResponseFunctionCalls,
collectOpenAiResponseImages,
collectOpenAiResponseText,
executeToolBatch,
getOpenAIResponsesToolsWithImage,
MAX_TOOL_ROUNDS,
OPENAI_IMAGE_PARTIALS,
openAiResponseItemCallId,
@@ -42,10 +40,9 @@ import {logError} from "../util/utils";
import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files";
import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings";
import {AiDownloadedFile} from "./telegram-attachments";
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
import {AiProvider} from "../model/ai-provider";
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
import {storeToolRankAudit} from "./tool-rank-audit";
import {getProviderAdapter} from "./provider-adapters";
import {runToolRankStage} from "./tool-rank-stage";
export async function runOpenAi(
msg: Message,
@@ -60,16 +57,15 @@ export async function runOpenAi(
documentRag?: OpenAiDocumentRagContext,
): Promise<void> {
const runnerStartedAt = Date.now();
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = [...messages] as Array<ResponseInputItem | OpenAiResponseOutputItem>;
const openAi = createOpenAiClient(config.openAiChatTarget);
const ownsDocumentRag = !documentRag;
const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document"));
const toolRanker = new ToolRanker(config);
const availableTools = getOpenAIResponsesToolsWithImage(
config,
msg.from?.id === Environment.CREATOR_ID,
preparedDocumentRag?.vectorStoreIds ?? [],
);
const adapter = getProviderAdapter(AiProvider.OPENAI);
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = adapter.mapMessages(messages) as unknown as Array<ResponseInputItem | OpenAiResponseOutputItem>;
const availableTools = adapter.rankTools(config, {
forCreator: msg.from?.id === Environment.CREATOR_ID,
vectorStoreIds: preparedDocumentRag?.vectorStoreIds ?? [],
});
const systemPrompt = buildSystemInstruction(
config,
@@ -93,43 +89,17 @@ export async function runOpenAi(
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
const roundStartedAt = Date.now();
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
streamMessage.setStatus(Environment.getSelectingToolsText());
await streamMessage.flush();
const toolRankStartedAt = Date.now();
const toolRankStartedAtIso = new Date().toISOString();
const rankerSelection = await toolRanker.selectTools({
const rankResult = await runToolRankStage({
provider: AiProvider.OPENAI,
userQuery: latestUserTextFromMessages(messages),
model: config.openAiChatTarget.model,
round,
config,
availableTools,
round,
messages,
streamMessage,
signal,
})
.catch(async error => {
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.OPENAI,
model: config.openAiChatTarget.model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
error,
});
throw error;
});
streamMessage.clearStatus();
await streamMessage.flush();
await storeToolRankAudit({
streamMessage,
provider: AiProvider.OPENAI,
model: config.openAiChatTarget.model,
round,
startedAt: toolRankStartedAt,
startedAtIso: toolRankStartedAtIso,
selectedTools: rankerSelection.toolNames,
});
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
const filteredTools = rankResult.filteredTools;
const requestTools = preparedDocumentRag?.vectorStoreIds.length
? (() => {
const tools = [...filteredTools];
@@ -151,7 +121,7 @@ export async function runOpenAi(
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
instructions: systemPrompt,
};
const response = await openAi.responses.create(request, {signal}) as OpenAiResponseLike;
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
const responseText = collectOpenAiResponseText(response);
streamMessage.append(responseText);
@@ -188,12 +158,12 @@ export async function runOpenAi(
});
}
const calls = collectOpenAiResponseFunctionCalls(response);
const calls = adapter.extractToolCalls(response);
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
calls: calls.map(call => ({
id: call.callId,
id: call.id,
name: call.name,
arguments: safeJsonParseObject(call.argumentsText)
})),
@@ -201,16 +171,13 @@ export async function runOpenAi(
if (!calls.length) return;
const toolCalls = calls.map(call => ({
id: call.callId,
id: call.id,
name: call.name,
argumentsText: call.argumentsText,
}));
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
const toolOutputs = calls.map((call, index) => ({
type: "function_call_output" as const,
call_id: call.callId,
output: toolResults[index] ?? "",
}));
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
adapter.appendToolResults(toolOutputs, calls, toolResults);
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
if (uploadFilesResult.found) {
@@ -243,7 +210,7 @@ export async function runOpenAi(
parallel_tool_calls: true,
instructions: systemPrompt
};
const response = await openAi.responses.create(request, {signal}) as AsyncIterableStream<ResponseStreamEvent>;
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
aiLog("debug", "openai.stream.open", {round});
@@ -253,7 +220,7 @@ export async function runOpenAi(
switch (event.type) {
case "response.output_text.delta":
streamMessage.append(event.delta ?? "");
streamMessage.append(adapter.extractTextDelta(event));
break;
case "response.image_generation_call.in_progress":
streamMessage.setStatus(Environment.startingImageGenText);
@@ -301,14 +268,11 @@ export async function runOpenAi(
case "response.code_interpreter_call_code.done":
break;
case "response.output_item.added":
if (event.item.type === "function_call" && event.item.name) {
const item = event.item as OpenAiResponseOutputItem & { id?: string };
localToolCalls.push({
id: openAiResponseItemCallId(item),
name: item.name ?? "",
argumentsText: item.arguments ?? "{}",
});
{
const streamedCalls = adapter.extractStreamingToolCalls(event);
if (streamedCalls.length) {
localToolCalls.push(...streamedCalls);
}
aiLog("info", "openai.stream.tool_call.added", {
round,
toolCalls: localToolCalls.map(aiLogToolCall)
@@ -383,12 +347,12 @@ export async function runOpenAi(
});
}
const calls = collectOpenAiResponseFunctionCalls(completedResponse);
const calls = adapter.extractToolCalls(completedResponse);
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
round,
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
calls: calls.map(call => ({
id: call.callId,
id: call.id,
name: call.name,
arguments: safeJsonParseObject(call.argumentsText)
})),
@@ -396,16 +360,13 @@ export async function runOpenAi(
if (!calls.length) return;
const toolCalls = calls.map(call => ({
id: call.callId,
id: call.id,
name: call.name,
argumentsText: call.argumentsText,
}));
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
const toolOutputs = calls.map((call, index) => ({
type: "function_call_output",
call_id: call.callId,
output: toolResults[index] ?? "",
}));
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
adapter.appendToolResults(toolOutputs, calls, toolResults);
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
if (uploadFilesResult.found) {
@@ -431,6 +392,7 @@ export async function runOpenAi(
if (ownsDocumentRag) {
await preparedDocumentRag?.cleanup().catch(logError);
}
await adapter.finalize().catch(logError);
}
}
+27 -27
View File
@@ -2,40 +2,40 @@ import {Message} from "typescript-telegram-bot-api";
import * as fs from "node:fs";
import path from "node:path";
import type {BoundaryValue} from "../common/boundary-types";
import {AiProvider} from "../model/ai-provider";
import {ToolRankerFallbackPolicy} from "../common/policies";
import {Environment} from "../common/environment";
import {photoGenDir} from "../index";
import {delay, logError, replyToMessage} from "../util/utils";
import {MessageStore} from "../common/message-store";
import type {OpenAiResponseTool} from "./tool-mappers";
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers";
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message";
import {AiDownloadedFile} from "./telegram-attachments";
import {getRuntimeCapabilities} from "./provider-model-runtime";
import {StoredAttachment} from "../model/stored-attachment";
import {AiChatMessage, ChatMessage} from "./chat-messages-types";
import {AiProvider} from "../model/ai-provider.js";
import {ToolRankerFallbackPolicy} from "../common/policies.js";
import {Environment} from "../common/environment.js";
import {photoGenDir} from "../index.js";
import {delay, logError, replyToMessage} from "../util/utils.js";
import {MessageStore} from "../common/message-store.js";
import type {OpenAiResponseTool} from "./tool-mappers.js";
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers.js";
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message.js";
import {AiDownloadedFile} from "./telegram-attachments.js";
import {getRuntimeCapabilities} from "./provider-model-runtime.js";
import {StoredAttachment} from "../model/stored-attachment.js";
import {AiChatMessage, ChatMessage} from "./chat-messages-types.js";
import {ListResponse, Ollama} from "ollama";
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime";
import {MessageImagePart, MessagePart} from "../common/message-part";
import {KeyedAsyncLock} from "../util/async-lock";
import {type AiRequestQueueTarget} from "./provider-request-queue";
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator";
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings";
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime.js";
import {MessageImagePart, MessagePart} from "../common/message-part.js";
import {KeyedAsyncLock} from "../util/async-lock.js";
import {type AiRequestQueueTarget} from "./provider-request-queue.js";
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator.js";
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings.js";
import {
isTranscribableAudioDownload,
resolveSpeechToTextProviderForUser,
transcribeSpeechDownloads
} from "./speech-to-text";
} from "./speech-to-text.js";
import type {ChatCompletionMessageParam} from "openai/resources/chat/completions";
import {MistralChatMessage} from "./mistral-chat-message";
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer";
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target";
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline";
import {MistralChatMessage} from "./mistral-chat-message.js";
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer.js";
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target.js";
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger.js";
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline.js";
import type {ResponseInputMessageContentList} from "openai/resources/responses/responses";
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store";
import {filterUserVisibleStoredAttachments} from "../common/stored-attachment-utils";
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store.js";
import {filterUserVisibleStoredAttachments} from "../common/attachment-visibility.js";
export type {Message} from "typescript-telegram-bot-api";
export type {AiRuntimeTarget} from "./ai-runtime-target";
+5
View File
@@ -0,0 +1,5 @@
import type {StoredAttachment} from "../model/stored-attachment";
/**
 * Returns the attachments that are allowed to be shown to the user.
 *
 * Any attachment whose `scope` equals `"internal_artifact"` is left out; all other
 * attachments (including those without a `scope`) are kept in their original order.
 * The input array itself is not modified; a new array is returned.
 *
 * @param attachments Stored attachments to inspect.
 * @returns A new array without the internal-artifact entries.
 */
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
    const isUserVisible = (candidate: StoredAttachment): boolean => {
        return candidate.scope !== "internal_artifact";
    };
    return attachments.filter(isUserVisible);
}
+10 -10
View File
@@ -3,18 +3,18 @@ import os from "node:os";
import path from "node:path";
import {parse as parseDotEnv} from "dotenv";
import {z} from "zod";
import {appLogger} from "../logging/logger";
import {appLogger} from "../logging/logger.js";
import type {BoundaryValue, ErrorLike} from "./boundary-types";
import {saveData} from "../db/database";
import {Answers} from "../model/answers";
import {ifTrue} from "../util/utils";
import {AiProvider} from "../model/ai-provider";
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies";
import {ToolRankerFallbackPolicy} from "./policies";
import type {ToolCallData} from "../ai/unified-ai-runner";
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator";
import {Localization, type LocalizationParams} from "./localization";
import {saveData} from "../db/database.js";
import {Answers} from "../model/answers.js";
import {ifTrue} from "../util/utils.js";
import {AiProvider} from "../model/ai-provider.js";
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies.js";
import {ToolRankerFallbackPolicy} from "./policies.js";
import type {ToolCallData} from "../ai/unified-ai-runner.js";
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js";
import {Localization, type LocalizationParams} from "./localization.js";
type EnvRecord = Record<string, string>;
type StringEnumLike = Record<string, string>;
+1 -4
View File
@@ -1,6 +1,7 @@
import path from "node:path";
import {Environment} from "./environment";
import {StoredAttachment} from "../model/stored-attachment";
export {filterUserVisibleStoredAttachments} from "./attachment-visibility";
export function photoCachePathForUniqueId(uniqueId: string): string {
return path.join(Environment.DATA_PATH, "cache", "photo", `${uniqueId}.jpg`);
@@ -44,7 +45,3 @@ export function uniqueStoredAttachments(attachments: StoredAttachment[]): Stored
return result;
}
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
return attachments.filter(attachment => attachment.scope !== "internal_artifact");
}
+4 -4
View File
@@ -1,9 +1,9 @@
import * as fs from "fs";
import {Environment} from "../common/environment";
import {logError} from "../util/utils";
import {Answers} from "../model/answers";
import {Environment} from "../common/environment.js";
import {logError} from "../util/utils.js";
import {Answers} from "../model/answers.js";
import path from "node:path";
import {KeyedAsyncLock} from "../util/async-lock";
import {KeyedAsyncLock} from "../util/async-lock.js";
type DataJsonFile = {
admins: number[]
+83
View File
@@ -0,0 +1,83 @@
import test from "node:test";
import assert from "node:assert/strict";
const {
extractOpenAiToolCalls,
extractOpenAiStreamingToolCalls,
extractOpenAiTextDelta,
extractMistralToolCalls,
extractMistralTextDelta,
extractOllamaToolCalls,
extractOllamaTextDelta,
} = await import("../dist/ai/provider-adapter-contract.js");
// Contract test for the OpenAI extractors. It feeds hand-built payloads (no real API calls)
// and checks the text delta plus the id/name of each extracted tool call.
test("openai contract extracts text delta and function calls", () => {
    // A streaming text event is expected to yield its `delta` string.
    const textEvent = {type: "response.output_text.delta", delta: "hello"};
    assert.equal(extractOpenAiTextDelta(textEvent), "hello");

    // A completed response with one `function_call` output item, identified by `call_id`.
    const completedResponse = {
        output: [
            {
                type: "function_call",
                call_id: "call-1",
                name: "read_file",
                arguments: "{\"path\":\"src/index.ts\"}",
            },
        ],
    };
    const completedCalls = extractOpenAiToolCalls(completedResponse);
    assert.equal(completedCalls.length, 1);
    const completedCall = completedCalls[0];
    assert.equal(completedCall.id, "call-1");
    assert.equal(completedCall.name, "read_file");

    // A streaming "output item added" event; here the item only carries `id`.
    const addedEvent = {
        type: "response.output_item.added",
        item: {
            type: "function_call",
            id: "call-2",
            name: "search_files",
            arguments: "{\"query\":\"sendMessage\"}",
        },
    };
    const streamedCalls = extractOpenAiStreamingToolCalls(addedEvent);
    assert.equal(streamedCalls.length, 1);
    const streamedCall = streamedCalls[0];
    assert.equal(streamedCall.id, "call-2");
    assert.equal(streamedCall.name, "search_files");
});
// Contract test for the Mistral extractors, using hand-built payloads (no real API calls).
test("mistral contract extracts content and tool calls", () => {
    // Content given as a list of text parts is expected to be joined into one string.
    const chunkedDelta = {
        content: [{text: "hello"}, {text: " world"}],
    };
    assert.equal(extractMistralTextDelta(chunkedDelta), "hello world");

    // One tool call whose arguments are an object rather than a JSON string.
    const weatherCall = {
        id: "m-1",
        function: {
            name: "get_weather",
            arguments: {location: "Moscow"},
        },
    };
    const extracted = extractMistralToolCalls({toolCalls: [weatherCall]});
    assert.equal(extracted.length, 1);
    const first = extracted[0];
    assert.equal(first.id, "m-1");
    assert.equal(first.name, "get_weather");
});
// Contract test for the Ollama extractors, using hand-built payloads (no real API calls).
test("ollama contract extracts content and tool calls", () => {
    // The text extractor is given a chunk-shaped object: the text sits under `message.content`.
    const chunk = {
        message: {content: "hello from ollama"},
    };
    assert.equal(extractOllamaTextDelta(chunk), "hello from ollama");

    // The tool-call extractor is given a message-shaped object with `tool_calls` at the top level.
    const searchCall = {
        id: "o-1",
        function: {
            name: "web_search",
            arguments: {query: "openai docs"},
        },
    };
    const extracted = extractOllamaToolCalls({tool_calls: [searchCall]});
    assert.equal(extracted.length, 1);
    const first = extracted[0];
    assert.equal(first.id, "o-1");
    assert.equal(first.name, "web_search");
});
+27 -94
View File
@@ -1,32 +1,13 @@
import test, {after} from "node:test";
import test from "node:test";
import assert from "node:assert/strict";
import fs from "node:fs";
import os from "node:os";
import path from "node:path";
const tempRoot = fs.mkdtempSync(path.join(os.tmpdir(), "tg-chat-bot-rag-"));
process.env.BOT_TOKEN = process.env.BOT_TOKEN ?? "test-token";
process.env.CREATOR_ID = process.env.CREATOR_ID ?? "1";
process.env.DATA_PATH = tempRoot;
process.env.DB_PATH = `file:${path.join(tempRoot, "test.sqlite")}`;
process.env.TEST_ENVIRONMENT = "true";
const {Environment} = await import("../dist/common/environment.js");
Environment.load();
const {DatabaseManager} = await import("../dist/db/database-manager.js");
DatabaseManager.init();
await DatabaseManager.ready;
const {ArtifactStore} = await import("../dist/common/artifact-store.js");
const {filterUserVisibleStoredAttachments} = await import("../dist/common/stored-attachment-utils.js");
const {
buildRagArtifactPayload,
} = await import("../dist/ai/rag-artifact-payload.js");
const {
filterUserVisibleStoredAttachments,
} = await import("../dist/common/attachment-visibility.js");
const {AiProvider} = await import("../dist/model/ai-provider.js");
const {persistRagArtifactAttachment} = await import("../dist/ai/rag-artifact-store.js");
after(async () => {
await DatabaseManager.close().catch(() => undefined);
fs.rmSync(tempRoot, {recursive: true, force: true});
});
test("internal artifacts are not treated as user-visible attachments", () => {
const visible = filterUserVisibleStoredAttachments([
@@ -50,65 +31,26 @@ test("internal artifacts are not treated as user-visible attachments", () => {
assert.equal(visible[0].fileId, "visible");
});
test("RAG artifacts persist structured ollama metadata", async () => {
const chatId = 42;
const messageId = 7;
const attachment = await persistRagArtifactAttachment({
test("RAG artifact payload keeps ollama retrieval metadata", () => {
const payload = buildRagArtifactPayload({
provider: AiProvider.OLLAMA,
prepared: {
provider: AiProvider.OLLAMA,
prepared: true,
cleanup: async () => undefined,
artifact: {
query: "What is in the file?",
extractedDocuments: [
{documentIndex: 0, fileName: "report.txt", textChars: 120},
],
selectedChunks: [
{
sourceId: "doc1-1",
documentIndex: 0,
documentName: "report.txt",
chunkIndex: 0,
chunkCount: 1,
textChars: 120,
score: 0.91,
},
],
skippedDocuments: [
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
],
providerState: {
embeddingModel: "nomic-embed-text:latest",
topK: 8,
chunkSize: 1400,
chunkOverlap: 220,
maxContextChars: 14000,
minScore: 0.12,
maxArchiveFiles: 200,
maxArchiveBytes: 50 * 1024 * 1024,
maxArchiveDepth: 2,
},
},
},
downloads: [{
kind: "document",
createdAt: "2026-01-01T00:00:00.000Z",
sources: [{
fileId: "file-1",
fileName: "report.txt",
buffer: Buffer.from("hello world"),
path: path.join(tempRoot, "report.txt"),
mimeType: "text/plain",
sizeBytes: 12,
sha256: "abc123",
uploadedFileId: "uploaded-1",
}],
chatId,
messageId,
details: {
providerState: {
provider: AiProvider.OLLAMA,
prepared: true,
embeddingModel: "nomic-embed-text:latest",
topK: 8,
chunkSize: 1400,
chunkOverlap: 220,
maxContextChars: 14000,
artifact: {
query: "What is in the file?",
extractedDocuments: [
{documentIndex: 0, fileName: "report.txt", textChars: 120},
],
@@ -126,29 +68,20 @@ test("RAG artifacts persist structured ollama metadata", async () => {
skippedDocuments: [
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
],
providerState: {
embeddingModel: "nomic-embed-text:latest",
topK: 8,
chunkSize: 1400,
chunkOverlap: 220,
maxContextChars: 14000,
minScore: 0.12,
maxArchiveFiles: 200,
maxArchiveBytes: 50 * 1024 * 1024,
maxArchiveDepth: 2,
},
},
query: "What is in the file?",
},
});
assert.equal(attachment?.artifactKind, "rag");
assert.equal(fs.existsSync(attachment.cachePath), true);
const stored = await ArtifactStore.getByMessage(chatId, messageId);
assert.equal(stored.length, 1);
assert.equal(stored[0].kind, "rag");
assert.equal(stored[0].payload.providerState.query, "What is in the file?");
assert.equal(stored[0].payload.providerState.selectedChunks[0].score, 0.91);
assert.equal(stored[0].payload.providerState.skippedDocuments[0].reason, "unsupported format");
assert.equal(stored[0].payload.providerState.ollama.embeddingModel, "nomic-embed-text:latest");
assert.equal(payload.artifactKind, "rag");
assert.equal(payload.provider, AiProvider.OLLAMA);
assert.equal(payload.sources[0].uploadedFileId, "uploaded-1");
assert.equal(payload.providerState.provider, AiProvider.OLLAMA);
assert.equal(payload.providerState.query, "What is in the file?");
assert.equal(payload.providerState.selectedChunks[0].score, 0.91);
assert.equal(payload.providerState.skippedDocuments[0].reason, "unsupported format");
assert.equal(payload.providerState.embeddingModel, "nomic-embed-text:latest");
});