Add unified request pipeline stages
This commit is contained in:
+15
-15
@@ -46,27 +46,27 @@
|
||||
- [x] Для Mistral сохранять `libraryId`.
|
||||
- [x] Для Mistral сохранять uploaded document ids.
|
||||
- [x] Для Mistral сохранять source file mapping: local attachment -> Mistral document id.
|
||||
- [ ] Добавить единый `providerState` schema для всех providers.
|
||||
- [ ] Добавить tests на сериализацию `RagArtifact`.
|
||||
- [ ] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
|
||||
- [x] Добавить единый `providerState` schema для всех providers.
|
||||
- [x] Добавить tests на сериализацию `RagArtifact`.
|
||||
- [x] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
|
||||
|
||||
## 4. Вынести provider runners в adapter layer
|
||||
|
||||
- [ ] Ввести интерфейс `AiProviderAdapter`.
|
||||
- [ ] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
|
||||
- [ ] Реализовать `OpenAiProviderAdapter`.
|
||||
- [ ] Реализовать `MistralProviderAdapter`.
|
||||
- [ ] Реализовать `OllamaProviderAdapter`.
|
||||
- [ ] Перенести provider-specific tool schema mapping внутрь adapter-ов.
|
||||
- [ ] Перенести provider-specific streaming parsing внутрь adapter-ов.
|
||||
- [ ] Перенести provider-specific tool result append внутрь adapter-ов.
|
||||
- [ ] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
|
||||
- [ ] Оставить compatibility wrappers для текущих imports.
|
||||
- [ ] Добавить tests на adapter contract без реальных API.
|
||||
- [x] Ввести интерфейс `AiProviderAdapter`.
|
||||
- [x] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
|
||||
- [x] Реализовать `OpenAiProviderAdapter`.
|
||||
- [x] Реализовать `MistralProviderAdapter`.
|
||||
- [x] Реализовать `OllamaProviderAdapter`.
|
||||
- [x] Перенести provider-specific tool schema mapping внутрь adapter-ов.
|
||||
- [x] Перенести provider-specific streaming parsing внутрь adapter-ов.
|
||||
- [x] Перенести provider-specific tool result append внутрь adapter-ов.
|
||||
- [x] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
|
||||
- [x] Оставить compatibility wrappers для текущих imports.
|
||||
- [x] Добавить tests на adapter contract без реальных API.
|
||||
|
||||
## 5. Сделать tool-ranker полноценным pipeline stage
|
||||
|
||||
- [ ] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
|
||||
- [x] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
|
||||
- [ ] Добавить stage `tool_rank`, который работает через provider adapter.
|
||||
- [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker.
|
||||
- [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`.
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import type {ToolCallData} from "./unified-ai-runner.shared.js";
|
||||
import type {ResponseStreamEvent} from "openai/resources/responses/responses";
|
||||
|
||||
/**
 * Type guard for plain "record-like" values.
 *
 * Accepts any non-null object that is not an array; primitives, `null`,
 * `undefined`, functions and arrays are rejected.
 */
function isRecord(value: unknown): value is Record<string, unknown> {
    if (value === null || typeof value !== "object") return false;
    return !Array.isArray(value);
}
|
||||
|
||||
/**
 * Picks a usable tool-call id.
 *
 * @param value - Candidate id as received from the provider.
 * @param fallback - Id to use when the candidate is not a string or is blank.
 * @returns The candidate unchanged (it is not trimmed) when it contains at
 *          least one non-whitespace character, otherwise `fallback`.
 */
function normalizeToolCallId(value: unknown, fallback: string): string {
    if (typeof value !== "string") return fallback;
    return value.trim().length === 0 ? fallback : value;
}
|
||||
|
||||
/**
 * Converts provider tool-call arguments into their textual form.
 *
 * @param value - Arguments as received: either already-serialized text or a
 *                structured value.
 * @returns Strings are returned unchanged. `null`/`undefined` become `"{}"`.
 *          Any other value is JSON-serialized; values that JSON cannot
 *          represent (for example functions or symbols) also become `"{}"`.
 */
function normalizeToolArguments(value: unknown): string {
    if (typeof value === "string") return value;
    // JSON.stringify yields `undefined` (not a string) for functions and
    // symbols, so guard the result to keep the declared return type honest.
    const serialized: string | undefined = JSON.stringify(value ?? {});
    return serialized ?? "{}";
}
|
||||
|
||||
/**
 * Collects function tool calls from an OpenAI response-like object.
 *
 * Only entries of `response.output` with `type === "function_call"` that
 * carry a string `call_id` or a string `name` are considered. Entries whose
 * name is missing or empty are dropped from the result.
 *
 * @param response - Value expected to expose an `output` array; anything else
 *                   yields an empty list.
 * @returns Normalized tool calls. When `call_id` is unusable the id falls back
 *          to `openai_<n>`, where `n` counts the considered entries
 *          (including ones later dropped for lacking a name).
 */
export function extractOpenAiToolCalls(response: unknown): ToolCallData[] {
    if (!isRecord(response) || !Array.isArray(response.output)) return [];

    const extracted: ToolCallData[] = [];
    let candidateCount = 0;

    for (const entry of response.output) {
        if (!isRecord(entry) || entry.type !== "function_call") continue;
        if (typeof entry.call_id !== "string" && typeof entry.name !== "string") continue;

        // The fallback index is consumed by every candidate, named or not.
        const position = candidateCount++;
        const name = typeof entry.name === "string" ? entry.name : "";
        if (name.length === 0) continue;

        const call = {
            id: normalizeToolCallId(entry.call_id, `openai_${position}`),
            name,
            argumentsText: normalizeToolArguments(entry.arguments),
        };
        extracted.push(call);
    }

    return extracted;
}
|
||||
|
||||
/**
 * Returns the text carried by an OpenAI streaming event.
 *
 * @param input - A stream event (or anything else).
 * @returns The event's `delta` when its type is
 *          `"response.output_text.delta"`; an empty string otherwise,
 *          including when the delta itself is null or undefined.
 */
export function extractOpenAiTextDelta(input: unknown): string {
    const event = input as ResponseStreamEvent | undefined;
    if (event?.type !== "response.output_text.delta") return "";
    return event.delta ?? "";
}
|
||||
|
||||
/**
 * Extracts a function tool call announced by an OpenAI streaming event.
 *
 * Only `"response.output_item.added"` events whose item is a
 * `"function_call"` are handled; the item is repackaged as a one-element
 * `output` list and delegated to {@link extractOpenAiToolCalls}.
 *
 * @param input - A stream event (or anything else).
 * @returns The normalized tool call(s), or an empty list for any other event.
 */
export function extractOpenAiStreamingToolCalls(input: unknown): ToolCallData[] {
    const event = input as ResponseStreamEvent | undefined;

    if (event?.type === "response.output_item.added" && isRecord(event.item) && event.item.type === "function_call") {
        const functionCall = {
            type: "function_call",
            // Prefer the call id; fall back to the item id when it is absent.
            call_id: event.item.call_id ?? event.item.id,
            name: event.item.name,
            arguments: event.item.arguments,
        };
        return extractOpenAiToolCalls({output: [functionCall]});
    }

    return [];
}
|
||||
|
||||
/**
 * Normalizes Mistral tool calls into `ToolCallData` entries.
 *
 * @param calls - Either an array of tool calls, or an object exposing them as
 *                `toolCalls` or `tool_calls`. Anything else yields an empty
 *                list.
 * @returns One entry per call that has a non-empty name. The name and
 *          arguments are read from `call.function` when present, otherwise
 *          from the call itself. A missing or blank `id` falls back to
 *          `mistral_<index>`, where the index is the call's position in the
 *          input list.
 */
export function extractMistralToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls)) {
        // Choose whichever property actually holds an array. A plain
        // `toolCalls ?? tool_calls` would select a non-null, non-array
        // `toolCalls` and discard a valid `tool_calls` list.
        if (Array.isArray(calls.toolCalls)) {
            rawCalls = calls.toolCalls;
        } else if (Array.isArray(calls.tool_calls)) {
            rawCalls = calls.tool_calls;
        }
    }

    return rawCalls
        .map((item, index) => {
            const call: Record<string, unknown> = isRecord(item) ? item : {};
            const fn = isRecord(call.function) ? call.function : undefined;
            const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : "";
            return {
                id: normalizeToolCallId(call.id, `mistral_${index}`),
                name,
                argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
            };
        })
        // Calls without a usable name cannot be dispatched, so drop them.
        .filter(call => call.name.length > 0);
}
|
||||
|
||||
/**
 * Returns the text carried by a Mistral delta-like object.
 *
 * @param input - Object whose `content` is either a string or an array of
 *                parts; anything else yields an empty string.
 * @returns String content as-is, or the concatenation (no separator) of every
 *          array part's string `text`; parts without one contribute nothing.
 */
export function extractMistralTextDelta(input: unknown): string {
    if (!isRecord(input)) return "";

    const {content} = input;
    if (typeof content === "string") return content;
    if (!Array.isArray(content)) return "";

    let text = "";
    for (const part of content) {
        if (isRecord(part) && typeof part.text === "string") text += part.text;
    }
    return text;
}
|
||||
|
||||
/**
 * Normalizes Ollama tool calls into `ToolCallData` entries.
 *
 * @param calls - Either an array of tool calls, or an object exposing them as
 *                `tool_calls`. Anything else yields an empty list.
 * @returns One entry per call that has a non-empty name. The name and
 *          arguments come from `call.function` when present, otherwise from
 *          the call itself. A missing or blank `id` falls back to
 *          `ollama_<index>`, where the index is the call's position in the
 *          input list.
 */
export function extractOllamaToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls) && Array.isArray(calls.tool_calls)) {
        rawCalls = calls.tool_calls;
    }

    const extracted: ToolCallData[] = [];
    rawCalls.forEach((entry, position) => {
        const call: Record<string, unknown> = isRecord(entry) ? entry : {};
        const fn = isRecord(call.function) ? call.function : undefined;

        let name = "";
        if (typeof fn?.name === "string") {
            name = fn.name;
        } else if (typeof call.name === "string") {
            name = call.name;
        }
        // Nameless calls are skipped but still occupy their input position.
        if (name.length === 0) return;

        const normalized = {
            id: normalizeToolCallId(call.id, `ollama_${position}`),
            name,
            argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
        };
        extracted.push(normalized);
    });

    return extracted;
}
|
||||
|
||||
/**
 * Returns the text carried by an Ollama chunk-like object.
 *
 * @param input - Object expected to hold `message.content`.
 * @returns `message.content` when it is a string, otherwise an empty string.
 */
export function extractOllamaTextDelta(input: unknown): string {
    if (!isRecord(input)) return "";

    const message = input.message;
    if (!isRecord(message)) return "";

    return typeof message.content === "string" ? message.content : "";
}
|
||||
@@ -0,0 +1,196 @@
|
||||
import {AiProvider} from "../model/ai-provider.js";
|
||||
import type {BoundaryValue} from "../common/boundary-types.js";
|
||||
import type {RuntimeConfigSnapshot, ToolCallData} from "./unified-ai-runner.shared.js";
|
||||
import {getMistralTools, getOllamaTools, getOpenAIResponsesTools, getOpenAICodeInterpreterTool} from "./tool-mappers.js";
|
||||
import type {MistralChatMessage as MistralMessageType} from "./mistral-chat-message.js";
|
||||
import type {OpenAIChatMessage as OpenAiMessageType} from "./openai-chat-message.js";
|
||||
import type {Message as OllamaMessage} from "ollama";
|
||||
import {
|
||||
extractMistralTextDelta,
|
||||
extractMistralToolCalls,
|
||||
extractOllamaTextDelta,
|
||||
extractOllamaToolCalls,
|
||||
extractOpenAiTextDelta,
|
||||
extractOpenAiStreamingToolCalls,
|
||||
extractOpenAiToolCalls,
|
||||
} from "./provider-adapter-contract.js";
|
||||
|
||||
/** Optional inputs accepted by `AiProviderAdapter.rankTools`. */
export type ProviderRankToolOptions = {
    /** Forwarded by the adapters in this file to the provider tool mappers. */
    forCreator?: boolean;
    /** When non-empty, the OpenAI adapter prepends a `file_search` tool bound to these ids. */
    vectorStoreIds?: Array<string>;
};
|
||||
|
||||
/**
 * Provider-specific operations used by the request pipeline and runners.
 *
 * The notes below describe what the implementations in this file do.
 */
export interface AiProviderAdapter {
    /** Provider this adapter serves. */
    readonly provider: AiProvider;

    /** Maps chat messages to the provider's message shape (the implementations here return them unchanged). */
    mapMessages(messages: readonly unknown[]): unknown[];

    /** Lists the provider-specific tool definitions available for ranking. */
    rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[];

    /** Runs `execute` for the given request (the implementations here invoke it directly and ignore `request`). */
    callModel<T>(request: unknown, execute: () => Promise<T>): Promise<T>;

    /** Extracts text from a provider delta/event; empty string when there is none. */
    extractTextDelta(input: unknown): string;

    /** Extracts normalized tool calls from a provider response or call list. */
    extractToolCalls(input: unknown): ToolCallData[];

    /** Extracts normalized tool calls from a streaming event or chunk. */
    extractStreamingToolCalls(input: unknown): ToolCallData[];

    /** Appends one provider-shaped tool-result message per call to `messages` (mutates the array). */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void;

    /** Completion hook (a no-op in the implementations here). */
    finalize(): Promise<void>;
}
|
||||
|
||||
/**
 * Appends one Ollama `tool` message per call to `messages` (mutates the array).
 *
 * @param messages - Conversation list to extend.
 * @param calls - Tool calls, in the order their results are listed.
 * @param results - Result text per call, matched by position; a missing entry
 *                  becomes an empty string.
 */
function appendOllamaToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
    calls.forEach((call, position) => {
        const content = results[position] ?? "";
        messages.push({
            role: "tool",
            content,
            tool_name: call.name,
        });
    });
}
|
||||
|
||||
/** Adapter exposing OpenAI tool definitions, extraction helpers and tool-result messages. */
class OpenAiProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OPENAI;

    /** Returns the messages unchanged; the cast only records the expected element type. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        return messages as OpenAiMessageType[];
    }

    /**
     * Builds the OpenAI tool list: an optional `file_search` tool first (only
     * when vector store ids are supplied), then the mapped function tools, the
     * code interpreter tool, image generation and web search.
     */
    rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const fileSearchTools: BoundaryValue[] = options?.vectorStoreIds?.length
            ? [{
                type: "file_search",
                vector_store_ids: options.vectorStoreIds,
            }]
            : [];

        const tools: BoundaryValue[] = [
            ...fileSearchTools,
            ...getOpenAIResponsesTools(options?.forCreator) as BoundaryValue[],
            getOpenAICodeInterpreterTool() as BoundaryValue,
            {
                type: "image_generation",
                model: config.openAiImageTarget.model,
                size: "auto",
                moderation: "low",
                output_format: "png",
                partial_images: 3,
            },
            {type: "web_search"},
        ];

        return tools;
    }

    /** Invokes `execute` directly; the request argument is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to {@link extractOpenAiTextDelta}. */
    extractTextDelta(input: unknown): string {
        return extractOpenAiTextDelta(input);
    }

    /** Delegates to {@link extractOpenAiToolCalls}. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiToolCalls(input);
    }

    /** Delegates to {@link extractOpenAiStreamingToolCalls}. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiStreamingToolCalls(input);
    }

    /**
     * Appends one `function_call_output` item per call, matched to results by
     * position; a missing result becomes an empty string.
     */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        calls.forEach((call, position) => {
            messages.push({
                type: "function_call_output",
                call_id: call.id,
                output: results[position] ?? "",
            });
        });
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
        return;
    }
}
|
||||
|
||||
/** Adapter exposing Mistral tool definitions, extraction helpers and tool-result messages. */
class MistralProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.MISTRAL;

    /** Returns the messages unchanged; the cast only records the expected element type. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        return messages as MistralMessageType[];
    }

    /** Returns the mapped Mistral tools; the runtime config is not consulted. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const tools = getMistralTools(options?.forCreator) as BoundaryValue[];
        return tools;
    }

    /** Invokes `execute` directly; the request argument is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to {@link extractMistralTextDelta}. */
    extractTextDelta(input: unknown): string {
        return extractMistralTextDelta(input);
    }

    /** Delegates to {@link extractMistralToolCalls}. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractMistralToolCalls(input);
    }

    /** Streaming input is handled exactly like non-streaming input. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /**
     * Appends one `tool` message per call, matched to results by position;
     * a missing result becomes an empty string.
     */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        calls.forEach((call, position) => {
            messages.push({
                role: "tool",
                name: call.name,
                toolCallId: call.id,
                content: results[position] ?? "",
            });
        });
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
        return;
    }
}
|
||||
|
||||
/** Adapter exposing Ollama tool definitions, extraction helpers and tool-result messages. */
class OllamaProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OLLAMA;

    /** Returns the messages unchanged; the cast only records the expected element type. */
    mapMessages(messages: readonly unknown[]): unknown[] {
        return messages as OllamaMessage[];
    }

    /** Returns the mapped Ollama tools; the runtime config is not consulted. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const tools = getOllamaTools(options?.forCreator) as BoundaryValue[];
        return tools;
    }

    /** Invokes `execute` directly; the request argument is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to {@link extractOllamaTextDelta}. */
    extractTextDelta(input: unknown): string {
        return extractOllamaTextDelta(input);
    }

    /** Delegates to {@link extractOllamaToolCalls}. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOllamaToolCalls(input);
    }

    /** Streaming input is handled exactly like non-streaming input. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /** Delegates to {@link appendOllamaToolResults}, which mutates `messages`. */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        appendOllamaToolResults(messages, calls, results);
    }

    /** Nothing to clean up. */
    async finalize(): Promise<void> {
        return;
    }
}
|
||||
|
||||
/**
 * Creates the adapter for the given provider.
 *
 * @param provider - Provider to serve.
 * @returns A new adapter instance on every call.
 * @throws Error when `provider` is not one of the handled values. Without the
 *         explicit default the function would fall through and return
 *         `undefined` at runtime, surfacing later as an opaque failure on the
 *         first adapter method call.
 */
export function getProviderAdapter(provider: AiProvider): AiProviderAdapter {
    switch (provider) {
        case AiProvider.OPENAI:
            return new OpenAiProviderAdapter();
        case AiProvider.MISTRAL:
            return new MistralProviderAdapter();
        case AiProvider.OLLAMA:
            return new OllamaProviderAdapter();
        default:
            // Reached only when a value outside the handled set arrives at runtime.
            throw new Error(`Unsupported AI provider: ${String(provider)}`);
    }
}
|
||||
@@ -0,0 +1,77 @@
|
||||
import type {AiProvider} from "../model/ai-provider";
|
||||
|
||||
/**
 * One source file recorded in a RAG artifact.
 *
 * Only `fileId` and `fileName` are required; the remaining fields are
 * optional metadata.
 */
export type RagArtifactSource = {
    fileId: string;
    fileName: string;
    mimeType?: string;
    sizeBytes?: number;
    sha256?: string;
    // NOTE(review): presumably the provider-side id of the uploaded file — confirm against the writer.
    uploadedFileId?: string;
    // NOTE(review): presumably the provider-side document id (e.g. Mistral) — confirm against the writer.
    documentId?: string;
};
|
||||
|
||||
/**
 * Serialized payload of an internal RAG artifact.
 *
 * `providerState` is a union discriminated by its own `provider` field, with
 * one variant per provider.
 */
export type RagArtifactPayload = {
    /** Constant tag identifying the artifact kind. */
    artifactKind: "rag";
    provider: AiProvider;
    /** Timestamp string; `buildRagArtifactPayload` defaults it to an ISO-8601 value. */
    createdAt: string;
    sources: RagArtifactSource[];
    providerState:
        | {
            provider: AiProvider.OPENAI;
            vectorStoreIds: string[];
            uploadedFileIds: string[];
        }
        | {
            provider: AiProvider.MISTRAL;
            libraryId?: string;
            documentCount: number;
        }
        | {
            provider: AiProvider.OLLAMA;
            prepared: boolean;
            embeddingModel?: string;
            topK?: number;
            chunkSize?: number;
            chunkOverlap?: number;
            maxContextChars?: number;
            extractedDocuments: {
                documentIndex: number;
                fileName: string;
                textChars: number;
            }[];
            selectedChunks: {
                sourceId: string;
                documentIndex: number;
                documentName: string;
                chunkIndex: number;
                chunkCount: number;
                textChars: number;
                score?: number;
            }[];
            skippedDocuments: {
                documentIndex: number;
                fileName: string;
                reason: string;
            }[];
            query: string;
            minScore: number;
            maxArchiveFiles: number;
            maxArchiveBytes: number;
            maxArchiveDepth: number;
        };
};
|
||||
|
||||
/**
 * Assembles a RAG artifact payload.
 *
 * @param params.provider - Provider recorded at the top level of the payload.
 * @param params.createdAt - Optional timestamp; defaults to the current time
 *                           as an ISO-8601 string.
 * @param params.sources - Source entries, stored by reference (not copied).
 * @param params.providerState - Provider-specific state, stored by reference.
 * @returns The payload tagged with `artifactKind: "rag"`.
 */
export function buildRagArtifactPayload(params: {
    provider: AiProvider;
    createdAt?: string;
    sources: RagArtifactSource[];
    providerState: RagArtifactPayload["providerState"];
}): RagArtifactPayload {
    const {provider, sources, providerState} = params;
    const createdAt = params.createdAt ?? new Date().toISOString();

    return {artifactKind: "rag", provider, createdAt, sources, providerState};
}
|
||||
@@ -4,75 +4,39 @@ import type {AiDownloadedFile} from "./telegram-attachments";
|
||||
import type {PreparedDocumentRag} from "./document-rag-pipeline";
|
||||
import type {OllamaRagArtifactDetails} from "./ollama-rag";
|
||||
import {persistInternalJsonArtifactAttachment} from "./internal-artifact-store";
|
||||
|
||||
type RagArtifactPayload = {
|
||||
artifactKind: "rag";
|
||||
provider: AiProvider;
|
||||
createdAt: string;
|
||||
sources: Array<{
|
||||
fileId: string;
|
||||
fileName: string;
|
||||
mimeType?: string;
|
||||
sizeBytes?: number;
|
||||
sha256?: string;
|
||||
uploadedFileId?: string;
|
||||
documentId?: string;
|
||||
}>;
|
||||
providerState: {
|
||||
vectorStoreIds?: string[];
|
||||
libraryId?: string;
|
||||
documentCount?: number;
|
||||
prepared?: boolean;
|
||||
uploadedFileIds?: string[];
|
||||
embeddingModel?: string;
|
||||
topK?: number;
|
||||
chunkSize?: number;
|
||||
chunkOverlap?: number;
|
||||
maxContextChars?: number;
|
||||
extractedDocuments?: Array<{
|
||||
documentIndex: number;
|
||||
fileName: string;
|
||||
textChars: number;
|
||||
}>;
|
||||
selectedChunks?: Array<{
|
||||
sourceId: string;
|
||||
documentIndex: number;
|
||||
documentName: string;
|
||||
chunkIndex: number;
|
||||
chunkCount: number;
|
||||
textChars: number;
|
||||
score?: number;
|
||||
}>;
|
||||
skippedDocuments?: Array<{
|
||||
documentIndex: number;
|
||||
fileName: string;
|
||||
reason: string;
|
||||
}>;
|
||||
query?: string;
|
||||
ollama?: OllamaRagArtifactDetails["providerState"];
|
||||
};
|
||||
};
|
||||
import {buildRagArtifactPayload, type RagArtifactPayload} from "./rag-artifact-payload";
|
||||
|
||||
/**
 * Builds the provider-specific state stored in a RAG artifact.
 *
 * @param prepared - Prepared document RAG; its `provider` selects the variant.
 * @param details - Optional extra details, read only for the Ollama variant.
 * @returns OpenAI: vector store and uploaded file ids. Mistral: library id
 *          and document count. Ollama: preparation flag plus retrieval
 *          settings and artifact details, with missing lists defaulting to
 *          `[]`, a missing query to `""` and missing numeric limits to `0`.
 */
function providerState(prepared: PreparedDocumentRag, details?: NonNullable<Parameters<typeof persistRagArtifactAttachment>[0]["details"]>): RagArtifactPayload["providerState"] {
    switch (prepared.provider) {
        case AiProvider.OPENAI: {
            const {vectorStoreIds, uploadedFileIds} = prepared;
            return {provider: AiProvider.OPENAI, vectorStoreIds, uploadedFileIds};
        }
        case AiProvider.MISTRAL: {
            const documentCount = prepared.documents.length;
            return {provider: AiProvider.MISTRAL, libraryId: prepared.libraryId, documentCount};
        }
        case AiProvider.OLLAMA: {
            const artifact = details?.artifact;
            const limits = artifact?.providerState;
            return {
                provider: AiProvider.OLLAMA,
                prepared: prepared.prepared,
                embeddingModel: details?.embeddingModel,
                topK: details?.topK,
                chunkSize: details?.chunkSize,
                chunkOverlap: details?.chunkOverlap,
                maxContextChars: details?.maxContextChars,
                extractedDocuments: artifact?.extractedDocuments ?? [],
                selectedChunks: artifact?.selectedChunks ?? [],
                skippedDocuments: artifact?.skippedDocuments ?? [],
                query: artifact?.query ?? "",
                minScore: limits?.minScore ?? 0,
                maxArchiveFiles: limits?.maxArchiveFiles ?? 0,
                maxArchiveBytes: limits?.maxArchiveBytes ?? 0,
                maxArchiveDepth: limits?.maxArchiveDepth ?? 0,
            };
        }
    }
}
|
||||
@@ -117,22 +81,11 @@ export async function persistRagArtifactAttachment(params: {
|
||||
|
||||
if (!sources.length) return Promise.resolve(undefined);
|
||||
|
||||
const payload: RagArtifactPayload = {
|
||||
artifactKind: "rag",
|
||||
const payload = buildRagArtifactPayload({
|
||||
provider: params.provider,
|
||||
createdAt: new Date().toISOString(),
|
||||
sources,
|
||||
providerState: {
|
||||
...providerState(params.prepared, params.details),
|
||||
...(params.details?.artifact ? {
|
||||
extractedDocuments: params.details.artifact.extractedDocuments,
|
||||
selectedChunks: params.details.artifact.selectedChunks,
|
||||
skippedDocuments: params.details.artifact.skippedDocuments,
|
||||
query: params.details.artifact.query,
|
||||
ollama: params.details.artifact.providerState,
|
||||
} : {}),
|
||||
},
|
||||
};
|
||||
providerState: providerState(params.prepared, params.details),
|
||||
});
|
||||
return await persistInternalJsonArtifactAttachment({
|
||||
artifactKind: "rag",
|
||||
fileNamePrefix: "rag",
|
||||
@@ -140,14 +93,8 @@ export async function persistRagArtifactAttachment(params: {
|
||||
messageId: params.messageId,
|
||||
payload,
|
||||
metadata: {
|
||||
provider: params.provider,
|
||||
sourceFileNames: sources.map(source => source.fileName),
|
||||
...payload.providerState,
|
||||
embeddingModel: params.details?.embeddingModel,
|
||||
topK: params.details?.topK,
|
||||
chunkSize: params.details?.chunkSize,
|
||||
chunkOverlap: params.details?.chunkOverlap,
|
||||
maxContextChars: params.details?.maxContextChars,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import {AiTool} from "./tool-types";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {getTools} from "./tools/registry";
|
||||
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator";
|
||||
import {AiProvider} from "../model/ai-provider.js";
|
||||
import {getTools} from "./tools/registry.js";
|
||||
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search.js";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator.js";
|
||||
|
||||
export type AiProviderName = "ollama" | "openai" | "mistral";
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import {Environment} from "../common/environment.js";
|
||||
import {AiProvider} from "../model/ai-provider.js";
|
||||
import type {BoundaryValue} from "../common/boundary-types.js";
|
||||
import type {TelegramStreamMessage} from "./telegram-stream-message.js";
|
||||
import type {RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js";
|
||||
import {filterRankedTools} from "./tool-ranker-pipeline.js";
|
||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker.js";
|
||||
import {storeToolRankAudit} from "./tool-rank-audit.js";
|
||||
|
||||
/**
 * Returns the text of the most recent user message.
 *
 * Messages are scanned from newest to oldest. String content is returned
 * as-is. Array content yields its non-empty string `text` parts joined with
 * newlines (possibly an empty string). A user message whose content is
 * neither a string nor an array is skipped in favour of an older one.
 *
 * @param messages - Conversation messages in chronological order.
 * @returns The extracted text, or `""` when no user message provides any.
 */
function latestUserText(messages: readonly { role?: string; content?: unknown }[]): string {
    let index = messages.length;

    while (index-- > 0) {
        const candidate = messages[index];
        if (candidate?.role !== "user") continue;

        const {content} = candidate;
        if (typeof content === "string") return content;
        if (!Array.isArray(content)) continue;

        const pieces: string[] = [];
        for (const part of content) {
            if (typeof part !== "object" || part === null || !("text" in part)) continue;
            const text = (part as { text?: unknown }).text;
            if (typeof text === "string" && text.length > 0) pieces.push(text);
        }
        return pieces.join("\n");
    }

    return "";
}
|
||||
|
||||
/**
 * Runs one tool-ranking round: shows the "selecting tools" status, asks the
 * ranker which tools to keep, stores an audit record and returns the
 * filtered tool list.
 *
 * @param params.provider - Provider the tools belong to.
 * @param params.model - Model name recorded in the audit entry.
 * @param params.round - Round number passed to the ranker and the audit.
 * @param params.config - Runtime config; used to build a `ToolRanker` when
 *                        none is supplied.
 * @param params.availableTools - Provider-specific tool definitions to rank.
 * @param params.messages - Conversation; the latest user text is the query.
 * @param params.streamMessage - Stream message whose status is set and cleared.
 * @param params.signal - Abort signal forwarded to the ranker.
 * @param params.toolRanker - Optional ranker to reuse.
 * @returns The filtered tools, the selected tool names and whether the ranker
 *          was used.
 * @throws Rethrows any ranker failure after clearing the status and storing
 *         an error audit entry.
 */
export async function runToolRankStage(params: {
    provider: AiProvider;
    model: string;
    round: number;
    config: RuntimeConfigSnapshot;
    availableTools: readonly BoundaryValue[];
    messages: readonly { role?: string; content?: unknown }[];
    streamMessage: TelegramStreamMessage;
    signal: AbortSignal;
    toolRanker?: ToolRanker;
}): Promise<{
    filteredTools: BoundaryValue[];
    selectedToolNames: string[];
    usedRanker: boolean;
}> {
    const toolRanker = params.toolRanker ?? new ToolRanker(params.config);
    const startedAt = Date.now();
    const startedAtIso = new Date().toISOString();

    params.streamMessage.setStatus(Environment.getSelectingToolsText());
    await params.streamMessage.flush();

    // Only the ranker call is guarded. If the success-path cleanup or audit
    // were inside the try block, a failure while storing the success audit
    // would be recorded a second time as a ranker error.
    let selection: Awaited<ReturnType<ToolRanker["selectTools"]>>;
    try {
        selection = await toolRanker.selectTools({
            provider: params.provider,
            userQuery: latestUserText(params.messages),
            availableTools: params.availableTools,
            round: params.round,
            signal: params.signal,
        });
    } catch (error) {
        params.streamMessage.clearStatus();
        await params.streamMessage.flush();
        await storeToolRankAudit({
            streamMessage: params.streamMessage,
            provider: params.provider,
            model: params.model,
            round: params.round,
            startedAt,
            startedAtIso,
            error,
        });
        throw error;
    }

    params.streamMessage.clearStatus();
    await params.streamMessage.flush();
    await storeToolRankAudit({
        streamMessage: params.streamMessage,
        provider: params.provider,
        model: params.model,
        round: params.round,
        startedAt,
        startedAtIso,
        selectedTools: selection.toolNames,
    });

    return {
        filteredTools: filterRankedTools(params.availableTools, selection.toolNames),
        selectedToolNames: selection.toolNames,
        usedRanker: selection.usedRanker,
    };
}
|
||||
@@ -2,6 +2,7 @@ import {AiProvider} from "../model/ai-provider";
|
||||
import {Environment} from "../common/environment";
|
||||
import {ifTrue, logError} from "../util/utils";
|
||||
import {UserRequestPipeline, type UserRequestPipelineState, type UserRequestPipelineStage} from "./user-request-pipeline";
|
||||
import {getProviderAdapter} from "./provider-adapters";
|
||||
import type {AiDownloadedFile} from "./telegram-attachments";
|
||||
import type {TelegramStreamMessage} from "./telegram-stream-message";
|
||||
import type {PreparedUnifiedAiRequest} from "./unified-ai-request-pipeline";
|
||||
@@ -9,12 +10,14 @@ import type {OpenAIChatMessage} from "./openai-chat-message";
|
||||
import type {MistralChatMessage} from "./mistral-chat-message";
|
||||
import type {ChatMessage} from "./chat-messages-types";
|
||||
import {
|
||||
allToolSchemaNames,
|
||||
providerName,
|
||||
RuntimeConfigSnapshot,
|
||||
snapshotModel,
|
||||
TELEGRAM_LIMIT,
|
||||
UnifiedRunOptions,
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {runToolRankStage} from "./tool-rank-stage";
|
||||
import {runOpenAi} from "./unified-ai-runner.openai";
|
||||
import {runOllama} from "./unified-ai-runner.ollama";
|
||||
import {runMistral} from "./unified-ai-runner.mistral";
|
||||
@@ -159,6 +162,9 @@ export async function runUnifiedAiResponsePipeline(params: {
|
||||
}): Promise<void> {
|
||||
const {options, config, downloads, prepared, streamMessage, controller} = params;
|
||||
const state = createResponsePipelineState(options);
|
||||
const adapter = getProviderAdapter(options.provider);
|
||||
let selectedToolNames: string[] = [];
|
||||
let filteredTools: unknown[] = [];
|
||||
|
||||
const stages: UserRequestPipelineStage[] = [
|
||||
{
|
||||
@@ -177,6 +183,62 @@ export async function runUnifiedAiResponsePipeline(params: {
|
||||
};
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_rank",
|
||||
async run() {
|
||||
const availableTools = adapter.rankTools(config, {
|
||||
forCreator: options.msg.from?.id === Environment.CREATOR_ID,
|
||||
vectorStoreIds: prepared.preparedDocumentRag?.provider === AiProvider.OPENAI
|
||||
? prepared.preparedDocumentRag.vectorStoreIds
|
||||
: [],
|
||||
});
|
||||
|
||||
const rankResult = await runToolRankStage({
|
||||
provider: options.provider,
|
||||
model: snapshotModel(options.provider, config),
|
||||
round: state.toolRankDecisions.length,
|
||||
config,
|
||||
availableTools,
|
||||
messages: prepared.chatMessages,
|
||||
streamMessage,
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
selectedToolNames = rankResult.selectedToolNames;
|
||||
filteredTools = rankResult.filteredTools;
|
||||
state.toolRankDecisions.push({
|
||||
provider: options.provider,
|
||||
round: state.toolRankDecisions.length,
|
||||
availableTools: allToolSchemaNames(availableTools),
|
||||
selectedTools: selectedToolNames,
|
||||
usedRanker: rankResult.usedRanker,
|
||||
});
|
||||
|
||||
return {
|
||||
stage: "tool_rank",
|
||||
status: "succeeded",
|
||||
details: {
|
||||
selectedTools: selectedToolNames,
|
||||
usedRanker: rankResult.usedRanker,
|
||||
availableTools: allToolSchemaNames(availableTools),
|
||||
toolRankDecision: state.toolRankDecisions.at(-1),
|
||||
},
|
||||
};
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_tools",
|
||||
async run() {
|
||||
return {
|
||||
stage: "filter_tools",
|
||||
status: "succeeded",
|
||||
details: {
|
||||
selectedTools: selectedToolNames,
|
||||
filteredToolCount: filteredTools.length,
|
||||
},
|
||||
};
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model_call",
|
||||
async run() {
|
||||
@@ -312,6 +374,8 @@ export async function runUnifiedAiResponsePipeline(params: {
|
||||
stages,
|
||||
stageNames: [
|
||||
"audit_start",
|
||||
"tool_rank",
|
||||
"filter_tools",
|
||||
"model_call",
|
||||
"tool_loop",
|
||||
"output_size_gate",
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import {Environment} from "../common/environment";
|
||||
import {getMistralTools} from "./tool-mappers";
|
||||
import {TelegramStreamMessage} from "./telegram-stream-message";
|
||||
import {ToolRuntimeContext} from "./tools/runtime";
|
||||
import {MistralChatMessage} from "./mistral-chat-message";
|
||||
import {createMistralClient} from "./ai-runtime-target";
|
||||
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
||||
import {getProviderAdapter} from "./provider-adapters";
|
||||
import {runToolRankStage} from "./tool-rank-stage";
|
||||
|
||||
import {
|
||||
contentFromMistralDelta,
|
||||
executeToolBatch,
|
||||
MAX_TOOL_ROUNDS,
|
||||
MistralDeltaLike,
|
||||
MistralDocumentReference,
|
||||
mistralToolCalls,
|
||||
normalizeMistralToolCalls,
|
||||
roundStatus,
|
||||
RuntimeConfigSnapshot,
|
||||
StreamingToolCallAccumulator,
|
||||
@@ -23,8 +19,6 @@ import {
|
||||
ToolExecutionMemory
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {Message} from "typescript-telegram-bot-api";
|
||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
||||
|
||||
export async function runMistral(
|
||||
msg: Message,
|
||||
@@ -39,8 +33,9 @@ export async function runMistral(
|
||||
): Promise<void> {
|
||||
const runnerStartedAt = Date.now();
|
||||
const mistralAi = createMistralClient(config.mistralChatTarget);
|
||||
const toolRanker = new ToolRanker(config);
|
||||
const availableTools = getMistralTools(msg.from?.id === Environment.CREATOR_ID);
|
||||
const adapter = getProviderAdapter(AiProvider.MISTRAL);
|
||||
const availableTools = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID});
|
||||
const requestMessages = adapter.mapMessages([...messages]) as unknown as MistralChatMessage[];
|
||||
aiLog("info", "mistral.run.start", {
|
||||
stream,
|
||||
target: aiLogProviderTarget(config.mistralChatTarget),
|
||||
@@ -50,49 +45,23 @@ export async function runMistral(
|
||||
});
|
||||
|
||||
const toolMemory: ToolExecutionMemory = new Map();
|
||||
|
||||
try {
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
||||
if (signal.aborted) throw new Error("Aborted");
|
||||
|
||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
||||
await streamMessage.flush();
|
||||
const toolRankStartedAt = Date.now();
|
||||
const toolRankStartedAtIso = new Date().toISOString();
|
||||
const rankerSelection = await toolRanker.selectTools({
|
||||
const rankResult = await runToolRankStage({
|
||||
provider: AiProvider.MISTRAL,
|
||||
userQuery: latestUserTextFromMessages(messages),
|
||||
model: config.mistralChatTarget.model,
|
||||
round,
|
||||
config,
|
||||
availableTools,
|
||||
round,
|
||||
messages,
|
||||
streamMessage,
|
||||
signal,
|
||||
})
|
||||
.catch(async error => {
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.MISTRAL,
|
||||
model: config.mistralChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
error,
|
||||
});
|
||||
throw error;
|
||||
});
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.MISTRAL,
|
||||
model: config.mistralChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
selectedTools: rankerSelection.toolNames,
|
||||
});
|
||||
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
|
||||
const filteredTools = rankResult.filteredTools;
|
||||
const requestTools = filteredTools.length ? filteredTools : undefined;
|
||||
|
||||
streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? "");
|
||||
@@ -101,15 +70,15 @@ export async function runMistral(
|
||||
if (!stream) {
|
||||
const request = {
|
||||
model: config.mistralChatTarget.model,
|
||||
messages,
|
||||
messages: requestMessages,
|
||||
tools: requestTools,
|
||||
documents: documents
|
||||
} as Parameters<typeof mistralAi.chat.complete>[0];
|
||||
const response = await mistralAi.chat.complete(request, {signal});
|
||||
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
|
||||
const message = response.choices?.[0]?.message;
|
||||
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
||||
streamMessage.append(text);
|
||||
const calls = normalizeMistralToolCalls(mistralToolCalls(message));
|
||||
const calls = adapter.extractToolCalls(message);
|
||||
aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", {
|
||||
round,
|
||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||
@@ -125,25 +94,27 @@ export async function runMistral(
|
||||
function: {name: call.name, arguments: call.argumentsText},
|
||||
})),
|
||||
});
|
||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||
for (const [index, call] of calls.entries()) {
|
||||
messages.push({
|
||||
role: "tool",
|
||||
name: call.name,
|
||||
toolCallId: call.id,
|
||||
content: toolResults[index] ?? "",
|
||||
requestMessages.push({
|
||||
role: "assistant",
|
||||
content: text,
|
||||
toolCalls: calls.map(call => ({
|
||||
id: call.id,
|
||||
function: {name: call.name, arguments: call.argumentsText},
|
||||
})),
|
||||
});
|
||||
}
|
||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||
adapter.appendToolResults(messages, calls, toolResults);
|
||||
adapter.appendToolResults(requestMessages, calls, toolResults);
|
||||
continue;
|
||||
}
|
||||
|
||||
const request = {
|
||||
model: config.mistralChatTarget.model,
|
||||
messages,
|
||||
messages: requestMessages,
|
||||
tools: requestTools,
|
||||
documents: documents
|
||||
} as Parameters<typeof mistralAi.chat.stream>[0];
|
||||
const streamResponse = await mistralAi.chat.stream(request, {signal});
|
||||
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
|
||||
aiLog("debug", "mistral.stream.open", {round});
|
||||
let calls: ToolCallData[] = [];
|
||||
const roundTextStart = streamMessage.getText().length;
|
||||
@@ -154,11 +125,10 @@ export async function runMistral(
|
||||
|
||||
const choice = event.data?.choices?.[0];
|
||||
const delta = choice?.delta;
|
||||
const mistralDelta = delta as MistralDeltaLike;
|
||||
const mistralDelta = delta;
|
||||
streamMessage.append(adapter.extractTextDelta(mistralDelta));
|
||||
|
||||
streamMessage.append(contentFromMistralDelta(mistralDelta));
|
||||
|
||||
const rawDeltaCalls = mistralToolCalls(mistralDelta);
|
||||
const rawDeltaCalls = adapter.extractStreamingToolCalls(mistralDelta);
|
||||
if (rawDeltaCalls.length) {
|
||||
calls = toolCallAccumulator.add(rawDeltaCalls);
|
||||
streamMessage.setStatus(Environment.getUseToolText(calls));
|
||||
@@ -178,14 +148,16 @@ export async function runMistral(
|
||||
content: roundText,
|
||||
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
|
||||
});
|
||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||
for (const [index, call] of calls.entries()) {
|
||||
messages.push({
|
||||
role: "tool",
|
||||
name: call.name,
|
||||
toolCallId: call.id,
|
||||
content: toolResults[index] ?? "",
|
||||
requestMessages.push({
|
||||
role: "assistant",
|
||||
content: roundText,
|
||||
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
|
||||
});
|
||||
}
|
||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||
adapter.appendToolResults(messages, calls, toolResults);
|
||||
adapter.appendToolResults(requestMessages, calls, toolResults);
|
||||
}
|
||||
} finally {
|
||||
await adapter.finalize().catch(() => undefined);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import {Environment} from "../common/environment";
|
||||
import type {BoundaryValue} from "../common/boundary-types";
|
||||
import {bot, notesDir} from "../index";
|
||||
import {clamp, logError} from "../util/utils";
|
||||
import {getOllamaTools} from "./tool-mappers";
|
||||
import {TelegramStreamMessage} from "./telegram-stream-message";
|
||||
import {ChatMessage} from "./chat-messages-types";
|
||||
import {ChatRequest, Tool} from "ollama";
|
||||
@@ -14,10 +13,11 @@ import {enqueueTelegramApiCall} from "../util/telegram-api-queue";
|
||||
import {loadOllamaModel, unloadAllOllamaModels} from "./tools/utils";
|
||||
import {createOllamaClient} from "./ai-runtime-target";
|
||||
import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
||||
import {getProviderAdapter} from "./provider-adapters";
|
||||
import {runToolRankStage} from "./tool-rank-stage";
|
||||
|
||||
import {
|
||||
allToolSchemaNames,
|
||||
appendOllamaToolResults,
|
||||
dedupeToolCalls,
|
||||
DEFAULT_OLLAMA_CONTEXT_SIZE,
|
||||
executeToolBatch,
|
||||
@@ -26,8 +26,6 @@ import {
|
||||
MAX_OLLAMA_CONTEXT_SIZE,
|
||||
MAX_TOOL_ROUNDS,
|
||||
MIN_OLLAMA_CONTEXT_SIZE,
|
||||
normalizeOllamaToolCalls,
|
||||
OllamaToolCallLike,
|
||||
roundStatus,
|
||||
RuntimeConfigSnapshot,
|
||||
safeJsonParseObject,
|
||||
@@ -35,14 +33,11 @@ import {
|
||||
ToolCallData,
|
||||
ToolExecutionMemory
|
||||
} from "./unified-ai-runner.shared";
|
||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
||||
import {getToolPrompts} from "./tools/registry";
|
||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||
import {getModelCapabilities} from "./provider-model-runtime";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {Message} from "typescript-telegram-bot-api";
|
||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
||||
|
||||
export async function runOllama(
|
||||
msg: Message,
|
||||
@@ -157,6 +152,7 @@ export async function runOllama(
|
||||
}
|
||||
|
||||
const toolMemory: ToolExecutionMemory = new Map();
|
||||
const adapter = getProviderAdapter(AiProvider.OLLAMA);
|
||||
|
||||
try {
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
@@ -183,7 +179,7 @@ export async function runOllama(
|
||||
|
||||
let activeToolNames: string[] = [];
|
||||
if ((await getModelCapabilities(AiProvider.OLLAMA, model, "tools"))?.tools?.supported) {
|
||||
const availableOllamaTools: Tool[] = getOllamaTools(msg.from?.id === Environment.CREATOR_ID) as Tool[];
|
||||
const availableOllamaTools: Tool[] = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID}) as Tool[];
|
||||
|
||||
aiLog("debug", "ollama.tools.available", {
|
||||
round,
|
||||
@@ -191,44 +187,18 @@ export async function runOllama(
|
||||
rankerEnabled: !!config.ollamaToolRankerTarget,
|
||||
});
|
||||
|
||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
||||
await streamMessage.flush();
|
||||
const toolRankStartedAt = Date.now();
|
||||
const toolRankStartedAtIso = new Date().toISOString();
|
||||
const rankerSelection = await new ToolRanker(config).selectTools({
|
||||
const rankResult = await runToolRankStage({
|
||||
provider: AiProvider.OLLAMA,
|
||||
userQuery: latestUserTextFromMessages(messages),
|
||||
model,
|
||||
round,
|
||||
config,
|
||||
availableTools: availableOllamaTools,
|
||||
round,
|
||||
messages,
|
||||
streamMessage,
|
||||
signal,
|
||||
})
|
||||
.catch(async error => {
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.OLLAMA,
|
||||
model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
error,
|
||||
});
|
||||
throw error;
|
||||
});
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.OLLAMA,
|
||||
model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
selectedTools: rankerSelection.toolNames,
|
||||
});
|
||||
|
||||
const filteredTools = [...new Set(filterRankedTools(availableOllamaTools, rankerSelection.toolNames))];
|
||||
const filteredTools = [...new Set(rankResult.filteredTools as Tool[])];
|
||||
activeToolNames = filteredTools.map(t => t.function.name ?? "");
|
||||
if (filteredTools.length > 0) {
|
||||
request.tools = [...filteredTools];
|
||||
@@ -256,24 +226,21 @@ export async function runOllama(
|
||||
round,
|
||||
tools: activeToolNames,
|
||||
count: activeToolNames.length,
|
||||
usedRanker: rankerSelection.usedRanker,
|
||||
usedRanker: rankResult.usedRanker,
|
||||
});
|
||||
}
|
||||
|
||||
if (!stream) {
|
||||
const response = await ollama.chat({
|
||||
const response = await adapter.callModel(request, () => ollama.chat({
|
||||
...request,
|
||||
stream: false
|
||||
});
|
||||
}));
|
||||
|
||||
const message = response.message;
|
||||
const rawContent = message?.content ?? "";
|
||||
|
||||
const nativeCalls = dedupeToolCalls(
|
||||
normalizeOllamaToolCalls(
|
||||
message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
||||
round,
|
||||
),
|
||||
adapter.extractToolCalls(message),
|
||||
);
|
||||
|
||||
const responseText = rawContent;
|
||||
@@ -301,7 +268,7 @@ export async function runOllama(
|
||||
break;
|
||||
}
|
||||
|
||||
const calls = nativeCalls;
|
||||
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
|
||||
|
||||
aiLog("info", "ollama.tool_calls", {
|
||||
round,
|
||||
@@ -319,11 +286,7 @@ export async function runOllama(
|
||||
})),
|
||||
});
|
||||
|
||||
appendOllamaToolResults(
|
||||
messages,
|
||||
calls,
|
||||
await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory),
|
||||
);
|
||||
adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory));
|
||||
|
||||
continue;
|
||||
}
|
||||
@@ -332,10 +295,10 @@ export async function runOllama(
|
||||
round,
|
||||
messageCount: request.messages?.length ?? 0,
|
||||
});
|
||||
const response = await ollama.chat({
|
||||
const response = await adapter.callModel(request, () => ollama.chat({
|
||||
...request,
|
||||
stream: true
|
||||
});
|
||||
}));
|
||||
|
||||
aiLog("debug", "ollama.stream.open", {round});
|
||||
const calls: ToolCallData[] = [];
|
||||
@@ -354,10 +317,7 @@ export async function runOllama(
|
||||
|
||||
const localToolCalls: ToolCallData[] = [];
|
||||
|
||||
localToolCalls.push(...normalizeOllamaToolCalls(
|
||||
chunk.message.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
||||
round,
|
||||
));
|
||||
localToolCalls.push(...adapter.extractStreamingToolCalls(chunk.message));
|
||||
|
||||
const newStatus = roundStatus(round, firstRoundStatus, chunk.message.content, localToolCalls, !!chunk.message.thinking);
|
||||
const previousStatus = streamMessage.getStatus();
|
||||
@@ -377,13 +337,10 @@ export async function runOllama(
|
||||
}
|
||||
|
||||
if (!(chunk.message?.thinking && streamMessage.getStatus() !== Environment.reasoningText)) {
|
||||
streamMessage.append(chunk.message?.content ?? "");
|
||||
streamMessage.append(adapter.extractTextDelta(chunk));
|
||||
}
|
||||
|
||||
calls.push(...normalizeOllamaToolCalls(
|
||||
chunk.message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
||||
round,
|
||||
));
|
||||
calls.push(...adapter.extractStreamingToolCalls(chunk.message));
|
||||
|
||||
if (chunk.done) {
|
||||
aiLog("debug", "ollama.stream.done", {
|
||||
@@ -471,9 +428,10 @@ export async function runOllama(
|
||||
}).catch(logError);
|
||||
}
|
||||
|
||||
appendOllamaToolResults(messages, calls, toolResults);
|
||||
adapter.appendToolResults(messages, calls, toolResults);
|
||||
}
|
||||
} finally {
|
||||
if (interval) clearInterval(interval);
|
||||
await adapter.finalize().catch(() => undefined);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,11 +17,9 @@ import {
|
||||
AsyncIterableStream,
|
||||
buildSystemInstruction,
|
||||
collectOpenAiResponseCodeInterpreterCalls,
|
||||
collectOpenAiResponseFunctionCalls,
|
||||
collectOpenAiResponseImages,
|
||||
collectOpenAiResponseText,
|
||||
executeToolBatch,
|
||||
getOpenAIResponsesToolsWithImage,
|
||||
MAX_TOOL_ROUNDS,
|
||||
OPENAI_IMAGE_PARTIALS,
|
||||
openAiResponseItemCallId,
|
||||
@@ -42,10 +40,9 @@ import {logError} from "../util/utils";
|
||||
import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files";
|
||||
import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings";
|
||||
import {AiDownloadedFile} from "./telegram-attachments";
|
||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
||||
import {getProviderAdapter} from "./provider-adapters";
|
||||
import {runToolRankStage} from "./tool-rank-stage";
|
||||
|
||||
export async function runOpenAi(
|
||||
msg: Message,
|
||||
@@ -60,16 +57,15 @@ export async function runOpenAi(
|
||||
documentRag?: OpenAiDocumentRagContext,
|
||||
): Promise<void> {
|
||||
const runnerStartedAt = Date.now();
|
||||
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = [...messages] as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
||||
const openAi = createOpenAiClient(config.openAiChatTarget);
|
||||
const ownsDocumentRag = !documentRag;
|
||||
const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document"));
|
||||
const toolRanker = new ToolRanker(config);
|
||||
const availableTools = getOpenAIResponsesToolsWithImage(
|
||||
config,
|
||||
msg.from?.id === Environment.CREATOR_ID,
|
||||
preparedDocumentRag?.vectorStoreIds ?? [],
|
||||
);
|
||||
const adapter = getProviderAdapter(AiProvider.OPENAI);
|
||||
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = adapter.mapMessages(messages) as unknown as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
||||
const availableTools = adapter.rankTools(config, {
|
||||
forCreator: msg.from?.id === Environment.CREATOR_ID,
|
||||
vectorStoreIds: preparedDocumentRag?.vectorStoreIds ?? [],
|
||||
});
|
||||
|
||||
const systemPrompt = buildSystemInstruction(
|
||||
config,
|
||||
@@ -93,43 +89,17 @@ export async function runOpenAi(
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||
const roundStartedAt = Date.now();
|
||||
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
||||
await streamMessage.flush();
|
||||
const toolRankStartedAt = Date.now();
|
||||
const toolRankStartedAtIso = new Date().toISOString();
|
||||
const rankerSelection = await toolRanker.selectTools({
|
||||
const rankResult = await runToolRankStage({
|
||||
provider: AiProvider.OPENAI,
|
||||
userQuery: latestUserTextFromMessages(messages),
|
||||
model: config.openAiChatTarget.model,
|
||||
round,
|
||||
config,
|
||||
availableTools,
|
||||
round,
|
||||
messages,
|
||||
streamMessage,
|
||||
signal,
|
||||
})
|
||||
.catch(async error => {
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.OPENAI,
|
||||
model: config.openAiChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
error,
|
||||
});
|
||||
throw error;
|
||||
});
|
||||
streamMessage.clearStatus();
|
||||
await streamMessage.flush();
|
||||
await storeToolRankAudit({
|
||||
streamMessage,
|
||||
provider: AiProvider.OPENAI,
|
||||
model: config.openAiChatTarget.model,
|
||||
round,
|
||||
startedAt: toolRankStartedAt,
|
||||
startedAtIso: toolRankStartedAtIso,
|
||||
selectedTools: rankerSelection.toolNames,
|
||||
});
|
||||
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
|
||||
const filteredTools = rankResult.filteredTools;
|
||||
const requestTools = preparedDocumentRag?.vectorStoreIds.length
|
||||
? (() => {
|
||||
const tools = [...filteredTools];
|
||||
@@ -151,7 +121,7 @@ export async function runOpenAi(
|
||||
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
||||
instructions: systemPrompt,
|
||||
};
|
||||
const response = await openAi.responses.create(request, {signal}) as OpenAiResponseLike;
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
|
||||
|
||||
const responseText = collectOpenAiResponseText(response);
|
||||
streamMessage.append(responseText);
|
||||
@@ -188,12 +158,12 @@ export async function runOpenAi(
|
||||
});
|
||||
}
|
||||
|
||||
const calls = collectOpenAiResponseFunctionCalls(response);
|
||||
const calls = adapter.extractToolCalls(response);
|
||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||
round,
|
||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||
calls: calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
@@ -201,16 +171,13 @@ export async function runOpenAi(
|
||||
if (!calls.length) return;
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
argumentsText: call.argumentsText,
|
||||
}));
|
||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||
const toolOutputs = calls.map((call, index) => ({
|
||||
type: "function_call_output" as const,
|
||||
call_id: call.callId,
|
||||
output: toolResults[index] ?? "",
|
||||
}));
|
||||
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||
|
||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||
if (uploadFilesResult.found) {
|
||||
@@ -243,7 +210,7 @@ export async function runOpenAi(
|
||||
parallel_tool_calls: true,
|
||||
instructions: systemPrompt
|
||||
};
|
||||
const response = await openAi.responses.create(request, {signal}) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
|
||||
|
||||
aiLog("debug", "openai.stream.open", {round});
|
||||
|
||||
@@ -253,7 +220,7 @@ export async function runOpenAi(
|
||||
|
||||
switch (event.type) {
|
||||
case "response.output_text.delta":
|
||||
streamMessage.append(event.delta ?? "");
|
||||
streamMessage.append(adapter.extractTextDelta(event));
|
||||
break;
|
||||
case "response.image_generation_call.in_progress":
|
||||
streamMessage.setStatus(Environment.startingImageGenText);
|
||||
@@ -301,14 +268,11 @@ export async function runOpenAi(
|
||||
case "response.code_interpreter_call_code.done":
|
||||
break;
|
||||
case "response.output_item.added":
|
||||
if (event.item.type === "function_call" && event.item.name) {
|
||||
const item = event.item as OpenAiResponseOutputItem & { id?: string };
|
||||
localToolCalls.push({
|
||||
id: openAiResponseItemCallId(item),
|
||||
name: item.name ?? "",
|
||||
argumentsText: item.arguments ?? "{}",
|
||||
});
|
||||
|
||||
{
|
||||
const streamedCalls = adapter.extractStreamingToolCalls(event);
|
||||
if (streamedCalls.length) {
|
||||
localToolCalls.push(...streamedCalls);
|
||||
}
|
||||
aiLog("info", "openai.stream.tool_call.added", {
|
||||
round,
|
||||
toolCalls: localToolCalls.map(aiLogToolCall)
|
||||
@@ -383,12 +347,12 @@ export async function runOpenAi(
|
||||
});
|
||||
}
|
||||
|
||||
const calls = collectOpenAiResponseFunctionCalls(completedResponse);
|
||||
const calls = adapter.extractToolCalls(completedResponse);
|
||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||
round,
|
||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||
calls: calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
arguments: safeJsonParseObject(call.argumentsText)
|
||||
})),
|
||||
@@ -396,16 +360,13 @@ export async function runOpenAi(
|
||||
if (!calls.length) return;
|
||||
|
||||
const toolCalls = calls.map(call => ({
|
||||
id: call.callId,
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
argumentsText: call.argumentsText,
|
||||
}));
|
||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||
const toolOutputs = calls.map((call, index) => ({
|
||||
type: "function_call_output",
|
||||
call_id: call.callId,
|
||||
output: toolResults[index] ?? "",
|
||||
}));
|
||||
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||
|
||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||
if (uploadFilesResult.found) {
|
||||
@@ -431,6 +392,7 @@ export async function runOpenAi(
|
||||
if (ownsDocumentRag) {
|
||||
await preparedDocumentRag?.cleanup().catch(logError);
|
||||
}
|
||||
await adapter.finalize().catch(logError);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,40 +2,40 @@ import {Message} from "typescript-telegram-bot-api";
|
||||
import * as fs from "node:fs";
|
||||
import path from "node:path";
|
||||
import type {BoundaryValue} from "../common/boundary-types";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {ToolRankerFallbackPolicy} from "../common/policies";
|
||||
import {Environment} from "../common/environment";
|
||||
import {photoGenDir} from "../index";
|
||||
import {delay, logError, replyToMessage} from "../util/utils";
|
||||
import {MessageStore} from "../common/message-store";
|
||||
import type {OpenAiResponseTool} from "./tool-mappers";
|
||||
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers";
|
||||
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message";
|
||||
import {AiDownloadedFile} from "./telegram-attachments";
|
||||
import {getRuntimeCapabilities} from "./provider-model-runtime";
|
||||
import {StoredAttachment} from "../model/stored-attachment";
|
||||
import {AiChatMessage, ChatMessage} from "./chat-messages-types";
|
||||
import {AiProvider} from "../model/ai-provider.js";
|
||||
import {ToolRankerFallbackPolicy} from "../common/policies.js";
|
||||
import {Environment} from "../common/environment.js";
|
||||
import {photoGenDir} from "../index.js";
|
||||
import {delay, logError, replyToMessage} from "../util/utils.js";
|
||||
import {MessageStore} from "../common/message-store.js";
|
||||
import type {OpenAiResponseTool} from "./tool-mappers.js";
|
||||
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers.js";
|
||||
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message.js";
|
||||
import {AiDownloadedFile} from "./telegram-attachments.js";
|
||||
import {getRuntimeCapabilities} from "./provider-model-runtime.js";
|
||||
import {StoredAttachment} from "../model/stored-attachment.js";
|
||||
import {AiChatMessage, ChatMessage} from "./chat-messages-types.js";
|
||||
import {ListResponse, Ollama} from "ollama";
|
||||
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime";
|
||||
import {MessageImagePart, MessagePart} from "../common/message-part";
|
||||
import {KeyedAsyncLock} from "../util/async-lock";
|
||||
import {type AiRequestQueueTarget} from "./provider-request-queue";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator";
|
||||
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings";
|
||||
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime.js";
|
||||
import {MessageImagePart, MessagePart} from "../common/message-part.js";
|
||||
import {KeyedAsyncLock} from "../util/async-lock.js";
|
||||
import {type AiRequestQueueTarget} from "./provider-request-queue.js";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator.js";
|
||||
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings.js";
|
||||
import {
|
||||
isTranscribableAudioDownload,
|
||||
resolveSpeechToTextProviderForUser,
|
||||
transcribeSpeechDownloads
|
||||
} from "./speech-to-text";
|
||||
} from "./speech-to-text.js";
|
||||
import type {ChatCompletionMessageParam} from "openai/resources/chat/completions";
|
||||
import {MistralChatMessage} from "./mistral-chat-message";
|
||||
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer";
|
||||
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target";
|
||||
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
||||
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline";
|
||||
import {MistralChatMessage} from "./mistral-chat-message.js";
|
||||
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer.js";
|
||||
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target.js";
|
||||
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger.js";
|
||||
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline.js";
|
||||
import type {ResponseInputMessageContentList} from "openai/resources/responses/responses";
|
||||
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store";
|
||||
import {filterUserVisibleStoredAttachments} from "../common/stored-attachment-utils";
|
||||
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store.js";
|
||||
import {filterUserVisibleStoredAttachments} from "../common/attachment-visibility.js";
|
||||
|
||||
export type {Message} from "typescript-telegram-bot-api";
|
||||
export type {AiRuntimeTarget} from "./ai-runtime-target";
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
import type {StoredAttachment} from "../model/stored-attachment";
|
||||
|
||||
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
|
||||
return attachments.filter(attachment => attachment.scope !== "internal_artifact");
|
||||
}
|
||||
+10
-10
@@ -3,18 +3,18 @@ import os from "node:os";
|
||||
import path from "node:path";
|
||||
import {parse as parseDotEnv} from "dotenv";
|
||||
import {z} from "zod";
|
||||
import {appLogger} from "../logging/logger";
|
||||
import {appLogger} from "../logging/logger.js";
|
||||
import type {BoundaryValue, ErrorLike} from "./boundary-types";
|
||||
|
||||
import {saveData} from "../db/database";
|
||||
import {Answers} from "../model/answers";
|
||||
import {ifTrue} from "../util/utils";
|
||||
import {AiProvider} from "../model/ai-provider";
|
||||
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies";
|
||||
import {ToolRankerFallbackPolicy} from "./policies";
|
||||
import type {ToolCallData} from "../ai/unified-ai-runner";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator";
|
||||
import {Localization, type LocalizationParams} from "./localization";
|
||||
import {saveData} from "../db/database.js";
|
||||
import {Answers} from "../model/answers.js";
|
||||
import {ifTrue} from "../util/utils.js";
|
||||
import {AiProvider} from "../model/ai-provider.js";
|
||||
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies.js";
|
||||
import {ToolRankerFallbackPolicy} from "./policies.js";
|
||||
import type {ToolCallData} from "../ai/unified-ai-runner.js";
|
||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js";
|
||||
import {Localization, type LocalizationParams} from "./localization.js";
|
||||
|
||||
type EnvRecord = Record<string, string>;
|
||||
type StringEnumLike = Record<string, string>;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import path from "node:path";
|
||||
import {Environment} from "./environment";
|
||||
import {StoredAttachment} from "../model/stored-attachment";
|
||||
export {filterUserVisibleStoredAttachments} from "./attachment-visibility";
|
||||
|
||||
export function photoCachePathForUniqueId(uniqueId: string): string {
|
||||
return path.join(Environment.DATA_PATH, "cache", "photo", `${uniqueId}.jpg`);
|
||||
@@ -44,7 +45,3 @@ export function uniqueStoredAttachments(attachments: StoredAttachment[]): Stored
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
|
||||
return attachments.filter(attachment => attachment.scope !== "internal_artifact");
|
||||
}
|
||||
|
||||
+4
-4
@@ -1,9 +1,9 @@
|
||||
import * as fs from "fs";
|
||||
import {Environment} from "../common/environment";
|
||||
import {logError} from "../util/utils";
|
||||
import {Answers} from "../model/answers";
|
||||
import {Environment} from "../common/environment.js";
|
||||
import {logError} from "../util/utils.js";
|
||||
import {Answers} from "../model/answers.js";
|
||||
import path from "node:path";
|
||||
import {KeyedAsyncLock} from "../util/async-lock";
|
||||
import {KeyedAsyncLock} from "../util/async-lock.js";
|
||||
|
||||
type DataJsonFile = {
|
||||
admins: number[]
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
|
||||
const {
|
||||
extractOpenAiToolCalls,
|
||||
extractOpenAiStreamingToolCalls,
|
||||
extractOpenAiTextDelta,
|
||||
extractMistralToolCalls,
|
||||
extractMistralTextDelta,
|
||||
extractOllamaToolCalls,
|
||||
extractOllamaTextDelta,
|
||||
} = await import("../dist/ai/provider-adapter-contract.js");
|
||||
|
||||
// Contract check for the OpenAI extractors. Covers three inputs:
//   1. a `response.output_text.delta` event  -> text delta,
//   2. a response object with an `output` array -> completed function calls,
//   3. a `response.output_item.added` event  -> streamed function calls.
test("openai contract extracts text delta and function calls", () => {
    const textDeltaEvent = {type: "response.output_text.delta", delta: "hello"};
    assert.equal(extractOpenAiTextDelta(textDeltaEvent), "hello");

    // Completed response: the call identifier is supplied as `call_id`.
    const completedResponse = {
        output: [{
            type: "function_call",
            call_id: "call-1",
            name: "read_file",
            arguments: "{\"path\":\"src/index.ts\"}",
        }],
    };
    const completedCalls = extractOpenAiToolCalls(completedResponse);

    assert.equal(completedCalls.length, 1);
    const completedCall = completedCalls[0];
    assert.equal(completedCall.id, "call-1");
    assert.equal(completedCall.name, "read_file");

    // Streaming event: the call identifier is supplied as `id` on the item.
    const outputItemAddedEvent = {
        type: "response.output_item.added",
        item: {
            type: "function_call",
            id: "call-2",
            name: "search_files",
            arguments: "{\"query\":\"sendMessage\"}",
        },
    };
    const streamedCalls = extractOpenAiStreamingToolCalls(outputItemAddedEvent);

    assert.equal(streamedCalls.length, 1);
    const streamedCall = streamedCalls[0];
    assert.equal(streamedCall.id, "call-2");
    assert.equal(streamedCall.name, "search_files");
});
|
||||
|
||||
// Contract check for the Mistral extractors: a `content` array of text parts
// must come back as one joined string, and a `toolCalls` entry must come back
// as a single call carrying its id and function name.
test("mistral contract extracts content and tool calls", () => {
    const contentChunk = {
        content: [{text: "hello"}, {text: " world"}],
    };
    assert.equal(extractMistralTextDelta(contentChunk), "hello world");

    // Arguments are passed here as an object rather than a JSON string.
    const toolCallChunk = {
        toolCalls: [{
            id: "m-1",
            function: {
                name: "get_weather",
                arguments: {location: "Moscow"},
            },
        }],
    };
    const extractedCalls = extractMistralToolCalls(toolCallChunk);

    assert.equal(extractedCalls.length, 1);
    const onlyCall = extractedCalls[0];
    assert.equal(onlyCall.id, "m-1");
    assert.equal(onlyCall.name, "get_weather");
});
|
||||
|
||||
// Contract check for the Ollama extractors: text is read from
// `message.content`, and a `tool_calls` entry must come back as a single call
// carrying its id and function name.
test("ollama contract extracts content and tool calls", () => {
    const messageChunk = {
        message: {content: "hello from ollama"},
    };
    assert.equal(extractOllamaTextDelta(messageChunk), "hello from ollama");

    // Arguments are passed here as an object rather than a JSON string.
    const toolCallPayload = {
        tool_calls: [{
            id: "o-1",
            function: {
                name: "web_search",
                arguments: {query: "openai docs"},
            },
        }],
    };
    const extractedCalls = extractOllamaToolCalls(toolCallPayload);

    assert.equal(extractedCalls.length, 1);
    const onlyCall = extractedCalls[0];
    assert.equal(onlyCall.id, "o-1");
    assert.equal(onlyCall.name, "web_search");
});
|
||||
+27
-94
@@ -1,32 +1,13 @@
|
||||
import test, {after} from "node:test";
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import fs from "node:fs";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
|
||||
const tempRoot = fs.mkdtempSync(path.join(os.tmpdir(), "tg-chat-bot-rag-"));
|
||||
process.env.BOT_TOKEN = process.env.BOT_TOKEN ?? "test-token";
|
||||
process.env.CREATOR_ID = process.env.CREATOR_ID ?? "1";
|
||||
process.env.DATA_PATH = tempRoot;
|
||||
process.env.DB_PATH = `file:${path.join(tempRoot, "test.sqlite")}`;
|
||||
process.env.TEST_ENVIRONMENT = "true";
|
||||
|
||||
const {Environment} = await import("../dist/common/environment.js");
|
||||
Environment.load();
|
||||
|
||||
const {DatabaseManager} = await import("../dist/db/database-manager.js");
|
||||
DatabaseManager.init();
|
||||
await DatabaseManager.ready;
|
||||
|
||||
const {ArtifactStore} = await import("../dist/common/artifact-store.js");
|
||||
const {filterUserVisibleStoredAttachments} = await import("../dist/common/stored-attachment-utils.js");
|
||||
const {
|
||||
buildRagArtifactPayload,
|
||||
} = await import("../dist/ai/rag-artifact-payload.js");
|
||||
const {
|
||||
filterUserVisibleStoredAttachments,
|
||||
} = await import("../dist/common/attachment-visibility.js");
|
||||
const {AiProvider} = await import("../dist/model/ai-provider.js");
|
||||
const {persistRagArtifactAttachment} = await import("../dist/ai/rag-artifact-store.js");
|
||||
|
||||
after(async () => {
|
||||
await DatabaseManager.close().catch(() => undefined);
|
||||
fs.rmSync(tempRoot, {recursive: true, force: true});
|
||||
});
|
||||
|
||||
test("internal artifacts are not treated as user-visible attachments", () => {
|
||||
const visible = filterUserVisibleStoredAttachments([
|
||||
@@ -50,65 +31,26 @@ test("internal artifacts are not treated as user-visible attachments", () => {
|
||||
assert.equal(visible[0].fileId, "visible");
|
||||
});
|
||||
|
||||
test("RAG artifacts persist structured ollama metadata", async () => {
|
||||
const chatId = 42;
|
||||
const messageId = 7;
|
||||
|
||||
const attachment = await persistRagArtifactAttachment({
|
||||
test("RAG artifact payload keeps ollama retrieval metadata", () => {
|
||||
const payload = buildRagArtifactPayload({
|
||||
provider: AiProvider.OLLAMA,
|
||||
prepared: {
|
||||
provider: AiProvider.OLLAMA,
|
||||
prepared: true,
|
||||
cleanup: async () => undefined,
|
||||
artifact: {
|
||||
query: "What is in the file?",
|
||||
extractedDocuments: [
|
||||
{documentIndex: 0, fileName: "report.txt", textChars: 120},
|
||||
],
|
||||
selectedChunks: [
|
||||
{
|
||||
sourceId: "doc1-1",
|
||||
documentIndex: 0,
|
||||
documentName: "report.txt",
|
||||
chunkIndex: 0,
|
||||
chunkCount: 1,
|
||||
textChars: 120,
|
||||
score: 0.91,
|
||||
},
|
||||
],
|
||||
skippedDocuments: [
|
||||
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
|
||||
],
|
||||
providerState: {
|
||||
embeddingModel: "nomic-embed-text:latest",
|
||||
topK: 8,
|
||||
chunkSize: 1400,
|
||||
chunkOverlap: 220,
|
||||
maxContextChars: 14000,
|
||||
minScore: 0.12,
|
||||
maxArchiveFiles: 200,
|
||||
maxArchiveBytes: 50 * 1024 * 1024,
|
||||
maxArchiveDepth: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
downloads: [{
|
||||
kind: "document",
|
||||
createdAt: "2026-01-01T00:00:00.000Z",
|
||||
sources: [{
|
||||
fileId: "file-1",
|
||||
fileName: "report.txt",
|
||||
buffer: Buffer.from("hello world"),
|
||||
path: path.join(tempRoot, "report.txt"),
|
||||
mimeType: "text/plain",
|
||||
sizeBytes: 12,
|
||||
sha256: "abc123",
|
||||
uploadedFileId: "uploaded-1",
|
||||
}],
|
||||
chatId,
|
||||
messageId,
|
||||
details: {
|
||||
providerState: {
|
||||
provider: AiProvider.OLLAMA,
|
||||
prepared: true,
|
||||
embeddingModel: "nomic-embed-text:latest",
|
||||
topK: 8,
|
||||
chunkSize: 1400,
|
||||
chunkOverlap: 220,
|
||||
maxContextChars: 14000,
|
||||
artifact: {
|
||||
query: "What is in the file?",
|
||||
extractedDocuments: [
|
||||
{documentIndex: 0, fileName: "report.txt", textChars: 120},
|
||||
],
|
||||
@@ -126,29 +68,20 @@ test("RAG artifacts persist structured ollama metadata", async () => {
|
||||
skippedDocuments: [
|
||||
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
|
||||
],
|
||||
providerState: {
|
||||
embeddingModel: "nomic-embed-text:latest",
|
||||
topK: 8,
|
||||
chunkSize: 1400,
|
||||
chunkOverlap: 220,
|
||||
maxContextChars: 14000,
|
||||
minScore: 0.12,
|
||||
maxArchiveFiles: 200,
|
||||
maxArchiveBytes: 50 * 1024 * 1024,
|
||||
maxArchiveDepth: 2,
|
||||
},
|
||||
},
|
||||
query: "What is in the file?",
|
||||
},
|
||||
});
|
||||
|
||||
assert.equal(attachment?.artifactKind, "rag");
|
||||
assert.equal(fs.existsSync(attachment.cachePath), true);
|
||||
|
||||
const stored = await ArtifactStore.getByMessage(chatId, messageId);
|
||||
assert.equal(stored.length, 1);
|
||||
assert.equal(stored[0].kind, "rag");
|
||||
assert.equal(stored[0].payload.providerState.query, "What is in the file?");
|
||||
assert.equal(stored[0].payload.providerState.selectedChunks[0].score, 0.91);
|
||||
assert.equal(stored[0].payload.providerState.skippedDocuments[0].reason, "unsupported format");
|
||||
assert.equal(stored[0].payload.providerState.ollama.embeddingModel, "nomic-embed-text:latest");
|
||||
assert.equal(payload.artifactKind, "rag");
|
||||
assert.equal(payload.provider, AiProvider.OLLAMA);
|
||||
assert.equal(payload.sources[0].uploadedFileId, "uploaded-1");
|
||||
assert.equal(payload.providerState.provider, AiProvider.OLLAMA);
|
||||
assert.equal(payload.providerState.query, "What is in the file?");
|
||||
assert.equal(payload.providerState.selectedChunks[0].score, 0.91);
|
||||
assert.equal(payload.providerState.skippedDocuments[0].reason, "unsupported format");
|
||||
assert.equal(payload.providerState.embeddingModel, "nomic-embed-text:latest");
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user