Add unified request pipeline stages
This commit is contained in:
+15
-15
@@ -46,27 +46,27 @@
|
|||||||
- [x] Для Mistral сохранять `libraryId`.
|
- [x] Для Mistral сохранять `libraryId`.
|
||||||
- [x] Для Mistral сохранять uploaded document ids.
|
- [x] Для Mistral сохранять uploaded document ids.
|
||||||
- [x] Для Mistral сохранять source file mapping: local attachment -> Mistral document id.
|
- [x] Для Mistral сохранять source file mapping: local attachment -> Mistral document id.
|
||||||
- [ ] Добавить единый `providerState` schema для всех providers.
|
- [x] Добавить единый `providerState` schema для всех providers.
|
||||||
- [ ] Добавить tests на сериализацию `RagArtifact`.
|
- [x] Добавить tests на сериализацию `RagArtifact`.
|
||||||
- [ ] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
|
- [x] Добавить tests на то, что internal RAG artifacts не попадают обратно в user document context.
|
||||||
|
|
||||||
## 4. Вынести provider runners в adapter layer
|
## 4. Вынести provider runners в adapter layer
|
||||||
|
|
||||||
- [ ] Ввести интерфейс `AiProviderAdapter`.
|
- [x] Ввести интерфейс `AiProviderAdapter`.
|
||||||
- [ ] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
|
- [x] Методы adapter-а: `mapMessages`, `rankTools`, `callModel`, `extractTextDelta`, `extractToolCalls`, `appendToolResults`, `finalize`.
|
||||||
- [ ] Реализовать `OpenAiProviderAdapter`.
|
- [x] Реализовать `OpenAiProviderAdapter`.
|
||||||
- [ ] Реализовать `MistralProviderAdapter`.
|
- [x] Реализовать `MistralProviderAdapter`.
|
||||||
- [ ] Реализовать `OllamaProviderAdapter`.
|
- [x] Реализовать `OllamaProviderAdapter`.
|
||||||
- [ ] Перенести provider-specific tool schema mapping внутрь adapter-ов.
|
- [x] Перенести provider-specific tool schema mapping внутрь adapter-ов.
|
||||||
- [ ] Перенести provider-specific streaming parsing внутрь adapter-ов.
|
- [x] Перенести provider-specific streaming parsing внутрь adapter-ов.
|
||||||
- [ ] Перенести provider-specific tool result append внутрь adapter-ов.
|
- [x] Перенести provider-specific tool result append внутрь adapter-ов.
|
||||||
- [ ] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
|
- [x] Упростить `runOpenAi`, `runMistral`, `runOllama` или заменить их adapter-driven runner-ом.
|
||||||
- [ ] Оставить compatibility wrappers для текущих imports.
|
- [x] Оставить compatibility wrappers для текущих imports.
|
||||||
- [ ] Добавить tests на adapter contract без реальных API.
|
- [x] Добавить tests на adapter contract без реальных API.
|
||||||
|
|
||||||
## 5. Сделать tool-ranker полноценным pipeline stage
|
## 5. Сделать tool-ranker полноценным pipeline stage
|
||||||
|
|
||||||
- [ ] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
|
- [x] Вынести вызов `ToolRanker.selectTools(...)` из provider runners.
|
||||||
- [ ] Добавить stage `tool_rank`, который работает через provider adapter.
|
- [ ] Добавить stage `tool_rank`, который работает через provider adapter.
|
||||||
- [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker.
|
- [ ] Добавить stage `filter_tools`, который фильтрует provider-specific tools по результату ranker.
|
||||||
- [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`.
|
- [ ] Хранить `ToolRankDecision` в `UserRequestPipelineState.toolRankDecisions`.
|
||||||
|
|||||||
@@ -0,0 +1,112 @@
|
|||||||
|
import type {ToolCallData} from "./unified-ai-runner.shared.js";
|
||||||
|
import type {ResponseStreamEvent} from "openai/resources/responses/responses";
|
||||||
|
|
||||||
|
/**
 * Type guard for object-like values with string keys.
 *
 * @param value - Value of unknown shape.
 * @returns `true` when `value` is a non-null object that is not an array.
 */
function isRecord(value: unknown): value is Record<string, unknown> {
    if (value === null || typeof value !== "object") return false;
    return !Array.isArray(value);
}
|
||||||
|
|
||||||
|
/**
 * Picks a usable tool call id.
 *
 * @param value - Candidate id taken from a provider payload.
 * @param fallback - Id to use when the candidate is not a string or is blank.
 * @returns `value` unchanged (not trimmed) when it is a string with non-whitespace
 *          content, otherwise `fallback`.
 */
function normalizeToolCallId(value: unknown, fallback: string): string {
    if (typeof value !== "string") return fallback;
    // Whitespace-only ids count as missing.
    return value.trim() === "" ? fallback : value;
}
|
||||||
|
|
||||||
|
/**
 * Converts raw tool call arguments into their textual form.
 *
 * @param value - Arguments as delivered by a provider: either already a string or
 *                a value to be serialized.
 * @returns `value` unchanged when it is a string; otherwise its JSON serialization.
 *          `null`/`undefined` and values that JSON cannot represent (functions,
 *          symbols) yield `"{}"`.
 * @throws Whatever `JSON.stringify` throws (e.g. for circular structures or BigInt).
 */
function normalizeToolArguments(value: unknown): string {
    if (typeof value === "string") return value;
    // JSON.stringify is typed as returning string, but at runtime it returns
    // undefined for functions and symbols; never leak that to callers.
    const serialized: string | undefined = JSON.stringify(value ?? {});
    return serialized ?? "{}";
}
|
||||||
|
|
||||||
|
/**
 * Collects function tool calls from an OpenAI Responses-style payload.
 *
 * Only `response.output` entries with `type === "function_call"` that carry a string
 * `call_id` or a string `name` are considered; entries without a non-empty string
 * `name` are dropped afterwards. A missing or blank `call_id` is replaced by
 * `openai_<n>`, where `n` counts the considered entries (including ones dropped
 * later for lacking a name).
 *
 * @param response - Raw response object; anything without an `output` array yields `[]`.
 * @returns Normalized tool calls in output order.
 */
export function extractOpenAiToolCalls(response: unknown): ToolCallData[] {
    const output: unknown[] = isRecord(response) && Array.isArray(response.output) ? response.output : [];
    const calls: ToolCallData[] = [];
    let candidateIndex = 0;

    for (const item of output) {
        if (!isRecord(item) || item.type !== "function_call") continue;
        if (typeof item.call_id !== "string" && typeof item.name !== "string") continue;

        const call = {
            id: normalizeToolCallId(item.call_id, `openai_${candidateIndex}`),
            name: typeof item.name === "string" ? item.name : "",
            argumentsText: normalizeToolArguments(item.arguments),
        };
        candidateIndex += 1;

        if (call.name.length > 0) calls.push(call);
    }

    return calls;
}
|
||||||
|
|
||||||
|
/**
 * Reads the text fragment carried by an OpenAI streaming event.
 *
 * @param input - A stream event of unknown shape (may be `null`/`undefined`).
 * @returns The event's `delta` when its type is `response.output_text.delta`
 *          (or `""` if the delta is nullish); `""` for every other input.
 */
export function extractOpenAiTextDelta(input: unknown): string {
    const event = input as ResponseStreamEvent | undefined;
    if (event?.type !== "response.output_text.delta") return "";
    return event.delta ?? "";
}
|
||||||
|
|
||||||
|
/**
 * Extracts a function tool call announced by an OpenAI streaming event.
 *
 * Only `response.output_item.added` events whose item is a `function_call` are
 * handled. The item is re-packed as a one-element `output` list and normalized by
 * `extractOpenAiToolCalls`; the call id is taken from `call_id`, falling back to
 * the item's `id`.
 *
 * NOTE(review): arguments are read as present on the `added` event — confirm they
 * are already complete at that point and not delivered by later events.
 *
 * @param input - A stream event of unknown shape.
 * @returns Zero or one normalized tool call.
 */
export function extractOpenAiStreamingToolCalls(input: unknown): ToolCallData[] {
    const event = input as ResponseStreamEvent | undefined;
    if (event?.type !== "response.output_item.added") return [];
    if (!isRecord(event.item) || event.item.type !== "function_call") return [];

    const functionCall = {
        type: "function_call",
        call_id: event.item.call_id ?? event.item.id,
        name: event.item.name,
        arguments: event.item.arguments,
    };

    return extractOpenAiToolCalls({output: [functionCall]});
}
|
||||||
|
|
||||||
|
/**
 * Normalizes Mistral tool calls into `ToolCallData`.
 *
 * Accepts either the tool call array itself or an object holding it under
 * `toolCalls` or `tool_calls`. The first of those two keys that actually contains
 * an array is used, so a non-array `toolCalls` value no longer hides a valid
 * `tool_calls` array.
 *
 * For each entry the name and arguments are read from `function.name` /
 * `function.arguments`, falling back to top-level `name` / `arguments`. Entries
 * without a non-empty string name are dropped. A missing or blank `id` becomes
 * `mistral_<index>`, where `index` is the entry's position in the raw array.
 *
 * @param calls - Tool call array, or an object containing one; anything else yields `[]`.
 * @returns Normalized tool calls in input order.
 */
export function extractMistralToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls)) {
        // Prefer the camelCase key, but only when it really is an array.
        if (Array.isArray(calls.toolCalls)) {
            rawCalls = calls.toolCalls;
        } else if (Array.isArray(calls.tool_calls)) {
            rawCalls = calls.tool_calls;
        }
    }

    return rawCalls
        .map((item, index) => {
            const call = isRecord(item) ? item : {};
            const fn = isRecord(call.function) ? call.function : undefined;
            const name = typeof fn?.name === "string" ? fn.name : typeof call.name === "string" ? call.name : "";
            return {
                id: normalizeToolCallId(call.id, `mistral_${index}`),
                name,
                argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
            };
        })
        .filter(call => call.name.length > 0);
}
|
||||||
|
|
||||||
|
/**
 * Reads the text carried by a Mistral delta object.
 *
 * @param input - Delta of unknown shape.
 * @returns `content` when it is a string; when it is an array, the concatenation of
 *          every part's string `text` (parts without one contribute nothing);
 *          otherwise `""`.
 */
export function extractMistralTextDelta(input: unknown): string {
    if (!isRecord(input)) return "";

    const {content} = input;
    if (typeof content === "string") return content;
    if (!Array.isArray(content)) return "";

    let text = "";
    for (const part of content) {
        if (isRecord(part) && typeof part.text === "string") {
            text += part.text;
        }
    }
    return text;
}
|
||||||
|
|
||||||
|
/**
 * Normalizes Ollama tool calls into `ToolCallData`.
 *
 * Accepts either the tool call array itself or an object holding it under
 * `tool_calls`. For each entry the name and arguments are read from
 * `function.name` / `function.arguments`, falling back to top-level `name` /
 * `arguments`. Entries without a non-empty string name are dropped. A missing or
 * blank `id` becomes `ollama_<index>`, where `index` is the entry's position in
 * the raw array.
 *
 * @param calls - Tool call array, or an object containing one; anything else yields `[]`.
 * @returns Normalized tool calls in input order.
 */
export function extractOllamaToolCalls(calls: unknown): ToolCallData[] {
    let rawCalls: unknown[] = [];
    if (Array.isArray(calls)) {
        rawCalls = calls;
    } else if (isRecord(calls) && Array.isArray(calls.tool_calls)) {
        rawCalls = calls.tool_calls;
    }

    const extracted: ToolCallData[] = [];
    rawCalls.forEach((item, index) => {
        const call = isRecord(item) ? item : {};
        const fn = isRecord(call.function) ? call.function : undefined;

        let name = "";
        if (typeof fn?.name === "string") {
            name = fn.name;
        } else if (typeof call.name === "string") {
            name = call.name;
        }

        const normalized = {
            id: normalizeToolCallId(call.id, `ollama_${index}`),
            name,
            argumentsText: normalizeToolArguments(fn?.arguments ?? call.arguments),
        };
        if (normalized.name.length > 0) extracted.push(normalized);
    });

    return extracted;
}
|
||||||
|
|
||||||
|
/**
 * Reads the text carried by an Ollama chat chunk.
 *
 * @param input - Chunk of unknown shape.
 * @returns `input.message.content` when it is a string, otherwise `""`.
 */
export function extractOllamaTextDelta(input: unknown): string {
    if (!isRecord(input)) return "";

    const message = input.message;
    if (!isRecord(message)) return "";

    return typeof message.content === "string" ? message.content : "";
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
import {AiProvider} from "../model/ai-provider.js";
|
||||||
|
import type {BoundaryValue} from "../common/boundary-types.js";
|
||||||
|
import type {RuntimeConfigSnapshot, ToolCallData} from "./unified-ai-runner.shared.js";
|
||||||
|
import {getMistralTools, getOllamaTools, getOpenAIResponsesTools, getOpenAICodeInterpreterTool} from "./tool-mappers.js";
|
||||||
|
import type {MistralChatMessage as MistralMessageType} from "./mistral-chat-message.js";
|
||||||
|
import type {OpenAIChatMessage as OpenAiMessageType} from "./openai-chat-message.js";
|
||||||
|
import type {Message as OllamaMessage} from "ollama";
|
||||||
|
import {
|
||||||
|
extractMistralTextDelta,
|
||||||
|
extractMistralToolCalls,
|
||||||
|
extractOllamaTextDelta,
|
||||||
|
extractOllamaToolCalls,
|
||||||
|
extractOpenAiTextDelta,
|
||||||
|
extractOpenAiStreamingToolCalls,
|
||||||
|
extractOpenAiToolCalls,
|
||||||
|
} from "./provider-adapter-contract.js";
|
||||||
|
|
||||||
|
/** Optional inputs for {@link AiProviderAdapter.rankTools}. */
export type ProviderRankToolOptions = {
    /** Forwarded to the provider tool getters when building the tool list. */
    forCreator?: boolean;
    /** When non-empty, the OpenAI adapter prepends a `file_search` tool using these ids. */
    vectorStoreIds?: string[];
};
|
||||||
|
|
||||||
|
/**
 * Provider-specific operations behind one common contract.
 *
 * Inputs and outputs are deliberately loosely typed (`unknown`); each adapter
 * interprets them in its provider's own format.
 */
export interface AiProviderAdapter {
    /** Provider this adapter serves. */
    readonly provider: AiProvider;

    /** Returns the messages in the provider's message format. */
    mapMessages(messages: readonly unknown[]): unknown[];

    /** Builds the provider-specific tool definitions for the given config and options. */
    rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[];

    /** Runs `execute` for `request` and resolves with its result. */
    callModel<T>(request: unknown, execute: () => Promise<T>): Promise<T>;

    /** Extracts the text fragment from a provider event/chunk; `""` when there is none. */
    extractTextDelta(input: unknown): string;

    /** Extracts normalized tool calls from a provider payload. */
    extractToolCalls(input: unknown): ToolCallData[];

    /** Extracts normalized tool calls from a provider streaming event/chunk. */
    extractStreamingToolCalls(input: unknown): ToolCallData[];

    /**
     * Appends one provider-formatted tool result message per call to `messages`
     * (mutating it); `results` is matched to `calls` by index.
     */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void;

    /** Completion hook invoked when the adapter is done. */
    finalize(): Promise<void>;
}
|
||||||
|
|
||||||
|
/**
 * Appends one Ollama `tool` message per tool call to `messages` (mutating it).
 *
 * @param messages - Conversation to extend.
 * @param calls - Tool calls that were executed.
 * @param results - Tool outputs matched to `calls` by index; a missing entry becomes `""`.
 */
function appendOllamaToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
    calls.forEach((call, position) => {
        const content = results[position] ?? "";
        messages.push({role: "tool", content, tool_name: call.name});
    });
}
|
||||||
|
|
||||||
|
/** Adapter for OpenAI: tool list assembly, stream parsing and tool result messages. */
class OpenAiProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OPENAI;

    /** Returns the same array, viewed as OpenAI chat messages (no copy, no conversion). */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const openAiMessages = messages as OpenAiMessageType[];
        return openAiMessages;
    }

    /**
     * Builds the OpenAI tool list: an optional `file_search` entry (only when
     * vector store ids are supplied), followed by the mapped function tools, the
     * code interpreter tool, `image_generation` and `web_search`.
     */
    rankTools(config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const functionTools = getOpenAIResponsesTools(options?.forCreator) as BoundaryValue[];
        const codeInterpreter = getOpenAICodeInterpreterTool() as BoundaryValue;
        const imageGeneration: BoundaryValue = {
            type: "image_generation",
            model: config.openAiImageTarget.model,
            size: "auto",
            moderation: "low",
            output_format: "png",
            partial_images: 3,
        };
        const webSearch: BoundaryValue = {type: "web_search"};

        const vectorStoreIds = options?.vectorStoreIds;
        const fileSearch: BoundaryValue[] = vectorStoreIds?.length
            ? [{type: "file_search", vector_store_ids: vectorStoreIds}]
            : [];

        return [...fileSearch, ...functionTools, codeInterpreter, imageGeneration, webSearch];
    }

    /** Invokes `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to `extractOpenAiTextDelta`. */
    extractTextDelta(input: unknown): string {
        return extractOpenAiTextDelta(input);
    }

    /** Delegates to `extractOpenAiToolCalls`. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiToolCalls(input);
    }

    /** Delegates to `extractOpenAiStreamingToolCalls`. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return extractOpenAiStreamingToolCalls(input);
    }

    /**
     * Appends one `function_call_output` item per call to `messages`; results are
     * matched by index and default to `""`.
     */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        calls.forEach((call, position) => {
            messages.push({
                type: "function_call_output",
                call_id: call.id,
                output: results[position] ?? "",
            });
        });
    }

    /** Nothing to clean up for this adapter. */
    async finalize(): Promise<void> {
        // Intentionally empty.
    }
}
|
||||||
|
|
||||||
|
/** Adapter for Mistral: tool list, stream parsing and tool result messages. */
class MistralProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.MISTRAL;

    /** Returns the same array, viewed as Mistral chat messages (no copy, no conversion). */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const mistralMessages = messages as MistralMessageType[];
        return mistralMessages;
    }

    /** Returns the Mistral tool definitions; the runtime config is not used. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const forCreator = options?.forCreator;
        return getMistralTools(forCreator) as BoundaryValue[];
    }

    /** Invokes `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to `extractMistralTextDelta`. */
    extractTextDelta(input: unknown): string {
        return extractMistralTextDelta(input);
    }

    /** Delegates to `extractMistralToolCalls`. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractMistralToolCalls(input);
    }

    /** Streaming chunks are parsed exactly like complete payloads. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /**
     * Appends one `tool` message per call to `messages`; results are matched by
     * index and default to `""`.
     */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        calls.forEach((call, position) => {
            messages.push({
                role: "tool",
                name: call.name,
                toolCallId: call.id,
                content: results[position] ?? "",
            });
        });
    }

    /** Nothing to clean up for this adapter. */
    async finalize(): Promise<void> {
        // Intentionally empty.
    }
}
|
||||||
|
|
||||||
|
/** Adapter for Ollama: tool list, stream parsing and tool result messages. */
class OllamaProviderAdapter implements AiProviderAdapter {
    readonly provider = AiProvider.OLLAMA;

    /** Returns the same array, viewed as Ollama messages (no copy, no conversion). */
    mapMessages(messages: readonly unknown[]): unknown[] {
        const ollamaMessages = messages as OllamaMessage[];
        return ollamaMessages;
    }

    /** Returns the Ollama tool definitions; the runtime config is not used. */
    rankTools(_config: RuntimeConfigSnapshot, options?: ProviderRankToolOptions): readonly BoundaryValue[] {
        const forCreator = options?.forCreator;
        return getOllamaTools(forCreator) as BoundaryValue[];
    }

    /** Invokes `execute` as-is; the request is not inspected. */
    async callModel<T>(_request: unknown, execute: () => Promise<T>): Promise<T> {
        return execute();
    }

    /** Delegates to `extractOllamaTextDelta`. */
    extractTextDelta(input: unknown): string {
        return extractOllamaTextDelta(input);
    }

    /** Delegates to `extractOllamaToolCalls`. */
    extractToolCalls(input: unknown): ToolCallData[] {
        return extractOllamaToolCalls(input);
    }

    /** Streaming chunks are parsed exactly like complete payloads. */
    extractStreamingToolCalls(input: unknown): ToolCallData[] {
        return this.extractToolCalls(input);
    }

    /** Delegates to `appendOllamaToolResults`, which mutates `messages`. */
    appendToolResults(messages: unknown[], calls: ToolCallData[], results: string[]): void {
        appendOllamaToolResults(messages, calls, results);
    }

    /** Nothing to clean up for this adapter. */
    async finalize(): Promise<void> {
        // Intentionally empty.
    }
}
|
||||||
|
|
||||||
|
/**
 * Creates the adapter for the given provider.
 *
 * @param provider - Provider to build an adapter for.
 * @returns A new adapter instance on every call.
 * @throws Error when `provider` is not one of the supported values; without this
 *         the function would fall out of the switch and return `undefined`
 *         despite its declared return type.
 */
export function getProviderAdapter(provider: AiProvider): AiProviderAdapter {
    switch (provider) {
        case AiProvider.OPENAI:
            return new OpenAiProviderAdapter();
        case AiProvider.MISTRAL:
            return new MistralProviderAdapter();
        case AiProvider.OLLAMA:
            return new OllamaProviderAdapter();
        default:
            // Reachable only with a value outside the handled cases (e.g. untyped input).
            throw new Error(`Unsupported AI provider: ${String(provider)}`);
    }
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
import type {AiProvider} from "../model/ai-provider";
|
||||||
|
|
||||||
|
/** One source file recorded in a RAG artifact. */
export type RagArtifactSource = {
    /** Identifier of the source file. */
    fileId: string;
    /** Name of the source file. */
    fileName: string;
    /** MIME type, when known. */
    mimeType?: string;
    /** File size in bytes, when known. */
    sizeBytes?: number;
    /** SHA-256 digest, when known. */
    sha256?: string;
    /** Presumably the provider-side id of the uploaded file — TODO confirm against the caller. */
    uploadedFileId?: string;
    /** Presumably the provider-side document id for this file — TODO confirm against the caller. */
    documentId?: string;
};
|
||||||
|
|
||||||
|
/**
 * Payload of a RAG artifact.
 *
 * `providerState` is a union discriminated by its own `provider` field; the type
 * itself does not tie it to the top-level `provider`.
 */
export type RagArtifactPayload = {
    /** Constant tag identifying this payload as a RAG artifact. */
    artifactKind: "rag";
    /** Provider the artifact was produced for. */
    provider: AiProvider;
    /** Creation timestamp as a string. */
    createdAt: string;
    /** Source files recorded in the artifact. */
    sources: RagArtifactSource[];
    /** Provider-specific state. */
    providerState:
        | {
              /** OpenAI state. */
              provider: AiProvider.OPENAI;
              vectorStoreIds: string[];
              uploadedFileIds: string[];
          }
        | {
              /** Mistral state. */
              provider: AiProvider.MISTRAL;
              libraryId?: string;
              documentCount: number;
          }
        | {
              /** Ollama state: settings plus per-document and per-chunk bookkeeping. */
              provider: AiProvider.OLLAMA;
              prepared: boolean;
              embeddingModel?: string;
              topK?: number;
              chunkSize?: number;
              chunkOverlap?: number;
              maxContextChars?: number;
              extractedDocuments: {
                  documentIndex: number;
                  fileName: string;
                  textChars: number;
              }[];
              selectedChunks: {
                  sourceId: string;
                  documentIndex: number;
                  documentName: string;
                  chunkIndex: number;
                  chunkCount: number;
                  textChars: number;
                  score?: number;
              }[];
              skippedDocuments: {
                  documentIndex: number;
                  fileName: string;
                  reason: string;
              }[];
              query: string;
              minScore: number;
              maxArchiveFiles: number;
              maxArchiveBytes: number;
              maxArchiveDepth: number;
          };
};
|
||||||
|
|
||||||
|
/**
 * Assembles a RAG artifact payload.
 *
 * @param params - Payload parts. `createdAt` defaults to the current time as an ISO
 *                 string; `sources` and `providerState` are stored by reference.
 * @returns The payload with `artifactKind` set to `"rag"`.
 */
export function buildRagArtifactPayload(params: {
    provider: AiProvider;
    createdAt?: string;
    sources: RagArtifactSource[];
    providerState: RagArtifactPayload["providerState"];
}): RagArtifactPayload {
    const {provider, sources, providerState} = params;
    const createdAt = params.createdAt ?? new Date().toISOString();

    // Key order is kept stable so the serialized form does not change.
    return {artifactKind: "rag", provider, createdAt, sources, providerState};
}
|
||||||
@@ -4,75 +4,39 @@ import type {AiDownloadedFile} from "./telegram-attachments";
|
|||||||
import type {PreparedDocumentRag} from "./document-rag-pipeline";
|
import type {PreparedDocumentRag} from "./document-rag-pipeline";
|
||||||
import type {OllamaRagArtifactDetails} from "./ollama-rag";
|
import type {OllamaRagArtifactDetails} from "./ollama-rag";
|
||||||
import {persistInternalJsonArtifactAttachment} from "./internal-artifact-store";
|
import {persistInternalJsonArtifactAttachment} from "./internal-artifact-store";
|
||||||
|
import {buildRagArtifactPayload, type RagArtifactPayload} from "./rag-artifact-payload";
|
||||||
type RagArtifactPayload = {
|
|
||||||
artifactKind: "rag";
|
|
||||||
provider: AiProvider;
|
|
||||||
createdAt: string;
|
|
||||||
sources: Array<{
|
|
||||||
fileId: string;
|
|
||||||
fileName: string;
|
|
||||||
mimeType?: string;
|
|
||||||
sizeBytes?: number;
|
|
||||||
sha256?: string;
|
|
||||||
uploadedFileId?: string;
|
|
||||||
documentId?: string;
|
|
||||||
}>;
|
|
||||||
providerState: {
|
|
||||||
vectorStoreIds?: string[];
|
|
||||||
libraryId?: string;
|
|
||||||
documentCount?: number;
|
|
||||||
prepared?: boolean;
|
|
||||||
uploadedFileIds?: string[];
|
|
||||||
embeddingModel?: string;
|
|
||||||
topK?: number;
|
|
||||||
chunkSize?: number;
|
|
||||||
chunkOverlap?: number;
|
|
||||||
maxContextChars?: number;
|
|
||||||
extractedDocuments?: Array<{
|
|
||||||
documentIndex: number;
|
|
||||||
fileName: string;
|
|
||||||
textChars: number;
|
|
||||||
}>;
|
|
||||||
selectedChunks?: Array<{
|
|
||||||
sourceId: string;
|
|
||||||
documentIndex: number;
|
|
||||||
documentName: string;
|
|
||||||
chunkIndex: number;
|
|
||||||
chunkCount: number;
|
|
||||||
textChars: number;
|
|
||||||
score?: number;
|
|
||||||
}>;
|
|
||||||
skippedDocuments?: Array<{
|
|
||||||
documentIndex: number;
|
|
||||||
fileName: string;
|
|
||||||
reason: string;
|
|
||||||
}>;
|
|
||||||
query?: string;
|
|
||||||
ollama?: OllamaRagArtifactDetails["providerState"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
function providerState(prepared: PreparedDocumentRag, details?: NonNullable<Parameters<typeof persistRagArtifactAttachment>[0]["details"]>): RagArtifactPayload["providerState"] {
|
function providerState(prepared: PreparedDocumentRag, details?: NonNullable<Parameters<typeof persistRagArtifactAttachment>[0]["details"]>): RagArtifactPayload["providerState"] {
|
||||||
switch (prepared.provider) {
|
switch (prepared.provider) {
|
||||||
case AiProvider.OPENAI:
|
case AiProvider.OPENAI:
|
||||||
return {
|
return {
|
||||||
|
provider: AiProvider.OPENAI,
|
||||||
vectorStoreIds: prepared.vectorStoreIds,
|
vectorStoreIds: prepared.vectorStoreIds,
|
||||||
uploadedFileIds: prepared.uploadedFileIds,
|
uploadedFileIds: prepared.uploadedFileIds,
|
||||||
};
|
};
|
||||||
case AiProvider.MISTRAL:
|
case AiProvider.MISTRAL:
|
||||||
return {
|
return {
|
||||||
|
provider: AiProvider.MISTRAL,
|
||||||
libraryId: prepared.libraryId,
|
libraryId: prepared.libraryId,
|
||||||
documentCount: prepared.documents.length,
|
documentCount: prepared.documents.length,
|
||||||
};
|
};
|
||||||
case AiProvider.OLLAMA:
|
case AiProvider.OLLAMA:
|
||||||
return {
|
return {
|
||||||
|
provider: AiProvider.OLLAMA,
|
||||||
prepared: prepared.prepared,
|
prepared: prepared.prepared,
|
||||||
embeddingModel: details?.embeddingModel,
|
embeddingModel: details?.embeddingModel,
|
||||||
topK: details?.topK,
|
topK: details?.topK,
|
||||||
chunkSize: details?.chunkSize,
|
chunkSize: details?.chunkSize,
|
||||||
chunkOverlap: details?.chunkOverlap,
|
chunkOverlap: details?.chunkOverlap,
|
||||||
maxContextChars: details?.maxContextChars,
|
maxContextChars: details?.maxContextChars,
|
||||||
|
extractedDocuments: details?.artifact?.extractedDocuments ?? [],
|
||||||
|
selectedChunks: details?.artifact?.selectedChunks ?? [],
|
||||||
|
skippedDocuments: details?.artifact?.skippedDocuments ?? [],
|
||||||
|
query: details?.artifact?.query ?? "",
|
||||||
|
minScore: details?.artifact?.providerState?.minScore ?? 0,
|
||||||
|
maxArchiveFiles: details?.artifact?.providerState?.maxArchiveFiles ?? 0,
|
||||||
|
maxArchiveBytes: details?.artifact?.providerState?.maxArchiveBytes ?? 0,
|
||||||
|
maxArchiveDepth: details?.artifact?.providerState?.maxArchiveDepth ?? 0,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,22 +81,11 @@ export async function persistRagArtifactAttachment(params: {
|
|||||||
|
|
||||||
if (!sources.length) return Promise.resolve(undefined);
|
if (!sources.length) return Promise.resolve(undefined);
|
||||||
|
|
||||||
const payload: RagArtifactPayload = {
|
const payload = buildRagArtifactPayload({
|
||||||
artifactKind: "rag",
|
|
||||||
provider: params.provider,
|
provider: params.provider,
|
||||||
createdAt: new Date().toISOString(),
|
|
||||||
sources,
|
sources,
|
||||||
providerState: {
|
providerState: providerState(params.prepared, params.details),
|
||||||
...providerState(params.prepared, params.details),
|
});
|
||||||
...(params.details?.artifact ? {
|
|
||||||
extractedDocuments: params.details.artifact.extractedDocuments,
|
|
||||||
selectedChunks: params.details.artifact.selectedChunks,
|
|
||||||
skippedDocuments: params.details.artifact.skippedDocuments,
|
|
||||||
query: params.details.artifact.query,
|
|
||||||
ollama: params.details.artifact.providerState,
|
|
||||||
} : {}),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
return await persistInternalJsonArtifactAttachment({
|
return await persistInternalJsonArtifactAttachment({
|
||||||
artifactKind: "rag",
|
artifactKind: "rag",
|
||||||
fileNamePrefix: "rag",
|
fileNamePrefix: "rag",
|
||||||
@@ -140,14 +93,8 @@ export async function persistRagArtifactAttachment(params: {
|
|||||||
messageId: params.messageId,
|
messageId: params.messageId,
|
||||||
payload,
|
payload,
|
||||||
metadata: {
|
metadata: {
|
||||||
provider: params.provider,
|
|
||||||
sourceFileNames: sources.map(source => source.fileName),
|
sourceFileNames: sources.map(source => source.fileName),
|
||||||
...payload.providerState,
|
...payload.providerState,
|
||||||
embeddingModel: params.details?.embeddingModel,
|
|
||||||
topK: params.details?.topK,
|
|
||||||
chunkSize: params.details?.chunkSize,
|
|
||||||
chunkOverlap: params.details?.chunkOverlap,
|
|
||||||
maxContextChars: params.details?.maxContextChars,
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import {AiTool} from "./tool-types";
|
import {AiTool} from "./tool-types";
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider.js";
|
||||||
import {getTools} from "./tools/registry";
|
import {getTools} from "./tools/registry.js";
|
||||||
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search";
|
import {WEB_SEARCH_TOOL_NAME} from "./tools/web-search.js";
|
||||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator";
|
import {PYTHON_INTERPRETER_TOOL_NAME} from "./tools/python-interpretator.js";
|
||||||
|
|
||||||
export type AiProviderName = "ollama" | "openai" | "mistral";
|
export type AiProviderName = "ollama" | "openai" | "mistral";
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
import {Environment} from "../common/environment.js";
|
||||||
|
import {AiProvider} from "../model/ai-provider.js";
|
||||||
|
import type {BoundaryValue} from "../common/boundary-types.js";
|
||||||
|
import type {TelegramStreamMessage} from "./telegram-stream-message.js";
|
||||||
|
import type {RuntimeConfigSnapshot} from "./unified-ai-runner.shared.js";
|
||||||
|
import {filterRankedTools} from "./tool-ranker-pipeline.js";
|
||||||
|
import {ToolRanker} from "./unified-ai-runner.tool-ranker.js";
|
||||||
|
import {storeToolRankAudit} from "./tool-rank-audit.js";
|
||||||
|
|
||||||
|
/**
 * Finds the text of the most recent user message.
 *
 * Messages are scanned from the end. The first `user` message whose content is a
 * string or an array decides the result: a string is returned as-is; for an array,
 * the non-empty string `text` values of its parts are joined with newlines (which
 * may produce `""`). User messages with any other content are skipped in favour of
 * earlier ones.
 *
 * @param messages - Conversation in chronological order.
 * @returns The extracted text, or `""` when no user message qualifies.
 */
function latestUserText(messages: readonly { role?: string; content?: unknown }[]): string {
    let position = messages.length;
    while (position-- > 0) {
        const candidate = messages[position];
        if (candidate?.role !== "user") continue;

        const content = candidate.content;
        if (typeof content === "string") return content;
        if (!Array.isArray(content)) continue;

        const texts: string[] = [];
        for (const part of content) {
            if (typeof part !== "object" || part === null || !("text" in part)) continue;
            const text = (part as { text?: unknown }).text;
            if (typeof text === "string" && text) texts.push(text);
        }
        return texts.join("\n");
    }

    return "";
}
|
||||||
|
|
||||||
|
/**
 * Runs the tool ranking step: shows a status, asks the ranker which tools to keep,
 * stores an audit record and returns the filtered tool list.
 *
 * The filtered list is computed before the success audit is written, so a failure
 * while filtering is audited once as an error instead of following an already
 * stored success record for the same round.
 *
 * @param params - Stage inputs. `toolRanker` is optional; when omitted a
 *                 `ToolRanker` is created from `config`. The user query is the text
 *                 of the latest user message in `messages`.
 * @returns The tools kept by the ranker, their names and whether the ranker was used.
 * @throws Rethrows any error raised while ranking, filtering, updating the status
 *         or storing the audit; an error audit is stored before rethrowing.
 */
export async function runToolRankStage(params: {
    provider: AiProvider;
    model: string;
    round: number;
    config: RuntimeConfigSnapshot;
    availableTools: readonly BoundaryValue[];
    messages: readonly { role?: string; content?: unknown }[];
    streamMessage: TelegramStreamMessage;
    signal: AbortSignal;
    toolRanker?: ToolRanker;
}): Promise<{
    filteredTools: BoundaryValue[];
    selectedToolNames: string[];
    usedRanker: boolean;
}> {
    const {streamMessage} = params;
    const toolRanker = params.toolRanker ?? new ToolRanker(params.config);
    const startedAt = Date.now();
    // Derived from the same instant so both audit timestamps always agree.
    const startedAtIso = new Date(startedAt).toISOString();

    // Fields shared by the success and the error audit record.
    const auditBase = {
        streamMessage,
        provider: params.provider,
        model: params.model,
        round: params.round,
        startedAt,
        startedAtIso,
    };

    const clearStatus = async (): Promise<void> => {
        streamMessage.clearStatus();
        await streamMessage.flush();
    };

    streamMessage.setStatus(Environment.getSelectingToolsText());
    await streamMessage.flush();

    try {
        const selection = await toolRanker.selectTools({
            provider: params.provider,
            userQuery: latestUserText(params.messages),
            availableTools: params.availableTools,
            round: params.round,
            signal: params.signal,
        });

        // Filter first: nothing that can fail should run after the success audit.
        const filteredTools = filterRankedTools(params.availableTools, selection.toolNames);

        await clearStatus();
        await storeToolRankAudit({...auditBase, selectedTools: selection.toolNames});

        return {
            filteredTools,
            selectedToolNames: selection.toolNames,
            usedRanker: selection.usedRanker,
        };
    } catch (error) {
        await clearStatus();
        await storeToolRankAudit({...auditBase, error});
        throw error;
    }
}
|
||||||
@@ -2,6 +2,7 @@ import {AiProvider} from "../model/ai-provider";
|
|||||||
import {Environment} from "../common/environment";
|
import {Environment} from "../common/environment";
|
||||||
import {ifTrue, logError} from "../util/utils";
|
import {ifTrue, logError} from "../util/utils";
|
||||||
import {UserRequestPipeline, type UserRequestPipelineState, type UserRequestPipelineStage} from "./user-request-pipeline";
|
import {UserRequestPipeline, type UserRequestPipelineState, type UserRequestPipelineStage} from "./user-request-pipeline";
|
||||||
|
import {getProviderAdapter} from "./provider-adapters";
|
||||||
import type {AiDownloadedFile} from "./telegram-attachments";
|
import type {AiDownloadedFile} from "./telegram-attachments";
|
||||||
import type {TelegramStreamMessage} from "./telegram-stream-message";
|
import type {TelegramStreamMessage} from "./telegram-stream-message";
|
||||||
import type {PreparedUnifiedAiRequest} from "./unified-ai-request-pipeline";
|
import type {PreparedUnifiedAiRequest} from "./unified-ai-request-pipeline";
|
||||||
@@ -9,12 +10,14 @@ import type {OpenAIChatMessage} from "./openai-chat-message";
|
|||||||
import type {MistralChatMessage} from "./mistral-chat-message";
|
import type {MistralChatMessage} from "./mistral-chat-message";
|
||||||
import type {ChatMessage} from "./chat-messages-types";
|
import type {ChatMessage} from "./chat-messages-types";
|
||||||
import {
|
import {
|
||||||
|
allToolSchemaNames,
|
||||||
providerName,
|
providerName,
|
||||||
RuntimeConfigSnapshot,
|
RuntimeConfigSnapshot,
|
||||||
snapshotModel,
|
snapshotModel,
|
||||||
TELEGRAM_LIMIT,
|
TELEGRAM_LIMIT,
|
||||||
UnifiedRunOptions,
|
UnifiedRunOptions,
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
|
import {runToolRankStage} from "./tool-rank-stage";
|
||||||
import {runOpenAi} from "./unified-ai-runner.openai";
|
import {runOpenAi} from "./unified-ai-runner.openai";
|
||||||
import {runOllama} from "./unified-ai-runner.ollama";
|
import {runOllama} from "./unified-ai-runner.ollama";
|
||||||
import {runMistral} from "./unified-ai-runner.mistral";
|
import {runMistral} from "./unified-ai-runner.mistral";
|
||||||
@@ -159,6 +162,9 @@ export async function runUnifiedAiResponsePipeline(params: {
|
|||||||
}): Promise<void> {
|
}): Promise<void> {
|
||||||
const {options, config, downloads, prepared, streamMessage, controller} = params;
|
const {options, config, downloads, prepared, streamMessage, controller} = params;
|
||||||
const state = createResponsePipelineState(options);
|
const state = createResponsePipelineState(options);
|
||||||
|
const adapter = getProviderAdapter(options.provider);
|
||||||
|
let selectedToolNames: string[] = [];
|
||||||
|
let filteredTools: unknown[] = [];
|
||||||
|
|
||||||
const stages: UserRequestPipelineStage[] = [
|
const stages: UserRequestPipelineStage[] = [
|
||||||
{
|
{
|
||||||
@@ -177,6 +183,62 @@ export async function runUnifiedAiResponsePipeline(params: {
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "tool_rank",
|
||||||
|
async run() {
|
||||||
|
const availableTools = adapter.rankTools(config, {
|
||||||
|
forCreator: options.msg.from?.id === Environment.CREATOR_ID,
|
||||||
|
vectorStoreIds: prepared.preparedDocumentRag?.provider === AiProvider.OPENAI
|
||||||
|
? prepared.preparedDocumentRag.vectorStoreIds
|
||||||
|
: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const rankResult = await runToolRankStage({
|
||||||
|
provider: options.provider,
|
||||||
|
model: snapshotModel(options.provider, config),
|
||||||
|
round: state.toolRankDecisions.length,
|
||||||
|
config,
|
||||||
|
availableTools,
|
||||||
|
messages: prepared.chatMessages,
|
||||||
|
streamMessage,
|
||||||
|
signal: controller.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
selectedToolNames = rankResult.selectedToolNames;
|
||||||
|
filteredTools = rankResult.filteredTools;
|
||||||
|
state.toolRankDecisions.push({
|
||||||
|
provider: options.provider,
|
||||||
|
round: state.toolRankDecisions.length,
|
||||||
|
availableTools: allToolSchemaNames(availableTools),
|
||||||
|
selectedTools: selectedToolNames,
|
||||||
|
usedRanker: rankResult.usedRanker,
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
stage: "tool_rank",
|
||||||
|
status: "succeeded",
|
||||||
|
details: {
|
||||||
|
selectedTools: selectedToolNames,
|
||||||
|
usedRanker: rankResult.usedRanker,
|
||||||
|
availableTools: allToolSchemaNames(availableTools),
|
||||||
|
toolRankDecision: state.toolRankDecisions.at(-1),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter_tools",
|
||||||
|
async run() {
|
||||||
|
return {
|
||||||
|
stage: "filter_tools",
|
||||||
|
status: "succeeded",
|
||||||
|
details: {
|
||||||
|
selectedTools: selectedToolNames,
|
||||||
|
filteredToolCount: filteredTools.length,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "model_call",
|
name: "model_call",
|
||||||
async run() {
|
async run() {
|
||||||
@@ -312,6 +374,8 @@ export async function runUnifiedAiResponsePipeline(params: {
|
|||||||
stages,
|
stages,
|
||||||
stageNames: [
|
stageNames: [
|
||||||
"audit_start",
|
"audit_start",
|
||||||
|
"tool_rank",
|
||||||
|
"filter_tools",
|
||||||
"model_call",
|
"model_call",
|
||||||
"tool_loop",
|
"tool_loop",
|
||||||
"output_size_gate",
|
"output_size_gate",
|
||||||
|
|||||||
@@ -1,21 +1,17 @@
|
|||||||
import {Environment} from "../common/environment";
|
import {Environment} from "../common/environment";
|
||||||
import {getMistralTools} from "./tool-mappers";
|
|
||||||
import {TelegramStreamMessage} from "./telegram-stream-message";
|
import {TelegramStreamMessage} from "./telegram-stream-message";
|
||||||
import {ToolRuntimeContext} from "./tools/runtime";
|
import {ToolRuntimeContext} from "./tools/runtime";
|
||||||
import {MistralChatMessage} from "./mistral-chat-message";
|
import {MistralChatMessage} from "./mistral-chat-message";
|
||||||
import {createMistralClient} from "./ai-runtime-target";
|
import {createMistralClient} from "./ai-runtime-target";
|
||||||
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider";
|
||||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
import {getProviderAdapter} from "./provider-adapters";
|
||||||
|
import {runToolRankStage} from "./tool-rank-stage";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
contentFromMistralDelta,
|
|
||||||
executeToolBatch,
|
executeToolBatch,
|
||||||
MAX_TOOL_ROUNDS,
|
MAX_TOOL_ROUNDS,
|
||||||
MistralDeltaLike,
|
|
||||||
MistralDocumentReference,
|
MistralDocumentReference,
|
||||||
mistralToolCalls,
|
|
||||||
normalizeMistralToolCalls,
|
|
||||||
roundStatus,
|
roundStatus,
|
||||||
RuntimeConfigSnapshot,
|
RuntimeConfigSnapshot,
|
||||||
StreamingToolCallAccumulator,
|
StreamingToolCallAccumulator,
|
||||||
@@ -23,8 +19,6 @@ import {
|
|||||||
ToolExecutionMemory
|
ToolExecutionMemory
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {Message} from "typescript-telegram-bot-api";
|
import {Message} from "typescript-telegram-bot-api";
|
||||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
|
||||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
|
||||||
|
|
||||||
export async function runMistral(
|
export async function runMistral(
|
||||||
msg: Message,
|
msg: Message,
|
||||||
@@ -39,8 +33,9 @@ export async function runMistral(
|
|||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const runnerStartedAt = Date.now();
|
const runnerStartedAt = Date.now();
|
||||||
const mistralAi = createMistralClient(config.mistralChatTarget);
|
const mistralAi = createMistralClient(config.mistralChatTarget);
|
||||||
const toolRanker = new ToolRanker(config);
|
const adapter = getProviderAdapter(AiProvider.MISTRAL);
|
||||||
const availableTools = getMistralTools(msg.from?.id === Environment.CREATOR_ID);
|
const availableTools = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID});
|
||||||
|
const requestMessages = adapter.mapMessages([...messages]) as unknown as MistralChatMessage[];
|
||||||
aiLog("info", "mistral.run.start", {
|
aiLog("info", "mistral.run.start", {
|
||||||
stream,
|
stream,
|
||||||
target: aiLogProviderTarget(config.mistralChatTarget),
|
target: aiLogProviderTarget(config.mistralChatTarget),
|
||||||
@@ -50,49 +45,23 @@ export async function runMistral(
|
|||||||
});
|
});
|
||||||
|
|
||||||
const toolMemory: ToolExecutionMemory = new Map();
|
const toolMemory: ToolExecutionMemory = new Map();
|
||||||
|
try {
|
||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||||
const roundStartedAt = Date.now();
|
const roundStartedAt = Date.now();
|
||||||
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
aiLog("debug", "mistral.round.start", {round, messages: messages.length, stream});
|
||||||
if (signal.aborted) throw new Error("Aborted");
|
if (signal.aborted) throw new Error("Aborted");
|
||||||
|
|
||||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
const rankResult = await runToolRankStage({
|
||||||
await streamMessage.flush();
|
|
||||||
const toolRankStartedAt = Date.now();
|
|
||||||
const toolRankStartedAtIso = new Date().toISOString();
|
|
||||||
const rankerSelection = await toolRanker.selectTools({
|
|
||||||
provider: AiProvider.MISTRAL,
|
provider: AiProvider.MISTRAL,
|
||||||
userQuery: latestUserTextFromMessages(messages),
|
model: config.mistralChatTarget.model,
|
||||||
|
round,
|
||||||
|
config,
|
||||||
availableTools,
|
availableTools,
|
||||||
round,
|
messages,
|
||||||
|
streamMessage,
|
||||||
signal,
|
signal,
|
||||||
})
|
|
||||||
.catch(async error => {
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.MISTRAL,
|
|
||||||
model: config.mistralChatTarget.model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
error,
|
|
||||||
});
|
});
|
||||||
throw error;
|
const filteredTools = rankResult.filteredTools;
|
||||||
});
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.MISTRAL,
|
|
||||||
model: config.mistralChatTarget.model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
selectedTools: rankerSelection.toolNames,
|
|
||||||
});
|
|
||||||
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
|
|
||||||
const requestTools = filteredTools.length ? filteredTools : undefined;
|
const requestTools = filteredTools.length ? filteredTools : undefined;
|
||||||
|
|
||||||
streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? "");
|
streamMessage.setStatus(roundStatus(round, firstRoundStatus) ?? "");
|
||||||
@@ -101,15 +70,15 @@ export async function runMistral(
|
|||||||
if (!stream) {
|
if (!stream) {
|
||||||
const request = {
|
const request = {
|
||||||
model: config.mistralChatTarget.model,
|
model: config.mistralChatTarget.model,
|
||||||
messages,
|
messages: requestMessages,
|
||||||
tools: requestTools,
|
tools: requestTools,
|
||||||
documents: documents
|
documents: documents
|
||||||
} as Parameters<typeof mistralAi.chat.complete>[0];
|
} as Parameters<typeof mistralAi.chat.complete>[0];
|
||||||
const response = await mistralAi.chat.complete(request, {signal});
|
const response = await adapter.callModel(request, () => mistralAi.chat.complete(request, {signal}));
|
||||||
const message = response.choices?.[0]?.message;
|
const message = response.choices?.[0]?.message;
|
||||||
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
const text = typeof message?.content === "string" ? message.content : JSON.stringify(message?.content ?? "");
|
||||||
streamMessage.append(text);
|
streamMessage.append(text);
|
||||||
const calls = normalizeMistralToolCalls(mistralToolCalls(message));
|
const calls = adapter.extractToolCalls(message);
|
||||||
aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", {
|
aiLog(calls.length ? "info" : "success", calls.length ? "mistral.tool_calls" : "mistral.run.done", {
|
||||||
round,
|
round,
|
||||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||||
@@ -125,25 +94,27 @@ export async function runMistral(
|
|||||||
function: {name: call.name, arguments: call.argumentsText},
|
function: {name: call.name, arguments: call.argumentsText},
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
requestMessages.push({
|
||||||
for (const [index, call] of calls.entries()) {
|
role: "assistant",
|
||||||
messages.push({
|
content: text,
|
||||||
role: "tool",
|
toolCalls: calls.map(call => ({
|
||||||
name: call.name,
|
id: call.id,
|
||||||
toolCallId: call.id,
|
function: {name: call.name, arguments: call.argumentsText},
|
||||||
content: toolResults[index] ?? "",
|
})),
|
||||||
});
|
});
|
||||||
}
|
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||||
|
adapter.appendToolResults(messages, calls, toolResults);
|
||||||
|
adapter.appendToolResults(requestMessages, calls, toolResults);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const request = {
|
const request = {
|
||||||
model: config.mistralChatTarget.model,
|
model: config.mistralChatTarget.model,
|
||||||
messages,
|
messages: requestMessages,
|
||||||
tools: requestTools,
|
tools: requestTools,
|
||||||
documents: documents
|
documents: documents
|
||||||
} as Parameters<typeof mistralAi.chat.stream>[0];
|
} as Parameters<typeof mistralAi.chat.stream>[0];
|
||||||
const streamResponse = await mistralAi.chat.stream(request, {signal});
|
const streamResponse = await adapter.callModel(request, () => mistralAi.chat.stream(request, {signal}));
|
||||||
aiLog("debug", "mistral.stream.open", {round});
|
aiLog("debug", "mistral.stream.open", {round});
|
||||||
let calls: ToolCallData[] = [];
|
let calls: ToolCallData[] = [];
|
||||||
const roundTextStart = streamMessage.getText().length;
|
const roundTextStart = streamMessage.getText().length;
|
||||||
@@ -154,11 +125,10 @@ export async function runMistral(
|
|||||||
|
|
||||||
const choice = event.data?.choices?.[0];
|
const choice = event.data?.choices?.[0];
|
||||||
const delta = choice?.delta;
|
const delta = choice?.delta;
|
||||||
const mistralDelta = delta as MistralDeltaLike;
|
const mistralDelta = delta;
|
||||||
|
streamMessage.append(adapter.extractTextDelta(mistralDelta));
|
||||||
|
|
||||||
streamMessage.append(contentFromMistralDelta(mistralDelta));
|
const rawDeltaCalls = adapter.extractStreamingToolCalls(mistralDelta);
|
||||||
|
|
||||||
const rawDeltaCalls = mistralToolCalls(mistralDelta);
|
|
||||||
if (rawDeltaCalls.length) {
|
if (rawDeltaCalls.length) {
|
||||||
calls = toolCallAccumulator.add(rawDeltaCalls);
|
calls = toolCallAccumulator.add(rawDeltaCalls);
|
||||||
streamMessage.setStatus(Environment.getUseToolText(calls));
|
streamMessage.setStatus(Environment.getUseToolText(calls));
|
||||||
@@ -178,14 +148,16 @@ export async function runMistral(
|
|||||||
content: roundText,
|
content: roundText,
|
||||||
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
|
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
|
||||||
});
|
});
|
||||||
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
requestMessages.push({
|
||||||
for (const [index, call] of calls.entries()) {
|
role: "assistant",
|
||||||
messages.push({
|
content: roundText,
|
||||||
role: "tool",
|
toolCalls: calls.map(c => ({id: c.id, function: {name: c.name, arguments: c.argumentsText}}))
|
||||||
name: call.name,
|
|
||||||
toolCallId: call.id,
|
|
||||||
content: toolResults[index] ?? "",
|
|
||||||
});
|
});
|
||||||
}
|
const toolResults = await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory);
|
||||||
|
adapter.appendToolResults(messages, calls, toolResults);
|
||||||
|
adapter.appendToolResults(requestMessages, calls, toolResults);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
await adapter.finalize().catch(() => undefined);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import {Environment} from "../common/environment";
|
|||||||
import type {BoundaryValue} from "../common/boundary-types";
|
import type {BoundaryValue} from "../common/boundary-types";
|
||||||
import {bot, notesDir} from "../index";
|
import {bot, notesDir} from "../index";
|
||||||
import {clamp, logError} from "../util/utils";
|
import {clamp, logError} from "../util/utils";
|
||||||
import {getOllamaTools} from "./tool-mappers";
|
|
||||||
import {TelegramStreamMessage} from "./telegram-stream-message";
|
import {TelegramStreamMessage} from "./telegram-stream-message";
|
||||||
import {ChatMessage} from "./chat-messages-types";
|
import {ChatMessage} from "./chat-messages-types";
|
||||||
import {ChatRequest, Tool} from "ollama";
|
import {ChatRequest, Tool} from "ollama";
|
||||||
@@ -14,10 +13,11 @@ import {enqueueTelegramApiCall} from "../util/telegram-api-queue";
|
|||||||
import {loadOllamaModel, unloadAllOllamaModels} from "./tools/utils";
|
import {loadOllamaModel, unloadAllOllamaModels} from "./tools/utils";
|
||||||
import {createOllamaClient} from "./ai-runtime-target";
|
import {createOllamaClient} from "./ai-runtime-target";
|
||||||
import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
import {aiLog, aiLogDuration, aiLogMessageIdentity, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
||||||
|
import {getProviderAdapter} from "./provider-adapters";
|
||||||
|
import {runToolRankStage} from "./tool-rank-stage";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
allToolSchemaNames,
|
allToolSchemaNames,
|
||||||
appendOllamaToolResults,
|
|
||||||
dedupeToolCalls,
|
dedupeToolCalls,
|
||||||
DEFAULT_OLLAMA_CONTEXT_SIZE,
|
DEFAULT_OLLAMA_CONTEXT_SIZE,
|
||||||
executeToolBatch,
|
executeToolBatch,
|
||||||
@@ -26,8 +26,6 @@ import {
|
|||||||
MAX_OLLAMA_CONTEXT_SIZE,
|
MAX_OLLAMA_CONTEXT_SIZE,
|
||||||
MAX_TOOL_ROUNDS,
|
MAX_TOOL_ROUNDS,
|
||||||
MIN_OLLAMA_CONTEXT_SIZE,
|
MIN_OLLAMA_CONTEXT_SIZE,
|
||||||
normalizeOllamaToolCalls,
|
|
||||||
OllamaToolCallLike,
|
|
||||||
roundStatus,
|
roundStatus,
|
||||||
RuntimeConfigSnapshot,
|
RuntimeConfigSnapshot,
|
||||||
safeJsonParseObject,
|
safeJsonParseObject,
|
||||||
@@ -35,14 +33,11 @@ import {
|
|||||||
ToolCallData,
|
ToolCallData,
|
||||||
ToolExecutionMemory
|
ToolExecutionMemory
|
||||||
} from "./unified-ai-runner.shared";
|
} from "./unified-ai-runner.shared";
|
||||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
|
||||||
import {getToolPrompts} from "./tools/registry";
|
import {getToolPrompts} from "./tools/registry";
|
||||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
|
||||||
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
import {GetNoteFileResult, GetNoteFileResultSchema} from "./tools/notes";
|
||||||
import {getModelCapabilities} from "./provider-model-runtime";
|
import {getModelCapabilities} from "./provider-model-runtime";
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider";
|
||||||
import {Message} from "typescript-telegram-bot-api";
|
import {Message} from "typescript-telegram-bot-api";
|
||||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
|
||||||
|
|
||||||
export async function runOllama(
|
export async function runOllama(
|
||||||
msg: Message,
|
msg: Message,
|
||||||
@@ -157,6 +152,7 @@ export async function runOllama(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const toolMemory: ToolExecutionMemory = new Map();
|
const toolMemory: ToolExecutionMemory = new Map();
|
||||||
|
const adapter = getProviderAdapter(AiProvider.OLLAMA);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||||
@@ -183,7 +179,7 @@ export async function runOllama(
|
|||||||
|
|
||||||
let activeToolNames: string[] = [];
|
let activeToolNames: string[] = [];
|
||||||
if ((await getModelCapabilities(AiProvider.OLLAMA, model, "tools"))?.tools?.supported) {
|
if ((await getModelCapabilities(AiProvider.OLLAMA, model, "tools"))?.tools?.supported) {
|
||||||
const availableOllamaTools: Tool[] = getOllamaTools(msg.from?.id === Environment.CREATOR_ID) as Tool[];
|
const availableOllamaTools: Tool[] = adapter.rankTools(config, {forCreator: msg.from?.id === Environment.CREATOR_ID}) as Tool[];
|
||||||
|
|
||||||
aiLog("debug", "ollama.tools.available", {
|
aiLog("debug", "ollama.tools.available", {
|
||||||
round,
|
round,
|
||||||
@@ -191,44 +187,18 @@ export async function runOllama(
|
|||||||
rankerEnabled: !!config.ollamaToolRankerTarget,
|
rankerEnabled: !!config.ollamaToolRankerTarget,
|
||||||
});
|
});
|
||||||
|
|
||||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
const rankResult = await runToolRankStage({
|
||||||
await streamMessage.flush();
|
|
||||||
const toolRankStartedAt = Date.now();
|
|
||||||
const toolRankStartedAtIso = new Date().toISOString();
|
|
||||||
const rankerSelection = await new ToolRanker(config).selectTools({
|
|
||||||
provider: AiProvider.OLLAMA,
|
provider: AiProvider.OLLAMA,
|
||||||
userQuery: latestUserTextFromMessages(messages),
|
model,
|
||||||
|
round,
|
||||||
|
config,
|
||||||
availableTools: availableOllamaTools,
|
availableTools: availableOllamaTools,
|
||||||
round,
|
messages,
|
||||||
|
streamMessage,
|
||||||
signal,
|
signal,
|
||||||
})
|
|
||||||
.catch(async error => {
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.OLLAMA,
|
|
||||||
model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
error,
|
|
||||||
});
|
|
||||||
throw error;
|
|
||||||
});
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.OLLAMA,
|
|
||||||
model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
selectedTools: rankerSelection.toolNames,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const filteredTools = [...new Set(filterRankedTools(availableOllamaTools, rankerSelection.toolNames))];
|
const filteredTools = [...new Set(rankResult.filteredTools as Tool[])];
|
||||||
activeToolNames = filteredTools.map(t => t.function.name ?? "");
|
activeToolNames = filteredTools.map(t => t.function.name ?? "");
|
||||||
if (filteredTools.length > 0) {
|
if (filteredTools.length > 0) {
|
||||||
request.tools = [...filteredTools];
|
request.tools = [...filteredTools];
|
||||||
@@ -256,24 +226,21 @@ export async function runOllama(
|
|||||||
round,
|
round,
|
||||||
tools: activeToolNames,
|
tools: activeToolNames,
|
||||||
count: activeToolNames.length,
|
count: activeToolNames.length,
|
||||||
usedRanker: rankerSelection.usedRanker,
|
usedRanker: rankResult.usedRanker,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
const response = await ollama.chat({
|
const response = await adapter.callModel(request, () => ollama.chat({
|
||||||
...request,
|
...request,
|
||||||
stream: false
|
stream: false
|
||||||
});
|
}));
|
||||||
|
|
||||||
const message = response.message;
|
const message = response.message;
|
||||||
const rawContent = message?.content ?? "";
|
const rawContent = message?.content ?? "";
|
||||||
|
|
||||||
const nativeCalls = dedupeToolCalls(
|
const nativeCalls = dedupeToolCalls(
|
||||||
normalizeOllamaToolCalls(
|
adapter.extractToolCalls(message),
|
||||||
message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
|
||||||
round,
|
|
||||||
),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const responseText = rawContent;
|
const responseText = rawContent;
|
||||||
@@ -301,7 +268,7 @@ export async function runOllama(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const calls = nativeCalls;
|
const calls = adapter.extractToolCalls(message).length ? adapter.extractToolCalls(message) : nativeCalls;
|
||||||
|
|
||||||
aiLog("info", "ollama.tool_calls", {
|
aiLog("info", "ollama.tool_calls", {
|
||||||
round,
|
round,
|
||||||
@@ -319,11 +286,7 @@ export async function runOllama(
|
|||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
|
|
||||||
appendOllamaToolResults(
|
adapter.appendToolResults(messages, calls, await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory));
|
||||||
messages,
|
|
||||||
calls,
|
|
||||||
await executeToolBatch(msg.from?.id, calls, streamMessage, toolContext, toolMemory),
|
|
||||||
);
|
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -332,10 +295,10 @@ export async function runOllama(
|
|||||||
round,
|
round,
|
||||||
messageCount: request.messages?.length ?? 0,
|
messageCount: request.messages?.length ?? 0,
|
||||||
});
|
});
|
||||||
const response = await ollama.chat({
|
const response = await adapter.callModel(request, () => ollama.chat({
|
||||||
...request,
|
...request,
|
||||||
stream: true
|
stream: true
|
||||||
});
|
}));
|
||||||
|
|
||||||
aiLog("debug", "ollama.stream.open", {round});
|
aiLog("debug", "ollama.stream.open", {round});
|
||||||
const calls: ToolCallData[] = [];
|
const calls: ToolCallData[] = [];
|
||||||
@@ -354,10 +317,7 @@ export async function runOllama(
|
|||||||
|
|
||||||
const localToolCalls: ToolCallData[] = [];
|
const localToolCalls: ToolCallData[] = [];
|
||||||
|
|
||||||
localToolCalls.push(...normalizeOllamaToolCalls(
|
localToolCalls.push(...adapter.extractStreamingToolCalls(chunk.message));
|
||||||
chunk.message.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
|
||||||
round,
|
|
||||||
));
|
|
||||||
|
|
||||||
const newStatus = roundStatus(round, firstRoundStatus, chunk.message.content, localToolCalls, !!chunk.message.thinking);
|
const newStatus = roundStatus(round, firstRoundStatus, chunk.message.content, localToolCalls, !!chunk.message.thinking);
|
||||||
const previousStatus = streamMessage.getStatus();
|
const previousStatus = streamMessage.getStatus();
|
||||||
@@ -377,13 +337,10 @@ export async function runOllama(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!(chunk.message?.thinking && streamMessage.getStatus() !== Environment.reasoningText)) {
|
if (!(chunk.message?.thinking && streamMessage.getStatus() !== Environment.reasoningText)) {
|
||||||
streamMessage.append(chunk.message?.content ?? "");
|
streamMessage.append(adapter.extractTextDelta(chunk));
|
||||||
}
|
}
|
||||||
|
|
||||||
calls.push(...normalizeOllamaToolCalls(
|
calls.push(...adapter.extractStreamingToolCalls(chunk.message));
|
||||||
chunk.message?.tool_calls as readonly OllamaToolCallLike[] | undefined,
|
|
||||||
round,
|
|
||||||
));
|
|
||||||
|
|
||||||
if (chunk.done) {
|
if (chunk.done) {
|
||||||
aiLog("debug", "ollama.stream.done", {
|
aiLog("debug", "ollama.stream.done", {
|
||||||
@@ -471,9 +428,10 @@ export async function runOllama(
|
|||||||
}).catch(logError);
|
}).catch(logError);
|
||||||
}
|
}
|
||||||
|
|
||||||
appendOllamaToolResults(messages, calls, toolResults);
|
adapter.appendToolResults(messages, calls, toolResults);
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (interval) clearInterval(interval);
|
if (interval) clearInterval(interval);
|
||||||
|
await adapter.finalize().catch(() => undefined);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,11 +17,9 @@ import {
|
|||||||
AsyncIterableStream,
|
AsyncIterableStream,
|
||||||
buildSystemInstruction,
|
buildSystemInstruction,
|
||||||
collectOpenAiResponseCodeInterpreterCalls,
|
collectOpenAiResponseCodeInterpreterCalls,
|
||||||
collectOpenAiResponseFunctionCalls,
|
|
||||||
collectOpenAiResponseImages,
|
collectOpenAiResponseImages,
|
||||||
collectOpenAiResponseText,
|
collectOpenAiResponseText,
|
||||||
executeToolBatch,
|
executeToolBatch,
|
||||||
getOpenAIResponsesToolsWithImage,
|
|
||||||
MAX_TOOL_ROUNDS,
|
MAX_TOOL_ROUNDS,
|
||||||
OPENAI_IMAGE_PARTIALS,
|
OPENAI_IMAGE_PARTIALS,
|
||||||
openAiResponseItemCallId,
|
openAiResponseItemCallId,
|
||||||
@@ -42,10 +40,9 @@ import {logError} from "../util/utils";
|
|||||||
import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files";
|
import {SendFileAttachmentResult, SendFileAttachmentResultSchema} from "./tools/files";
|
||||||
import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings";
|
import {DEFAULT_AI_RESPONSE_LANGUAGE} from "../common/user-ai-settings";
|
||||||
import {AiDownloadedFile} from "./telegram-attachments";
|
import {AiDownloadedFile} from "./telegram-attachments";
|
||||||
import {ToolRanker} from "./unified-ai-runner.tool-ranker";
|
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider";
|
||||||
import {filterRankedTools, latestUserTextFromMessages} from "./tool-ranker-pipeline";
|
import {getProviderAdapter} from "./provider-adapters";
|
||||||
import {storeToolRankAudit} from "./tool-rank-audit";
|
import {runToolRankStage} from "./tool-rank-stage";
|
||||||
|
|
||||||
export async function runOpenAi(
|
export async function runOpenAi(
|
||||||
msg: Message,
|
msg: Message,
|
||||||
@@ -60,16 +57,15 @@ export async function runOpenAi(
|
|||||||
documentRag?: OpenAiDocumentRagContext,
|
documentRag?: OpenAiDocumentRagContext,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const runnerStartedAt = Date.now();
|
const runnerStartedAt = Date.now();
|
||||||
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = [...messages] as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
|
||||||
const openAi = createOpenAiClient(config.openAiChatTarget);
|
const openAi = createOpenAiClient(config.openAiChatTarget);
|
||||||
const ownsDocumentRag = !documentRag;
|
const ownsDocumentRag = !documentRag;
|
||||||
const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document"));
|
const preparedDocumentRag = documentRag ?? await prepareOpenAiDocumentRag(openAi, downloads.filter(download => download.kind === "document"));
|
||||||
const toolRanker = new ToolRanker(config);
|
const adapter = getProviderAdapter(AiProvider.OPENAI);
|
||||||
const availableTools = getOpenAIResponsesToolsWithImage(
|
let responseInput: Array<ResponseInputItem | OpenAiResponseOutputItem> = adapter.mapMessages(messages) as unknown as Array<ResponseInputItem | OpenAiResponseOutputItem>;
|
||||||
config,
|
const availableTools = adapter.rankTools(config, {
|
||||||
msg.from?.id === Environment.CREATOR_ID,
|
forCreator: msg.from?.id === Environment.CREATOR_ID,
|
||||||
preparedDocumentRag?.vectorStoreIds ?? [],
|
vectorStoreIds: preparedDocumentRag?.vectorStoreIds ?? [],
|
||||||
);
|
});
|
||||||
|
|
||||||
const systemPrompt = buildSystemInstruction(
|
const systemPrompt = buildSystemInstruction(
|
||||||
config,
|
config,
|
||||||
@@ -93,43 +89,17 @@ export async function runOpenAi(
|
|||||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
||||||
const roundStartedAt = Date.now();
|
const roundStartedAt = Date.now();
|
||||||
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
aiLog("debug", "openai.round.start", {round, inputItems: responseInput.length, stream});
|
||||||
streamMessage.setStatus(Environment.getSelectingToolsText());
|
const rankResult = await runToolRankStage({
|
||||||
await streamMessage.flush();
|
|
||||||
const toolRankStartedAt = Date.now();
|
|
||||||
const toolRankStartedAtIso = new Date().toISOString();
|
|
||||||
const rankerSelection = await toolRanker.selectTools({
|
|
||||||
provider: AiProvider.OPENAI,
|
provider: AiProvider.OPENAI,
|
||||||
userQuery: latestUserTextFromMessages(messages),
|
model: config.openAiChatTarget.model,
|
||||||
|
round,
|
||||||
|
config,
|
||||||
availableTools,
|
availableTools,
|
||||||
round,
|
messages,
|
||||||
|
streamMessage,
|
||||||
signal,
|
signal,
|
||||||
})
|
|
||||||
.catch(async error => {
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.OPENAI,
|
|
||||||
model: config.openAiChatTarget.model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
error,
|
|
||||||
});
|
});
|
||||||
throw error;
|
const filteredTools = rankResult.filteredTools;
|
||||||
});
|
|
||||||
streamMessage.clearStatus();
|
|
||||||
await streamMessage.flush();
|
|
||||||
await storeToolRankAudit({
|
|
||||||
streamMessage,
|
|
||||||
provider: AiProvider.OPENAI,
|
|
||||||
model: config.openAiChatTarget.model,
|
|
||||||
round,
|
|
||||||
startedAt: toolRankStartedAt,
|
|
||||||
startedAtIso: toolRankStartedAtIso,
|
|
||||||
selectedTools: rankerSelection.toolNames,
|
|
||||||
});
|
|
||||||
const filteredTools = filterRankedTools(availableTools, rankerSelection.toolNames);
|
|
||||||
const requestTools = preparedDocumentRag?.vectorStoreIds.length
|
const requestTools = preparedDocumentRag?.vectorStoreIds.length
|
||||||
? (() => {
|
? (() => {
|
||||||
const tools = [...filteredTools];
|
const tools = [...filteredTools];
|
||||||
@@ -151,7 +121,7 @@ export async function runOpenAi(
|
|||||||
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
tools: requestTools as ResponseCreateParamsNonStreaming["tools"],
|
||||||
instructions: systemPrompt,
|
instructions: systemPrompt,
|
||||||
};
|
};
|
||||||
const response = await openAi.responses.create(request, {signal}) as OpenAiResponseLike;
|
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as OpenAiResponseLike;
|
||||||
|
|
||||||
const responseText = collectOpenAiResponseText(response);
|
const responseText = collectOpenAiResponseText(response);
|
||||||
streamMessage.append(responseText);
|
streamMessage.append(responseText);
|
||||||
@@ -188,12 +158,12 @@ export async function runOpenAi(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const calls = collectOpenAiResponseFunctionCalls(response);
|
const calls = adapter.extractToolCalls(response);
|
||||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||||
round,
|
round,
|
||||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||||
calls: calls.map(call => ({
|
calls: calls.map(call => ({
|
||||||
id: call.callId,
|
id: call.id,
|
||||||
name: call.name,
|
name: call.name,
|
||||||
arguments: safeJsonParseObject(call.argumentsText)
|
arguments: safeJsonParseObject(call.argumentsText)
|
||||||
})),
|
})),
|
||||||
@@ -201,16 +171,13 @@ export async function runOpenAi(
|
|||||||
if (!calls.length) return;
|
if (!calls.length) return;
|
||||||
|
|
||||||
const toolCalls = calls.map(call => ({
|
const toolCalls = calls.map(call => ({
|
||||||
id: call.callId,
|
id: call.id,
|
||||||
name: call.name,
|
name: call.name,
|
||||||
argumentsText: call.argumentsText,
|
argumentsText: call.argumentsText,
|
||||||
}));
|
}));
|
||||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||||
const toolOutputs = calls.map((call, index) => ({
|
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||||
type: "function_call_output" as const,
|
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||||
call_id: call.callId,
|
|
||||||
output: toolResults[index] ?? "",
|
|
||||||
}));
|
|
||||||
|
|
||||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||||
if (uploadFilesResult.found) {
|
if (uploadFilesResult.found) {
|
||||||
@@ -243,7 +210,7 @@ export async function runOpenAi(
|
|||||||
parallel_tool_calls: true,
|
parallel_tool_calls: true,
|
||||||
instructions: systemPrompt
|
instructions: systemPrompt
|
||||||
};
|
};
|
||||||
const response = await openAi.responses.create(request, {signal}) as AsyncIterableStream<ResponseStreamEvent>;
|
const response = await adapter.callModel(request, () => openAi.responses.create(request, {signal})) as AsyncIterableStream<ResponseStreamEvent>;
|
||||||
|
|
||||||
aiLog("debug", "openai.stream.open", {round});
|
aiLog("debug", "openai.stream.open", {round});
|
||||||
|
|
||||||
@@ -253,7 +220,7 @@ export async function runOpenAi(
|
|||||||
|
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case "response.output_text.delta":
|
case "response.output_text.delta":
|
||||||
streamMessage.append(event.delta ?? "");
|
streamMessage.append(adapter.extractTextDelta(event));
|
||||||
break;
|
break;
|
||||||
case "response.image_generation_call.in_progress":
|
case "response.image_generation_call.in_progress":
|
||||||
streamMessage.setStatus(Environment.startingImageGenText);
|
streamMessage.setStatus(Environment.startingImageGenText);
|
||||||
@@ -301,14 +268,11 @@ export async function runOpenAi(
|
|||||||
case "response.code_interpreter_call_code.done":
|
case "response.code_interpreter_call_code.done":
|
||||||
break;
|
break;
|
||||||
case "response.output_item.added":
|
case "response.output_item.added":
|
||||||
if (event.item.type === "function_call" && event.item.name) {
|
{
|
||||||
const item = event.item as OpenAiResponseOutputItem & { id?: string };
|
const streamedCalls = adapter.extractStreamingToolCalls(event);
|
||||||
localToolCalls.push({
|
if (streamedCalls.length) {
|
||||||
id: openAiResponseItemCallId(item),
|
localToolCalls.push(...streamedCalls);
|
||||||
name: item.name ?? "",
|
}
|
||||||
argumentsText: item.arguments ?? "{}",
|
|
||||||
});
|
|
||||||
|
|
||||||
aiLog("info", "openai.stream.tool_call.added", {
|
aiLog("info", "openai.stream.tool_call.added", {
|
||||||
round,
|
round,
|
||||||
toolCalls: localToolCalls.map(aiLogToolCall)
|
toolCalls: localToolCalls.map(aiLogToolCall)
|
||||||
@@ -383,12 +347,12 @@ export async function runOpenAi(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const calls = collectOpenAiResponseFunctionCalls(completedResponse);
|
const calls = adapter.extractToolCalls(completedResponse);
|
||||||
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
aiLog(calls.length ? "info" : "success", calls.length ? "openai.tool_calls" : "openai.run.done", {
|
||||||
round,
|
round,
|
||||||
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
duration: calls.length ? aiLogDuration(roundStartedAt) : aiLogDuration(runnerStartedAt),
|
||||||
calls: calls.map(call => ({
|
calls: calls.map(call => ({
|
||||||
id: call.callId,
|
id: call.id,
|
||||||
name: call.name,
|
name: call.name,
|
||||||
arguments: safeJsonParseObject(call.argumentsText)
|
arguments: safeJsonParseObject(call.argumentsText)
|
||||||
})),
|
})),
|
||||||
@@ -396,16 +360,13 @@ export async function runOpenAi(
|
|||||||
if (!calls.length) return;
|
if (!calls.length) return;
|
||||||
|
|
||||||
const toolCalls = calls.map(call => ({
|
const toolCalls = calls.map(call => ({
|
||||||
id: call.callId,
|
id: call.id,
|
||||||
name: call.name,
|
name: call.name,
|
||||||
argumentsText: call.argumentsText,
|
argumentsText: call.argumentsText,
|
||||||
}));
|
}));
|
||||||
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
const toolResults = await executeToolBatch(msg.from?.id, toolCalls, streamMessage, toolContext, toolMemory);
|
||||||
const toolOutputs = calls.map((call, index) => ({
|
const toolOutputs: Array<{type: "function_call_output"; call_id: string; output: string}> = [];
|
||||||
type: "function_call_output",
|
adapter.appendToolResults(toolOutputs, calls, toolResults);
|
||||||
call_id: call.callId,
|
|
||||||
output: toolResults[index] ?? "",
|
|
||||||
}));
|
|
||||||
|
|
||||||
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
const uploadFilesResult = await tryToUploadFiles(msg, toolResults);
|
||||||
if (uploadFilesResult.found) {
|
if (uploadFilesResult.found) {
|
||||||
@@ -431,6 +392,7 @@ export async function runOpenAi(
|
|||||||
if (ownsDocumentRag) {
|
if (ownsDocumentRag) {
|
||||||
await preparedDocumentRag?.cleanup().catch(logError);
|
await preparedDocumentRag?.cleanup().catch(logError);
|
||||||
}
|
}
|
||||||
|
await adapter.finalize().catch(logError);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,40 +2,40 @@ import {Message} from "typescript-telegram-bot-api";
|
|||||||
import * as fs from "node:fs";
|
import * as fs from "node:fs";
|
||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
import type {BoundaryValue} from "../common/boundary-types";
|
import type {BoundaryValue} from "../common/boundary-types";
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider.js";
|
||||||
import {ToolRankerFallbackPolicy} from "../common/policies";
|
import {ToolRankerFallbackPolicy} from "../common/policies.js";
|
||||||
import {Environment} from "../common/environment";
|
import {Environment} from "../common/environment.js";
|
||||||
import {photoGenDir} from "../index";
|
import {photoGenDir} from "../index.js";
|
||||||
import {delay, logError, replyToMessage} from "../util/utils";
|
import {delay, logError, replyToMessage} from "../util/utils.js";
|
||||||
import {MessageStore} from "../common/message-store";
|
import {MessageStore} from "../common/message-store.js";
|
||||||
import type {OpenAiResponseTool} from "./tool-mappers";
|
import type {OpenAiResponseTool} from "./tool-mappers.js";
|
||||||
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers";
|
import {AiProviderName, getOpenAICodeInterpreterTool, getOpenAIResponsesTools} from "./tool-mappers.js";
|
||||||
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message";
|
import {TelegramArtifactFile, TelegramStreamMessage} from "./telegram-stream-message.js";
|
||||||
import {AiDownloadedFile} from "./telegram-attachments";
|
import {AiDownloadedFile} from "./telegram-attachments.js";
|
||||||
import {getRuntimeCapabilities} from "./provider-model-runtime";
|
import {getRuntimeCapabilities} from "./provider-model-runtime.js";
|
||||||
import {StoredAttachment} from "../model/stored-attachment";
|
import {StoredAttachment} from "../model/stored-attachment.js";
|
||||||
import {AiChatMessage, ChatMessage} from "./chat-messages-types";
|
import {AiChatMessage, ChatMessage} from "./chat-messages-types.js";
|
||||||
import {ListResponse, Ollama} from "ollama";
|
import {ListResponse, Ollama} from "ollama";
|
||||||
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime";
|
import {executeToolCall, ToolRuntimeContext} from "./tools/runtime.js";
|
||||||
import {MessageImagePart, MessagePart} from "../common/message-part";
|
import {MessageImagePart, MessagePart} from "../common/message-part.js";
|
||||||
import {KeyedAsyncLock} from "../util/async-lock";
|
import {KeyedAsyncLock} from "../util/async-lock.js";
|
||||||
import {type AiRequestQueueTarget} from "./provider-request-queue";
|
import {type AiRequestQueueTarget} from "./provider-request-queue.js";
|
||||||
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator";
|
import {PYTHON_INTERPRETER_TOOL_NAME, pythonInterpreterToolPrompt} from "./tools/python-interpretator.js";
|
||||||
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings";
|
import {getResponseLanguageInstruction, UserAiResponseLanguage, UserAiVoiceMode} from "../common/user-ai-settings.js";
|
||||||
import {
|
import {
|
||||||
isTranscribableAudioDownload,
|
isTranscribableAudioDownload,
|
||||||
resolveSpeechToTextProviderForUser,
|
resolveSpeechToTextProviderForUser,
|
||||||
transcribeSpeechDownloads
|
transcribeSpeechDownloads
|
||||||
} from "./speech-to-text";
|
} from "./speech-to-text.js";
|
||||||
import type {ChatCompletionMessageParam} from "openai/resources/chat/completions";
|
import type {ChatCompletionMessageParam} from "openai/resources/chat/completions";
|
||||||
import {MistralChatMessage} from "./mistral-chat-message";
|
import {MistralChatMessage} from "./mistral-chat-message.js";
|
||||||
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer";
|
import {prepareTelegramMarkdownV2} from "../util/markdown-v2-renderer.js";
|
||||||
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target";
|
import {AiRuntimeTarget, createMistralClient, resolveAiRuntimeTarget} from "./ai-runtime-target.js";
|
||||||
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger";
|
import {aiLog, aiLogDuration, aiLogProviderTarget, aiLogToolCall} from "../logging/ai-logger.js";
|
||||||
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline";
|
import {buildConversationSnapshot, serializeConversationSnapshot} from "./conversation-pipeline.js";
|
||||||
import type {ResponseInputMessageContentList} from "openai/resources/responses/responses";
|
import type {ResponseInputMessageContentList} from "openai/resources/responses/responses";
|
||||||
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store";
|
import {persistToolResultArtifactAttachment} from "./tool-result-artifact-store.js";
|
||||||
import {filterUserVisibleStoredAttachments} from "../common/stored-attachment-utils";
|
import {filterUserVisibleStoredAttachments} from "../common/attachment-visibility.js";
|
||||||
|
|
||||||
export type {Message} from "typescript-telegram-bot-api";
|
export type {Message} from "typescript-telegram-bot-api";
|
||||||
export type {AiRuntimeTarget} from "./ai-runtime-target";
|
export type {AiRuntimeTarget} from "./ai-runtime-target";
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
import type {StoredAttachment} from "../model/stored-attachment";
|
||||||
|
|
||||||
|
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
|
||||||
|
return attachments.filter(attachment => attachment.scope !== "internal_artifact");
|
||||||
|
}
|
||||||
+10
-10
@@ -3,18 +3,18 @@ import os from "node:os";
|
|||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
import {parse as parseDotEnv} from "dotenv";
|
import {parse as parseDotEnv} from "dotenv";
|
||||||
import {z} from "zod";
|
import {z} from "zod";
|
||||||
import {appLogger} from "../logging/logger";
|
import {appLogger} from "../logging/logger.js";
|
||||||
import type {BoundaryValue, ErrorLike} from "./boundary-types";
|
import type {BoundaryValue, ErrorLike} from "./boundary-types";
|
||||||
|
|
||||||
import {saveData} from "../db/database";
|
import {saveData} from "../db/database.js";
|
||||||
import {Answers} from "../model/answers";
|
import {Answers} from "../model/answers.js";
|
||||||
import {ifTrue} from "../util/utils";
|
import {ifTrue} from "../util/utils.js";
|
||||||
import {AiProvider} from "../model/ai-provider";
|
import {AiProvider} from "../model/ai-provider.js";
|
||||||
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies";
|
import {ImageHandleFallbackPolicy, ImageHandlePolicy, RateLimitFallbackPolicy} from "./policies.js";
|
||||||
import {ToolRankerFallbackPolicy} from "./policies";
|
import {ToolRankerFallbackPolicy} from "./policies.js";
|
||||||
import type {ToolCallData} from "../ai/unified-ai-runner";
|
import type {ToolCallData} from "../ai/unified-ai-runner.js";
|
||||||
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator";
|
import {PYTHON_INTERPRETER_TOOL_NAME} from "../ai/tools/python-interpretator.js";
|
||||||
import {Localization, type LocalizationParams} from "./localization";
|
import {Localization, type LocalizationParams} from "./localization.js";
|
||||||
|
|
||||||
type EnvRecord = Record<string, string>;
|
type EnvRecord = Record<string, string>;
|
||||||
type StringEnumLike = Record<string, string>;
|
type StringEnumLike = Record<string, string>;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
import {Environment} from "./environment";
|
import {Environment} from "./environment";
|
||||||
import {StoredAttachment} from "../model/stored-attachment";
|
import {StoredAttachment} from "../model/stored-attachment";
|
||||||
|
export {filterUserVisibleStoredAttachments} from "./attachment-visibility";
|
||||||
|
|
||||||
export function photoCachePathForUniqueId(uniqueId: string): string {
|
export function photoCachePathForUniqueId(uniqueId: string): string {
|
||||||
return path.join(Environment.DATA_PATH, "cache", "photo", `${uniqueId}.jpg`);
|
return path.join(Environment.DATA_PATH, "cache", "photo", `${uniqueId}.jpg`);
|
||||||
@@ -44,7 +45,3 @@ export function uniqueStoredAttachments(attachments: StoredAttachment[]): Stored
|
|||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function filterUserVisibleStoredAttachments(attachments: StoredAttachment[]): StoredAttachment[] {
|
|
||||||
return attachments.filter(attachment => attachment.scope !== "internal_artifact");
|
|
||||||
}
|
|
||||||
|
|||||||
+4
-4
@@ -1,9 +1,9 @@
|
|||||||
import * as fs from "fs";
|
import * as fs from "fs";
|
||||||
import {Environment} from "../common/environment";
|
import {Environment} from "../common/environment.js";
|
||||||
import {logError} from "../util/utils";
|
import {logError} from "../util/utils.js";
|
||||||
import {Answers} from "../model/answers";
|
import {Answers} from "../model/answers.js";
|
||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
import {KeyedAsyncLock} from "../util/async-lock";
|
import {KeyedAsyncLock} from "../util/async-lock.js";
|
||||||
|
|
||||||
type DataJsonFile = {
|
type DataJsonFile = {
|
||||||
admins: number[]
|
admins: number[]
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
import test from "node:test";
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
|
||||||
|
const {
|
||||||
|
extractOpenAiToolCalls,
|
||||||
|
extractOpenAiStreamingToolCalls,
|
||||||
|
extractOpenAiTextDelta,
|
||||||
|
extractMistralToolCalls,
|
||||||
|
extractMistralTextDelta,
|
||||||
|
extractOllamaToolCalls,
|
||||||
|
extractOllamaTextDelta,
|
||||||
|
} = await import("../dist/ai/provider-adapter-contract.js");
|
||||||
|
|
||||||
|
test("openai contract extracts text delta and function calls", () => {
|
||||||
|
assert.equal(extractOpenAiTextDelta({type: "response.output_text.delta", delta: "hello"}), "hello");
|
||||||
|
|
||||||
|
const calls = extractOpenAiToolCalls({
|
||||||
|
output: [{
|
||||||
|
type: "function_call",
|
||||||
|
call_id: "call-1",
|
||||||
|
name: "read_file",
|
||||||
|
arguments: "{\"path\":\"src/index.ts\"}",
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(calls.length, 1);
|
||||||
|
assert.equal(calls[0].id, "call-1");
|
||||||
|
assert.equal(calls[0].name, "read_file");
|
||||||
|
|
||||||
|
const streamed = extractOpenAiStreamingToolCalls({
|
||||||
|
type: "response.output_item.added",
|
||||||
|
item: {
|
||||||
|
type: "function_call",
|
||||||
|
id: "call-2",
|
||||||
|
name: "search_files",
|
||||||
|
arguments: "{\"query\":\"sendMessage\"}",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(streamed.length, 1);
|
||||||
|
assert.equal(streamed[0].id, "call-2");
|
||||||
|
assert.equal(streamed[0].name, "search_files");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("mistral contract extracts content and tool calls", () => {
|
||||||
|
assert.equal(extractMistralTextDelta({
|
||||||
|
content: [{text: "hello"}, {text: " world"}],
|
||||||
|
}), "hello world");
|
||||||
|
|
||||||
|
const calls = extractMistralToolCalls({
|
||||||
|
toolCalls: [{
|
||||||
|
id: "m-1",
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: {location: "Moscow"},
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(calls.length, 1);
|
||||||
|
assert.equal(calls[0].id, "m-1");
|
||||||
|
assert.equal(calls[0].name, "get_weather");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("ollama contract extracts content and tool calls", () => {
|
||||||
|
assert.equal(extractOllamaTextDelta({
|
||||||
|
message: {content: "hello from ollama"},
|
||||||
|
}), "hello from ollama");
|
||||||
|
|
||||||
|
const calls = extractOllamaToolCalls({
|
||||||
|
tool_calls: [{
|
||||||
|
id: "o-1",
|
||||||
|
function: {
|
||||||
|
name: "web_search",
|
||||||
|
arguments: {query: "openai docs"},
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
assert.equal(calls.length, 1);
|
||||||
|
assert.equal(calls[0].id, "o-1");
|
||||||
|
assert.equal(calls[0].name, "web_search");
|
||||||
|
});
|
||||||
+27
-94
@@ -1,32 +1,13 @@
|
|||||||
import test, {after} from "node:test";
|
import test from "node:test";
|
||||||
import assert from "node:assert/strict";
|
import assert from "node:assert/strict";
|
||||||
import fs from "node:fs";
|
|
||||||
import os from "node:os";
|
|
||||||
import path from "node:path";
|
|
||||||
|
|
||||||
const tempRoot = fs.mkdtempSync(path.join(os.tmpdir(), "tg-chat-bot-rag-"));
|
const {
|
||||||
process.env.BOT_TOKEN = process.env.BOT_TOKEN ?? "test-token";
|
buildRagArtifactPayload,
|
||||||
process.env.CREATOR_ID = process.env.CREATOR_ID ?? "1";
|
} = await import("../dist/ai/rag-artifact-payload.js");
|
||||||
process.env.DATA_PATH = tempRoot;
|
const {
|
||||||
process.env.DB_PATH = `file:${path.join(tempRoot, "test.sqlite")}`;
|
filterUserVisibleStoredAttachments,
|
||||||
process.env.TEST_ENVIRONMENT = "true";
|
} = await import("../dist/common/attachment-visibility.js");
|
||||||
|
|
||||||
const {Environment} = await import("../dist/common/environment.js");
|
|
||||||
Environment.load();
|
|
||||||
|
|
||||||
const {DatabaseManager} = await import("../dist/db/database-manager.js");
|
|
||||||
DatabaseManager.init();
|
|
||||||
await DatabaseManager.ready;
|
|
||||||
|
|
||||||
const {ArtifactStore} = await import("../dist/common/artifact-store.js");
|
|
||||||
const {filterUserVisibleStoredAttachments} = await import("../dist/common/stored-attachment-utils.js");
|
|
||||||
const {AiProvider} = await import("../dist/model/ai-provider.js");
|
const {AiProvider} = await import("../dist/model/ai-provider.js");
|
||||||
const {persistRagArtifactAttachment} = await import("../dist/ai/rag-artifact-store.js");
|
|
||||||
|
|
||||||
after(async () => {
|
|
||||||
await DatabaseManager.close().catch(() => undefined);
|
|
||||||
fs.rmSync(tempRoot, {recursive: true, force: true});
|
|
||||||
});
|
|
||||||
|
|
||||||
test("internal artifacts are not treated as user-visible attachments", () => {
|
test("internal artifacts are not treated as user-visible attachments", () => {
|
||||||
const visible = filterUserVisibleStoredAttachments([
|
const visible = filterUserVisibleStoredAttachments([
|
||||||
@@ -50,65 +31,26 @@ test("internal artifacts are not treated as user-visible attachments", () => {
|
|||||||
assert.equal(visible[0].fileId, "visible");
|
assert.equal(visible[0].fileId, "visible");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("RAG artifacts persist structured ollama metadata", async () => {
|
test("RAG artifact payload keeps ollama retrieval metadata", () => {
|
||||||
const chatId = 42;
|
const payload = buildRagArtifactPayload({
|
||||||
const messageId = 7;
|
|
||||||
|
|
||||||
const attachment = await persistRagArtifactAttachment({
|
|
||||||
provider: AiProvider.OLLAMA,
|
provider: AiProvider.OLLAMA,
|
||||||
prepared: {
|
createdAt: "2026-01-01T00:00:00.000Z",
|
||||||
provider: AiProvider.OLLAMA,
|
sources: [{
|
||||||
prepared: true,
|
|
||||||
cleanup: async () => undefined,
|
|
||||||
artifact: {
|
|
||||||
query: "What is in the file?",
|
|
||||||
extractedDocuments: [
|
|
||||||
{documentIndex: 0, fileName: "report.txt", textChars: 120},
|
|
||||||
],
|
|
||||||
selectedChunks: [
|
|
||||||
{
|
|
||||||
sourceId: "doc1-1",
|
|
||||||
documentIndex: 0,
|
|
||||||
documentName: "report.txt",
|
|
||||||
chunkIndex: 0,
|
|
||||||
chunkCount: 1,
|
|
||||||
textChars: 120,
|
|
||||||
score: 0.91,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
skippedDocuments: [
|
|
||||||
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
|
|
||||||
],
|
|
||||||
providerState: {
|
|
||||||
embeddingModel: "nomic-embed-text:latest",
|
|
||||||
topK: 8,
|
|
||||||
chunkSize: 1400,
|
|
||||||
chunkOverlap: 220,
|
|
||||||
maxContextChars: 14000,
|
|
||||||
minScore: 0.12,
|
|
||||||
maxArchiveFiles: 200,
|
|
||||||
maxArchiveBytes: 50 * 1024 * 1024,
|
|
||||||
maxArchiveDepth: 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
downloads: [{
|
|
||||||
kind: "document",
|
|
||||||
fileId: "file-1",
|
fileId: "file-1",
|
||||||
fileName: "report.txt",
|
fileName: "report.txt",
|
||||||
buffer: Buffer.from("hello world"),
|
mimeType: "text/plain",
|
||||||
path: path.join(tempRoot, "report.txt"),
|
sizeBytes: 12,
|
||||||
|
sha256: "abc123",
|
||||||
|
uploadedFileId: "uploaded-1",
|
||||||
}],
|
}],
|
||||||
chatId,
|
providerState: {
|
||||||
messageId,
|
provider: AiProvider.OLLAMA,
|
||||||
details: {
|
prepared: true,
|
||||||
embeddingModel: "nomic-embed-text:latest",
|
embeddingModel: "nomic-embed-text:latest",
|
||||||
topK: 8,
|
topK: 8,
|
||||||
chunkSize: 1400,
|
chunkSize: 1400,
|
||||||
chunkOverlap: 220,
|
chunkOverlap: 220,
|
||||||
maxContextChars: 14000,
|
maxContextChars: 14000,
|
||||||
artifact: {
|
|
||||||
query: "What is in the file?",
|
|
||||||
extractedDocuments: [
|
extractedDocuments: [
|
||||||
{documentIndex: 0, fileName: "report.txt", textChars: 120},
|
{documentIndex: 0, fileName: "report.txt", textChars: 120},
|
||||||
],
|
],
|
||||||
@@ -126,29 +68,20 @@ test("RAG artifacts persist structured ollama metadata", async () => {
|
|||||||
skippedDocuments: [
|
skippedDocuments: [
|
||||||
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
|
{documentIndex: 1, fileName: "ignored.bin", reason: "unsupported format"},
|
||||||
],
|
],
|
||||||
providerState: {
|
|
||||||
embeddingModel: "nomic-embed-text:latest",
|
|
||||||
topK: 8,
|
|
||||||
chunkSize: 1400,
|
|
||||||
chunkOverlap: 220,
|
|
||||||
maxContextChars: 14000,
|
|
||||||
minScore: 0.12,
|
minScore: 0.12,
|
||||||
maxArchiveFiles: 200,
|
maxArchiveFiles: 200,
|
||||||
maxArchiveBytes: 50 * 1024 * 1024,
|
maxArchiveBytes: 50 * 1024 * 1024,
|
||||||
maxArchiveDepth: 2,
|
maxArchiveDepth: 2,
|
||||||
},
|
query: "What is in the file?",
|
||||||
},
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
assert.equal(attachment?.artifactKind, "rag");
|
assert.equal(payload.artifactKind, "rag");
|
||||||
assert.equal(fs.existsSync(attachment.cachePath), true);
|
assert.equal(payload.provider, AiProvider.OLLAMA);
|
||||||
|
assert.equal(payload.sources[0].uploadedFileId, "uploaded-1");
|
||||||
const stored = await ArtifactStore.getByMessage(chatId, messageId);
|
assert.equal(payload.providerState.provider, AiProvider.OLLAMA);
|
||||||
assert.equal(stored.length, 1);
|
assert.equal(payload.providerState.query, "What is in the file?");
|
||||||
assert.equal(stored[0].kind, "rag");
|
assert.equal(payload.providerState.selectedChunks[0].score, 0.91);
|
||||||
assert.equal(stored[0].payload.providerState.query, "What is in the file?");
|
assert.equal(payload.providerState.skippedDocuments[0].reason, "unsupported format");
|
||||||
assert.equal(stored[0].payload.providerState.selectedChunks[0].score, 0.91);
|
assert.equal(payload.providerState.embeddingModel, "nomic-embed-text:latest");
|
||||||
assert.equal(stored[0].payload.providerState.skippedDocuments[0].reason, "unsupported format");
|
|
||||||
assert.equal(stored[0].payload.providerState.ollama.embeddingModel, "nomic-embed-text:latest");
|
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user