ability to cancel ollama generation

This commit is contained in:
2026-01-15 19:57:28 +03:00
parent f70b103282
commit 8743258474
6 changed files with 142 additions and 17 deletions
+22
View File
@@ -0,0 +1,22 @@
import {CallbackCommand} from "../base/callback-command";
export class Cancel extends CallbackCommand {
text = "❌ Отменить";
data = null;
constructor(text?: string, data?: string) {
super();
this.text = text ?? this.text;
this.data = data ?? this.data;
}
static withData(data?: string): Cancel {
return new Cancel(null, data);
}
async execute(): Promise<void> {
return Promise.resolve();
}
}
+32
View File
@@ -0,0 +1,32 @@
import {CallbackCommand} from "../base/callback-command";
import {CallbackQuery} from "typescript-telegram-bot-api";
import {abortOllamaRequest, bot, getOllamaRequest} from "../index";
import {logError} from "../util/utils";
export class OllamaCancel extends CallbackCommand {
data = "/cancel_ollama";
text = "Cancel Ollama generation";
async execute(query: CallbackQuery): Promise<void> {
const chatId = query.message.chat.id;
const fromId = query.from.id;
const messageId = query.message.message_id;
const uuid = query.data.split(" ")[1];
if (!uuid) return;
const request = getOllamaRequest(uuid);
if (!request) return;
if (request.fromId !== fromId) return;
const aborted = abortOllamaRequest(uuid);
console.log(`aborted request ${uuid}:`, aborted);
await bot.editMessageReplyMarkup({
chat_id: chatId,
message_id: messageId,
reply_markup: {inline_keyboard: []}
}).catch(logError);
}
}
+28 -5
View File
@@ -1,6 +1,6 @@
import {ChatCommand} from "../base/chat-command"; import {ChatCommand} from "../base/chat-command";
import {Message} from "typescript-telegram-bot-api"; import {Message} from "typescript-telegram-bot-api";
import {bot, ollama} from "../index"; import {abortOllamaRequest, bot, getOllamaRequest, ollama, ollamaRequests} from "../index";
import { import {
collectReplyChainText, collectReplyChainText,
editMessageText, editMessageText,
@@ -16,6 +16,8 @@ import {MessageStore} from "../common/message-store";
import axios from "axios"; import axios from "axios";
import * as fs from "node:fs"; import * as fs from "node:fs";
import path from "node:path"; import path from "node:path";
import {Cancel} from "../callback_commands/cancel";
import {OllamaCancel} from "../callback_commands/ollama-cancel";
export class OllamaChat extends ChatCommand { export class OllamaChat extends ChatCommand {
regexp = /^\/ollama\s([^]+)/; regexp = /^\/ollama\s([^]+)/;
@@ -73,13 +75,18 @@ export class OllamaChat extends ChatCommand {
const startTime = Date.now(); const startTime = Date.now();
try { try {
let isOver: boolean = false;
const uuid = crypto.randomUUID();
const cancelMarkup = {inline_keyboard: [[Cancel.withData(new OllamaCancel().data + " " + uuid).asButton()]]};
waitMessage = await bot.sendMessage({ waitMessage = await bot.sendMessage({
chat_id: chatId, chat_id: chatId,
text: maxSize !== null ? `🔍 Внимательно изучаю изображение...\n🤓 ${maxSize.width}x${maxSize.height}px` : Environment.waitText, text: maxSize !== null ? `🔍 Внимательно изучаю изображение...\n🤓 ${maxSize.width}x${maxSize.height}px` : Environment.waitText,
reply_parameters: { reply_parameters: {
chat_id: chatId, chat_id: chatId,
message_id: msg.message_id message_id: msg.message_id
} },
reply_markup: cancelMarkup
}); });
const stream = await ollama.chat({ const stream = await ollama.chat({
@@ -90,6 +97,8 @@ export class OllamaChat extends ChatCommand {
messages: chatMessages messages: chatMessages
}); });
ollamaRequests.push({uuid: uuid, stream: stream, done: false, fromId: msg.from.id, chatId: msg.chat.id});
let currentText = ""; let currentText = "";
let shouldBreak = false; let shouldBreak = false;
@@ -97,7 +106,13 @@ export class OllamaChat extends ChatCommand {
intervalMs: 4500, intervalMs: 4500,
getText: () => currentText, getText: () => currentText,
editFn: async (text) => { editFn: async (text) => {
await editMessageText(chatId, waitMessage.message_id, escapeMarkdownV2Text(text), "Markdown"); await editMessageText(
chatId,
waitMessage.message_id,
escapeMarkdownV2Text(text),
"Markdown",
isOver ? {inline_keyboard: []} : cancelMarkup
).catch(logError);
}, },
onStop: async () => { onStop: async () => {
} }
@@ -105,6 +120,7 @@ export class OllamaChat extends ChatCommand {
try { try {
for await (const chunk of stream) { for await (const chunk of stream) {
if (!getOllamaRequest(uuid).done) {
const content = chunk.message.content; const content = chunk.message.content;
currentText += content; currentText += content;
@@ -113,8 +129,13 @@ export class OllamaChat extends ChatCommand {
currentText = currentText.slice(0, 4093) + "..."; currentText = currentText.slice(0, 4093) + "...";
shouldBreak = true; shouldBreak = true;
} }
} else {
shouldBreak = true;
}
if (shouldBreak || chunk.done) { if (shouldBreak || chunk.done) {
isOver = true;
console.log("messageText", currentText); console.log("messageText", currentText);
console.log("length", length); console.log("length", length);
@@ -124,7 +145,7 @@ export class OllamaChat extends ChatCommand {
console.log("ended", true); console.log("ended", true);
} }
stream.abort(); console.log(`aborted request ${uuid}:`, abortOllamaRequest(uuid));
const diff = Math.abs(Date.now() - startTime) / 1000; const diff = Math.abs(Date.now() - startTime) / 1000;
@@ -140,13 +161,15 @@ export class OllamaChat extends ChatCommand {
} }
} }
} finally { } finally {
console.log(`aborted request ${uuid}:`, abortOllamaRequest(uuid));
await editor.tick(); await editor.tick();
await editor.stop(); await editor.stop();
} }
} catch (error) { } catch (error) {
if (error.message === "This operation was aborted") return;
console.error(error); console.error(error);
await replyToMessage(waitMessage, `Произошла ошибка!\n${error.toString()}`).catch(logError); await replyToMessage(waitMessage, `Произошла ошибка!\n${error.toString()}`).catch(logError);
} }
return Promise.resolve();
} }
} }
+40
View File
@@ -6,6 +6,7 @@ import {
checkRequirements, checkRequirements,
executeChatCommand, executeChatCommand,
extractTextMessage, extractTextMessage,
findAndExecuteCallbackCommand,
initSystemSpecs, initSystemSpecs,
logError, logError,
randomValue, randomValue,
@@ -53,6 +54,9 @@ import {MessageDao} from "./db/message-dao";
import {DatabaseManager} from "./db/database-manager"; import {DatabaseManager} from "./db/database-manager";
import {UserDao} from "./db/user-dao"; import {UserDao} from "./db/user-dao";
import {UserStore} from "./common/user-store"; import {UserStore} from "./common/user-store";
import {OllamaRequest} from "./model/ollama-request";
import {CallbackCommand} from "./base/callback-command";
import {OllamaCancel} from "./callback_commands/ollama-cancel";
process.setUncaughtExceptionCaptureCallback(console.error); process.setUncaughtExceptionCaptureCallback(console.error);
@@ -70,6 +74,33 @@ export const ollama = new Ollama({
headers: {"Authorization": `Bearer ${Environment.OLLAMA_API_KEY}`} headers: {"Authorization": `Bearer ${Environment.OLLAMA_API_KEY}`}
}); });
export const ollamaRequests: OllamaRequest[] = [];
export function getOllamaRequest(uuid: string): OllamaRequest | null {
return ollamaRequests.find(r => r.uuid === uuid);
}
export function updateOllamaRequest(uuid: string, request: OllamaRequest) {
const index = ollamaRequests.findIndex(r => r.uuid === uuid);
if (index >= 0) {
ollamaRequests[index] = request;
}
}
export function abortOllamaRequest(uuid: string): boolean {
const request = getOllamaRequest(uuid);
if (!request || request.done) return false;
try {
request.stream.abort();
updateOllamaRequest(uuid, {...request, done: true});
return true;
} catch (e) {
console.error(e);
return false;
}
}
export const googleAi = new GoogleGenAI({apiKey: Environment.GEMINI_API_KEY}); export const googleAi = new GoogleGenAI({apiKey: Environment.GEMINI_API_KEY});
export let systemInfoText: string = ""; export let systemInfoText: string = "";
@@ -113,6 +144,10 @@ export const chatCommands: ChatCommand[] = [
new Leave(), new Leave(),
]; ];
export const callbackCommands: CallbackCommand[] = [
new OllamaCancel()
];
if (Environment.OLLAMA_ADDRESS && Environment.OLLAMA_MODEL && Environment.SYSTEM_PROMPT) { if (Environment.OLLAMA_ADDRESS && Environment.OLLAMA_MODEL && Environment.SYSTEM_PROMPT) {
chatCommands.push(new OllamaChat(), new OllamaPrompt(), new OllamaKill()); chatCommands.push(new OllamaChat(), new OllamaPrompt(), new OllamaKill());
} }
@@ -257,4 +292,9 @@ bot.on("inline_query", async (query) => {
} }
}); });
bot.on("callback_query", async (query) => {
console.log(query);
await findAndExecuteCallbackCommand(callbackCommands, query);
});
main().catch(console.error); main().catch(console.error);
+7
View File
@@ -0,0 +1,7 @@
export type OllamaRequest = {
uuid: string;
stream: any;
done: boolean;
fromId: number;
chatId: number;
}
+8 -7
View File
@@ -127,20 +127,21 @@ export async function findAndExecuteCallbackCommand(commands: CallbackCommand[],
const fromId = query.from.id; const fromId = query.from.id;
const data = query.data || ""; const data = query.data || "";
const command = searchCallbackCommand(commands, data); const cmd = searchCallbackCommand(commands, data);
if (!command) return false; if (!cmd) return false;
const requirements = command.requirements; // TODO: 15/01/2026, Danil Nikolaev: reimplement
const requirements = cmd.requirements;
if (requirements) { if (requirements) {
if (requirements.isRequiresBotAdmin() && !Environment.ADMIN_IDS.has(fromId)) { if (requirements.isRequiresBotAdmin() && !Environment.ADMIN_IDS.has(fromId)) {
console.log(`${command.data}: adminId is bad: ${fromId}`); console.log(`${cmd.data}: adminId is bad: ${fromId}`);
return false; return false;
} }
} }
await command.execute(query); await cmd.execute(query);
await command.answerCallbackQuery(query); await cmd.answerCallbackQuery(query);
await command.afterExecute(query); await cmd.afterExecute(query);
return true; return true;
} }