From 5b67e2306002cc715f9f004d39d0a4d7651773bd Mon Sep 17 00:00:00 2001 From: Danil Nikolaev Date: Tue, 19 May 2026 01:46:12 +0300 Subject: [PATCH] shitton --- .env.example | 11 + README.md | 12 + src/ai/mcp/mcp-client.ts | 421 +++++++++++++++++++++++++++++++++ src/ai/mcp/mcp-config.ts | 106 +++++++++ src/ai/mcp/mcp-json-schema.ts | 123 ++++++++++ src/ai/mcp/mcp-registry.ts | 165 +++++++++++++ src/ai/tool-ranker-metadata.ts | 39 ++- src/ai/tools/registry.ts | 47 ++-- src/common/environment.ts | 6 + src/index.ts | 12 +- 10 files changed, 912 insertions(+), 30 deletions(-) create mode 100644 src/ai/mcp/mcp-client.ts create mode 100644 src/ai/mcp/mcp-config.ts create mode 100644 src/ai/mcp/mcp-json-schema.ts create mode 100644 src/ai/mcp/mcp-registry.ts diff --git a/.env.example b/.env.example index c2501c4..5599b82 100644 --- a/.env.example +++ b/.env.example @@ -43,6 +43,9 @@ ONLY_FOR_CREATOR_MODE=false # Use user names in AI prompts USE_NAMES_IN_PROMPT=true +# Disable all built-in local tools and keep only MCP tools +DISABLE_LOCAL_TOOLS=false + # Custom system prompt for AI (or put it into data/SYSTEM_PROMPT.md) SYSTEM_PROMPT= @@ -99,6 +102,14 @@ OPENAI_TTS_VOICE=alloy OPENAI_TTS_INSTRUCTIONS= OPENAI_MAX_CONCURRENT_REQUESTS=3 +# MCP servers +# JSON array or {"mcpServers": {"name": {...}}} +# Stdio example: +# MCP_SERVERS=[{"name":"local-tools","transport":"stdio","command":"node","args":["./mcp-server.js"]}] +# HTTP example: +# MCP_SERVERS=[{"name":"remote-tools","transport":"http","url":"https://example.com/mcp"}] +MCP_SERVERS= + # Per-capability AI endpoint overrides # Pattern: # __MODEL= diff --git a/README.md b/README.md index 36c4f10..4ee6fec 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,18 @@ The bot initializes and migrates its database schema automatically on startup. `/exportdb` sends the SQLite file when available, plus a `.sql` dump and a JSON backup. `/importdb` restores the database from the JSON backup format. +MCP tool servers can be configured through `MCP_SERVERS` in `.env`. Use a JSON array with `stdio` or `http` transports. Example: + +```bash +MCP_SERVERS=[{"name":"local-tools","transport":"stdio","command":"node","args":["./mcp-server.js"]}] +``` + +If you want to disable all built-in local tools and use only MCP tools, set: + +```bash +DISABLE_LOCAL_TOOLS=true +``` + For local Ollama document RAG, install an embedding model locally and set it in `.env`: ```bash diff --git a/src/ai/mcp/mcp-client.ts b/src/ai/mcp/mcp-client.ts new file mode 100644 index 0000000..74a4bce --- /dev/null +++ b/src/ai/mcp/mcp-client.ts @@ -0,0 +1,421 @@ +import {spawn, type ChildProcessWithoutNullStreams} from "node:child_process"; +import type {BoundaryValue} from "../../common/boundary-types.js"; +import {toolsLogger} from "../tools/tool-logger.js"; +import type {McpServerConfig} from "./mcp-config.js"; + +const logger = toolsLogger.child("mcp"); +const MCP_PROTOCOL_VERSION = "2025-06-18"; +const DEFAULT_REQUEST_TIMEOUT_MS = 30_000; + +export type McpToolDefinition = { + name: string; + description?: string; + inputSchema?: BoundaryValue; +}; + +type JsonRpcRequest = { + jsonrpc: "2.0"; + id: number; + method: string; + params?: BoundaryValue; +}; + +type JsonRpcNotification = { + jsonrpc: "2.0"; + method: string; + params?: BoundaryValue; +}; + +type JsonRpcResponse = { + jsonrpc?: "2.0"; + id?: BoundaryValue; + result?: BoundaryValue; + error?: { + code?: number; + message?: string; + data?: BoundaryValue; + }; +}; + +interface JsonRpcTransport { + request(method: string, params?: BoundaryValue): Promise; + notify(method: string, params?: BoundaryValue): Promise; + close(): Promise; +} + +function isRecord(value: BoundaryValue): value is Record { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +function toJsonRpcResponse(value: BoundaryValue): JsonRpcResponse | undefined { + if (!isRecord(value)) return undefined; + if (value.jsonrpc !== undefined && value.jsonrpc !== "2.0") return undefined; + return value as JsonRpcResponse; +} + +function extractJsonRpcResult(response: BoundaryValue, expectedId?: number): BoundaryValue { + const parsed = toJsonRpcResponse(response); + if (!parsed) { + throw new Error("Invalid JSON-RPC response from MCP server."); + } + + if (parsed.error) { + throw new Error(parsed.error.message || "MCP server returned an error."); + } + + if (expectedId !== undefined && parsed.id !== undefined && parsed.id !== expectedId) { + throw new Error(`Unexpected JSON-RPC response id from MCP server. Expected ${expectedId}, got ${String(parsed.id)}.`); + } + + return parsed.result ?? {}; +} + +function parseSsePayload(text: string): BoundaryValue[] { + const events: string[] = []; + let current: string[] = []; + + for (const rawLine of text.split(/\r?\n/)) { + const line = rawLine.trimEnd(); + + if (!line) { + if (current.length) { + events.push(current.join("\n")); + current = []; + } + continue; + } + + if (line.startsWith("data:")) { + current.push(line.slice(5).replace(/^ /, "")); + } + } + + if (current.length) { + events.push(current.join("\n")); + } + + return events.map(event => { + try { + return JSON.parse(event) as BoundaryValue; + } catch { + return undefined; + } + }).filter((event): event is BoundaryValue => event !== undefined); +} + +function timeoutPromise(promise: Promise, timeoutMs: number, label: string): Promise { + if (!Number.isFinite(timeoutMs) || timeoutMs <= 0) return promise; + + let timeoutId: NodeJS.Timeout | undefined; + const timeout = new Promise((_, reject) => { + timeoutId = setTimeout(() => { + reject(new Error(`${label} timed out after ${timeoutMs}ms`)); + }, timeoutMs); + }); + + return Promise.race([promise, timeout]).finally(() => { + if (timeoutId) clearTimeout(timeoutId); + }); +} + +class StdioJsonRpcTransport implements JsonRpcTransport { + private readonly process: ChildProcessWithoutNullStreams; + private readonly pending = new Map void; reject: (error: Error) => void;}>(); + private buffer = ""; + private nextId = 1; + + constructor(private readonly config: McpServerConfig) { + if (!config.command) { + throw new Error(`MCP stdio server '${config.name}' is missing command.`); + } + + this.process = spawn(config.command, config.args ?? [], { + cwd: config.cwd, + env: { + ...process.env, + ...config.env, + }, + stdio: ["pipe", "pipe", "pipe"], + windowsHide: true, + }); + + this.process.stdout.on("data", chunk => this.handleStdout(chunk)); + this.process.stderr.on("data", chunk => { + const text = chunk.toString("utf8").trim(); + if (text) logger.debug("stdio.stderr", {server: config.name, text}); + }); + this.process.on("error", error => this.failAll(error)); + this.process.on("exit", code => this.failAll(new Error(`MCP stdio server '${config.name}' exited with code ${code ?? "unknown"}.`))); + } + + private handleStdout(chunk: Buffer): void { + this.buffer += chunk.toString("utf8"); + + let newlineIndex = this.buffer.indexOf("\n"); + while (newlineIndex !== -1) { + const line = this.buffer.slice(0, newlineIndex).trim(); + this.buffer = this.buffer.slice(newlineIndex + 1); + newlineIndex = this.buffer.indexOf("\n"); + + if (!line) continue; + + try { + const message = JSON.parse(line) as JsonRpcResponse | JsonRpcNotification; + if ("id" in message && message.id !== undefined) { + const pending = this.pending.get(Number(message.id)); + if (pending) { + this.pending.delete(Number(message.id)); + if ("error" in message && message.error) { + pending.reject(new Error(message.error.message || "MCP stdio request failed.")); + } else { + pending.resolve((message as JsonRpcResponse).result ?? {}); + } + } + continue; + } + + if ("method" in message) { + logger.debug("stdio.notification", {server: this.config.name, method: message.method}); + } + } catch (error) { + logger.warn("stdio.parse_failed", { + server: this.config.name, + line: line.slice(0, 500), + error: error instanceof Error ? error.message : String(error), + }); + } + } + } + + private failAll(error: Error): void { + for (const pending of this.pending.values()) { + pending.reject(error); + } + this.pending.clear(); + } + + async request(method: string, params?: BoundaryValue): Promise { + if (this.process.exitCode !== null) { + throw new Error(`MCP stdio server '${this.config.name}' is not running.`); + } + + const id = this.nextId++; + const request: JsonRpcRequest = { + jsonrpc: "2.0", + id, + method, + params, + }; + + const result = new Promise((resolve, reject) => { + this.pending.set(id, {resolve, reject}); + }); + + this.process.stdin.write(`${JSON.stringify(request)}\n`); + return timeoutPromise(result, this.config.timeoutMs ?? DEFAULT_REQUEST_TIMEOUT_MS, `${this.config.name}.${method}`); + } + + async notify(method: string, params?: BoundaryValue): Promise { + if (this.process.exitCode !== null) { + throw new Error(`MCP stdio server '${this.config.name}' is not running.`); + } + + const notification: JsonRpcNotification = { + jsonrpc: "2.0", + method, + params, + }; + + this.process.stdin.write(`${JSON.stringify(notification)}\n`); + } + + async close(): Promise { + this.failAll(new Error(`MCP stdio server '${this.config.name}' closed.`)); + if (!this.process.killed) { + this.process.kill(); + } + } +} + +class HttpJsonRpcTransport implements JsonRpcTransport { + private nextId = 1; + private sessionId?: string; + + constructor(private readonly config: McpServerConfig) { + if (!config.url) { + throw new Error(`MCP HTTP server '${config.name}' is missing url.`); + } + } + + private async post(body: BoundaryValue): Promise { + const controller = new AbortController(); + const timeoutMs = this.config.timeoutMs ?? DEFAULT_REQUEST_TIMEOUT_MS; + const timeoutId = setTimeout(() => controller.abort(), timeoutMs); + + try { + return await fetch(this.config.url!, { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + ...(this.sessionId ? {"Mcp-Session-Id": this.sessionId} : {}), + ...(this.config.headers ?? {}), + }, + body: JSON.stringify(body), + signal: controller.signal, + }).finally(() => clearTimeout(timeoutId)); + } catch (error) { + clearTimeout(timeoutId); + throw error; + } + } + + async request(method: string, params?: BoundaryValue): Promise { + const id = this.nextId++; + const request: JsonRpcRequest = { + jsonrpc: "2.0", + id, + method, + params, + }; + + const response = await this.post(request); + const sessionId = response.headers.get("Mcp-Session-Id"); + if (sessionId) { + this.sessionId = sessionId; + } + + if (!response.ok) { + const errorText = await response.text().catch(() => ""); + throw new Error(`MCP HTTP server '${this.config.name}' returned ${response.status}: ${errorText || response.statusText}`); + } + + const contentType = response.headers.get("content-type")?.toLowerCase() ?? ""; + let payload: BoundaryValue; + + if (contentType.includes("text/event-stream")) { + const text = await response.text(); + const messages = parseSsePayload(text); + const responseMessage = messages.map(toJsonRpcResponse).find(message => message?.id === id && (message.result !== undefined || message.error)); + payload = extractJsonRpcResult(responseMessage ?? messages[0] ?? {}, id); + } else { + payload = extractJsonRpcResult(await response.json() as BoundaryValue, id); + } + + return payload; + } + + async notify(method: string, params?: BoundaryValue): Promise { + const response = await this.post({ + jsonrpc: "2.0", + method, + params, + }); + + const sessionId = response.headers.get("Mcp-Session-Id"); + if (sessionId) { + this.sessionId = sessionId; + } + + if (!response.ok && response.status !== 202) { + const errorText = await response.text().catch(() => ""); + throw new Error(`MCP HTTP notification failed for '${this.config.name}' with ${response.status}: ${errorText || response.statusText}`); + } + } + + async close(): Promise { + return; + } +} + +function createTransport(config: McpServerConfig): JsonRpcTransport { + return config.transport === "stdio" + ? new StdioJsonRpcTransport(config) + : new HttpJsonRpcTransport(config); +} + +function normalizeToolResultContent(content: BoundaryValue): string { + if (content === undefined || content === null) return ""; + if (typeof content === "string") return content; + if (typeof content === "number" || typeof content === "boolean") return String(content); + if (Array.isArray(content)) return content.map(item => normalizeToolResultContent(item)).filter(Boolean).join("\n"); + if (!isRecord(content)) return JSON.stringify(content); + + if (content.type === "text" && typeof content.text === "string") return content.text; + if (content.type === "image") { + return `[image ${typeof content.mimeType === "string" ? content.mimeType : "unknown"}]`; + } + if (content.type === "resource" && isRecord(content.resource)) { + if (typeof content.resource.text === "string") return content.resource.text; + return JSON.stringify(content.resource); + } + + return JSON.stringify(content); +} + +export class McpClient { + private readonly transport: JsonRpcTransport; + private initialized = false; + + constructor(readonly config: McpServerConfig) { + this.transport = createTransport(config); + } + + async initialize(): Promise { + if (this.initialized) return; + + await this.transport.request("initialize", { + protocolVersion: MCP_PROTOCOL_VERSION, + clientInfo: { + name: "tg-chat-bot", + version: "1.0.0", + }, + capabilities: {}, + }); + + await this.transport.notify("notifications/initialized"); + this.initialized = true; + } + + async listTools(): Promise { + await this.initialize(); + const result = await this.transport.request("tools/list"); + + if (!isRecord(result)) return []; + + const tools = Array.isArray(result.tools) ? result.tools : []; + return tools.flatMap(tool => { + if (!isRecord(tool) || typeof tool.name !== "string") return []; + return [{ + name: tool.name, + description: typeof tool.description === "string" ? tool.description : undefined, + inputSchema: tool.inputSchema, + }]; + }); + } + + async callTool(name: string, args?: BoundaryValue): Promise { + await this.initialize(); + const result = await this.transport.request("tools/call", { + name, + arguments: args ?? {}, + }); + + if (!isRecord(result)) { + return normalizeToolResultContent(result); + } + + const content = Array.isArray(result.content) ? result.content : []; + const text = content.map(item => normalizeToolResultContent(item)).filter(Boolean).join("\n"); + + if (result.isError) { + return text ? `[MCP error] ${text}` : "[MCP error]"; + } + + return text || JSON.stringify(result); + } + + async close(): Promise { + await this.transport.close(); + } +} diff --git a/src/ai/mcp/mcp-config.ts b/src/ai/mcp/mcp-config.ts new file mode 100644 index 0000000..8c413ab --- /dev/null +++ b/src/ai/mcp/mcp-config.ts @@ -0,0 +1,106 @@ +import type {BoundaryValue} from "../../common/boundary-types.js"; + +export type McpTransport = "stdio" | "http"; + +export type McpServerConfig = { + name: string; + transport: McpTransport; + command?: string; + args?: string[]; + cwd?: string; + env?: Record; + url?: string; + headers?: Record; + timeoutMs?: number; +}; + +function isRecord(value: BoundaryValue): value is Record { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +function asString(value: BoundaryValue): string | undefined { + return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined; +} + +function toStringRecord(value: BoundaryValue): Record | undefined { + if (!isRecord(value)) return undefined; + + const result: Record = {}; + for (const [key, entry] of Object.entries(value)) { + if (typeof entry === "string" || typeof entry === "number" || typeof entry === "boolean") { + result[key] = String(entry); + } + } + + return Object.keys(result).length ? result : undefined; +} + +function toStringArray(value: BoundaryValue): string[] | undefined { + if (!Array.isArray(value)) return undefined; + const items = value.filter((item): item is string => typeof item === "string" && item.trim().length > 0) + .map(item => item.trim()); + return items.length ? items : undefined; +} + +function toPositiveInt(value: BoundaryValue): number | undefined { + const n = typeof value === "number" + ? value + : typeof value === "string" + ? Number(value) + : NaN; + + if (!Number.isFinite(n) || n <= 0) return undefined; + return Math.floor(n); +} + +function normalizeServerConfig(value: BoundaryValue, fallbackName?: string): McpServerConfig | undefined { + if (!isRecord(value)) return undefined; + + const name = asString(value.name) ?? fallbackName; + const transportRaw = asString(value.transport); + const transport = transportRaw === "http" || transportRaw === "stdio" ? transportRaw : undefined; + + if (!name || !transport) return undefined; + + return { + name, + transport, + command: asString(value.command), + args: toStringArray(value.args), + cwd: asString(value.cwd), + env: toStringRecord(value.env), + url: asString(value.url), + headers: toStringRecord(value.headers), + timeoutMs: toPositiveInt(value.timeoutMs), + }; +} + +export function parseMcpServerConfigs(raw: string | undefined): McpServerConfig[] { + if (!raw?.trim()) return []; + + let parsed: BoundaryValue; + try { + parsed = JSON.parse(raw) as BoundaryValue; + } catch (error) { + throw new Error(`Invalid MCP_SERVERS JSON: ${error instanceof Error ? error.message : String(error)}`); + } + + if (Array.isArray(parsed)) { + return parsed.flatMap((item, index) => normalizeServerConfig(item, `server-${index + 1}`) ? [normalizeServerConfig(item, `server-${index + 1}`)!] : []); + } + + if (!isRecord(parsed)) { + return []; + } + + if (Array.isArray(parsed.servers)) { + return parsed.servers.flatMap((item, index) => normalizeServerConfig(item, `server-${index + 1}`) ? [normalizeServerConfig(item, `server-${index + 1}`)!] : []); + } + + if (isRecord(parsed.mcpServers)) { + return Object.entries(parsed.mcpServers).flatMap(([name, item]) => normalizeServerConfig(item, name) ? [normalizeServerConfig(item, name)!] : []); + } + + const single = normalizeServerConfig(parsed); + return single ? [single] : []; +} diff --git a/src/ai/mcp/mcp-json-schema.ts b/src/ai/mcp/mcp-json-schema.ts new file mode 100644 index 0000000..d916e16 --- /dev/null +++ b/src/ai/mcp/mcp-json-schema.ts @@ -0,0 +1,123 @@ +import type {AiJsonValue, AiToolParameters} from "../tool-types.js"; +import type {BoundaryValue} from "../../common/boundary-types.js"; + +type JsonSchemaRecord = Record; + +function isRecord(value: BoundaryValue): value is JsonSchemaRecord { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +function toAiJsonValue(value: BoundaryValue): AiJsonValue | undefined { + if (value === undefined) return undefined; + if (value === null) return null; + if (typeof value === "string" || typeof value === "number" || typeof value === "boolean") return value; + + if (Array.isArray(value)) { + return value.map(item => toAiJsonValue(item) ?? null); + } + + if (!isRecord(value)) return undefined; + + const result: Record = {}; + for (const [key, entry] of Object.entries(value)) { + const normalized = toAiJsonValue(entry); + if (normalized !== undefined) { + result[key] = normalized; + } + } + + return result; +} + +function normalizeType(value: BoundaryValue): AiToolParameters["type"] | undefined { + const candidates = Array.isArray(value) + ? value.filter((item): item is string => typeof item === "string") + : typeof value === "string" + ? [value] + : []; + + const prioritized = candidates.find(item => item !== "null") ?? candidates[0]; + if (!prioritized) return undefined; + + switch (prioritized) { + case "object": + case "string": + case "number": + case "integer": + case "boolean": + case "array": + return prioritized; + default: + return undefined; + } +} + +export function convertJsonSchemaToToolParameters(schema: BoundaryValue): AiToolParameters | undefined { + if (!isRecord(schema)) return undefined; + + const declaredType = normalizeType(schema.type); + const inferredType = declaredType + ?? (schema.properties !== undefined || schema.additionalProperties !== undefined ? "object" : undefined) + ?? (schema.items !== undefined ? "array" : undefined) + ?? "object"; + + const result: AiToolParameters = { + type: inferredType, + }; + + const description = typeof schema.description === "string" && schema.description.trim().length > 0 + ? schema.description.trim() + : undefined; + if (description) result.description = description; + + const defaultValue = toAiJsonValue(schema.default); + if (defaultValue !== undefined) result.default = defaultValue; + + if (Array.isArray(schema.enum)) { + const enumValues = schema.enum + .filter((item): item is string => typeof item === "string" && item.length > 0); + if (enumValues.length) result.enum = enumValues; + } + + if (typeof schema.minItems === "number") result.minItems = schema.minItems; + if (typeof schema.maxItems === "number") result.maxItems = schema.maxItems; + if (typeof schema.minimum === "number") result.minimum = schema.minimum; + if (typeof schema.maximum === "number") result.maximum = schema.maximum; + + if (Array.isArray(schema.required)) { + const required = schema.required.filter((item): item is string => typeof item === "string" && item.trim().length > 0); + if (required.length) result.required = required; + } + + if (inferredType === "object" || schema.properties !== undefined || schema.additionalProperties !== undefined) { + if (isRecord(schema.properties)) { + const properties: Record = {}; + for (const [key, value] of Object.entries(schema.properties)) { + const converted = convertJsonSchemaToToolParameters(value); + if (converted) properties[key] = converted; + } + if (Object.keys(properties).length) result.properties = properties; + } + + if (schema.additionalProperties !== undefined) { + result.additionalProperties = typeof schema.additionalProperties === "boolean" + ? schema.additionalProperties + : convertJsonSchemaToToolParameters(schema.additionalProperties); + } + } + + if (inferredType === "array" || schema.items !== undefined) { + if (Array.isArray(schema.items)) { + const firstItem = schema.items[0]; + if (firstItem !== undefined) { + const converted = convertJsonSchemaToToolParameters(firstItem); + if (converted) result.items = converted; + } + } else { + const converted = convertJsonSchemaToToolParameters(schema.items); + if (converted) result.items = converted; + } + } + + return result; +} diff --git a/src/ai/mcp/mcp-registry.ts b/src/ai/mcp/mcp-registry.ts new file mode 100644 index 0000000..25b00f0 --- /dev/null +++ b/src/ai/mcp/mcp-registry.ts @@ -0,0 +1,165 @@ +import {Environment} from "../../common/environment.js"; +import type {AiTool} from "../tool-types.js"; +import type {ToolHandler} from "../tools/types.js"; +import {normalizeToolArguments} from "../tools/utils.js"; +import {toolsLogger} from "../tools/tool-logger.js"; +import {convertJsonSchemaToToolParameters} from "./mcp-json-schema.js"; +import {McpClient, type McpToolDefinition} from "./mcp-client.js"; +import {parseMcpServerConfigs, type McpServerConfig} from "./mcp-config.js"; + +const logger = toolsLogger.child("mcp-registry"); + +type McpToolBinding = { + server: McpServerConfig; + client: McpClient; + remoteToolName: string; + localToolName: string; + tool: AiTool; +}; + +type McpInitSummary = { + servers: number; + loadedServers: number; + tools: number; + failedServers: string[]; +}; + +const toolBindings = new Map(); +const clients = new Map(); +let initPromise: Promise | undefined; + +function sanitizeSegment(value: string): string { + return value + .trim() + .replace(/[^a-zA-Z0-9_]+/g, "_") + .replace(/_+/g, "_") + .replace(/^_+|_+$/g, "") || "tool"; +} + +function buildLocalToolName(serverName: string, toolName: string): string { + return `mcp__${sanitizeSegment(serverName)}__${sanitizeSegment(toolName)}`; +} + +function buildTool(serverName: string, tool: McpToolDefinition): AiTool { + const localName = buildLocalToolName(serverName, tool.name); + const description = tool.description?.trim() + ? `[MCP ${serverName}] ${tool.description.trim()}` + : `[MCP ${serverName}] ${tool.name}`; + + return { + type: "function", + function: { + name: localName, + description, + parameters: convertJsonSchemaToToolParameters(tool.inputSchema), + }, + }; +} + +async function loadServer(config: McpServerConfig): Promise<{loaded: boolean; tools: number}> { + const client = new McpClient(config); + clients.set(config.name, client); + + try { + const remoteTools = await client.listTools(); + let loaded = 0; + + for (const remoteTool of remoteTools) { + const localName = buildLocalToolName(config.name, remoteTool.name); + if (toolBindings.has(localName)) { + logger.warn("tool.duplicate", { + server: config.name, + tool: remoteTool.name, + localName, + }); + continue; + } + + const binding: McpToolBinding = { + server: config, + client, + remoteToolName: remoteTool.name, + localToolName: localName, + tool: buildTool(config.name, remoteTool), + }; + + toolBindings.set(localName, binding); + loaded += 1; + } + + logger.info("server.loaded", { + server: config.name, + transport: config.transport, + tools: loaded, + }); + return {loaded: true, tools: loaded}; + } catch (error) { + logger.error("server.failed", { + server: config.name, + transport: config.transport, + error: error instanceof Error ? error.message : String(error), + }); + await client.close().catch(() => undefined); + clients.delete(config.name); + return {loaded: false, tools: 0}; + } +} + +export async function initializeMcpTools(): Promise { + if (initPromise) return initPromise; + + initPromise = (async () => { + toolBindings.clear(); + await Promise.all([...clients.values()].map(client => client.close().catch(() => undefined))); + clients.clear(); + + const configs = parseMcpServerConfigs(Environment.MCP_SERVERS); + const results = await Promise.all(configs.map(config => loadServer(config))); + + return { + servers: configs.length, + loadedServers: results.filter(result => result.loaded).length, + tools: [...results].reduce((sum, result) => sum + result.tools, 0), + failedServers: configs.filter((_, index) => !results[index]?.loaded).map(config => config.name), + }; + })(); + + try { + const summary = await initPromise; + logger.info("init.done", summary); + return summary; + } catch (error) { + initPromise = undefined; + logger.error("init.failed", {error: error instanceof Error ? error.message : String(error)}); + throw error; + } +} + +export function getMcpTools(): AiTool[] { + return [...toolBindings.values()].map(binding => binding.tool); +} + +export function getMcpToolHandlers(): Record { + const handlers: Record = {}; + + for (const binding of toolBindings.values()) { + handlers[binding.localToolName] = async args => { + const normalized = normalizeToolArguments(args, undefined); + return binding.client.callTool(binding.remoteToolName, normalized); + }; + } + + return handlers; +} + +export function getMcpToolPrompts(_toolNames: string[]): string[] { + return []; +} + +export async function shutdownMcpTools(): Promise { + initPromise = undefined; + toolBindings.clear(); + + await Promise.all([...clients.values()].map(client => client.close().catch(() => undefined))); + clients.clear(); +} diff --git a/src/ai/tool-ranker-metadata.ts b/src/ai/tool-ranker-metadata.ts index 1d45989..7cf47f0 100644 --- a/src/ai/tool-ranker-metadata.ts +++ b/src/ai/tool-ranker-metadata.ts @@ -352,6 +352,20 @@ function toolNamesFromTool(tool: BoundaryValue): string[] { return name ? [name] : []; } +function fallbackToolInfoFromTool(toolValue: BoundaryValue, name: string): ToolRankerToolInfo | undefined { + if (!isRecord(toolValue)) return undefined; + + const fn = isRecord(toolValue.function) ? toolValue.function : undefined; + const description = asOptionalString(fn?.description ?? toolValue.description) + ?? `Tool ${name}.`; + + return tool( + name, + description, + "Use when the tool description matches the user's request.", + ); +} + export function getToolRankerToolInfo(name: string): ToolRankerToolInfo | undefined { return TOOL_RANKER_TOOL_INFOS[name as ToolRankerToolName]; } @@ -363,10 +377,25 @@ export function getToolRankerToolInfos(names: readonly string[]): ToolRankerTool } export function getToolRankerAvailableToolInfos(availableTools: readonly BoundaryValue[]): ToolRankerToolInfo[] { - return getToolRankerToolInfos([ - "no_tool", - ...availableTools.flatMap(toolNamesFromTool), - ]); + const infos = new Map(); + + infos.set("no_tool", TOOL_RANKER_TOOL_INFOS.no_tool); + + for (const tool of availableTools) { + for (const name of toolNamesFromTool(tool)) { + if (infos.has(name)) continue; + + const known = getToolRankerToolInfo(name); + const fallback = fallbackToolInfoFromTool(tool, name); + if (known) { + infos.set(name, known); + } else if (fallback) { + infos.set(name, fallback); + } + } + } + + return [...infos.values()]; } function renderToolLine(tool: ToolRankerToolInfo, compact: boolean): string { @@ -471,7 +500,7 @@ export function buildToolRankerSystemPrompt(params: { const includeExamples = params.includeExamples ?? false; const maxExamplesPerTool = Math.max(0, params.maxExamplesPerTool ?? 1); const compact = params.compact ?? true; - const availableTools = getToolRankerToolInfos(params.availableTools.map(tool => tool.name)); + const availableTools = params.availableTools; const availableToolNames = availableTools.map(tool => tool.name); const sections: string[] = [ diff --git a/src/ai/tools/registry.ts b/src/ai/tools/registry.ts index 04c4c71..dce51c5 100644 --- a/src/ai/tools/registry.ts +++ b/src/ai/tools/registry.ts @@ -45,6 +45,7 @@ import { writeFileChunk, writeFileChunkTool } from "./files"; +import {getMcpToolHandlers, getMcpToolPrompts, getMcpTools} from "../mcp/mcp-registry.js"; export const defaultTools: AiTool[] = [ getCurrentDateTimeTool, @@ -72,22 +73,16 @@ export const fileTools = [ deletePathTool, ] satisfies AiTool[]; -// export const notesFileTools: AiTool[] = [ -// createNoteTool, -// listNotesTool, -// getNoteContentTool, -// updateNoteContentTool, -// deleteNoteTool, -// sendNoteAsFileTool, -// searchNotesTool -// ] - export const getTools = (forCreator?: boolean) => { - const tools: AiTool[] = [ + const tools: AiTool[] = Environment.DISABLE_LOCAL_TOOLS ? [] : [ ...defaultTools, - // ...notesFileTools ]; + if (Environment.DISABLE_LOCAL_TOOLS) { + tools.push(...getMcpTools()); + return tools; + } + if (Environment.BRAVE_SEARCH_API_KEY) { tools.push(webSearchTool); } @@ -110,6 +105,8 @@ export const getTools = (forCreator?: boolean) => { } } + tools.push(...getMcpTools()); + return tools; }; @@ -136,20 +133,20 @@ export const fileToolHandlers = { export const getToolHandlers = () => { let handlers: Record = { + ...getMcpToolHandlers(), + }; + + if (Environment.DISABLE_LOCAL_TOOLS) { + return handlers; + } + + handlers = { + ...handlers, get_datetime: getCurrentDateTime, get_financial_market_data: getMarketRates, - // create_note: createNote, - // list_notes: listNotes, - // get_note_content: getNoteContent, - // update_note_content: updateNoteContent, - // delete_note: deleteNote, - // send_note_as_file: sendNoteAsFile, - // search_notes: searchNotes, - ...fileToolHandlers, - python_interpreter: runPythonInterpreter, shell_execute: shellExecute, @@ -157,13 +154,16 @@ export const getToolHandlers = () => { web_search: webSearch, get_weather: getWeather, - }; return handlers; }; export function getToolPrompts(toolNames: string[]): string[] { + if (Environment.DISABLE_LOCAL_TOOLS) { + return getMcpToolPrompts(toolNames); + } + const prompts: string[] = []; for (const toolName of toolNames) { @@ -185,5 +185,6 @@ export function getToolPrompts(toolNames: string[]): string[] { } } + prompts.push(...getMcpToolPrompts(toolNames)); return prompts; -} \ No newline at end of file +} diff --git a/src/common/environment.ts b/src/common/environment.ts index f19e4f6..03d82f6 100644 --- a/src/common/environment.ts +++ b/src/common/environment.ts @@ -214,6 +214,8 @@ const RuntimeEnvSchema = z.object({ SEND_TIME_TOOK: optionalBooleanSchema, ENABLE_PYTHON_INTERPRETER: optionalBooleanSchema, + DISABLE_LOCAL_TOOLS: optionalBooleanSchema, + MCP_SERVERS: optionalStringSchema, OLLAMA_API_KEY: optionalStringSchema, OLLAMA_ADDRESS: optionalStringSchema, @@ -308,6 +310,8 @@ export class Environment { static SEND_TIME_TOOK: boolean = false; static ENABLE_PYTHON_INTERPRETER: boolean = false; + static DISABLE_LOCAL_TOOLS: boolean = false; + static MCP_SERVERS?: string; static OLLAMA_API_KEY?: string; static OLLAMA_ADDRESS?: string; @@ -1842,6 +1846,8 @@ export class Environment { Environment.SEND_TIME_TOOK = env.SEND_TIME_TOOK ?? false; Environment.ENABLE_PYTHON_INTERPRETER = env.ENABLE_PYTHON_INTERPRETER ?? false; + Environment.DISABLE_LOCAL_TOOLS = env.DISABLE_LOCAL_TOOLS ?? false; + Environment.MCP_SERVERS = env.MCP_SERVERS; Environment.OLLAMA_API_KEY = env.OLLAMA_API_KEY; Environment.OLLAMA_ADDRESS = env.OLLAMA_ADDRESS; diff --git a/src/index.ts b/src/index.ts index bce4378..7473b84 100644 --- a/src/index.ts +++ b/src/index.ts @@ -79,6 +79,7 @@ import {AIAudit} from "./commands/ai-audit.js"; import {AIMetrics} from "./commands/ai-metrics.js"; import {AIRequests} from "./commands/ai-requests.js"; import {cleanupStaleRagProviderState} from "./ai/rag-retention.js"; +import {initializeMcpTools, shutdownMcpTools} from "./ai/mcp/mcp-registry.js"; process.setUncaughtExceptionCaptureCallback(logError); @@ -236,11 +237,17 @@ export async function shutdown(signal: NodeJS.Signals | "manual") { logError(error instanceof Error ? error : String(error)); } finally { try { - await DatabaseManager.close(); + await shutdownMcpTools(); } catch (error) { logError(error instanceof Error ? error : String(error)); + } finally { + try { + await DatabaseManager.close(); + } catch (error) { + logError(error instanceof Error ? error : String(error)); + } + process.exit(0); } - process.exit(0); } } @@ -280,6 +287,7 @@ async function main() { await measureStartupStep("cleanup_internal_artifacts", () => cleanupInternalArtifactCache(), () => ({retentionDays: 14})); await measureStartupStep("cleanup_stale_rag_provider_state", () => cleanupStaleRagProviderState(), () => ({retentionDays: 14})); + await measureStartupStep("mcp.initialize", () => initializeMcpTools()); await measureStartupStep("observability.snapshot", async () => { const [aiRequests, attachments, artifacts, requestAudits] = await Promise.all([ DatabaseManager.getAllAiRequests(),