├── bot ├── __init__.py ├── ai │ ├── __init__.py │ ├── images.py │ └── chat.py ├── commands │ ├── __init__.py │ ├── retry.py │ ├── constants.py │ ├── start.py │ ├── message.py │ ├── error.py │ ├── help.py │ ├── imagine.py │ ├── model.py │ ├── version.py │ ├── prompt.py │ └── config.py ├── shortcuts.py ├── cli.py ├── markdown.py ├── filters.py ├── fetcher.py ├── askers.py ├── questions.py ├── models.py ├── bot.py └── config.py ├── tests ├── __init__.py ├── test_shortcuts.py ├── test_filters.py ├── test_markdown.py ├── mocks.py ├── test_ai.py ├── test_fetcher.py ├── test_askers.py ├── test_questions.py ├── test_models.py ├── test_config.py └── test_commands.py ├── requirements.txt ├── Dockerfile ├── compose.yml ├── Makefile ├── .github └── workflows │ └── build.yml ├── LICENSE ├── config.example.yml ├── .gitignore └── README.md /bot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bot/ai/__init__.py: -------------------------------------------------------------------------------- 1 | from . import chat 2 | from . import images 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | httpcore==1.0.2 2 | httpx==0.25.1 3 | beautifulsoup4==4.12.2 4 | python-telegram-bot==20.6 5 | PyYAML==6.0.1 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | ARG PIP_DISABLE_PIP_VERSION_CHECK=1 4 | ARG PIP_NO_CACHE_DIR=1 5 | ENV PYTHONDONTWRITEBYTECODE=1 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | WORKDIR /code 9 | COPY ./requirements.txt . 10 | RUN pip3 install -r requirements.txt 11 | 12 | COPY ./bot ./bot 13 | 14 | CMD ["python","-m","bot.bot"] 15 | -------------------------------------------------------------------------------- /compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | pokitoki: 3 | container_name: pokitoki 4 | restart: unless-stopped 5 | build: 6 | context: . 7 | dockerfile: Dockerfile 8 | volumes: 9 | - ./config.yml:/code/config.yml 10 | - data:/code/data 11 | volumes: 12 | data: 13 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: start stop test 2 | .SILENT: start stop test 3 | 4 | start: 5 | CONFIG=config.$(name).yml nohup env/bin/python -m bot.bot > $(name).log 2>&1 & echo $$! 
> $(name).pid 6 | echo "Started $(name) bot" 7 | 8 | stop: 9 | kill $(shell cat $(name).pid) 10 | rm -f $(name).pid 11 | echo "Stopped $(name) bot" 12 | 13 | test: 14 | env/bin/python -m unittest discover 15 | -------------------------------------------------------------------------------- /bot/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import BOT_COMMANDS 2 | from .config import ConfigCommand as Config 3 | from .error import ErrorCommand as Error 4 | from .help import HelpCommand as Help 5 | from .imagine import ImagineCommand as Imagine 6 | from .message import MessageCommand as Message 7 | from .model import ModelCommand as Model 8 | from .prompt import PromptCommand as Prompt 9 | from .retry import RetryCommand as Retry 10 | from .start import StartCommand as Start 11 | from .version import VersionCommand as Version 12 | -------------------------------------------------------------------------------- /bot/commands/retry.py: -------------------------------------------------------------------------------- 1 | """/retry command.""" 2 | 3 | from typing import Awaitable 4 | from telegram import Update 5 | from telegram.ext import CallbackContext 6 | from bot.models import UserData 7 | 8 | 9 | class RetryCommand: 10 | """Retries asking the last question (if any).""" 11 | 12 | def __init__(self, reply_func: Awaitable) -> None: 13 | self.reply_func = reply_func 14 | 15 | async def __call__(self, update: Update, context: CallbackContext) -> None: 16 | user = UserData(context.user_data) 17 | last_message = user.messages.pop() 18 | if not last_message: 19 | await update.message.reply_text("No message to retry 🤷‍♂️") 20 | return 21 | await self.reply_func( 22 | update=update, message=update.message, context=context, question=last_message.question 23 | ) 24 | -------------------------------------------------------------------------------- /bot/shortcuts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Working with shortcuts. 3 | A shortcut is an action that preprocesses a question before asking it of the AI. 
4 | """ 5 | 6 | import re 7 | from bot.config import config 8 | 9 | shortcut_re = re.compile(r"^!(\w+)\b") 10 | 11 | 12 | def extract(question: str) -> tuple[str, str]: 13 | """Extracts a shortcut from the question.""" 14 | match = shortcut_re.match(question) 15 | if not match: 16 | raise ValueError("failed to extract shortcut") 17 | name = match.group(1) 18 | question = question.removeprefix(match.group(0)).strip() 19 | return name, question 20 | 21 | 22 | def apply(name: str, question: str) -> str: 23 | """Applies a given shortcut to a text.""" 24 | prompt = config.shortcuts.get(name) 25 | if not prompt: 26 | raise ValueError(f"unknown shortcut: {name}") 27 | return f"{prompt}\n\n{question}" 28 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | paths: 7 | - .github/** 8 | - bot/** 9 | - tests/** 10 | pull_request: 11 | branches: [main] 12 | workflow_dispatch: 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.11" 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r requirements.txt 29 | 30 | - name: Run tests 31 | run: | 32 | cp config.example.yml config.yml 33 | python -m unittest discover 34 | -------------------------------------------------------------------------------- /tests/test_shortcuts.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from bot.config import config 4 | from bot import shortcuts 5 | 6 | 7 | class ExtractTest(unittest.TestCase): 8 | def test_extract(self): 9 | name, question = shortcuts.extract("!translate Ciao") 10 | self.assertEqual(name, "translate") 11 | self.assertEqual(question, "Ciao") 12 | 13 | def test_failed(self): 14 | config.shortcuts["translate"] = "Translate into English." 15 | with self.assertRaises(ValueError): 16 | shortcuts.extract("Ciao") 17 | 18 | 19 | class ApplyTest(unittest.TestCase): 20 | def test_apply(self): 21 | config.shortcuts["translate"] = "Translate into English." 22 | question = shortcuts.apply("translate", "Ciao") 23 | self.assertEqual(question, "Translate into English.\n\nCiao") 24 | 25 | def test_unknown_shortcut(self): 26 | with self.assertRaises(ValueError): 27 | shortcuts.apply("sing", "Ciao") 28 | -------------------------------------------------------------------------------- /bot/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command-line interface to the OpenAI API. 3 | 4 | Usage example: 5 | $ python -m bot.cli "What is your name?" 
6 | """ 7 | 8 | import asyncio 9 | import os 10 | import sys 11 | import textwrap 12 | 13 | from bot.config import config 14 | from bot.fetcher import Fetcher 15 | import bot.ai.chat 16 | 17 | 18 | async def main(question): 19 | print(f"> {question}") 20 | fetcher = Fetcher() 21 | question = await fetcher.substitute_urls(question) 22 | ai = init_model() 23 | answer = await ai.ask(prompt=config.openai.prompt, question=question, history=[]) 24 | await fetcher.close() 25 | lines = textwrap.wrap(answer, width=60) 26 | for line in lines: 27 | print(line) 28 | 29 | 30 | def init_model(): 31 | name = os.getenv("OPENAI_MODEL") or config.openai.model 32 | return bot.ai.chat.Model(name) 33 | 34 | 35 | if __name__ == "__main__": 36 | if len(sys.argv) == 0: 37 | exit(1) 38 | asyncio.run(main(sys.argv[1])) 39 | -------------------------------------------------------------------------------- /bot/commands/constants.py: -------------------------------------------------------------------------------- 1 | """Bot command constants.""" 2 | 3 | HELP_MESSAGE = """Send me a question, and I will do my best to answer it. Please be specific, as I'm not very clever. 4 | 5 | I don't remember chat context by default. To ask follow-up questions, reply to my messages or start your questions with a '+' sign. 6 | 7 | Built-in commands: 8 | {commands}{admin_commands} 9 | 10 | AI shortcuts: 11 | {shortcuts} 12 | 13 | [More features →](https://github.com/nalgeon/pokitoki#readme) 14 | """ 15 | 16 | PRIVACY_MESSAGE = ( 17 | "⚠️ The bot does not have access to group messages, " 18 | "so it cannot reply in groups. Use @botfather " 19 | "to give the bot access (Bot Settings > Group Privacy > Turn off)" 20 | ) 21 | 22 | BOT_COMMANDS = [ 23 | ("retry", "retry the last question"), 24 | ("imagine", "generate described image"), 25 | ("version", "show debug information"), 26 | ("help", "show help"), 27 | ] 28 | 29 | ADMIN_COMMANDS = { 30 | "config": "view or edit the config", 31 | } 32 | -------------------------------------------------------------------------------- /bot/ai/images.py: -------------------------------------------------------------------------------- 1 | """OpenAI-compatible image generation model.""" 2 | 3 | import httpx 4 | from bot.config import config 5 | 6 | client = httpx.AsyncClient(timeout=60.0) 7 | 8 | 9 | class Model: 10 | """AI API wrapper.""" 11 | 12 | async def imagine(self, prompt: str, size: str) -> str: 13 | """Generates an image of the specified size according to the description.""" 14 | response = await client.post( 15 | f"{config.openai.url}/images/generations", 16 | headers={"Authorization": f"Bearer {config.openai.api_key}"}, 17 | json={ 18 | "model": config.openai.image_model, 19 | "prompt": prompt, 20 | "size": size, 21 | "n": 1, 22 | }, 23 | ) 24 | resp = response.json() 25 | if "data" not in resp: 26 | raise Exception(resp) 27 | if len(resp["data"]) == 0: 28 | raise Exception("received an empty answer") 29 | return resp["data"][0]["url"] 30 | -------------------------------------------------------------------------------- /bot/commands/start.py: -------------------------------------------------------------------------------- 1 | """/start command.""" 2 | 3 | from telegram import Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | 7 | from bot.config import config 8 | from . import constants 9 | from . 
import help 10 | 11 | 12 | class StartCommand: 13 | """Answers the `start` command.""" 14 | 15 | async def __call__(self, update: Update, context: CallbackContext) -> None: 16 | if update.effective_user.username not in config.telegram.usernames: 17 | text = ( 18 | "Sorry, I don't know you. To setup your own bot, " 19 | "visit https://github.com/nalgeon/pokitoki" 20 | ) 21 | await update.message.reply_text(text) 22 | return 23 | 24 | text = "Hi! I'm a humble AI-driven chat bot.\n\n" 25 | text += help.generate_message(update.effective_user.username) 26 | if not context.bot.can_read_all_group_messages: 27 | text += f"\n\n{constants.PRIVACY_MESSAGE}" 28 | await update.message.reply_text( 29 | text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Karim Iskakov, Anton Zhiyanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /bot/commands/message.py: -------------------------------------------------------------------------------- 1 | """Text message handler.""" 2 | 3 | import logging 4 | from typing import Awaitable 5 | from telegram import Chat, Update 6 | from telegram.ext import CallbackContext 7 | from bot import questions 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class MessageCommand: 13 | """Answers a question from the user.""" 14 | 15 | def __init__(self, reply_func: Awaitable) -> None: 16 | self.reply_func = reply_func 17 | 18 | async def __call__(self, update: Update, context: CallbackContext) -> None: 19 | message = update.message or update.edited_message 20 | logger.debug(update) 21 | 22 | # the bot is meant to answer questions in private chats, 23 | # but it can also answer a specific question in a group when mentioned 24 | if message.chat.type == Chat.PRIVATE: 25 | question = await questions.extract_private(message, context) 26 | else: 27 | question, message = await questions.extract_group(message, context) 28 | 29 | if not question: 30 | # this is not a question to the bot, so ignore it 31 | return 32 | 33 | await self.reply_func(update=update, message=message, context=context, question=question) 34 | -------------------------------------------------------------------------------- /bot/markdown.py: -------------------------------------------------------------------------------- 1 | """Markdown/HTML text formatting.""" 2 | 3 | import re 4 | 5 | # Code blocks, e.g.: 6 | # ```sql 7 | # select count(*) from messages; 8 | # ``` 9 | pre_re = re.compile(r"^[ ]*```\w*$(.+?)^```$", re.MULTILINE | re.DOTALL) 10 | 11 | # Inline code, e.g.: 12 | # `print(message)` displays the message. 13 | code_re = re.compile(r"`([^`\n]+)`") 14 | 15 | # Bold text, e.g.: 16 | # **Note**. Cats **do** like boxes. 17 | bold_re = re.compile(r"\*\*([^<*]+?)\*\*") 18 | 19 | # Unordered list items, e.g.: 20 | # * Wake up. 21 | # * Have breakfast. 22 | # * Go to work. 23 | bullet_re = re.compile(r"^\*\s\s+(.+)$", re.MULTILINE) 24 | 25 | 26 | def to_html(text: str) -> str: 27 | """ 28 | Converts Markdown text to "Telegram HTML", which supports only some of the tags. 29 | See https://core.telegram.org/bots/api#html-style for details. 30 | Escapes certain entities and converts `code` and `pre`, 31 | but ignores all other formatting. 32 | """ 33 | text = text.replace("&", "&").replace("<", "<").replace(">", ">") 34 | text = pre_re.sub(r"
<pre>\1</pre>
", text) 35 | text = code_re.sub(r"\1", text) 36 | text = bold_re.sub(r"\1", text) 37 | text = bullet_re.sub(r"— \1", text) 38 | return text 39 | -------------------------------------------------------------------------------- /bot/commands/error.py: -------------------------------------------------------------------------------- 1 | """Generic error handler.""" 2 | 3 | import logging 4 | import textwrap 5 | from telegram import Chat, Update 6 | from telegram.ext import CallbackContext 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ErrorCommand: 12 | """If the bot failed to answer, prints the error and the stack trace (if any).""" 13 | 14 | async def __call__(self, update: Update, context: CallbackContext) -> None: 15 | if not update: 16 | # telegram.error.NetworkError or a similar error, there is no chat to respond to. 17 | # Not sure if we should completely silence such errors. 18 | logger.warning("General exception: %s:", context.error) 19 | return 20 | 21 | class_name = f"{context.error.__class__.__module__}.{context.error.__class__.__qualname__}" 22 | error_text = f"{class_name}: {context.error}" 23 | logger.warning("Exception while handling an update %s: %s", update, error_text) 24 | text = textwrap.shorten(f"⚠️ {error_text}", width=255, placeholder="...") 25 | 26 | message = update.message 27 | reply_to_message_id = message.id if message and message.chat.type != Chat.PRIVATE else None 28 | await context.bot.send_message( 29 | update.effective_chat.id, text, reply_to_message_id=reply_to_message_id 30 | ) 31 | -------------------------------------------------------------------------------- /bot/commands/help.py: -------------------------------------------------------------------------------- 1 | """/help command.""" 2 | 3 | from telegram import Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | from bot.config import config 7 | from . 
import constants 8 | 9 | 10 | class HelpCommand: 11 | """Answers the `help` command.""" 12 | 13 | async def __call__(self, update: Update, context: CallbackContext) -> None: 14 | text = generate_message(update.effective_user.username) 15 | await update.message.reply_text( 16 | text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True 17 | ) 18 | 19 | 20 | def generate_message(username: str) -> str: 21 | """Generates a help message, including a list of allowed commands.""" 22 | 23 | # user commands 24 | commands = "\n".join(f"/{cmd} - {descr}" for cmd, descr in constants.BOT_COMMANDS) 25 | 26 | # admin commands 27 | admin_commands = "" 28 | if username in config.telegram.admins: 29 | admin_commands += "\n\nAdmin-only commads:\n" 30 | admin_commands += f"/config - {constants.ADMIN_COMMANDS['config']}\n" 31 | admin_commands = admin_commands.rstrip() 32 | 33 | # shortcuts 34 | if config.shortcuts: 35 | shortcuts = "\n".join(f"`!{shortcut}`" for shortcut in config.shortcuts) 36 | else: 37 | shortcuts = "none" 38 | 39 | return constants.HELP_MESSAGE.format( 40 | commands=commands, admin_commands=admin_commands, shortcuts=shortcuts 41 | ) 42 | -------------------------------------------------------------------------------- /bot/commands/imagine.py: -------------------------------------------------------------------------------- 1 | """/imagine command.""" 2 | 3 | from typing import Awaitable 4 | from telegram import Update 5 | from telegram.ext import CallbackContext 6 | from telegram.constants import ParseMode 7 | from bot.config import config 8 | 9 | 10 | class ImagineCommand: 11 | """Generates an image according to the description.""" 12 | 13 | def __init__(self, reply_func: Awaitable) -> None: 14 | self.reply_func = reply_func 15 | 16 | async def __call__(self, update: Update, context: CallbackContext) -> None: 17 | message = update.message or update.edited_message 18 | 19 | if config.imagine.enabled == "none": 20 | await message.reply_text( 21 | "The `imagine` command is disabled.", parse_mode=ParseMode.MARKDOWN 22 | ) 23 | return 24 | 25 | if ( 26 | config.imagine.enabled == "users_only" 27 | and update.effective_user.username not in config.telegram.usernames 28 | ): 29 | await message.reply_text( 30 | "The `imagine` command is disabled for group users.", parse_mode=ParseMode.MARKDOWN 31 | ) 32 | return 33 | 34 | if not context.args: 35 | await message.reply_text( 36 | "Please describe an image. 
" 37 | "For example:\n/imagine a lazy cat on a sunny day", 38 | parse_mode=ParseMode.HTML, 39 | ) 40 | return 41 | await self.reply_func( 42 | update=update, message=update.message, context=context, question=message.text 43 | ) 44 | -------------------------------------------------------------------------------- /bot/commands/model.py: -------------------------------------------------------------------------------- 1 | """/model command.""" 2 | 3 | from telegram import Chat, Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | 7 | from bot.config import config 8 | from bot.models import ChatData 9 | 10 | HELP_MESSAGE = """Syntax: 11 | /model [AI model name] 12 | 13 | For example: 14 | /model gpt-4o 15 | 16 | To use the default model: 17 | /model reset""" 18 | 19 | RESET = "reset" 20 | 21 | 22 | class ModelCommand: 23 | """Sets an AI model.""" 24 | 25 | async def __call__(self, update: Update, context: CallbackContext) -> None: 26 | message = update.message or update.edited_message 27 | 28 | if ( 29 | message.chat.type != Chat.PRIVATE 30 | and update.effective_user.username not in config.telegram.admins 31 | ): 32 | # Only admins are allowed to change the model in group chats. 33 | return 34 | 35 | chat = ChatData(context.chat_data) 36 | _, _, model = message.text.partition(" ") 37 | if not model: 38 | # /model without arguments 39 | if chat.model: 40 | # the model is already set, show it 41 | await message.reply_text( 42 | f"Using model:\n{chat.model}", 43 | parse_mode=ParseMode.HTML, 44 | ) 45 | return 46 | else: 47 | # the model is not set, show help message 48 | await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML) 49 | return 50 | 51 | if model == RESET: 52 | # /model with "reset" argument 53 | chat.model = "" 54 | await message.reply_text( 55 | f"✓ Using default model:\n{config.openai.model}", 56 | parse_mode=ParseMode.HTML, 57 | ) 58 | return 59 | 60 | # /model with a name 61 | chat.model = model 62 | await message.reply_text( 63 | f"✓ Set model:\n{model}", 64 | parse_mode=ParseMode.HTML, 65 | ) 66 | -------------------------------------------------------------------------------- /bot/filters.py: -------------------------------------------------------------------------------- 1 | """Bot message filters.""" 2 | 3 | from typing import Union 4 | from dataclasses import dataclass 5 | from telegram.ext import filters 6 | from bot.config import config 7 | 8 | 9 | @dataclass 10 | class Filters: 11 | """Filters for the incoming Telegram messages.""" 12 | 13 | users: Union[filters.MessageFilter, filters.User] 14 | admins: filters.User 15 | chats: Union[filters.MessageFilter, filters.Chat] 16 | 17 | users_or_chats: filters.BaseFilter 18 | admins_private: filters.BaseFilter 19 | messages: filters.BaseFilter 20 | 21 | def __init__(self) -> None: 22 | """Defines users and chats that are allowed to use the bot.""" 23 | if config.telegram.usernames: 24 | self.users = filters.User(username=config.telegram.usernames) 25 | self.chats = filters.Chat(chat_id=config.telegram.chat_ids) 26 | else: 27 | self.users = filters.ALL 28 | self.chats = filters.ALL 29 | 30 | if config.telegram.admins: 31 | self.admins = filters.User(username=config.telegram.admins) 32 | else: 33 | self.admins = filters.User(username=[]) 34 | 35 | self.users_or_chats = self.users | self.chats 36 | self.admins_private = self.admins & filters.ChatType.PRIVATE 37 | self.messages = ( 38 | (filters.TEXT | filters.Document.TEXT) & ~filters.COMMAND & self.users_or_chats 39 | ) 40 
| 41 | def reload(self) -> None: 42 | """Reloads users and chats from config.""" 43 | if self.users == filters.ALL and config.telegram.usernames: 44 | # cannot update the filter from ALL to specific usernames without a restart 45 | raise Exception("Restart the bot for changes to take effect") 46 | self.users.usernames = config.telegram.usernames 47 | self.chats.chat_ids = config.telegram.chat_ids 48 | self.admins.usernames = config.telegram.admins 49 | 50 | def is_known_user(self, username: str) -> bool: 51 | """Checks if the username is included in the `users` filter.""" 52 | if self.users == filters.ALL: 53 | return False 54 | return username in self.users.usernames 55 | -------------------------------------------------------------------------------- /bot/commands/version.py: -------------------------------------------------------------------------------- 1 | """/version command.""" 2 | 3 | from urllib import parse 4 | 5 | from telegram import Update 6 | from telegram.ext import CallbackContext 7 | from telegram.constants import ParseMode 8 | 9 | from bot.config import config 10 | from . import constants 11 | 12 | 13 | class VersionCommand: 14 | """Answers the `version` command.""" 15 | 16 | async def __call__(self, update: Update, context: CallbackContext) -> None: 17 | chat = update.message.chat 18 | # chat information 19 | text = ( 20 | "
"
21 |             "Chat information:\n"
22 |             f"- id: {chat.id}\n"
23 |             f"- title: {chat.title}\n"
24 |             f"- type: {chat.type}"
25 |             "</pre>
" 26 | ) 27 | bot = await context.bot.get_me() 28 | usernames = ( 29 | "all" if not config.telegram.usernames else f"{len(config.telegram.usernames)} users" 30 | ) 31 | admins = "none" if not config.telegram.admins else f"{len(config.telegram.admins)} users" 32 | 33 | # bot information 34 | text += ( 35 | "\n\n
"
36 |             "Bot information:\n"
37 |             f"- id: {bot.id}\n"
38 |             f"- name: {bot.name}\n"
39 |             f"- version: {config.version}\n"
40 |             f"- usernames: {usernames}\n"
41 |             f"- admins: {admins}\n"
42 |             f"- chat IDs: {config.telegram.chat_ids}\n"
43 |             f"- access to messages: {bot.can_read_all_group_messages}"
44 |             "</pre>
" 45 | ) 46 | if not bot.can_read_all_group_messages: 47 | text += f"\n\n{constants.PRIVACY_MESSAGE}" 48 | 49 | # AI information 50 | provider = parse.urlparse(config.openai.url).hostname 51 | text += ( 52 | "\n\n
"
53 |             "AI information:\n"
54 |             f"- provider: {provider}\n"
55 |             f"- model: {config.openai.model}\n"
56 |             f"- history depth: {config.conversation.depth}\n"
57 |             f"- imagine: {config.imagine.enabled}\n"
58 |             f"- shortcuts: {', '.join(config.shortcuts.keys())}"
59 |             "</pre>
" 60 | ) 61 | await update.message.reply_text(text, parse_mode=ParseMode.HTML) 62 | -------------------------------------------------------------------------------- /bot/commands/prompt.py: -------------------------------------------------------------------------------- 1 | """/prompt command.""" 2 | 3 | from telegram import Chat, Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | 7 | from bot.config import config 8 | from bot.models import ChatData 9 | 10 | HELP_MESSAGE = """Syntax: 11 | /prompt [custom prompt] 12 | 13 | For example: 14 | /prompt You are an evil genius. Reply with an evil laugh. 15 | 16 | To use the default prompt: 17 | /prompt reset""" 18 | 19 | RESET = "reset" 20 | 21 | 22 | class PromptCommand: 23 | """Sets a custom chat prompt.""" 24 | 25 | async def __call__(self, update: Update, context: CallbackContext) -> None: 26 | message = update.message or update.edited_message 27 | 28 | if ( 29 | message.chat.type != Chat.PRIVATE 30 | and update.effective_user.username not in config.telegram.admins 31 | ): 32 | # Only admins are allowed to change the prompt in group chats. 33 | return 34 | 35 | chat = ChatData(context.chat_data) 36 | _, _, prompt = message.text.partition(" ") 37 | if not prompt: 38 | # /prompt without arguments 39 | if chat.prompt: 40 | # custom prompt is already set, show it 41 | await message.reply_text( 42 | f"Using custom prompt:\n{chat.prompt}", 43 | parse_mode=ParseMode.HTML, 44 | ) 45 | return 46 | else: 47 | # custom prompt is not set, show help message 48 | await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML) 49 | return 50 | 51 | if prompt == RESET: 52 | # /prompt with "reset" argument 53 | chat.prompt = "" 54 | await message.reply_text( 55 | f"✓ Using default prompt:\n{config.openai.prompt}", 56 | parse_mode=ParseMode.HTML, 57 | ) 58 | return 59 | 60 | # /prompt with a custom prompt 61 | chat.prompt = prompt 62 | await message.reply_text( 63 | f"✓ Set custom prompt:\n{prompt}", 64 | parse_mode=ParseMode.HTML, 65 | ) 66 | -------------------------------------------------------------------------------- /tests/test_filters.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from telegram.ext import filters as tg_filters 4 | from bot.config import config 5 | from bot.filters import Filters 6 | 7 | 8 | class FiltersTest(unittest.TestCase): 9 | def test_init(self): 10 | config.telegram.usernames = ["alice", "bob"] 11 | config.telegram.chat_ids = [-100, -200] 12 | config.telegram.admins = ["admin"] 13 | filters = Filters() 14 | self.assertEqual(filters.users.usernames, set(["alice", "bob"])) 15 | self.assertEqual(filters.chats.chat_ids, set([-100, -200])) 16 | self.assertEqual(filters.admins.usernames, set(["admin"])) 17 | 18 | def test_reload(self): 19 | config.telegram.usernames = ["alice", "bob"] 20 | config.telegram.chat_ids = [-100, -200] 21 | config.telegram.admins = ["admin"] 22 | filters = Filters() 23 | 24 | config.telegram.usernames = ["alice", "bob", "cindy"] 25 | config.telegram.chat_ids = [-300] 26 | config.telegram.admins = ["zappa", "xanos"] 27 | filters.reload() 28 | self.assertEqual(filters.users.usernames, set(["alice", "bob", "cindy"])) 29 | self.assertEqual(filters.chats.chat_ids, set([-300])) 30 | self.assertEqual(filters.admins.usernames, set(["zappa", "xanos"])) 31 | 32 | def test_is_known_user(self): 33 | config.telegram.usernames = ["alice", "bob"] 34 | filters = Filters() 35 | 
self.assertTrue(filters.is_known_user("alice")) 36 | self.assertFalse(filters.is_known_user("cindy")) 37 | 38 | 39 | class EmptyTest(unittest.TestCase): 40 | def test_init(self): 41 | config.telegram.usernames = [] 42 | config.telegram.chat_ids = [-100, -200] 43 | config.telegram.admins = ["admin"] 44 | filters = Filters() 45 | self.assertEqual(filters.users, tg_filters.ALL) 46 | self.assertEqual(filters.chats, tg_filters.ALL) 47 | self.assertEqual(filters.admins.usernames, set(["admin"])) 48 | 49 | def test_reload(self): 50 | config.telegram.usernames = [] 51 | filters = Filters() 52 | config.telegram.usernames = ["alice", "bob"] 53 | with self.assertRaises(Exception): 54 | filters.reload() 55 | 56 | def test_is_known_user(self): 57 | config.telegram.usernames = [] 58 | filters = Filters() 59 | self.assertFalse(filters.is_known_user("alice")) 60 | self.assertFalse(filters.is_known_user("cindy")) 61 | -------------------------------------------------------------------------------- /bot/commands/config.py: -------------------------------------------------------------------------------- 1 | """/config command.""" 2 | 3 | from telegram import Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | 7 | from bot.config import config, ConfigEditor 8 | from bot.filters import Filters 9 | 10 | HELP_MESSAGE = """Syntax: 11 | /config property [value] 12 | 13 | E.g. to view the property value: 14 | /config openai.prompt 15 | 16 | E.g. to change the property value: 17 | /config openai.prompt You are an AI assistant""" 18 | 19 | editor = ConfigEditor(config) 20 | 21 | 22 | class ConfigCommand: 23 | """Displays or changes config properties.""" 24 | 25 | def __init__(self, filters: Filters) -> None: 26 | self.filters = filters 27 | 28 | async def __call__(self, update: Update, context: CallbackContext) -> None: 29 | message = update.message or update.edited_message 30 | 31 | parts = message.text.split() 32 | if len(parts) == 1: 33 | # /config without arguments 34 | await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML) 35 | return 36 | 37 | property = parts[1] 38 | value = editor.get_value(property) 39 | value = value if value is not None else "(empty)" 40 | 41 | if len(parts) == 2: 42 | # view config property (`/config {property}`) 43 | await message.reply_text(f"`{value}`", parse_mode=ParseMode.MARKDOWN) 44 | return 45 | 46 | # change config property (`/config {property} {new_value}`) 47 | new_value = " ".join(parts[2:]) 48 | has_changed, is_immediate, new_val = editor.set_value(property, new_value) 49 | 50 | if not has_changed: 51 | text = f"✗ The `{property}` property already equals to `{new_value}`" 52 | await message.reply_text(text, parse_mode=ParseMode.MARKDOWN) 53 | return 54 | 55 | editor.save() 56 | if self._should_reload_filters(property): 57 | self.filters.reload() 58 | 59 | text = f"✓ Changed the `{property}` property: `{value}` → `{new_val}`" 60 | if not is_immediate: 61 | text += "\n❗️Restart the bot for changes to take effect." 
62 | await message.reply_text(text, parse_mode=ParseMode.MARKDOWN) 63 | 64 | def _should_reload_filters(self, property: str) -> bool: 65 | return property in ( 66 | "telegram.usernames", 67 | "telegram.chat_ids", 68 | "telegram.admins", 69 | ) 70 | -------------------------------------------------------------------------------- /tests/test_markdown.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from bot import markdown 3 | 4 | TEXT_MD = r"""You can easily regular expressions them using the `sqlean-regexp` extension. 5 | 6 | > **Note**. Unlike other DBMS, adding extensions to SQLite is a breeze. 7 | 8 | With `sqlean-regexp`, matching a string against a pattern becomes as easy as: 9 | 10 | ```sql 11 | select count(*) from messages 12 | where msg_text regexp '\d+'; 13 | ``` 14 | 15 | `regexp_like(source, pattern)` checks if the source string matches the pattern: 16 | 17 | ```sql 18 | select regexp_like('Meet me at 10:30', '\d+:\d+'); 19 | select 10 > 5 = true; 20 | ``` 21 | 22 | See [Documentation](https://github.com/nalgeon/sqlean) for reference. 23 | """ 24 | 25 | TEXT_HTML = r"""You can easily regular expressions them using the sqlean-regexp extension. 26 | 27 | > Note. Unlike other DBMS, adding extensions to SQLite is a breeze. 28 | 29 | With sqlean-regexp, matching a string against a pattern becomes as easy as: 30 | 31 |
<pre>
32 | select count(*) from messages
33 | where msg_text regexp '\d+';
34 | </pre>
35 | 36 | <code>regexp_like(source, pattern)</code> checks if the source string matches the pattern: 37 | 38 | <pre>
39 | select regexp_like('Meet me at 10:30', '\d+:\d+');
40 | select 10 &gt; 5 = true;
41 | </pre>
42 | 43 | See [Documentation](https://github.com/nalgeon/sqlean) for reference. 44 | """ 45 | 46 | 47 | class Test(unittest.TestCase): 48 | def test_to_html(self): 49 | text = markdown.to_html(TEXT_MD) 50 | self.assertEqual(text, TEXT_HTML) 51 | 52 | def test_ticks(self): 53 | text = markdown.to_html("one `two` three") 54 | self.assertEqual(text, "one two three") 55 | text = markdown.to_html("one `two three") 56 | self.assertEqual(text, "one `two three") 57 | text = markdown.to_html("one `two\n` three") 58 | self.assertEqual(text, "one `two\n` three") 59 | 60 | def test_bold(self): 61 | text = markdown.to_html("one **two** three") 62 | self.assertEqual(text, "one two three") 63 | text = markdown.to_html("one *two* three") 64 | self.assertEqual(text, "one *two* three") 65 | text = markdown.to_html("one **two three") 66 | self.assertEqual(text, "one **two three") 67 | text = markdown.to_html("operator `**` is so ** powerful") 68 | self.assertEqual(text, "operator ** is so ** powerful") 69 | 70 | def test_bullet(self): 71 | text = markdown.to_html("* one two three") 72 | self.assertEqual(text, "— one two three") 73 | text = markdown.to_html("* one two three") 74 | self.assertEqual(text, "* one two three") 75 | text = markdown.to_html(" * one two three") 76 | self.assertEqual(text, " * one two three") 77 | -------------------------------------------------------------------------------- /tests/mocks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from telegram import User 3 | from bot import askers 4 | 5 | 6 | class FakeGPT: 7 | def __init__(self, error: Optional[Exception] = None): 8 | self.error = error 9 | self.prompt = None 10 | self.question = None 11 | self.history = None 12 | 13 | async def ask(self, prompt: str, question: str, history: list) -> str: 14 | self.prompt = prompt 15 | self.question = question 16 | self.history = history 17 | if self.error: 18 | raise self.error 19 | return question 20 | 21 | 22 | class FakeDalle: 23 | def __init__(self, error: Optional[Exception] = None): 24 | self.error = error 25 | self.prompt = None 26 | self.size = None 27 | 28 | async def imagine(self, prompt: str, size: str) -> str: 29 | self.prompt = prompt 30 | self.size = size 31 | if self.error: 32 | raise self.error 33 | return "image" 34 | 35 | 36 | class FakeFile: 37 | def __init__(self, file_id: str) -> None: 38 | self.file_id = file_id 39 | 40 | async def download_as_bytearray(self, buf=None, **kwargs) -> bytearray: 41 | return bytearray(b"file content") 42 | 43 | 44 | class FakeBot: 45 | def __init__(self, username: str) -> None: 46 | self.user = User( 47 | id=42, 48 | first_name=username, 49 | is_bot=True, 50 | username=username, 51 | can_read_all_group_messages=True, 52 | ) 53 | self.text = "" 54 | 55 | @property 56 | def username(self) -> str: 57 | return self.user.username 58 | 59 | @property 60 | def name(self) -> str: 61 | return f"@{self.username}" 62 | 63 | @property 64 | def can_read_all_group_messages(self) -> bool: 65 | return self.user.can_read_all_group_messages 66 | 67 | async def get_file(self, file_id, **kwargs): 68 | return FakeFile(file_id) 69 | 70 | async def send_chat_action(self, **kwargs) -> None: 71 | pass 72 | 73 | async def send_message(self, chat_id: int, text: str, **kwargs) -> None: 74 | self.text = text 75 | 76 | async def send_document( 77 | self, chat_id: int, document: object, caption: str, filename: str, **kwargs 78 | ) -> None: 79 | self.text = f"{caption}: {filename}" 80 | 81 | async 
def send_photo(self, chat_id: int, photo: str, caption: str = None, **kwargs) -> None: 82 | self.text = f"{caption}: {photo}" 83 | 84 | async def get_me(self, **kwargs) -> User: 85 | return self.user 86 | 87 | 88 | class FakeApplication: 89 | def __init__(self, bot: FakeBot) -> None: 90 | self.chat_data = {1: {}} 91 | self.user_data = {1: {}} 92 | self.bot = bot 93 | 94 | 95 | def mock_text_asker(ai: FakeGPT) -> None: 96 | mock_init = lambda asker, _: setattr(asker, "model", ai) 97 | askers.TextAsker.__init__ = mock_init 98 | -------------------------------------------------------------------------------- /bot/fetcher.py: -------------------------------------------------------------------------------- 1 | """Retrieves remote content over HTTP.""" 2 | 3 | import re 4 | import httpx 5 | from bs4 import BeautifulSoup 6 | 7 | 8 | class Fetcher: 9 | """Retrieves remote content over HTTP.""" 10 | 11 | # Matches non-quoted URLs in text 12 | url_re = re.compile(r"(?:[^'\"]|^)\b(https?://\S+)\b(?:[^'\"]|$)") 13 | timeout = 3 # seconds 14 | 15 | def __init__(self): 16 | self.client = httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) 17 | 18 | async def substitute_urls(self, text: str) -> str: 19 | """ 20 | Extracts URLs from text, fetches their contents, 21 | and appends the contents to the text. 22 | """ 23 | urls = self._extract_urls(text) 24 | for url in urls: 25 | content = await self._fetch_url(url) 26 | text += f"\n\n---\n{url} contents:\n\n{content}\n---" 27 | return text 28 | 29 | async def close(self) -> None: 30 | """Frees network connections.""" 31 | await self.client.aclose() 32 | 33 | def _extract_urls(self, text: str) -> list[str]: 34 | """Extracts URLs from text.""" 35 | urls = self.url_re.findall(text) 36 | return urls 37 | 38 | async def _fetch_url(self, url: str) -> str: 39 | """Retrieves URL content and returns it as text.""" 40 | try: 41 | response = await self.client.get(url) 42 | response.raise_for_status() 43 | content = Content(response) 44 | return content.extract_text() 45 | except Exception as exc: 46 | class_name = f"{exc.__class__.__module__}.{exc.__class__.__qualname__}" 47 | return f"Failed to fetch ({class_name})" 48 | 49 | 50 | class Content: 51 | """Extracts resource content as human-readable text.""" 52 | 53 | allowed_content_types = set( 54 | [ 55 | "application/json", 56 | "application/sql", 57 | "application/xml", 58 | ] 59 | ) 60 | 61 | def __init__(self, response: httpx.Response) -> None: 62 | self.response = response 63 | content_type, _, _ = response.headers.get("content-type").partition(";") 64 | self.content_type = content_type 65 | 66 | def extract_text(self) -> str: 67 | """Extracts resource content as human-readable text.""" 68 | if not self.is_text(): 69 | return "Unknown binary content" 70 | if self.content_type != "text/html": 71 | return self.response.text 72 | html = BeautifulSoup(self.response.text, "html.parser") 73 | article = html.find("main") or html.find("body") 74 | return article.get_text() 75 | 76 | def is_text(self) -> bool: 77 | """Checks if the content type is plain text.""" 78 | if not self.content_type: 79 | return False 80 | if self.content_type.startswith("text/"): 81 | return True 82 | if self.content_type in self.allowed_content_types: 83 | return True 84 | return False 85 | -------------------------------------------------------------------------------- /config.example.yml: -------------------------------------------------------------------------------- 1 | # Bot configuration parameters. 
2 | 3 | # Config schema version. Increments for backward-incompatible changes. 4 | schema_version: 3 5 | 6 | # Telegram settings. 7 | telegram: 8 | # Telegram Bot API token. 9 | token: "" 10 | 11 | # The list of Telegram usernames allowed to chat with the bot. 12 | # If empty, the bot will be available to anyone. 13 | usernames: [] 14 | 15 | # The list of Telegram usernames allowed to change bot settings. 16 | # If empty, then no one is allowed to change the settings. 17 | admins: [] 18 | 19 | # The list of Telegram group ids, whose members are allowed to chat with the bot. 20 | # If empty, the bot will only be available to `telegram.usernames`. 21 | chat_ids: [] 22 | 23 | # AI settings. 24 | openai: 25 | # AI API endpoint. You can use any OpenAI-compatible provider like 26 | # OpenRouter (https://openrouter.ai/api/v1) or Nebius (https://api.studio.nebius.ai/v1). 27 | url: "https://api.openai.com/v1" 28 | 29 | # AI API key. 30 | api_key: "" 31 | 32 | # Chat model name. 33 | # See https://platform.openai.com/docs/models for description. 34 | model: "gpt-4o-mini" 35 | 36 | # Image generation model name. 37 | image_model: "dall-e-3" 38 | 39 | # Context window size in tokens. 40 | # Applies only to custom models. 41 | window: 128000 42 | 43 | # Model prompt. 44 | prompt: "You are an AI assistant." 45 | 46 | # Model parameters. 47 | # See https://platform.openai.com/docs/api-reference/chat/create for description. 48 | params: 49 | temperature: 0.7 50 | max_tokens: 4096 51 | 52 | conversation: 53 | # The maximum number of previous messages 54 | # the bot will remember when talking to a user. 55 | depth: 3 56 | 57 | # The maximum number of messages a user can send during the specified time period. 58 | # Only applies to users who are not listed in `telegram.usernames`. 59 | # `count` = an integer number of messages (0 = unlimited) 60 | # `period` = minute | hour | day 61 | message_limit: 62 | count: 0 63 | period: hour 64 | 65 | # Image generation settings. 66 | imagine: 67 | # Enable/disable image generation: 68 | # - none = disabled for everyone 69 | # - users_only = enabled only for users listed in `telegram.usernames` 70 | # - users_and_groups = enabled for both users listed in `telegram.usernames` 71 | # and members of `telegrams.chat_ids` 72 | enabled: none 73 | 74 | # Where to store the chat context file. 75 | persistence_path: "./data/persistence.pkl" 76 | 77 | # Custom AI commands (additional prompts) 78 | shortcuts: 79 | bugfix: "Examine the following code. Rewrite it if necessary to fix bugs and various problems. Explain the changes you've made." 80 | 81 | proofread: "Proofread the following text. Correct grammar and punctuation errors. Rephrase if necessary. Make sure the resulting text is clear, concise, and easy to read. Explain the changes you've made." 82 | 83 | summarize: "Explain the following text in simple terms. Use no more than two paragraphs." 84 | 85 | translate: "Translate the following text into English." 
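(A minimal usage sketch, not a file from the repository: it shows how the shortcut prompts defined above are consumed by bot/shortcuts.py, assuming config.example.yml has been copied to config.yml so that config.shortcuts contains these entries.)

# Hypothetical illustration based on bot/shortcuts.py shown earlier:
# a chat message starting with "!<name>" selects a shortcut, and apply()
# prepends the configured prompt to the question text.
from bot import shortcuts

name, question = shortcuts.extract("!translate Ciao")  # -> ("translate", "Ciao")
combined = shortcuts.apply(name, question)
# -> "Translate the following text into English.\n\nCiao"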
86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Custom 156 | config.yml 157 | *.pkl -------------------------------------------------------------------------------- /tests/test_ai.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from bot.config import config 3 | from bot.ai import chat 4 | from bot.models import UserMessage 5 | 6 | 7 | class ModelTest(unittest.TestCase): 8 | def setUp(self) -> None: 9 | self.model = chat.Model("gpt") 10 | 11 | def test_generate_messages(self): 12 | history = [UserMessage("Hello", "Hi"), UserMessage("Is it cold today?", "Yep!")] 13 | messages = self.model._generate_messages( 14 | prompt_role="system", prompt="", question="What's your name?", history=history 15 | ) 16 | self.assertEqual(len(messages), 6) 17 | 18 | self.assertEqual(messages[0]["role"], "system") 19 | self.assertEqual(messages[0]["content"], config.openai.prompt) 20 | 21 | self.assertEqual(messages[1]["role"], "user") 22 | self.assertEqual(messages[1]["content"], "Hello") 23 | 24 | self.assertEqual(messages[2]["role"], "assistant") 25 | self.assertEqual(messages[2]["content"], "Hi") 26 | 27 | self.assertEqual(messages[3]["role"], "user") 28 | self.assertEqual(messages[3]["content"], "Is it cold today?") 29 | 30 | self.assertEqual(messages[4]["role"], "assistant") 31 | self.assertEqual(messages[4]["content"], "Yep!") 32 | 33 | self.assertEqual(messages[5]["role"], "user") 34 | self.assertEqual(messages[5]["content"], "What's your name?") 35 | 36 | 37 | class ShortenTest(unittest.TestCase): 38 | def test_do_not_shorten(self): 39 | messages = [ 40 | {"role": "system", "content": "You are an AI assistant."}, 41 | {"role": "user", "content": "Hello"}, 42 | ] 43 | shortened = chat.shorten(messages, length=11) 44 | self.assertEqual(shortened, messages) 45 | 46 | def test_remove_messages_1(self): 47 | messages = [ 48 | {"role": "system", "content": "You are an AI assistant."}, 49 | {"role": "user", "content": "What is your name?"}, 50 | {"role": "assistant", "content": "My name is Alice."}, 51 | {"role": "user", "content": "Is it cold today?"}, 52 | ] 53 | shortened = chat.shorten(messages, length=11) 54 | self.assertEqual( 55 | shortened, 56 | [ 57 | {"role": "system", "content": "You are an AI assistant."}, 58 | {"role": "user", "content": "Is it cold today?"}, 59 | ], 60 | ) 61 | 62 | def test_remove_messages_2(self): 63 | messages = [ 64 | {"role": "system", "content": "You are an AI assistant."}, 65 | {"role": "user", "content": "What is your name?"}, 66 | {"role": "assistant", "content": "My name is Alice."}, 67 | {"role": "user", "content": "Is it cold today?"}, 68 | ] 69 | shortened = chat.shorten(messages, length=15) 70 | self.assertEqual( 71 | shortened, 72 | [ 73 | {"role": "system", "content": "You are an AI assistant."}, 74 | {"role": 
"assistant", "content": "My name is Alice."}, 75 | {"role": "user", "content": "Is it cold today?"}, 76 | ], 77 | ) 78 | 79 | def test_shorten_question(self): 80 | messages = [ 81 | {"role": "system", "content": "You are an AI assistant."}, 82 | {"role": "user", "content": "Is it cold today? I think it's rather cold"}, 83 | ] 84 | shortened = chat.shorten(messages, length=10) 85 | self.assertEqual( 86 | shortened, 87 | [ 88 | {"role": "system", "content": "You are an AI assistant."}, 89 | {"role": "user", "content": "Is it cold today?"}, 90 | ], 91 | ) 92 | -------------------------------------------------------------------------------- /bot/askers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Asker is an abstraction that sends questions to the AI 3 | and responds to the user with answers provided by the AI. 4 | """ 5 | 6 | import io 7 | import re 8 | import textwrap 9 | 10 | from telegram import Chat, Message 11 | from telegram.constants import MessageLimit, ParseMode 12 | from telegram.ext import CallbackContext 13 | 14 | from bot import ai 15 | from bot import markdown 16 | 17 | 18 | class Asker: 19 | """Asks AI questions and responds with answers.""" 20 | 21 | async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str: 22 | """Asks AI a question.""" 23 | pass 24 | 25 | async def reply(self, message: Message, context: CallbackContext, answer: str) -> None: 26 | """Replies with an answer from AI.""" 27 | pass 28 | 29 | 30 | class TextAsker(Asker): 31 | """Works with chat completion AI.""" 32 | 33 | def __init__(self, model_name: str) -> None: 34 | self.model = ai.chat.Model(model_name) 35 | 36 | async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str: 37 | """Asks AI a question.""" 38 | return await self.model.ask(prompt, question, history) 39 | 40 | async def reply(self, message: Message, context: CallbackContext, answer: str) -> None: 41 | """Replies with an answer from AI.""" 42 | html_answer = markdown.to_html(answer) 43 | if len(html_answer) <= MessageLimit.MAX_TEXT_LENGTH: 44 | await message.reply_text(html_answer, parse_mode=ParseMode.HTML) 45 | return 46 | 47 | doc = io.StringIO(answer) 48 | caption = ( 49 | textwrap.shorten(answer, width=255, placeholder="...") 50 | + " (see attachment for the rest)" 51 | ) 52 | reply_to_message_id = message.id if message.chat.type != Chat.PRIVATE else None 53 | await context.bot.send_document( 54 | chat_id=message.chat_id, 55 | caption=caption, 56 | filename=f"{message.id}.md", 57 | document=doc, 58 | reply_to_message_id=reply_to_message_id, 59 | ) 60 | 61 | 62 | class ImagineAsker(Asker): 63 | """Works with image generation AI.""" 64 | 65 | model = ai.images.Model() 66 | size_re = re.compile(r"(256|512|1024)(?:x\1)?\s?(?:px)?") 67 | sizes = { 68 | "256": "256x256", 69 | "512": "512x512", 70 | "1024": "1024x1024", 71 | "1792": "1792x1024", 72 | } 73 | default_size = "1024x1024" 74 | 75 | def __init__(self) -> None: 76 | self.caption = "" 77 | 78 | async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str: 79 | """Asks AI a question.""" 80 | size = self._extract_size(question) 81 | self.caption = self._extract_caption(question) 82 | return await self.model.imagine(prompt=self.caption, size=size) 83 | 84 | async def reply(self, message: Message, context: CallbackContext, answer: str) -> None: 85 | """Replies with an answer from AI.""" 86 | await message.reply_photo(answer, caption=self.caption) 87 | 88 | def 
_extract_size(self, question: str) -> str: 89 | match = self.size_re.search(question) 90 | if not match: 91 | return self.default_size 92 | width = match.group(1) 93 | return self.sizes.get(width, width) 94 | 95 | def _extract_caption(self, question: str) -> str: 96 | caption = self.size_re.sub("", question).strip() 97 | return caption 98 | 99 | 100 | def create(model: str, question: str) -> Asker: 101 | """Creates a new asker based on the question asked.""" 102 | if question.startswith("/imagine"): 103 | return ImagineAsker() 104 | return TextAsker(model) 105 | -------------------------------------------------------------------------------- /tests/test_fetcher.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from httpx import Request, Response 3 | 4 | from bot.fetcher import Fetcher, Content 5 | 6 | 7 | class FakeClient: 8 | def __init__(self, responses: dict[str, Response | Exception]) -> None: 9 | self.responses = responses 10 | 11 | async def get(self, url: str) -> Response: 12 | request = Request(method="GET", url=url) 13 | response = self.responses[url] 14 | if isinstance(response, Exception): 15 | raise response 16 | return Response( 17 | status_code=response.status_code, 18 | headers=response.headers, 19 | text=response.text, 20 | request=request, 21 | ) 22 | 23 | 24 | class FetcherTest(unittest.IsolatedAsyncioTestCase): 25 | def setUp(self) -> None: 26 | self.fetcher = Fetcher() 27 | 28 | async def test_substitute_urls(self): 29 | resp_1 = Response(status_code=200, headers={"content-type": "text/plain"}, text="first") 30 | resp_2 = Response(status_code=200, headers={"content-type": "text/plain"}, text="second") 31 | self.fetcher.client = FakeClient( 32 | { 33 | "https://example.org/first": resp_1, 34 | "https://example.org/second": resp_2, 35 | } 36 | ) 37 | text = "Compare https://example.org/first and https://example.org/second" 38 | text = await self.fetcher.substitute_urls(text) 39 | self.assertEqual( 40 | text, 41 | """Compare https://example.org/first and https://example.org/second 42 | 43 | --- 44 | https://example.org/first contents: 45 | 46 | first 47 | --- 48 | 49 | --- 50 | https://example.org/second contents: 51 | 52 | second 53 | ---""", 54 | ) 55 | 56 | async def test_fetch_url(self): 57 | resp = Response(status_code=200, headers={"content-type": "text/plain"}, text="hello") 58 | exc = ConnectionError("timeout") 59 | self.fetcher.client = FakeClient({"https://success.org": resp, "https://failure.org": exc}) 60 | text = await self.fetcher._fetch_url("https://success.org") 61 | self.assertEqual(text, "hello") 62 | text = await self.fetcher._fetch_url("https://failure.org") 63 | self.assertEqual(text, "Failed to fetch (builtins.ConnectionError)") 64 | 65 | async def test_ignore_quoted(self): 66 | src = "What is 'https://example.org/first'?" 67 | text = await self.fetcher.substitute_urls(src) 68 | self.assertEqual(text, src) 69 | 70 | async def test_nothing_to_substitute(self): 71 | src = "How are you?" 72 | text = await self.fetcher.substitute_urls(src) 73 | self.assertEqual(text, src) 74 | 75 | def test_extract_urls(self): 76 | text = "Compare https://example.org/first and https://example.org/second" 77 | urls = self.fetcher._extract_urls(text) 78 | self.assertEqual(urls, ["https://example.org/first", "https://example.org/second"]) 79 | 80 | text = "Extract https://example.org/first." 
81 | urls = self.fetcher._extract_urls(text) 82 | self.assertEqual(urls, ["https://example.org/first"]) 83 | 84 | text = 'Extract "https://example.org/first"' 85 | urls = self.fetcher._extract_urls(text) 86 | self.assertEqual(urls, []) 87 | 88 | 89 | class ContentTest(unittest.TestCase): 90 | def test_extract_as_is(self): 91 | resp = Response( 92 | status_code=200, headers={"content-type": "application/sql"}, text="select 42;" 93 | ) 94 | content = Content(resp) 95 | text = content.extract_text() 96 | self.assertEqual(text, "select 42;") 97 | 98 | def test_extract_html(self): 99 | html = "
<html><body>hello</body></html>
" 100 | resp = Response(status_code=200, headers={"content-type": "text/html"}, text=html) 101 | content = Content(resp) 102 | text = content.extract_text() 103 | self.assertEqual(text, "hello") 104 | 105 | def test_extract_unknown(self): 106 | resp = Response(status_code=200, headers={"content-type": "application/pdf"}, text="...") 107 | content = Content(resp) 108 | text = content.extract_text() 109 | self.assertEqual(text, "Unknown binary content") 110 | -------------------------------------------------------------------------------- /tests/test_askers.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | from telegram import Chat, Message, User 4 | from telegram.ext import CallbackContext 5 | 6 | from bot import askers 7 | from bot.askers import ImagineAsker, TextAsker 8 | from tests.mocks import FakeApplication, FakeBot, FakeDalle, FakeGPT, mock_text_asker 9 | 10 | 11 | class TextAskerTest(unittest.IsolatedAsyncioTestCase): 12 | def setUp(self) -> None: 13 | self.ai = FakeGPT() 14 | mock_text_asker(self.ai) 15 | 16 | async def test_ask(self): 17 | asker = TextAsker("gpt") 18 | await asker.ask( 19 | prompt="Answer me", question="What is your name?", history=[("Hello", "Hi")] 20 | ) 21 | self.assertEqual(self.ai.prompt, "Answer me") 22 | self.assertEqual(self.ai.question, "What is your name?") 23 | self.assertEqual(self.ai.history, [("Hello", "Hi")]) 24 | 25 | async def test_reply(self): 26 | message, context = _create_message() 27 | asker = TextAsker("gpt") 28 | await asker.reply(message, context, answer="My name is ChatGPT.") 29 | self.assertEqual(context.bot.text, "My name is ChatGPT.") 30 | 31 | 32 | class ImagineAskerTest(unittest.IsolatedAsyncioTestCase): 33 | def setUp(self) -> None: 34 | self.ai = FakeDalle() 35 | ImagineAsker.model = self.ai 36 | 37 | async def test_ask(self): 38 | asker = ImagineAsker() 39 | await asker.ask(prompt="answer me", question="a cat 256x256", history=[]) 40 | self.assertEqual(self.ai.prompt, "a cat") 41 | self.assertEqual(self.ai.size, "256x256") 42 | 43 | async def test_reply(self): 44 | asker = ImagineAsker() 45 | await asker.ask(prompt="answer me", question="a cat 256x256", history=[]) 46 | message, context = _create_message() 47 | await asker.reply(message, context, answer="https://image.url") 48 | self.assertEqual(context.bot.text, "a cat: https://image.url") 49 | 50 | def test_extract_size(self): 51 | asker = ImagineAsker() 52 | size = asker._extract_size(question="a cat 256x256") 53 | self.assertEqual(size, "256x256") 54 | size = asker._extract_size(question="a cat 512x512") 55 | self.assertEqual(size, "512x512") 56 | size = asker._extract_size(question="a cat 1024x1024") 57 | self.assertEqual(size, "1024x1024") 58 | size = asker._extract_size(question="a cat 256") 59 | self.assertEqual(size, "256x256") 60 | size = asker._extract_size(question="a cat 256px") 61 | self.assertEqual(size, "256x256") 62 | size = asker._extract_size(question="a cat 384") 63 | self.assertEqual(size, "1024x1024") 64 | 65 | def test_extract_caption(self): 66 | asker = ImagineAsker() 67 | caption = asker._extract_caption(question="a cat 256x256") 68 | self.assertEqual(caption, "a cat") 69 | caption = asker._extract_caption(question="a cat 512x512") 70 | self.assertEqual(caption, "a cat") 71 | caption = asker._extract_caption(question="a cat 1024x1024") 72 | self.assertEqual(caption, "a cat") 73 | caption = asker._extract_caption(question="a cat 256") 74 | self.assertEqual(caption, "a cat") 
75 | caption = asker._extract_caption(question="a cat 256px") 76 | self.assertEqual(caption, "a cat") 77 | caption = asker._extract_caption(question="a cat 384") 78 | self.assertEqual(caption, "a cat 384") 79 | 80 | 81 | class CreateTest(unittest.TestCase): 82 | def test_text_asker(self): 83 | asker = askers.create(model="gpt", question="What is your name?") 84 | self.assertIsInstance(asker, TextAsker) 85 | 86 | def test_imagine_asker(self): 87 | asker = askers.create(model="dalle", question="/imagine a cat") 88 | self.assertIsInstance(asker, ImagineAsker) 89 | 90 | 91 | def _create_message() -> tuple[Message, CallbackContext]: 92 | bot = FakeBot("bot") 93 | chat = Chat(id=1, type=Chat.PRIVATE) 94 | chat.set_bot(bot) 95 | application = FakeApplication(bot) 96 | context = CallbackContext(application, chat_id=1, user_id=1) 97 | user = User(id=1, first_name="Alice", is_bot=False, username="alice") 98 | message = Message( 99 | message_id=11, 100 | date=dt.datetime.now(), 101 | chat=chat, 102 | from_user=user, 103 | ) 104 | message.set_bot(bot) 105 | return message, context 106 | -------------------------------------------------------------------------------- /bot/questions.py: -------------------------------------------------------------------------------- 1 | """Extracts questions from chat messages.""" 2 | 3 | from telegram import Message, MessageEntity 4 | from telegram.ext import CallbackContext 5 | from bot import shortcuts 6 | 7 | 8 | async def extract_private(message: Message, context: CallbackContext) -> str: 9 | """Extracts a question from a message in a private chat.""" 10 | # allow any messages in a private chat 11 | question = await _extract_text(message, context) 12 | if message.reply_to_message: 13 | # it's a follow-up question 14 | question = f"+ {question}" 15 | return question 16 | 17 | 18 | async def extract_group(message: Message, context: CallbackContext) -> tuple[str, Message]: 19 | """Extracts a question from a message in a group chat.""" 20 | text = await _extract_text(message, context) 21 | if not text: 22 | # ignore messages without text 23 | return "", message 24 | 25 | if ( 26 | message.reply_to_message 27 | and message.reply_to_message.from_user.username == context.bot.username 28 | ): 29 | # treat a reply to the bot as a follow-up question 30 | question = f"+ {text}" 31 | return question, message 32 | 33 | entities = message.entities or message.caption_entities 34 | mention = entities[0] if entities and entities[0].type == MessageEntity.MENTION else None 35 | if not mention: 36 | # the message is not a reply to the bot, 37 | # so ignore it unless it's mentioning the bot 38 | return "", message 39 | 40 | mention_text = text[mention.offset : mention.offset + mention.length] 41 | if mention_text.lower() != context.bot.name.lower(): 42 | # the message mentions someone else 43 | return "", message 44 | 45 | # the message is mentioning the bot, 46 | # so remove the mention to get the question 47 | question = text[: mention.offset] + text[mention.offset + mention.length :] 48 | question = question.strip() 49 | 50 | # messages in topics are technically replies to the 'topic created' message 51 | # so we should ignore such replies 52 | if message.reply_to_message and not message.reply_to_message.forum_topic_created: 53 | # the real question is in the original message 54 | reply_text = await _extract_text(message.reply_to_message, context) 55 | question = f"{question}: {reply_text}" if question else reply_text 56 | return question, message.reply_to_message 57 | 58 | return 
question, message 59 | 60 | 61 | def extract_prev(message: Message, context: CallbackContext) -> str: 62 | """Extracts the previous message by the bot, if any.""" 63 | if ( 64 | message.reply_to_message 65 | and message.reply_to_message.from_user.username == context.bot.username 66 | ): 67 | # treat a reply to the bot as a follow-up question 68 | return message.reply_to_message.text 69 | 70 | # otherwise, ignore previous messages 71 | return "" 72 | 73 | 74 | def prepare(question: str) -> tuple[str, bool]: 75 | """ 76 | Returns the question without the special commands 77 | and indicates whether it is a follow-up. 78 | """ 79 | 80 | if question[0] == "+": 81 | question = question.strip("+ ") 82 | is_follow_up = True 83 | else: 84 | is_follow_up = False 85 | 86 | if question[0] == "!": 87 | # this is a shortcut, so the bot should 88 | # process the question before asking it 89 | shortcut, question = shortcuts.extract(question) 90 | question = shortcuts.apply(shortcut, question) 91 | 92 | elif question[0] == "/": 93 | # this is a command, so the bot should 94 | # strip it from the question before asking 95 | _, _, question = question.partition(" ") 96 | question = question.strip() 97 | 98 | return question, is_follow_up 99 | 100 | 101 | async def _extract_text(message: Message, context: CallbackContext) -> str: 102 | """Extracts text from a text message or a document message.""" 103 | if message.text: 104 | return message.text 105 | if message.document: 106 | return await _extract_document_text(message, context) 107 | return "" 108 | 109 | 110 | async def _extract_document_text(message: Message, context: CallbackContext) -> str: 111 | """Extracts text from a document message.""" 112 | file = await context.bot.get_file(message.document.file_id) 113 | bytes = await file.download_as_bytearray() 114 | text = bytes.decode("utf-8").strip() 115 | caption = f"{message.caption}\n\n" if message.caption else "" 116 | return f"{caption}{message.document.file_name}:\n```\n{text}\n```" 117 | -------------------------------------------------------------------------------- /bot/models.py: -------------------------------------------------------------------------------- 1 | """Bot data models.""" 2 | 3 | from collections import deque 4 | import datetime as dt 5 | from typing import Generic, Mapping, NamedTuple, Optional, TypeVar 6 | from bot.config import config 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | class ChatData: 12 | """Represents data associated with a specific chat.""" 13 | 14 | def __init__(self, data: Mapping): 15 | # data should be a 'chat data' mapping from the chat context 16 | self.data = data 17 | 18 | @property 19 | def model(self) -> str: 20 | return self.data.get("model") or "" 21 | 22 | @model.setter 23 | def model(self, value: str) -> str: 24 | self.data["model"] = value 25 | 26 | @property 27 | def prompt(self) -> str: 28 | return self.data.get("prompt") or "" 29 | 30 | @prompt.setter 31 | def prompt(self, value: str) -> str: 32 | self.data["prompt"] = value 33 | 34 | 35 | class UserData: 36 | """Represents data associated with a specific user.""" 37 | 38 | def __init__(self, data: Mapping): 39 | # data should be a 'user data' mapping from the chat context 40 | self.data = data 41 | self.messages = UserMessages(data, maxlen=config.conversation.depth) 42 | period = parse_period(value=1, period=config.conversation.message_limit.period) 43 | message_count = TimestampedValue(data, name="message_counter", initial=0) 44 | self.message_counter = ExpiringCounter(message_count, period=period) 45 | 
46 | 47 | class UserMessage(NamedTuple): 48 | """Represents a question and an answer to it.""" 49 | 50 | question: str 51 | answer: str 52 | 53 | 54 | class UserMessages: 55 | """Represents user message history.""" 56 | 57 | def __init__(self, data: Mapping, maxlen: int) -> None: 58 | messages = data.get("messages") or [] 59 | data["messages"] = deque(messages, maxlen) 60 | self.data = data 61 | self.messages = data["messages"] 62 | 63 | @property 64 | def last(self) -> Optional[UserMessage]: 65 | """The latest chat message (if any).""" 66 | if not self.messages: 67 | return None 68 | return self.messages[-1] 69 | 70 | def add(self, question: str, answer: str): 71 | """Adds a message to the message history.""" 72 | self.messages.append(UserMessage(question, answer)) 73 | 74 | def pop(self) -> Optional[UserMessage]: 75 | """Removes the last message from the message history and returns it.""" 76 | if not self.messages: 77 | return None 78 | return self.messages.pop() 79 | 80 | def clear(self): 81 | """Clears the message history.""" 82 | self.messages.clear() 83 | 84 | def as_list(self): 85 | return list(self.messages) 86 | 87 | def __str__(self) -> str: 88 | return str(self.messages) 89 | 90 | def __repr__(self) -> str: 91 | return repr(self.messages) 92 | 93 | 94 | class TimestampedValue(Generic[T]): 95 | """A value with a 'last modified' timestamp.""" 96 | 97 | def __init__(self, data: Mapping, name: str, initial: Optional[T] = None) -> None: 98 | if name not in data: 99 | data[name] = {"value": initial, "timestamp": dt.datetime.now()} 100 | self._data = data[name] 101 | 102 | @property 103 | def value(self) -> T: 104 | """Returns the value.""" 105 | return self._data["value"] 106 | 107 | @value.setter 108 | def value(self, value: T) -> None: 109 | """Sets the value.""" 110 | self._data["value"] = value 111 | self._data["timestamp"] = dt.datetime.now() 112 | 113 | @property 114 | def timestamp(self) -> dt.datetime: 115 | """Returns the date and time of the last modification.""" 116 | return self._data["timestamp"] 117 | 118 | 119 | class ExpiringCounter: 120 | """A counter that expires after a given period of time.""" 121 | 122 | def __init__(self, data: TimestampedValue, period: dt.timedelta) -> None: 123 | self._data = data 124 | self.period = period 125 | 126 | @property 127 | def value(self) -> int: 128 | """Counter value.""" 129 | return self._data.value 130 | 131 | def is_expired(self) -> bool: 132 | """Checks if the counter value has expired.""" 133 | return dt.datetime.now() > self._data.timestamp + self.period 134 | 135 | def expires_after(self) -> dt.timedelta: 136 | """ 137 | Returns the timedelta after which the counter will expire 138 | (with respect to the current time). 139 | If the counter has already expired, returns zero timedelta.
140 | """ 141 | if self.is_expired(): 142 | return dt.timedelta(0) 143 | return self._data.timestamp + self.period - dt.datetime.now() 144 | 145 | def increment(self) -> int: 146 | """Increments and returns the counter value.""" 147 | if self.is_expired(): 148 | self._data.value = 0 149 | self._data.value += 1 150 | return self._data.value 151 | 152 | 153 | def parse_period(value: int, period: str) -> dt.timedelta: 154 | """Creates a timedelta from a time period description.""" 155 | if value < 0: 156 | raise ValueError(f"Invalid value: {value}") 157 | if period not in ("second", "minute", "hour", "day", "week"): 158 | raise ValueError(f"Invalid period: {period}") 159 | kwargs = {} 160 | kwargs[f"{period}s"] = value 161 | return dt.timedelta(**kwargs) 162 | 163 | 164 | def format_timedelta(delta: dt.timedelta) -> str: 165 | """Returns a string representation of a timedelta.""" 166 | if delta == dt.timedelta(0): 167 | return "now" 168 | seconds = int(delta.total_seconds()) 169 | if seconds < 60: 170 | return f"{seconds} seconds" 171 | if seconds < 3600: 172 | return f"{seconds // 60} minutes" 173 | if seconds < 2 * 3600: 174 | hours = round(seconds / 3600, 1) 175 | return f"{hours} hours" 176 | return f"{seconds // 3600} hours" 177 | -------------------------------------------------------------------------------- /bot/ai/chat.py: -------------------------------------------------------------------------------- 1 | """OpenAI-compatible language model.""" 2 | 3 | import logging 4 | import httpx 5 | from bot.config import config 6 | 7 | client = httpx.AsyncClient(timeout=60.0) 8 | logger = logging.getLogger(__name__) 9 | 10 | # Known models and their context windows 11 | MODELS = { 12 | # Gemini 13 | "gemini-2.5-pro": 1_048_576, 14 | "gemini-2.5-flash": 1_048_576, 15 | "gemini-2.5-flash-lite": 1_048_576, 16 | "gemini-2.0-flash": 1_048_576, 17 | "gemini-1.5-flash": 1_048_576, 18 | "gemini-1.5-flash-8b": 1_048_576, 19 | "gemini-1.5-pro": 2_097_152, 20 | # OpenAI 21 | "o1": 200000, 22 | "o1-mini": 128000, 23 | "o1-pro": 200000, 24 | "o3": 200000, 25 | "o3-mini": 200000, 26 | "o4": 200000, 27 | "o4-mini": 200000, 28 | "gpt-5.1": 128000, 29 | "gpt-5": 128000, 30 | "gpt-5-mini": 128000, 31 | "gpt-5-nano": 128000, 32 | "gpt-4.1": 1_047_576, 33 | "gpt-4.1-mini": 1_047_576, 34 | "gpt-4.1-nano": 1_047_576, 35 | "gpt-4o": 128000, 36 | "gpt-4o-mini": 128000, 37 | "gpt-4-turbo": 128000, 38 | "gpt-4-turbo-preview": 128000, 39 | "gpt-4-vision-preview": 128000, 40 | "gpt-4": 8192, 41 | "gpt-4-32k": 32768, 42 | "gpt-3.5-turbo": 16385, 43 | } 44 | 45 | # Prompt role name overrides. 46 | ROLE_OVERRIDES = { 47 | "o1": "user", 48 | "o1-mini": "user", 49 | "o1-pro": "user", 50 | "o3": "user", 51 | "o3-mini": "user", 52 | "o4": "user", 53 | "o4-mini": "user", 54 | } 55 | 56 | # Model parameter overrides. 
57 | PARAM_OVERRIDES = { 58 | "gpt-5.1": lambda params: {}, 59 | "gpt-5": lambda params: {}, 60 | "gpt-5-mini": lambda params: {}, 61 | "gpt-5-nano": lambda params: {}, 62 | "o1": lambda params: {}, 63 | "o1-mini": lambda params: {}, 64 | "o1-pro": lambda params: {}, 65 | "o3": lambda params: {}, 66 | "o3-mini": lambda params: {}, 67 | "o4": lambda params: {}, 68 | "o4-mini": lambda params: {}, 69 | } 70 | 71 | 72 | class Model: 73 | """AI API wrapper.""" 74 | 75 | def __init__(self, name: str) -> None: 76 | """Creates a wrapper for a given OpenAI large language model.""" 77 | self.name = name 78 | 79 | async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str: 80 | """Asks the language model a question and returns an answer.""" 81 | model = self.name 82 | prompt_role = ROLE_OVERRIDES.get(model) or "system" 83 | params_func = PARAM_OVERRIDES.get(model) or (lambda params: params) 84 | 85 | n_input = _calc_n_input(model, n_output=config.openai.params["max_tokens"]) 86 | messages = self._generate_messages(prompt_role, prompt, question, history) 87 | messages = shorten(messages, length=n_input) 88 | 89 | params = params_func(config.openai.params) 90 | logger.debug( 91 | "> chat request: model=%s, params=%s, messages=%s", 92 | model, 93 | params, 94 | messages, 95 | ) 96 | response = await client.post( 97 | f"{config.openai.url}/chat/completions", 98 | headers={"Authorization": f"Bearer {config.openai.api_key}"}, 99 | json={ 100 | "model": model, 101 | "messages": messages, 102 | **params, 103 | }, 104 | ) 105 | resp = response.json() 106 | if "usage" not in resp: 107 | raise Exception(resp) 108 | logger.debug( 109 | "< chat response: prompt_tokens=%s, completion_tokens=%s, total_tokens=%s", 110 | resp["usage"]["prompt_tokens"], 111 | resp["usage"]["completion_tokens"], 112 | resp["usage"]["total_tokens"], 113 | ) 114 | answer = self._prepare_answer(resp) 115 | return answer 116 | 117 | def _generate_messages( 118 | self, 119 | prompt_role: str, 120 | prompt: str, 121 | question: str, 122 | history: list[tuple[str, str]], 123 | ) -> list[dict]: 124 | """Builds message history to provide context for the language model.""" 125 | messages = [{"role": prompt_role, "content": prompt or config.openai.prompt}] 126 | for prev_question, prev_answer in history: 127 | messages.append({"role": "user", "content": prev_question}) 128 | messages.append({"role": "assistant", "content": prev_answer}) 129 | messages.append({"role": "user", "content": question}) 130 | return messages 131 | 132 | def _prepare_answer(self, resp) -> str: 133 | """Post-processes an answer from the language model.""" 134 | if len(resp["choices"]) == 0: 135 | raise ValueError("received an empty answer") 136 | 137 | answer = resp["choices"][0]["message"]["content"] 138 | answer = answer.strip() 139 | return answer 140 | 141 | 142 | def shorten(messages: list[dict], length: int) -> list[dict]: 143 | """ 144 | Truncates messages so that the total number of tokens 145 | does not exceed the specified length.
146 | """ 147 | lengths = [_calc_tokens(m["content"]) for m in messages] 148 | total_len = sum(lengths) 149 | if total_len <= length: 150 | return messages 151 | 152 | # exclude older messages to fit into the desired length 153 | # can't exclude the prompt though 154 | prompt_msg, messages = messages[0], messages[1:] 155 | prompt_len, lengths = lengths[0], lengths[1:] 156 | while len(messages) > 1 and total_len > length: 157 | messages = messages[1:] 158 | first_len, lengths = lengths[0], lengths[1:] 159 | total_len -= first_len 160 | messages = [prompt_msg] + messages 161 | if total_len <= length: 162 | return messages 163 | 164 | # there is only one message left, and it's still longer than allowed 165 | # so we have to shorten it 166 | maxlen = length - prompt_len 167 | tokens = messages[1]["content"].split()[:maxlen] 168 | messages[1]["content"] = " ".join(tokens) 169 | return messages 170 | 171 | 172 | def _calc_tokens(s: str) -> int: 173 | """Calculates the number of tokens in a string.""" 174 | return int(len(s.split()) * 1.2) 175 | 176 | 177 | def _calc_n_input(name: str, n_output: int) -> int: 178 | """ 179 | Calculates the maximum number of input tokens 180 | according to the model and the maximum number of output tokens. 181 | """ 182 | # OpenAI counts length in tokens, not characters. 183 | # We need to leave some tokens reserved for the output. 184 | n_total = MODELS.get(name) or config.openai.window 185 | logger.debug("model=%s, n_total=%s, n_output=%s", name, n_total, n_output) 186 | return n_total - n_output 187 | -------------------------------------------------------------------------------- /tests/test_questions.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | from telegram import Chat, Document, Message, MessageEntity, User 4 | from telegram.constants import ChatType 5 | from telegram.ext import CallbackContext 6 | 7 | from bot import questions 8 | from bot.config import config 9 | from tests.mocks import FakeApplication, FakeBot 10 | 11 | 12 | class ExtractPrivateTest(unittest.IsolatedAsyncioTestCase): 13 | def setUp(self): 14 | self.chat = Chat(id=1, type=ChatType.PRIVATE) 15 | bot = FakeBot("bot") 16 | self.context = CallbackContext(FakeApplication(bot)) 17 | 18 | async def test_text(self): 19 | message = Message( 20 | message_id=123, 21 | date=dt.datetime.now(), 22 | chat=self.chat, 23 | text="What is the capital of France?", 24 | ) 25 | result = await questions.extract_private(message, self.context) 26 | self.assertEqual(result, "What is the capital of France?") 27 | 28 | async def test_document(self): 29 | message = Message( 30 | message_id=123, 31 | date=dt.datetime.now(), 32 | chat=self.chat, 33 | caption="What is this?", 34 | document=Document( 35 | file_id="f1234", file_unique_id="f1234", file_name="file.txt", file_size=1234 36 | ), 37 | ) 38 | result = await questions.extract_private(message, self.context) 39 | self.assertEqual(result, "What is this?\n\nfile.txt:\n```\nfile content\n```") 40 | 41 | async def test_reply(self): 42 | reply_message = Message( 43 | message_id=124, date=dt.datetime.now(), chat=self.chat, text="It is Paris."
44 | ) 45 | message = Message( 46 | message_id=123, 47 | date=dt.datetime.now(), 48 | chat=self.chat, 49 | text="Isn't it London?", 50 | reply_to_message=reply_message, 51 | ) 52 | result = await questions.extract_private(message, self.context) 53 | self.assertEqual(result, "+ Isn't it London?") 54 | 55 | 56 | class ExtractGroupTest(unittest.IsolatedAsyncioTestCase): 57 | def setUp(self): 58 | self.chat = Chat(id=1, type=ChatType.GROUP) 59 | self.bot = FakeBot("bot") 60 | self.context = CallbackContext(FakeApplication(self.bot)) 61 | self.user = User(id=1, first_name="Alice", is_bot=False, username="alice") 62 | 63 | async def test_message(self): 64 | message = Message( 65 | message_id=11, 66 | date=dt.datetime.now(), 67 | chat=self.chat, 68 | text="How are you?", 69 | reply_to_message=None, 70 | ) 71 | result = await questions.extract_group(message, self.context) 72 | self.assertEqual(result, ("", message)) 73 | 74 | async def test_reply_to_bot(self): 75 | bot_user = User(id=2, first_name="Bot", is_bot=True, username=self.bot.username) 76 | bot_message = Message( 77 | message_id=11, 78 | date=dt.datetime.now(), 79 | chat=self.chat, 80 | text="It's cold today.", 81 | from_user=bot_user, 82 | ) 83 | message = Message( 84 | message_id=12, 85 | date=dt.datetime.now(), 86 | chat=self.chat, 87 | text="Is it?", 88 | reply_to_message=bot_message, 89 | ) 90 | result = await questions.extract_group(message, self.context) 91 | self.assertEqual(result, ("+ Is it?", message)) 92 | 93 | async def test_reply_to_other_user(self): 94 | other_message = Message( 95 | message_id=11, 96 | date=dt.datetime.now(), 97 | chat=self.chat, 98 | text="It's cold today.", 99 | from_user=self.user, 100 | ) 101 | message = Message( 102 | message_id=12, 103 | date=dt.datetime.now(), 104 | chat=self.chat, 105 | text="Is it?", 106 | reply_to_message=other_message, 107 | ) 108 | result = await questions.extract_group(message, self.context) 109 | self.assertEqual(result, ("", message)) 110 | 111 | async def test_mention(self): 112 | message = Message( 113 | message_id=11, 114 | date=dt.datetime.now(), 115 | chat=self.chat, 116 | text="@bot How are you?", 117 | entities=(MessageEntity(type=MessageEntity.MENTION, offset=0, length=4),), 118 | reply_to_message=None, 119 | ) 120 | result = await questions.extract_group(message, self.context) 121 | self.assertEqual(result, ("How are you?", message)) 122 | 123 | async def test_mention_case_insensitive(self): 124 | message = Message( 125 | message_id=11, 126 | date=dt.datetime.now(), 127 | chat=self.chat, 128 | text="@Bot How are you?", 129 | entities=(MessageEntity(type=MessageEntity.MENTION, offset=0, length=4),), 130 | reply_to_message=None, 131 | ) 132 | result = await questions.extract_group(message, self.context) 133 | self.assertEqual(result, ("How are you?", message)) 134 | 135 | async def test_mention_in_the_middle(self): 136 | message = Message( 137 | message_id=11, 138 | date=dt.datetime.now(), 139 | chat=self.chat, 140 | text="How are you @bot?", 141 | entities=(MessageEntity(type=MessageEntity.MENTION, offset=12, length=4),), 142 | reply_to_message=None, 143 | ) 144 | result = await questions.extract_group(message, self.context) 145 | self.assertEqual(result, ("How are you ?", message)) 146 | 147 | async def test_mention_other_user(self): 148 | message = Message( 149 | message_id=11, 150 | date=dt.datetime.now(), 151 | chat=self.chat, 152 | text="@bob How are you?", 153 | entities=(MessageEntity(type=MessageEntity.MENTION, offset=0, length=4),), 154 | 
reply_to_message=None, 155 | ) 156 | result = await questions.extract_group(message, self.context) 157 | self.assertEqual(result, ("", message)) 158 | 159 | async def test_mention_in_reply(self): 160 | original = Message( 161 | message_id=11, 162 | date=dt.datetime.now(), 163 | chat=self.chat, 164 | text="What time is it now?", 165 | from_user=self.user, 166 | ) 167 | message = Message( 168 | message_id=11, 169 | date=dt.datetime.now(), 170 | chat=self.chat, 171 | text="@bot help", 172 | entities=(MessageEntity(type=MessageEntity.MENTION, offset=0, length=4),), 173 | reply_to_message=original, 174 | ) 175 | result = await questions.extract_group(message, self.context) 176 | self.assertEqual(result, ("help: What time is it now?", original)) 177 | 178 | async def test_mention_document(self): 179 | message = Message( 180 | message_id=11, 181 | date=dt.datetime.now(), 182 | chat=self.chat, 183 | caption_entities=(MessageEntity(type=MessageEntity.MENTION, offset=0, length=4),), 184 | caption="@bot What is this?", 185 | document=Document( 186 | file_id="f1234", file_unique_id="f1234", file_name="file.txt", file_size=1234 187 | ), 188 | ) 189 | result, _ = await questions.extract_group(message, self.context) 190 | self.assertEqual(result, "What is this?\n\nfile.txt:\n```\nfile content\n```") 191 | 192 | 193 | class TestPrepare(unittest.TestCase): 194 | def test_ordinary(self): 195 | question, is_follow_up = questions.prepare("How are you?") 196 | self.assertEqual(question, "How are you?") 197 | self.assertFalse(is_follow_up) 198 | 199 | def test_follow_up(self): 200 | question, is_follow_up = questions.prepare("+ How are you?") 201 | self.assertEqual(question, "How are you?") 202 | self.assertTrue(is_follow_up) 203 | 204 | def test_shortcut(self): 205 | config.shortcuts["translate"] = "Translate into English." 
206 | question, is_follow_up = questions.prepare("!translate Ciao") 207 | self.assertEqual(question, "Translate into English.\n\nCiao") 208 | self.assertFalse(is_follow_up) 209 | -------------------------------------------------------------------------------- /bot/bot.py: -------------------------------------------------------------------------------- 1 | """Telegram chat bot built using the language model from OpenAI.""" 2 | 3 | import logging 4 | import sys 5 | import textwrap 6 | import time 7 | 8 | from telegram import Chat, Message, Update 9 | from telegram.ext import ( 10 | Application, 11 | ApplicationBuilder, 12 | CallbackContext, 13 | CommandHandler, 14 | MessageHandler, 15 | PicklePersistence, 16 | ) 17 | from bot import askers 18 | from bot import commands 19 | from bot import questions 20 | from bot import models 21 | from bot.config import config 22 | from bot.fetcher import Fetcher 23 | from bot.filters import Filters 24 | from bot.models import ChatData, UserData 25 | 26 | 27 | logging.basicConfig( 28 | stream=sys.stdout, 29 | level=logging.INFO, 30 | format="%(asctime)s %(levelname)s %(name)s %(message)s", 31 | ) 32 | logging.getLogger("httpx").setLevel(logging.WARNING) 33 | logging.getLogger("openai").setLevel(logging.WARNING) 34 | logging.getLogger("bot").setLevel(logging.INFO) 35 | logging.getLogger("bot.ai.chat").setLevel(logging.INFO) 36 | logging.getLogger("bot.commands").setLevel(logging.INFO) 37 | logging.getLogger("bot.questions").setLevel(logging.INFO) 38 | logging.getLogger("__main__").setLevel(logging.INFO) 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | # retrieves remote content 43 | fetcher = Fetcher() 44 | 45 | # telegram message filters 46 | filters = Filters() 47 | 48 | 49 | def main(): 50 | persistence = PicklePersistence(filepath=config.persistence_path) 51 | application = ( 52 | ApplicationBuilder() 53 | .token(config.telegram.token) 54 | .post_init(post_init) 55 | .post_shutdown(post_shutdown) 56 | .persistence(persistence) 57 | .concurrent_updates(True) 58 | .get_updates_http_version("1.1") 59 | .http_version("1.1") 60 | .build() 61 | ) 62 | add_handlers(application) 63 | application.run_polling() 64 | 65 | 66 | def add_handlers(application: Application): 67 | """Adds command handlers.""" 68 | 69 | # info commands 70 | application.add_handler(CommandHandler("start", commands.Start())) 71 | application.add_handler(CommandHandler("help", commands.Help(), filters=filters.users)) 72 | application.add_handler(CommandHandler("version", commands.Version(), filters=filters.users)) 73 | 74 | # admin commands 75 | application.add_handler( 76 | CommandHandler("config", commands.Config(filters), filters=filters.admins_private) 77 | ) 78 | 79 | # message-related commands 80 | application.add_handler( 81 | CommandHandler("imagine", commands.Imagine(reply_to), filters=filters.users_or_chats) 82 | ) 83 | application.add_handler(CommandHandler("model", commands.Model(), filters=filters.users)) 84 | application.add_handler(CommandHandler("prompt", commands.Prompt(), filters=filters.users)) 85 | application.add_handler( 86 | CommandHandler("retry", commands.Retry(reply_to), filters=filters.users_or_chats) 87 | ) 88 | 89 | # non-command handler: the default action is to reply to a message 90 | application.add_handler(MessageHandler(filters.messages, commands.Message(reply_to))) 91 | 92 | # generic error handler 93 | application.add_error_handler(commands.Error()) 94 | 95 | 96 | async def post_init(application: Application) -> None: 97 | """Defines bot 
settings.""" 98 | bot = application.bot 99 | logging.info(f"config: file={config.filename}, version={config.version}") 100 | logging.info(f"allowed users: {config.telegram.usernames}") 101 | logging.info(f"allowed chats: {config.telegram.chat_ids}") 102 | logging.info(f"admins: {config.telegram.admins}") 103 | logging.info(f"api url: {config.openai.url}") 104 | logging.info(f"model name: {config.openai.model}") 105 | logging.info(f"bot: username={bot.username}, id={bot.id}") 106 | await bot.set_my_commands(commands.BOT_COMMANDS) 107 | 108 | 109 | async def post_shutdown(application: Application) -> None: 110 | """Frees acquired resources.""" 111 | await fetcher.close() 112 | 113 | 114 | def with_message_limit(func): 115 | """Refuses to reply if the user has exceeded the message limit.""" 116 | 117 | async def wrapper( 118 | update: Update, message: Message, context: CallbackContext, question: str 119 | ) -> None: 120 | username = update.effective_user.username 121 | user = UserData(context.user_data) 122 | 123 | # check if the message counter exceeds the message limit 124 | if ( 125 | not filters.is_known_user(username) 126 | and user.message_counter.value >= config.conversation.message_limit.count > 0 127 | and not user.message_counter.is_expired() 128 | ): 129 | # this is a group user and they have exceeded the message limit 130 | wait_for = models.format_timedelta(user.message_counter.expires_after()) 131 | await message.reply_text(f"Please wait {wait_for} before asking a new question.") 132 | return 133 | 134 | # this is a known user or they have not exceeded the message limit, 135 | # so proceed to the actual message handler 136 | await func(update=update, message=message, context=context, question=question) 137 | 138 | # increment the message counter 139 | message_count = user.message_counter.increment() 140 | logger.debug(f"user={username}, n_messages={message_count}") 141 | 142 | return wrapper 143 | 144 | 145 | @with_message_limit 146 | async def reply_to( 147 | update: Update, message: Message, context: CallbackContext, question: str 148 | ) -> None: 149 | """Replies to a specific question.""" 150 | await message.chat.send_action(action="typing", message_thread_id=message.message_thread_id) 151 | 152 | try: 153 | chat = ChatData(context.chat_data) 154 | model = chat.model or config.openai.model 155 | asker = askers.create(model=model, question=question) 156 | if message.chat.type == Chat.PRIVATE and message.forward_date: 157 | # this is a forwarded message, don't answer yet 158 | answer = "This is a forwarded message. What should I do with it?" 
159 | else: 160 | answer = await _ask_question(message, context, question, asker) 161 | 162 | user = UserData(context.user_data) 163 | user.messages.add(question, answer) 164 | logger.debug(user.messages) 165 | await asker.reply(message, context, answer) 166 | 167 | except Exception as exc: 168 | class_name = f"{exc.__class__.__module__}.{exc.__class__.__qualname__}" 169 | error_text = f"{class_name}: {exc}" 170 | logger.error("Failed to answer: %s", error_text) 171 | text = textwrap.shorten(f"⚠️ {error_text}", width=255, placeholder="...") 172 | await message.reply_text(text) 173 | 174 | 175 | async def _ask_question( 176 | message: Message, context: CallbackContext, question: str, asker: askers.Asker 177 | ) -> str: 178 | """Answers a question using the OpenAI model.""" 179 | user_id = message.from_user.username or message.from_user.id 180 | logger.info(f"-> question id={message.id}, user={user_id}, n_chars={len(question)}") 181 | 182 | question, is_follow_up = questions.prepare(question) 183 | question = await fetcher.substitute_urls(question) 184 | logger.debug(f"Prepared question: {question}") 185 | 186 | user = UserData(context.user_data) 187 | if message.chat.type == Chat.PRIVATE: 188 | # in private chats the bot remembers previous messages 189 | if is_follow_up: 190 | # this is a follow-up question, 191 | # so the bot should retain the previous history 192 | history = user.messages.as_list() 193 | else: 194 | # user is asking a question 'from scratch', 195 | # so the bot should forget the previous history 196 | user.messages.clear() 197 | history = [] 198 | else: 199 | # in group chats the bot only answers direct questions 200 | # or follow-up questions to the bot messages 201 | prev_message = questions.extract_prev(message, context) 202 | history = [("", prev_message)] if prev_message else [] 203 | 204 | chat = ChatData(context.chat_data) 205 | start = time.perf_counter_ns() 206 | answer = await asker.ask(prompt=chat.prompt, question=question, history=history) 207 | elapsed = int((time.perf_counter_ns() - start) / 1e6) 208 | 209 | logger.info( 210 | f"<- answer id={message.id}, user={user_id}, " 211 | f"n_chars={len(answer)}, len_history={len(history)}, took={elapsed}ms" 212 | ) 213 | return answer 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import datetime as dt 3 | import unittest 4 | 5 | from bot import models 6 | from bot.config import config 7 | from bot.models import ExpiringCounter, TimestampedValue, UserData, UserMessage, UserMessages 8 | 9 | 10 | class UserDataTest(unittest.TestCase): 11 | def test_init(self): 12 | data = {} 13 | user = UserData(data) 14 | self.assertEqual(user.messages.as_list(), []) 15 | self.assertEqual(data["messages"], deque([], maxlen=config.conversation.depth)) 16 | self.assertEqual(user.message_counter.value, 0) 17 | self.assertEqual(data["message_counter"]["value"], 0) 18 | 19 | def test_messages(self): 20 | data = {} 21 | user = UserData(data) 22 | user.messages.add("question", "answer") 23 | user.messages.add("question", "answer") 24 | self.assertEqual(len(user.messages.as_list()), 2) 25 | self.assertEqual(len(data["messages"]), 2) 26 | 27 | def test_message_counter(self): 28 | data = {} 29 | user = UserData(data) 30 | user.message_counter.increment() 31 | user.message_counter.increment() 32 | 
self.assertEqual(user.message_counter.value, 2) 33 | self.assertEqual(data["message_counter"]["value"], 2) 34 | 35 | 36 | class UserMessagesTest(unittest.TestCase): 37 | def test_init(self): 38 | um = UserMessages({}, maxlen=config.conversation.depth) 39 | self.assertIsInstance(um.messages, deque) 40 | self.assertEqual(um.messages.maxlen, config.conversation.depth) 41 | 42 | data = {"messages": deque([UserMessage("Hello", "Hi")])} 43 | um = UserMessages(data, maxlen=3) 44 | self.assertEqual(um.messages, data["messages"]) 45 | 46 | def test_last(self): 47 | um = UserMessages({}, maxlen=3) 48 | self.assertIsNone(um.last) 49 | 50 | data = { 51 | "messages": deque( 52 | [UserMessage("Hello", "Hi"), UserMessage("Is it cold today?", "Yep!")] 53 | ) 54 | } 55 | um = UserMessages(data, maxlen=3) 56 | self.assertEqual(um.last, ("Is it cold today?", "Yep!")) 57 | 58 | def test_add(self): 59 | data = {"messages": deque([UserMessage("Hello", "Hi")])} 60 | um = UserMessages(data, maxlen=3) 61 | um.add("Is it cold today?", "Yep!") 62 | self.assertEqual( 63 | um.messages, 64 | deque([UserMessage("Hello", "Hi"), UserMessage("Is it cold today?", "Yep!")]), 65 | ) 66 | 67 | def test_pop(self): 68 | data = {"messages": deque([UserMessage("Hello", "Hi")])} 69 | um = UserMessages(data, maxlen=3) 70 | message = um.pop() 71 | self.assertEqual(message.question, "Hello") 72 | self.assertEqual(message.answer, "Hi") 73 | 74 | message = um.pop() 75 | self.assertIsNone(message) 76 | 77 | def test_clear(self): 78 | data = {"messages": deque([UserMessage("Hello", "Hi")])} 79 | um = UserMessages(data, maxlen=3) 80 | um.clear() 81 | self.assertEqual(len(um.messages), 0) 82 | 83 | def test_as_list(self): 84 | data = { 85 | "messages": deque( 86 | [UserMessage("Hello", "Hi"), UserMessage("Is it cold today?", "Yep!")] 87 | ) 88 | } 89 | um = UserMessages(data, maxlen=3) 90 | self.assertEqual(um.as_list(), [("Hello", "Hi"), ("Is it cold today?", "Yep!")]) 91 | 92 | 93 | class TimestampedValueTest(unittest.TestCase): 94 | def test_init(self): 95 | data = {} 96 | now = dt.datetime.now() 97 | counter = TimestampedValue(data, name="counter") 98 | self.assertEqual(data["counter"]["value"], None) 99 | self.assertGreaterEqual(data["counter"]["timestamp"], now) 100 | self.assertIsNone(counter.value) 101 | self.assertGreaterEqual(counter.timestamp, now) 102 | 103 | def test_init_initial(self): 104 | data = {} 105 | now = dt.datetime.now() 106 | counter = TimestampedValue(data, name="counter", initial=42) 107 | self.assertEqual(data["counter"]["value"], 42) 108 | self.assertGreaterEqual(data["counter"]["timestamp"], now) 109 | self.assertEqual(counter.value, 42) 110 | self.assertGreaterEqual(counter.timestamp, now) 111 | 112 | def test_value(self): 113 | data = {} 114 | counter = TimestampedValue(data, name="counter") 115 | 116 | counter.value = 11 117 | self.assertEqual(data["counter"]["value"], 11) 118 | self.assertEqual(counter.value, 11) 119 | 120 | counter.value = 21 121 | self.assertEqual(data["counter"]["value"], 21) 122 | self.assertEqual(counter.value, 21) 123 | 124 | def test_timestamp(self): 125 | data = {} 126 | counter = TimestampedValue(data, name="counter") 127 | 128 | now = dt.datetime.now() 129 | counter.value = 11 130 | self.assertGreaterEqual(data["counter"]["timestamp"], now) 131 | self.assertGreaterEqual(counter.timestamp, now) 132 | 133 | now = dt.datetime.now() 134 | counter.value = 21 135 | self.assertGreaterEqual(data["counter"]["timestamp"], now) 136 | self.assertGreaterEqual(counter.timestamp, now) 137 | 138 
| 139 | class ExpiringCounterTest(unittest.TestCase): 140 | def setUp(self) -> None: 141 | self.data = TimestampedValue(data={}, name="counter", initial=0) 142 | self.counter = ExpiringCounter(self.data, period=dt.timedelta(hours=1)) 143 | 144 | def test_increment(self): 145 | self.assertEqual(self.data.value, 0) 146 | self.counter.increment() 147 | self.assertEqual(self.data.value, 1) 148 | self.counter.increment() 149 | self.assertEqual(self.data.value, 2) 150 | 151 | def test_is_expired(self): 152 | self.assertFalse(self.counter.is_expired()) 153 | self.data._data["timestamp"] = dt.datetime.now() - dt.timedelta(hours=2) 154 | self.assertTrue(self.counter.is_expired()) 155 | 156 | def test_expires_after(self): 157 | self.assertGreater(self.counter.expires_after(), dt.timedelta(minutes=59)) 158 | self.assertLessEqual(self.counter.expires_after(), dt.timedelta(minutes=60)) 159 | 160 | self.data._data["timestamp"] = self.data._data["timestamp"] - dt.timedelta(minutes=30) 161 | self.assertGreater(self.counter.expires_after(), dt.timedelta(minutes=29)) 162 | self.assertLessEqual(self.counter.expires_after(), dt.timedelta(minutes=30)) 163 | 164 | self.data._data["timestamp"] = dt.datetime.now() - dt.timedelta(hours=2) 165 | self.assertEqual(self.counter.expires_after(), dt.timedelta(0)) 166 | 167 | def test_increment_expired(self): 168 | self.counter.increment() 169 | self.counter.increment() 170 | self.counter.increment() 171 | self.assertEqual(self.data.value, 3) 172 | 173 | self.data._data["timestamp"] = dt.datetime.now() - dt.timedelta(hours=2) 174 | self.counter.increment() 175 | self.assertEqual(self.data.value, 1) 176 | 177 | 178 | class ParsePeriodTest(unittest.TestCase): 179 | def test_parse(self): 180 | delta = models.parse_period(5, "minute") 181 | self.assertEqual(delta, dt.timedelta(minutes=5)) 182 | delta = models.parse_period(3, "hour") 183 | self.assertEqual(delta, dt.timedelta(hours=3)) 184 | delta = models.parse_period(1, "day") 185 | self.assertEqual(delta, dt.timedelta(days=1)) 186 | 187 | def test_zero(self): 188 | delta = models.parse_period(0, "minute") 189 | self.assertEqual(delta, dt.timedelta(0)) 190 | delta = models.parse_period(0, "hour") 191 | self.assertEqual(delta, dt.timedelta(0)) 192 | delta = models.parse_period(0, "day") 193 | self.assertEqual(delta, dt.timedelta(0)) 194 | 195 | def test_invalid_value(self): 196 | with self.assertRaises(ValueError): 197 | models.parse_period(-5, "minute") 198 | 199 | def test_invalid_period(self): 200 | with self.assertRaises(ValueError): 201 | models.parse_period(1, "month") 202 | 203 | 204 | class FormatTimedeltaTest(unittest.TestCase): 205 | def test_zero(self): 206 | val = models.format_timedelta(dt.timedelta(0)) 207 | self.assertEqual(val, "now") 208 | 209 | def test_seconds(self): 210 | val = models.format_timedelta(dt.timedelta(seconds=30)) 211 | self.assertEqual(val, "30 seconds") 212 | 213 | def test_minutes(self): 214 | val = models.format_timedelta(dt.timedelta(minutes=42)) 215 | self.assertEqual(val, "42 minutes") 216 | 217 | def test_one_hour(self): 218 | val = models.format_timedelta(dt.timedelta(minutes=90)) 219 | self.assertEqual(val, "1.5 hours") 220 | 221 | def test_hours(self): 222 | val = models.format_timedelta(dt.timedelta(hours=5)) 223 | self.assertEqual(val, "5 hours") 224 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Humble AI Telegram Bot 2 | 3 | This is a Telegram 
chat bot (AI assistant) that uses large language models from OpenAI-compatible providers (like [OpenAI](https://platform.openai.com/), [OpenRouter](https://openrouter.ai/docs/), [Nebius](https://docs.nebius.com/) or [Gemini](https://ai.google.dev/gemini-api/docs/openai)). 4 | 5 | Notable features: 6 | 7 | - Both one-on-one and group chats. 8 | - Direct questions, mentions, follow-ups. 9 | - Access external links (articles, code, data). 10 | - Shortcuts (custom AI commands). 11 | - Custom per-chat prompts and models. 12 | - On-the-fly configuration. 13 | 14 | Read on for a detailed feature description or jump to [setup](#setup). 15 | 16 | ## Personal chats 17 | 18 | The bot acts as your personal assistant: 19 | 20 | > 🧑 Explain Apache Kafka to a three year old 21 | > 22 | > 🤖 Apache Kafka is like a big mailbox that helps different toys (computers) send messages to each other. When one toy has something to say, it puts a message in the mailbox, and other toys can take that message out and read it. This way, they can talk and share things quickly! 23 | 24 | To allow other users to use the bot, list them in the `telegram.usernames` config property. 25 | 26 | The bot has a terrible memory, so don't expect it to remember any chat context by default. You can, however, reply with a follow-up question (`Ctrl/Cmd + ↑`). Alternatively, use a plus sign to follow up: 27 | 28 | > 🧑 Yesterday, all my troubles seemed so far away 29 | > 30 | > 🤖 It sounds like you're quoting "Yesterday" by The Beatles. It's a classic song about longing and nostalgia. 31 | > 32 | > 🧑 + who wrote it? 33 | > 34 | > 🤖 "Yesterday" was written by Paul McCartney of The Beatles. 35 | 36 | Available commands: 37 | 38 | - `/retry` - retry answering the last question 39 | - `/help` - show help 40 | - `/version` - show bot info 41 | 42 | To rephrase or add to the last question, simply edit it. The bot will then answer the updated question. 43 | 44 | ## Groups 45 | 46 | To get an answer from the bot in a group, mention it in a reply to a question: 47 | 48 | > 🧑 I forgot who played Ramsy in the Game of Thrones 😕
49 | > 🧑 ↳ @pokitokibot help 50 | > 51 | > 🤖 Ramsay Bolton in Game of Thrones was played by Iwan Rheon. 52 | 53 | Or ask a question directly: 54 | 55 | > 🧑 @pokitokibot Who played Ramsy in the Game of Thrones? 56 | > 57 | > 🤖 Iwan Rheon played Ramsay Bolton in Game of Thrones. 58 | 59 | To make the bot reply to group members, list the group id in the `telegram.chat_ids` config property. Otherwise, the bot will ignore questions from group members unless they are listed in the `telegram.usernames` config property. 60 | 61 | If you don't know the group id, run the `/version` bot command in a group to find it: 62 | 63 | ``` 64 | Chat information: 65 | - id: -1001405001234 66 | - title: My Favorite Group 67 | - type: supergroup 68 | ... 69 | ``` 70 | 71 | ## External links 72 | 73 | If you ask "vanilla" ChatGPT about external resources, it will either hallucinate or admit that it doesn't have access to remote content: 74 | 75 | > 🧑 What is the content of https://sqlime.org/employees.sql? Make no assumptions. 76 | > 77 | > 🤖 As an AI language model, I cannot access external URLs on the internet. 78 | 79 | The bot solves the problem by fetching the remote content and feeding it to the model: 80 | 81 | > 🧑 Name 1 major feature of the Go 1.23 release https://tip.golang.org/doc/go1.23 Be brief! 82 | > 83 | > 🤖 A major feature of Go 1.23 is the inclusion of the "range-over-func" experiment as a standard language feature, allowing the "range" clause in "for-range" loops to accept iterator functions. 84 | 85 | Currently only supports text content (articles, code, data), not PDFs, images or audio. 86 | 87 | If you _don't want_ the bot to access the URL, quote it: 88 | 89 | > 🧑 Exact contents of "https://antonz.org/robots.txt" 90 | > 91 | > 🤖 I can't access external websites directly. You can check the contents of a robots.txt file by visiting the URL in your browser. 92 | 93 | ## Shortcuts 94 | 95 | Use short commands to save time and ask the bot to do something specific with your questions. For example, ask it to proofread your writing with a `!proofread` command: 96 | 97 | > 🧑 !proofread I can has write java programz 98 | > 99 | > 🤖 Revised text: "I can write Java programs." 100 | > 101 | > Changes made: 102 | > 103 | > 1. Removed "has" for grammatical correctness; "can" does not require "has." 104 | > 2. Corrected "java" to "Java" for proper capitalization as it is a proper noun. 105 | > 3. Changed "programz" to "programs" for correct spelling and clarity. 106 | 107 | There are several built-in shortcuts: 108 | 109 | - `bugfix` fixes your code. 110 | - `proofread` fixes your writing. 111 | - `translate` translates your text into English. 112 | - `summarize` gives a two paragraph summary of a text. 113 | 114 | You can add your own shortcuts. See `config.example.yml` for details. 115 | 116 | ## Per-chat settings 117 | 118 | You can customize some settings on a per-chat basis. For example, to have one prompt for your private chat with a bot and another prompt for a group chat that you manage. 119 | 120 | ### Prompt 121 | 122 | Set a custom prompt for the current chat with the `/prompt` command: 123 | 124 | > 🧑 /prompt You are an evil genius. Reply with an evil laugh. 125 | 126 | To return to the default prompt, use `/prompt reset`. 127 | 128 | The `/prompt` command in group chats is only available to admins - users listed in the `telegram.admins` property. 
129 | 130 | ### Model 131 | 132 | Set a custom AI model for the current chat with the `/model` command: 133 | 134 | > 🧑 /model o3-mini 135 | 136 | To return to the default model, use `/model reset`. 137 | 138 | The `/model` command in group chats is only available to admins - users listed in the `telegram.admins` property. 139 | 140 | ## Other useful features 141 | 142 | The convenience of working with a bot is made up of small details. Here are some situations where it can save you time and effort. 143 | 144 | ### Forwarding 145 | 146 | Say you received a message from a colleague or read a post on a channel and want to ask a question. Simply forward the message to the bot and answer the clarifying question it asks. 147 | 148 | ### Ask with file 149 | 150 | To ask a question about a document, send it as a file and write the question in the caption. The bot will read the file contents and answer. Currently only supports text content (plain text, code, data), not PDFs, images or audio. Sending multiple files is also not supported. 151 | 152 | ### Reply with attachment 153 | 154 | Sometimes the AI's reply exceeds the maximum message length set by Telegram. In this case, the bot will not fail or spam you with messages. Instead, it will send the answer as an attached markdown file. 155 | 156 | ### Edited question 157 | 158 | To rephrase or add to the last question, edit it (`↑` shortcut). The bot will notice this and respond to the clarified question. 159 | 160 | ## Bot information 161 | 162 | Use the `/version` command to print detailed information about the current chat, bot, and AI model: 163 | 164 | ``` 165 | Chat information: 166 | - id: -1001405001234 167 | - title: My Favorite Group 168 | - type: supergroup 169 | 170 | Bot information: 171 | - id: 5930739038 172 | - name: @pokitokibot 173 | - version: 129 174 | - usernames: 6 users 175 | - admins: 1 users 176 | - chat IDs: [] 177 | - access to messages: True 178 | 179 | AI information: 180 | - provider: api.openai.com 181 | - model: gpt-4o-mini 182 | - history depth: 5 183 | - shortcuts: proofread, summarize 184 | ``` 185 | 186 | ## Configuration 187 | 188 | Use the `/config` command to change almost any setting on the fly, without restarting the bot. 189 | 190 | - Add or remove users and chats allowed to interact with the bot (`telegram.usernames` and `telegram.chat_ids`). 191 | - Adjust the AI provider (`openai.url`), API key (`openai.api_key`), model (`openai.model`), prompt (`openai.prompt`) and params (`openai.params`). 192 | - Add or change AI shortcuts (`shortcuts`). 193 | - Change any other config property. 194 | 195 | To view a specific config property, put its name after `/config`: 196 | 197 | ``` 198 | /config telegram.usernames 199 | /config openai.prompt 200 | ``` 201 | 202 | To change a specific config property, put its name and value after `/config`: 203 | 204 | ``` 205 | /config telegram.usernames ["alice", "bob", "cindy"] 206 | /config openai.url https://api.studio.nebius.ai/v1 207 | /config openai.model meta-llama/Meta-Llama-3.1-70B-Instruct 208 | /config openai.prompt "You are an evil AI bot" 209 | ``` 210 | 211 | When working with list properties like `telegram.usernames`, you can add or remove individual items instead of redefining the whole list: 212 | 213 | ``` 214 | /config telegram.usernames +cindy 215 | /config telegram.usernames -bob 216 | ``` 217 | 218 | The `/config` command is only available to admins - users listed in the `telegram.admins` property. 
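As a concrete example of an editable property, the `shortcuts` mapping (see the Shortcuts section above) pairs each shortcut name with the instruction that gets prepended to your question. A minimal sketch, using a made-up `haiku` shortcut (the authoritative format is in `config.example.yml`):

```yaml
shortcuts:
  # each key is a shortcut name, each value is the instruction
  # prepended to the question text
  haiku: "Rewrite the following text as a haiku."
```

You could then trigger it with `!haiku <your text>`, or inspect the current definitions with `/config shortcuts`.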
219 | 220 | ## Message limits 221 | 222 | Heated discussions with the bot in large groups can lead to high usage of the AI API. To avoid spending your entire budget, set message limits for groups with the `conversation.message_limit` config property. 223 | 224 | You can limit the number of messages from a user over a period of time. For example: 225 | 226 | - 10 messages per user per hour: 227 | 228 | ```yaml 229 | message_limit: 230 | count: 10 231 | period: hour 232 | ``` 233 | 234 | - 1 message per user per minute: 235 | 236 | ```yaml 237 | message_limit: 238 | count: 1 239 | period: minute 240 | ``` 241 | 242 | - 30 messages per user per day: 243 | 244 | ```yaml 245 | message_limit: 246 | count: 30 247 | period: day 248 | ``` 249 | 250 | ## Setup 251 | 252 | 1. Get your AI API key (from [OpenAI](https://openai.com/api/) or other provider) 253 | 254 | 2. Get your Telegram bot token from [@BotFather](https://t.me/BotFather) 255 | 256 | 3. Clone the repo to the server: 257 | 258 | ```bash 259 | git clone https://github.com/nalgeon/pokitoki.git 260 | ``` 261 | 262 | 4. Copy `config.example.yml` to `config.yml` and specify your tokens there. Add yourself to the `telegram.usernames` property, otherwise the bot will be available to everyone. 263 | 264 | 5. Start the bot: 265 | 266 | ```bash 267 | docker compose up --build --detach 268 | ``` 269 | 270 | To stop the bot: 271 | 272 | ```bash 273 | docker compose stop 274 | ``` 275 | 276 | To update the bot to a new version: 277 | 278 | ```bash 279 | docker compose down 280 | git pull 281 | docker compose up --build --detach 282 | ``` 283 | 284 | ## Development setup 285 | 286 | Prepare the environment: 287 | 288 | ``` 289 | python3 -m venv env 290 | . env/bin/activate 291 | pip install -r requirements.txt 292 | ``` 293 | 294 | Prepare the config and data directory: 295 | 296 | ``` 297 | cp config.example.yml config.yml 298 | mkdir ./data 299 | ``` 300 | 301 | Set tokens and other settings in `config.yml`. 302 | 303 | Run tests: 304 | 305 | ``` 306 | python -m unittest discover 307 | ``` 308 | 309 | Start the bot: 310 | 311 | ``` 312 | python -m bot.bot 313 | ``` 314 | 315 | ## Contributing 316 | 317 | Contributions are welcome. For anything other than bugfixes, please first open an issue to discuss what you want to change. 318 | 319 | Use [Black](https://black.readthedocs.io/en/stable/) for code formatting. Be sure to add or update tests as appropriate. 320 | 321 | ## Changelog 322 | 323 | For feature history, see [releases](https://github.com/nalgeon/pokitoki/releases). 324 | 325 | ## Credits 326 | 327 | Originally based on [chatgpt_telegram_bot](https://github.com/karfly/chatgpt_telegram_bot), although the projects have diverged quite a bit. 
328 | -------------------------------------------------------------------------------- /bot/config.py: -------------------------------------------------------------------------------- 1 | """Bot configuration parameters.""" 2 | 3 | import os 4 | from typing import Any 5 | import yaml 6 | import dataclasses 7 | from dataclasses import dataclass 8 | 9 | 10 | @dataclass 11 | class Telegram: 12 | token: str 13 | usernames: list 14 | admins: list 15 | chat_ids: list 16 | 17 | 18 | @dataclass 19 | class OpenAI: 20 | url: str 21 | api_key: str 22 | model: str 23 | image_model: str 24 | window: int 25 | prompt: str 26 | params: dict 27 | 28 | default_url = "https://api.openai.com/v1" 29 | default_model = "gpt-4o-mini" 30 | default_image_model = "dall-e-3" 31 | default_window = 128000 32 | default_prompt = "You are an AI assistant." 33 | default_params = { 34 | "temperature": 0.7, 35 | "max_tokens": 4096, 36 | } 37 | 38 | def __init__( 39 | self, 40 | url: str, 41 | api_key: str, 42 | model: str, 43 | image_model: str, 44 | window: int, 45 | prompt: str, 46 | params: dict, 47 | ) -> None: 48 | self.url = url or self.default_url 49 | self.api_key = api_key 50 | self.model = model or self.default_model 51 | self.image_model = image_model or self.default_image_model 52 | self.window = window or self.default_window 53 | self.prompt = prompt or self.default_prompt 54 | self.params = self.default_params.copy() 55 | self.params.update(params) 56 | 57 | 58 | @dataclass 59 | class RateLimit: 60 | count: int 61 | period: str 62 | 63 | allowed_periods = ("minute", "hour", "day") 64 | default_period = "hour" 65 | 66 | def __init__(self, count: int = 0, period: str = default_period) -> None: 67 | self.count = count 68 | if period not in self.allowed_periods: 69 | period = self.default_period 70 | self.period = period 71 | 72 | def __bool__(self) -> bool: 73 | return self.count > 0 74 | 75 | 76 | @dataclass 77 | class Conversation: 78 | depth: int 79 | message_limit: RateLimit 80 | 81 | default_depth = 3 82 | 83 | def __init__(self, depth: int, message_limit: dict) -> None: 84 | self.depth = depth or self.default_depth 85 | self.message_limit = RateLimit(**message_limit) 86 | 87 | 88 | @dataclass 89 | class Imagine: 90 | enabled: str 91 | 92 | def __init__(self, enabled: str) -> None: 93 | self.enabled = enabled if enabled in ("none", "users_only", "users_and_groups") else "none" 94 | 95 | 96 | class Config: 97 | """Config properties.""" 98 | 99 | # Config schema version. Increments for backward-incompatible changes. 100 | schema_version = 4 101 | # Bot version. 102 | version = 239 103 | 104 | def __init__(self, filename: str, src: dict) -> None: 105 | # Config filename. 106 | self.filename = filename 107 | 108 | # Telegram settings. 109 | self.telegram = Telegram( 110 | token=src["telegram"]["token"], 111 | usernames=src["telegram"].get("usernames") or [], 112 | admins=src["telegram"].get("admins") or [], 113 | chat_ids=src["telegram"].get("chat_ids") or [], 114 | ) 115 | 116 | # OpenAI settings. 117 | self.openai = OpenAI( 118 | url=src["openai"].get("url"), 119 | api_key=src["openai"]["api_key"], 120 | model=src["openai"].get("model"), 121 | image_model=src["openai"].get("image_model"), 122 | window=src["openai"].get("window"), 123 | prompt=src["openai"].get("prompt"), 124 | params=src["openai"].get("params") or {}, 125 | ) 126 | 127 | # Conversation settings. 
128 | self.conversation = Conversation( 129 | depth=src["conversation"].get("depth"), 130 | message_limit=src["conversation"].get("message_limit") or {}, 131 | ) 132 | 133 | # Image generation settings. 134 | self.imagine = Imagine(enabled=src["imagine"].get("enabled") or "") 135 | 136 | # Where to store the chat context file. 137 | self.persistence_path = src.get("persistence_path") or "./data/persistence.pkl" 138 | 139 | # Custom AI commands (additional prompts). 140 | self.shortcuts = src.get("shortcuts") or {} 141 | 142 | def as_dict(self) -> dict: 143 | """Converts the config into a dictionary.""" 144 | return { 145 | "schema_version": self.schema_version, 146 | "telegram": dataclasses.asdict(self.telegram), 147 | "openai": dataclasses.asdict(self.openai), 148 | "conversation": dataclasses.asdict(self.conversation), 149 | "imagine": dataclasses.asdict(self.imagine), 150 | "persistence_path": self.persistence_path, 151 | "shortcuts": self.shortcuts, 152 | } 153 | 154 | 155 | class ConfigEditor: 156 | """ 157 | Config properties editor. 158 | Gets/sets config properties by their 'path', 159 | e.g. 'openai.params.temperature' or 'conversation.depth'. 160 | """ 161 | 162 | # These properties cannot be changed at all. 163 | readonly = [ 164 | "schema_version", 165 | "version", 166 | "filename", 167 | ] 168 | # Changes made to these properties take effect immediately. 169 | immediate = [ 170 | "telegram", 171 | "openai", 172 | "conversation", 173 | "imagine", 174 | "shortcuts", 175 | ] 176 | # Changes made to these properties take effect after a restart. 177 | delayed = [ 178 | "telegram.token", 179 | "persistence_path", 180 | ] 181 | # All editable properties. 182 | editable = immediate + delayed 183 | # All known properties. 184 | known = readonly + immediate + delayed 185 | 186 | def __init__(self, config: Config) -> None: 187 | self.config = config 188 | 189 | def get_value(self, property: str) -> Any: 190 | """Returns a config property value.""" 191 | names = property.split(".") 192 | if names[0] not in self.known: 193 | raise ValueError(f"No such property: {property}") 194 | 195 | obj = self.config 196 | for name in names[:-1]: 197 | if not hasattr(obj, name): 198 | raise ValueError(f"No such property: {property}") 199 | obj = getattr(obj, name) 200 | 201 | name = names[-1] 202 | if isinstance(obj, dict): 203 | return obj.get(name) 204 | 205 | if isinstance(obj, object): 206 | if not hasattr(obj, name): 207 | raise ValueError(f"No such property: {property}") 208 | val = getattr(obj, name) 209 | if dataclasses.is_dataclass(val): 210 | return dataclasses.asdict(val) 211 | return val 212 | 213 | raise ValueError(f"Failed to get property: {property}") 214 | 215 | def set_value(self, property: str, value: str) -> tuple[bool, bool]: 216 | """ 217 | Changes a config property value. 218 | Returns a tuple `(has_changed, is_immediate, new_val)` 219 | - `has_changed` = True if the value has actually changed, False otherwise. 220 | - `is_immediate` = True if the change takes effect immediately, False otherwise. 221 | - `new_val` is the new value 222 | """ 223 | try: 224 | val = yaml.safe_load(value) 225 | except Exception: 226 | raise ValueError(f"Invalid value: {value}") 227 | 228 | old_val = self.get_value(property) 229 | if val == old_val: 230 | return False, False, old_val 231 | 232 | if isinstance(old_val, list) and isinstance(val, str): 233 | # allow changing list properties by adding or removing individual items 234 | # e.g. 
/config telegram.usernames +bob 235 | # or /config telegram.usernames -alice 236 | if val[0] == "+": 237 | item = yaml.safe_load(val[1:]) 238 | val = old_val.copy() 239 | val.append(item) 240 | elif val[0] == "-": 241 | item = yaml.safe_load(val[1:]) 242 | val = old_val.copy() 243 | val.remove(item) 244 | 245 | old_cls = old_val.__class__ 246 | val_cls = val.__class__ 247 | if old_val is not None and old_cls != val_cls: 248 | raise ValueError( 249 | f"Property {property} should be of type {old_cls.__name__}, not {val_cls.__name__}" 250 | ) 251 | 252 | if not isinstance(val, (list, str, int, float, bool)): 253 | raise ValueError(f"Cannot set composite value for property: {property}") 254 | 255 | names = property.split(".") 256 | if names[0] not in self.editable: 257 | raise ValueError(f"Property {property} is not editable") 258 | 259 | is_immediate = property not in self.delayed 260 | 261 | obj = self.config 262 | for name in names[:-1]: 263 | obj = getattr(obj, name, val) 264 | 265 | name = names[-1] 266 | if isinstance(obj, dict): 267 | obj[name] = val 268 | return True, is_immediate, val 269 | 270 | if isinstance(obj, object): 271 | if not hasattr(obj, name): 272 | raise ValueError(f"No such property: {property}") 273 | setattr(obj, name, val) 274 | return True, is_immediate, val 275 | 276 | raise ValueError(f"Failed to set property: {property}") 277 | 278 | def save(self) -> None: 279 | """Saves the config to disk.""" 280 | data = self.config.as_dict() 281 | with open(self.config.filename, "w") as file: 282 | yaml.safe_dump(data, file, indent=4, allow_unicode=True) 283 | 284 | 285 | class SchemaMigrator: 286 | """Migrates the configuration data dictionary according to the schema version.""" 287 | 288 | @classmethod 289 | def migrate(cls, data: dict) -> tuple[dict, bool]: 290 | """Migrates the configuration to the latest schema version.""" 291 | has_changed = False 292 | if data.get("schema_version", 1) == 1: 293 | data = cls._migrate_v1(data) 294 | has_changed = True 295 | if data["schema_version"] == 2: 296 | data = cls._migrate_v2(data) 297 | has_changed = True 298 | if data["schema_version"] == 3: 299 | data = cls._migrate_v3(data) 300 | has_changed = True 301 | return data, has_changed 302 | 303 | @classmethod 304 | def _migrate_v1(cls, old: dict) -> dict: 305 | data = { 306 | "schema_version": 2, 307 | "telegram": None, 308 | "openai": None, 309 | "max_history_depth": old.get("max_history_depth"), 310 | "imagine": old.get("imagine"), 311 | "persistence_path": old.get("persistence_path"), 312 | "shortcuts": old.get("shortcuts"), 313 | } 314 | data["telegram"] = { 315 | "token": old["telegram_token"], 316 | "usernames": old.get("telegram_usernames"), 317 | "chat_ids": old.get("telegram_chat_ids"), 318 | } 319 | data["openai"] = { 320 | "api_key": old["openai_api_key"], 321 | "model": old.get("openai_model"), 322 | } 323 | return data 324 | 325 | @classmethod 326 | def _migrate_v2(cls, old: dict) -> dict: 327 | data = { 328 | "schema_version": 3, 329 | "telegram": old["telegram"], 330 | "openai": old["openai"], 331 | "imagine": old.get("imagine"), 332 | "persistence_path": old.get("persistence_path"), 333 | "shortcuts": old.get("shortcuts"), 334 | } 335 | data["conversation"] = {"depth": old.get("max_history_depth") or Conversation.default_depth} 336 | return data 337 | 338 | def _migrate_v3(old: dict) -> dict: 339 | data = { 340 | "schema_version": 4, 341 | "telegram": old["telegram"], 342 | "openai": old["openai"], 343 | "conversation": old["conversation"], 344 | "persistence_path": 
old.get("persistence_path"), 345 | "shortcuts": old.get("shortcuts"), 346 | } 347 | imagine_enabled = old.get("imagine") 348 | imagine_enabled = True if imagine_enabled is None else imagine_enabled 349 | data["imagine"] = {"enabled": "users_only" if imagine_enabled else "none"} 350 | return data 351 | 352 | 353 | def load(filename) -> dict: 354 | """Loads the configuration data dictionary from a file.""" 355 | with open(filename, "r") as f: 356 | data = yaml.safe_load(f) 357 | 358 | data, has_changed = SchemaMigrator.migrate(data) 359 | if has_changed: 360 | with open(filename, "w") as f: 361 | yaml.safe_dump(data, f, indent=4, allow_unicode=True) 362 | return data 363 | 364 | 365 | filename = os.getenv("CONFIG", "config.yml") 366 | _config = load(filename) 367 | config = Config(filename, _config) 368 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from bot.config import Config, ConfigEditor, SchemaMigrator 3 | 4 | 5 | class ConfigTest(unittest.TestCase): 6 | def test_init(self): 7 | src = { 8 | "telegram": {"token": "tg-1234", "usernames": ["nalgeon"]}, 9 | "openai": {"api_key": "oa-1234", "model": "gpt-4"}, 10 | "conversation": {"depth": 5}, 11 | "imagine": {"enabled": "none"}, 12 | } 13 | config = Config("config.test.yml", src) 14 | 15 | self.assertEqual(config.telegram.token, "tg-1234") 16 | self.assertEqual(config.telegram.usernames, ["nalgeon"]) 17 | self.assertEqual(config.telegram.admins, []) 18 | self.assertEqual(config.telegram.chat_ids, []) 19 | 20 | self.assertEqual(config.openai.api_key, "oa-1234") 21 | self.assertEqual(config.openai.model, "gpt-4") 22 | self.assertEqual(config.openai.window, 128000) 23 | self.assertTrue(config.openai.prompt, "You are an AI assistant.") 24 | self.assertEqual(config.openai.params["temperature"], 0.7) 25 | self.assertEqual(config.openai.params["max_tokens"], 4096) 26 | 27 | self.assertEqual(config.conversation.depth, 5) 28 | self.assertEqual(config.imagine.enabled, "none") 29 | self.assertEqual(config.persistence_path, "./data/persistence.pkl") 30 | self.assertEqual(config.shortcuts, {}) 31 | 32 | def test_as_dict(self): 33 | src = { 34 | "telegram": {"token": "tg-1234", "usernames": ["nalgeon"]}, 35 | "openai": {"api_key": "oa-1234", "model": "gpt-4"}, 36 | "conversation": {"depth": 5}, 37 | "imagine": {"enabled": "none"}, 38 | } 39 | config = Config("config.test.yml", src) 40 | data = config.as_dict() 41 | self.assertEqual(data["telegram"]["token"], src["telegram"]["token"]) 42 | self.assertEqual(data["telegram"]["usernames"], src["telegram"]["usernames"]) 43 | self.assertEqual(data["telegram"]["admins"], []) 44 | self.assertEqual(data["telegram"]["chat_ids"], []) 45 | self.assertEqual(data["openai"]["api_key"], src["openai"]["api_key"]) 46 | self.assertEqual(data["openai"]["model"], src["openai"]["model"]) 47 | self.assertEqual(data["conversation"]["depth"], src["conversation"]["depth"]) 48 | self.assertEqual(data["imagine"]["enabled"], src["imagine"]["enabled"]) 49 | 50 | 51 | class GetValueTest(unittest.TestCase): 52 | def setUp(self) -> None: 53 | src = { 54 | "telegram": {"token": "tg-1234", "usernames": ["nalgeon"]}, 55 | "openai": {"api_key": "oa-1234", "model": "gpt-4"}, 56 | "conversation": {"depth": 5}, 57 | "imagine": {"enabled": "none"}, 58 | "shortcuts": {"translate": "Translate into English"}, 59 | } 60 | self.editor = ConfigEditor(Config("config.test.yml", src)) 61 | 
62 | def test_object(self): 63 | value = self.editor.get_value("telegram") 64 | self.assertEqual( 65 | value, {"token": "tg-1234", "usernames": ["nalgeon"], "admins": [], "chat_ids": []} 66 | ) 67 | 68 | def test_object_attr(self): 69 | value = self.editor.get_value("telegram.token") 70 | self.assertEqual(value, "tg-1234") 71 | 72 | def test_list(self): 73 | value = self.editor.get_value("telegram.usernames") 74 | self.assertEqual(value, ["nalgeon"]) 75 | 76 | def test_dict(self): 77 | value = self.editor.get_value("shortcuts") 78 | self.assertEqual(value, {"translate": "Translate into English"}) 79 | 80 | def test_dict_value(self): 81 | value = self.editor.get_value("shortcuts.translate") 82 | self.assertEqual(value, "Translate into English") 83 | 84 | def test_str(self): 85 | value = self.editor.get_value("persistence_path") 86 | self.assertEqual(value, "./data/persistence.pkl") 87 | 88 | def test_int(self): 89 | value = self.editor.get_value("conversation.depth") 90 | self.assertEqual(value, 5) 91 | 92 | def test_float(self): 93 | value = self.editor.get_value("openai.params.temperature") 94 | self.assertEqual(value, 0.7) 95 | 96 | def test_not_allowed(self): 97 | with self.assertRaises(ValueError): 98 | self.editor.get_value("__class__") 99 | 100 | def test_does_not_exist(self): 101 | with self.assertRaises(ValueError): 102 | self.editor.get_value("quack") 103 | 104 | def test_object_attr_not_exist(self): 105 | with self.assertRaises(ValueError): 106 | self.editor.get_value("telegram.godmode") 107 | 108 | def test_dict_value_does_not_exist(self): 109 | value = self.editor.get_value("shortcuts.bugfix") 110 | self.assertIsNone(value) 111 | 112 | 113 | class SetValueTest(unittest.TestCase): 114 | def setUp(self) -> None: 115 | src = { 116 | "telegram": { 117 | "token": "tg-1234", 118 | "usernames": ["alice", "bob"], 119 | "admins": ["botfather"], 120 | }, 121 | "openai": {"api_key": "oa-1234", "model": "gpt-4"}, 122 | "conversation": {"depth": 5}, 123 | "imagine": {"enabled": "none"}, 124 | "shortcuts": {"translate": "Translate into English"}, 125 | } 126 | self.editor = ConfigEditor(Config("config.test.yml", src)) 127 | self.editor.save = lambda: None 128 | 129 | def test_object(self): 130 | with self.assertRaises(ValueError): 131 | self.editor.set_value("telegram", '{"token": "tg-1234", "usernames": ["alice"]}') 132 | 133 | def test_object_attr(self): 134 | self.editor.set_value("telegram.token", "tg-5678") 135 | value = self.editor.get_value("telegram.token") 136 | self.assertEqual(value, "tg-5678") 137 | 138 | def test_list(self): 139 | self.editor.set_value("telegram.usernames", '["cindy", "dave"]') 140 | value = self.editor.get_value("telegram.usernames") 141 | self.assertEqual(value, ["cindy", "dave"]) 142 | 143 | def test_list_add(self): 144 | self.editor.set_value("telegram.usernames", "+cindy") 145 | value = self.editor.get_value("telegram.usernames") 146 | self.assertEqual(value, ["alice", "bob", "cindy"]) 147 | 148 | def test_list_remove(self): 149 | self.editor.set_value("telegram.usernames", "-bob") 150 | value = self.editor.get_value("telegram.usernames") 151 | self.assertEqual(value, ["alice"]) 152 | 153 | def test_dict(self): 154 | with self.assertRaises(ValueError): 155 | self.editor.set_value("shortcuts.bugfix", '{"bugfix": "Fix bugs in the code"}') 156 | 157 | def test_dict_value(self): 158 | self.editor.set_value("shortcuts.translate", "Translate into Spanish") 159 | value = self.editor.get_value("shortcuts.translate") 160 | self.assertEqual(value, "Translate into 
Spanish") 161 | 162 | def test_int(self): 163 | self.editor.set_value("openai.params.max_tokens", "500") 164 | value = self.editor.get_value("openai.params.max_tokens") 165 | self.assertEqual(value, 500) 166 | 167 | def test_float(self): 168 | self.editor.set_value("openai.params.temperature", "0.5") 169 | value = self.editor.get_value("openai.params.temperature") 170 | self.assertEqual(value, 0.5) 171 | 172 | def test_invalid_type_1(self): 173 | with self.assertRaises(ValueError): 174 | self.editor.set_value("conversation.depth", "five") 175 | 176 | def test_invalid_type_2(self): 177 | with self.assertRaises(ValueError): 178 | self.editor.set_value("telegram.usernames", "alice") 179 | 180 | def test_invalid_type_3(self): 181 | with self.assertRaises(ValueError): 182 | self.editor.set_value("imagine", "no") 183 | 184 | def test_not_allowed(self): 185 | with self.assertRaises(ValueError): 186 | self.editor.set_value("__class__", "{}") 187 | 188 | def test_readonly(self): 189 | with self.assertRaises(ValueError): 190 | self.editor.set_value("version", "10") 191 | 192 | def test_does_not_exist(self): 193 | with self.assertRaises(ValueError): 194 | self.editor.set_value("quack", "yes") 195 | 196 | def test_object_attr_not_exist(self): 197 | with self.assertRaises(ValueError): 198 | self.editor.set_value("telegram.godmode", "on") 199 | 200 | def test_dict_value_does_not_exist(self): 201 | self.editor.set_value("shortcuts.bugfix", "Fix bugs in the code") 202 | value = self.editor.get_value("shortcuts.bugfix") 203 | self.assertEqual(value, "Fix bugs in the code") 204 | 205 | def test_invalid_value(self): 206 | with self.assertRaises(ValueError): 207 | self.editor.set_value("imagine.enabled", '"users_only') 208 | 209 | def test_has_changed(self): 210 | has_changed, _, new_val = self.editor.set_value("imagine.enabled", "users_only") 211 | self.assertTrue(has_changed) 212 | self.assertEqual(new_val, "users_only") 213 | 214 | def test_has_not_changed(self): 215 | has_changed, _, new_val = self.editor.set_value("imagine.enabled", "none") 216 | self.assertFalse(has_changed) 217 | self.assertEqual(new_val, "none") 218 | 219 | def test_is_immediate_1(self): 220 | _, is_immediate, _ = self.editor.set_value("imagine.enabled", "users_only") 221 | self.assertTrue(is_immediate) 222 | 223 | def test_is_immediate_2(self): 224 | _, is_immediate, _ = self.editor.set_value("telegram.usernames", '["cindy", "dave"]') 225 | self.assertTrue(is_immediate) 226 | 227 | def test_is_delayed(self): 228 | _, is_immediate, _ = self.editor.set_value("telegram.token", "tg-5678") 229 | self.assertFalse(is_immediate) 230 | 231 | 232 | class MigrateTest(unittest.TestCase): 233 | def test_migrate_v1(self): 234 | old = { 235 | "telegram_token": "tg-1234", 236 | "telegram_usernames": ["alice"], 237 | "telegram_chat_ids": [-100500], 238 | "openai_api_key": "oa-1234", 239 | "openai_model": "gpt-3.5-turbo", 240 | "persistence_path": "./data/persistence.pkl", 241 | } 242 | migrated, has_changed = SchemaMigrator.migrate(old) 243 | self.assertTrue(has_changed) 244 | self.assertEqual( 245 | migrated, 246 | { 247 | "schema_version": 4, 248 | "telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 249 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 250 | "conversation": {"depth": 3}, 251 | "imagine": {"enabled": "users_only"}, 252 | "persistence_path": "./data/persistence.pkl", 253 | "shortcuts": None, 254 | }, 255 | ) 256 | 257 | def test_migrate_v2(self): 258 | old = { 259 | "schema_version": 2, 260 | 
"telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 261 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 262 | "max_history_depth": 5, 263 | "imagine": False, 264 | "persistence_path": "./data/persistence.pkl", 265 | "shortcuts": {"bugfix": "Fix bugs"}, 266 | } 267 | migrated, has_changed = SchemaMigrator.migrate(old) 268 | self.assertTrue(has_changed) 269 | self.assertEqual( 270 | migrated, 271 | { 272 | "schema_version": 4, 273 | "telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 274 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 275 | "conversation": {"depth": 5}, 276 | "imagine": {"enabled": "none"}, 277 | "persistence_path": "./data/persistence.pkl", 278 | "shortcuts": {"bugfix": "Fix bugs"}, 279 | }, 280 | ) 281 | 282 | def test_migrate_v3(self): 283 | old = { 284 | "schema_version": 3, 285 | "telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 286 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 287 | "conversation": {"depth": 5}, 288 | "imagine": True, 289 | "persistence_path": "./data/persistence.pkl", 290 | "shortcuts": {"bugfix": "Fix bugs"}, 291 | } 292 | migrated, has_changed = SchemaMigrator.migrate(old) 293 | self.assertTrue(has_changed) 294 | self.assertEqual( 295 | migrated, 296 | { 297 | "schema_version": 4, 298 | "telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 299 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 300 | "conversation": {"depth": 5}, 301 | "imagine": {"enabled": "users_only"}, 302 | "persistence_path": "./data/persistence.pkl", 303 | "shortcuts": {"bugfix": "Fix bugs"}, 304 | }, 305 | ) 306 | 307 | def test_not_changed(self): 308 | old = { 309 | "schema_version": 4, 310 | "telegram": {"token": "tg-1234", "usernames": ["alice"], "chat_ids": [-100500]}, 311 | "openai": {"api_key": "oa-1234", "model": "gpt-3.5-turbo"}, 312 | "conversation": {"depth": 5}, 313 | "imagine": {"enabled": "users_only"}, 314 | "persistence_path": "./data/persistence.pkl", 315 | "shortcuts": {"bugfix": "Fix bugs"}, 316 | } 317 | migrated, has_changed = SchemaMigrator.migrate(old) 318 | self.assertFalse(has_changed) 319 | self.assertEqual(migrated, old) 320 | -------------------------------------------------------------------------------- /tests/test_commands.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | from telegram import Chat, Message, MessageEntity, Update, User 4 | from telegram.constants import ChatType 5 | from telegram.ext import CallbackContext 6 | from telegram.ext import filters as tg_filters 7 | 8 | from bot import askers 9 | from bot import bot 10 | from bot import commands 11 | from bot import models 12 | from bot.config import config 13 | from bot.filters import Filters 14 | from tests.mocks import FakeGPT, FakeDalle, FakeApplication, FakeBot, mock_text_asker 15 | 16 | 17 | class Helper: 18 | def _create_update(self, update_id: int, text: str = None, **kwargs) -> Update: 19 | if "user" in kwargs: 20 | user = kwargs["user"] 21 | del kwargs["user"] 22 | else: 23 | user = self.user 24 | message = Message( 25 | message_id=update_id, 26 | date=dt.datetime.now(), 27 | chat=self.chat, 28 | text=text, 29 | from_user=user, 30 | **kwargs, 31 | ) 32 | message.set_bot(self.bot) 33 | return Update(update_id=update_id, message=message) 34 | 35 | 36 | class StartTest(unittest.IsolatedAsyncioTestCase, Helper): 37 | def 
setUp(self): 38 | mock_text_asker(FakeGPT()) 39 | self.bot = FakeBot("bot") 40 | self.chat = Chat(id=1, type=ChatType.PRIVATE) 41 | self.chat.set_bot(self.bot) 42 | self.application = FakeApplication(self.bot) 43 | self.context = CallbackContext(self.application, chat_id=1, user_id=1) 44 | self.user = User(id=1, first_name="Alice", is_bot=False, username="alice") 45 | config.telegram.usernames = ["alice"] 46 | self.command = commands.Start() 47 | 48 | async def test_start(self): 49 | update = self._create_update(11) 50 | await self.command(update, self.context) 51 | self.assertTrue(self.bot.text.startswith("Hi! I'm a humble AI-driven chat bot.")) 52 | 53 | async def test_start_unknown(self): 54 | user = User(id=2, first_name="Bob", is_bot=False, username="bob") 55 | message = Message( 56 | message_id=11, 57 | date=dt.datetime.now(), 58 | chat=self.chat, 59 | from_user=user, 60 | ) 61 | message.set_bot(self.bot) 62 | update = Update(update_id=11, message=message) 63 | await self.command(update, self.context) 64 | self.assertTrue(self.bot.text.startswith("Sorry, I don't know you")) 65 | 66 | 67 | class HelpTest(unittest.IsolatedAsyncioTestCase, Helper): 68 | def setUp(self): 69 | self.bot = FakeBot("bot") 70 | self.chat = Chat(id=1, type=ChatType.PRIVATE) 71 | self.chat.set_bot(self.bot) 72 | self.application = FakeApplication(self.bot) 73 | self.context = CallbackContext(self.application, chat_id=1, user_id=1) 74 | self.user = User(id=1, first_name="Alice", is_bot=False, username="alice") 75 | config.telegram.usernames = ["alice"] 76 | self.command = commands.Help() 77 | 78 | async def test_help(self): 79 | update = self._create_update(11) 80 | await self.command(update, self.context) 81 | self.assertTrue(self.bot.text.startswith("Send me a question")) 82 | 83 | 84 | class VersionTest(unittest.IsolatedAsyncioTestCase, Helper): 85 | def setUp(self): 86 | self.bot = FakeBot("bot") 87 | self.chat = Chat(id=1, type=ChatType.PRIVATE) 88 | self.chat.set_bot(self.bot) 89 | self.application = FakeApplication(self.bot) 90 | self.context = CallbackContext(self.application, chat_id=1, user_id=1) 91 | self.user = User(id=1, first_name="Alice", is_bot=False, username="alice") 92 | config.version = 101 93 | config.telegram.usernames = ["alice", "bob"] 94 | config.telegram.admins = ["alice"] 95 | config.telegram.chat_ids = [-100500] 96 | config.openai.url = "https://api.openai.com/v1" 97 | config.openai.model = "gpt-4" 98 | config.conversation.depth = 10 99 | config.imagine.enabled = "none" 100 | config.shortcuts = { 101 | "translate_en": "Translate into English", 102 | "translate_fr": "Translate into French", 103 | } 104 | self.command = commands.Version() 105 | 106 | async def test_version(self): 107 | update = self._create_update(11) 108 | await self.command(update, self.context) 109 | self.assertTrue(self.bot.text.startswith("
Chat information:"))
110 |         self.assertTrue("- name: @bot" in self.bot.text)
111 |         self.assertTrue("- version: 101" in self.bot.text)
112 |         self.assertTrue("- usernames: 2 users" in self.bot.text)
113 |         self.assertTrue("- admins: 1 users" in self.bot.text)
114 |         self.assertTrue("- chat IDs: [-100500]" in self.bot.text)
115 |         self.assertTrue("- access to messages: True" in self.bot.text)
116 | 
117 |         self.assertTrue("AI information:" in self.bot.text)
118 |         self.assertTrue("- provider: api.openai.com" in self.bot.text)
119 |         self.assertTrue("- model: gpt-4" in self.bot.text)
120 |         self.assertTrue("- history depth: 10" in self.bot.text)
121 |         self.assertTrue("- imagine: none" in self.bot.text)
122 |         self.assertTrue("- shortcuts: translate_en, translate_fr" in self.bot.text)
123 | 
124 | 
125 | class ConfigTest(unittest.IsolatedAsyncioTestCase, Helper):
126 |     def setUp(self):
127 |         self.bot = FakeBot("bot")
128 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
129 |         self.chat.set_bot(self.bot)
130 |         self.application = FakeApplication(self.bot)
131 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
132 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
133 |         config.telegram.usernames = ["alice"]
134 |         config.telegram.admins = ["alice"]
135 |         self.filters = Filters()
136 |         self.command = commands.Config(self.filters)
137 | 
138 |     async def test_help(self):
139 |         update = self._create_update(11, "/config")
140 |         await self.command(update, self.context)
141 |         self.assertTrue(self.bot.text.startswith("Syntax:"))
142 | 
143 |     async def test_view(self):
144 |         config.openai.model = "gpt-3.5-turbo"
145 |         update = self._create_update(11, "/config openai.model")
146 |         await self.command(update, self.context)
147 |         self.assertEqual(self.bot.text, "`gpt-3.5-turbo`")
148 | 
149 |     async def test_change(self):
150 |         commands.config.editor.save = lambda: None
151 |         config.openai.model = "gpt-3.5-turbo"
152 |         update = self._create_update(11, "/config openai.model gpt-4")
153 |         await self.command(update, self.context)
154 |         self.assertTrue(self.bot.text.startswith("✓ Changed the `openai.model` property"))
155 | 
156 |     async def test_conversation_depth(self):
157 |         commands.config.editor.save = lambda: None
158 |         config.conversation.depth = 3
159 |         user = models.UserData(self.context.user_data)
160 |         assert user.messages.messages.maxlen == 3
161 |         update = self._create_update(11, "/config conversation.depth 5")
162 |         await self.command(update, self.context)
163 |         user = models.UserData(self.context.user_data)
164 |         self.assertEqual(user.messages.messages.maxlen, 5)
165 | 
166 |     async def test_not_changed(self):
167 |         commands.config.editor.save = lambda: None
168 |         config.openai.model = "gpt-3.5-turbo"
169 |         update = self._create_update(11, "/config openai.model gpt-3.5-turbo")
170 |         await self.command(update, self.context)
171 |         self.assertEqual(
172 |             self.bot.text,
173 |             "✗ The `openai.model` property already equals to `gpt-3.5-turbo`",
174 |         )
175 | 
176 |     async def test_delayed(self):
177 |         commands.config.editor.save = lambda: None
178 |         config.persistence_path = "./data/persistence.pkl"
179 |         update = self._create_update(11, "/config persistence_path /tmp/data.pkl")
180 |         await self.command(update, self.context)
181 |         self.assertTrue("Restart the bot" in self.bot.text)
182 | 
183 |     async def test_telegram_usernames(self):
184 |         update = self._create_update(11, '/config telegram.usernames ["alice", "bob"]')
185 |         await self.command(update, self.context)
186 |         self.assertEqual(self.filters.users.usernames, frozenset(["alice", "bob"]))
187 | 
188 |     async def test_telegram_admins(self):
189 |         update = self._create_update(11, '/config telegram.admins ["alice", "bob"]')
190 |         await self.command(update, self.context)
191 |         self.assertEqual(self.filters.admins.usernames, frozenset(["alice", "bob"]))
192 | 
193 |     async def test_telegram_chat_ids(self):
194 |         update = self._create_update(11, "/config telegram.chat_ids [-100500]")
195 |         await self.command(update, self.context)
196 |         self.assertEqual(self.filters.chats.chat_ids, frozenset([-100500]))
197 | 
198 | 
199 | class ModelPrivateTest(unittest.IsolatedAsyncioTestCase, Helper):
200 |     def setUp(self):
201 |         self.bot = FakeBot("bot")
202 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
203 |         self.chat.set_bot(self.bot)
204 |         self.application = FakeApplication(self.bot)
205 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
206 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
207 |         config.telegram.usernames = ["alice"]
208 |         self.command = commands.Model()
209 | 
210 |     async def test_help(self):
211 |         update = self._create_update(11, "/model")
212 |         await self.command(update, self.context)
213 |         self.assertTrue(self.bot.text.startswith("Syntax:"))
214 |         self.assertEqual(self.application.chat_data[1], {})
215 | 
216 |     async def test_set(self):
217 |         update = self._create_update(11, text="/model gpt-5")
218 |         await self.command(update, self.context)
219 |         self.assertTrue(self.bot.text.startswith("✓ Set model"))
220 |         self.assertEqual(self.application.chat_data[1]["model"], "gpt-5")
221 | 
222 |     async def test_show(self):
223 |         update = self._create_update(11, text="/model gpt-5")
224 |         await self.command(update, self.context)
225 | 
226 |         update = self._create_update(11, text="/model")
227 |         await self.command(update, self.context)
228 |         self.assertTrue(self.bot.text.startswith("Using model"))
229 |         self.assertEqual(self.application.chat_data[1]["model"], "gpt-5")
230 | 
231 |     async def test_reset(self):
232 |         update = self._create_update(11, text="/model gpt-5")
233 |         await self.command(update, self.context)
234 | 
235 |         update = self._create_update(11, text="/model reset")
236 |         await self.command(update, self.context)
237 |         self.assertTrue(self.bot.text.startswith("✓ Using default model"))
238 |         self.assertEqual(self.application.chat_data[1]["model"], "")
239 | 
240 | 
241 | class ModelGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
242 |     def setUp(self):
243 |         self.bot = FakeBot("bot")
244 |         self.chat = Chat(id=1, type=ChatType.GROUP)
245 |         self.chat.set_bot(self.bot)
246 |         self.application = FakeApplication(self.bot)
247 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
248 |         self.alice = User(id=1, first_name="Alice", is_bot=False, username="alice")
249 |         self.bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
250 |         config.telegram.admins = ["alice"]
251 |         self.command = commands.Model()
252 | 
253 |     async def test_allowed(self):
254 |         update = self._create_update(11, "/model@bot gpt-5", user=self.alice)
255 |         await self.command(update, self.context)
256 |         self.assertTrue(self.bot.text.startswith("✓ Set model"))
257 |         self.assertEqual(self.application.chat_data[1]["model"], "gpt-5")
258 | 
259 |     async def test_not_allowed(self):
260 |         update = self._create_update(11, "/model@bot gpt-5", user=self.bob)
261 |         await self.command(update, self.context)
262 |         self.assertEqual(self.bot.text, "")
263 |         self.assertEqual(self.application.chat_data[1], {})
264 | 
265 | 
266 | class PromptPrivateTest(unittest.IsolatedAsyncioTestCase, Helper):
267 |     def setUp(self):
268 |         self.bot = FakeBot("bot")
269 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
270 |         self.chat.set_bot(self.bot)
271 |         self.application = FakeApplication(self.bot)
272 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
273 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
274 |         config.telegram.usernames = ["alice"]
275 |         self.command = commands.Prompt()
276 | 
277 |     async def test_not_set(self):
278 |         update = self._create_update(11, "/prompt")
279 |         await self.command(update, self.context)
280 |         self.assertTrue(self.bot.text.startswith("Syntax:"))
281 |         self.assertEqual(self.application.chat_data[1], {})
282 | 
283 |     async def test_set_custom(self):
284 |         update = self._create_update(11, text="/prompt Laugh")
285 |         await self.command(update, self.context)
286 |         self.assertTrue(self.bot.text.startswith("✓ Set custom prompt"))
287 |         self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")
288 | 
289 |     async def test_show_custom(self):
290 |         update = self._create_update(11, text="/prompt Laugh")
291 |         await self.command(update, self.context)
292 | 
293 |         update = self._create_update(11, text="/prompt")
294 |         await self.command(update, self.context)
295 |         self.assertTrue(self.bot.text.startswith("Using custom prompt"))
296 |         self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")
297 | 
298 |     async def test_reset(self):
299 |         update = self._create_update(11, text="/prompt Laugh")
300 |         await self.command(update, self.context)
301 | 
302 |         update = self._create_update(11, text="/prompt reset")
303 |         await self.command(update, self.context)
304 |         self.assertTrue(self.bot.text.startswith("✓ Using default prompt"))
305 |         self.assertEqual(self.application.chat_data[1]["prompt"], "")
306 | 
307 | 
308 | class PromptGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
309 |     def setUp(self):
310 |         self.bot = FakeBot("bot")
311 |         self.chat = Chat(id=1, type=ChatType.GROUP)
312 |         self.chat.set_bot(self.bot)
313 |         self.application = FakeApplication(self.bot)
314 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
315 |         self.alice = User(id=1, first_name="Alice", is_bot=False, username="alice")
316 |         self.bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
317 |         config.telegram.admins = ["alice"]
318 |         self.command = commands.Prompt()
319 | 
320 |     async def test_allowed(self):
321 |         update = self._create_update(11, "/prompt@bot Laugh", user=self.alice)
322 |         await self.command(update, self.context)
323 |         self.assertTrue(self.bot.text.startswith("✓ Set custom prompt"))
324 |         self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")
325 | 
326 |     async def test_not_allowed(self):
327 |         update = self._create_update(11, "/prompt@bot Laugh", user=self.bob)
328 |         await self.command(update, self.context)
329 |         self.assertEqual(self.bot.text, "")
330 |         self.assertEqual(self.application.chat_data[1], {})
331 | 
332 | 
333 | class RetryTest(unittest.IsolatedAsyncioTestCase, Helper):
334 |     def setUp(self):
335 |         mock_text_asker(FakeGPT())
336 |         self.bot = FakeBot("bot")
337 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
338 |         self.chat.set_bot(self.bot)
339 |         self.application = FakeApplication(self.bot)
340 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
341 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
342 |         self.command = commands.Retry(bot.reply_to)
343 |         config.telegram.usernames = ["alice"]
344 | 
345 |     async def test_retry(self):
346 |         user_data = models.UserData(self.context.user_data)
347 |         user_data.messages.add("What is your name?", "My name is AI.")
348 |         update = self._create_update(11)
349 |         await self.command(update, self.context)
350 |         self.assertEqual(self.bot.text, "What is your name?")
351 | 
352 | 
353 | class ImagineTest(unittest.IsolatedAsyncioTestCase, Helper):
354 |     def setUp(self):
355 |         askers.ImagineAsker.model = FakeDalle()
356 |         self.bot = FakeBot("bot")
357 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
358 |         self.chat.set_bot(self.bot)
359 |         self.application = FakeApplication(self.bot)
360 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
361 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
362 |         self.command = commands.Imagine(bot.reply_to)
363 |         config.telegram.usernames = ["alice"]
364 | 
365 |     async def test_imagine(self):
366 |         config.imagine.enabled = "users_only"
367 |         update = self._create_update(11, "/imagine a cat")
368 |         self.context.args = ["a", "cat"]
369 |         await self.command(update, self.context)
370 |         self.assertEqual(self.bot.text, "a cat: image")
371 | 
372 |     async def test_disabled(self):
373 |         config.imagine.enabled = "none"
374 |         update = self._create_update(11, "/imagine a cat")
375 |         self.context.args = ["a", "cat"]
376 |         await self.command(update, self.context)
377 |         self.assertTrue("command is disabled" in self.bot.text)
378 | 
379 |     async def test_users_only(self):
380 |         config.imagine.enabled = "users_only"
381 |         user = User(id=2, first_name="Bob", is_bot=False, username="bob")
382 |         update = self._create_update(11, "/imagine a cat", user=user)
383 |         self.context.args = ["a", "cat"]
384 |         await self.command(update, self.context)
385 |         self.assertTrue("command is disabled" in self.bot.text)
386 | 
387 |     async def test_users_and_groups(self):
388 |         config.imagine.enabled = "users_and_groups"
389 |         user = User(id=2, first_name="Bob", is_bot=False, username="bob")
390 |         update = self._create_update(11, "/imagine a cat", user=user)
391 |         self.context.args = ["a", "cat"]
392 |         await self.command(update, self.context)
393 |         self.assertEqual(self.bot.text, "a cat: image")
394 | 
395 | 
396 | class MessageTest(unittest.IsolatedAsyncioTestCase, Helper):
397 |     def setUp(self):
398 |         self.ai = FakeGPT()
399 |         mock_text_asker(self.ai)
400 |         self.bot = FakeBot("bot")
401 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
402 |         self.chat.set_bot(self.bot)
403 |         self.application = FakeApplication(self.bot)
404 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
405 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
406 |         self.command = commands.Message(bot.reply_to)
407 |         config.telegram.usernames = ["alice"]
408 | 
409 |     async def test_message(self):
410 |         update = self._create_update(11, text="What is your name?")
411 |         await self.command(update, self.context)
412 |         self.assertEqual(self.bot.text, "What is your name?")
413 |         self.assertEqual(self.ai.question, "What is your name?")
414 |         self.assertEqual(self.ai.history, [])
415 | 
416 |     async def test_follow_up(self):
417 |         update = self._create_update(11, text="What is your name?")
418 |         await self.command(update, self.context)
419 |         self.assertEqual(self.ai.question, "What is your name?")
420 |         self.assertEqual(self.ai.history, [])
421 | 
422 |         update = self._create_update(12, text="+ And why is that?")
423 |         await self.command(update, self.context)
424 |         self.assertEqual(self.ai.question, "And why is that?")
425 |         self.assertEqual(self.ai.history, [("What is your name?", "What is your name?")])
426 | 
427 |         update = self._create_update(13, text="+ Where are you?")
428 |         await self.command(update, self.context)
429 |         self.assertEqual(self.ai.question, "Where are you?")
430 |         self.assertEqual(
431 |             self.ai.history,
432 |             [
433 |                 ("What is your name?", "What is your name?"),
434 |                 ("+ And why is that?", "And why is that?"),
435 |             ],
436 |         )
437 | 
438 |     async def test_forward(self):
439 |         update = self._create_update(11, text="What is your name?", forward_date=dt.datetime.now())
440 |         await self.command(update, self.context)
441 |         self.assertTrue(self.bot.text.startswith("This is a forwarded message"))
442 | 
443 |     async def test_document(self):
444 |         update = self._create_update(11, text="I have so much to say" + "." * 5000)
445 |         await self.command(update, self.context)
446 |         self.assertEqual(self.bot.text, "I have so much to... (see attachment for the rest): 11.md")
447 | 
448 |     async def test_exception(self):
449 |         ai = FakeGPT(error=Exception("connection timeout"))
450 |         mock_text_asker(ai)
451 |         update = self._create_update(11, text="What is your name?")
452 |         await self.command(update, self.context)
453 |         self.assertTrue(self.bot.text.startswith("⚠️ builtins.Exception:"))
454 |         self.assertTrue("connection timeout" in self.bot.text)
455 | 
456 | 
457 | class MessageGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
458 |     def setUp(self):
459 |         mock_text_asker(FakeGPT())
460 |         self.bot = FakeBot("bot")
461 |         self.chat = Chat(id=1, type=ChatType.GROUP)
462 |         self.chat.set_bot(self.bot)
463 |         self.application = FakeApplication(self.bot)
464 |         self.application.user_data[1] = {}
465 |         self.application.user_data[2] = {}
466 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
467 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
468 |         self.user_erik = User(id=2, first_name="Erik", is_bot=False, username="erik")
469 |         self.user_bot = User(id=42, first_name="Bot", is_bot=True, username="bot")
470 |         self.command = commands.Message(bot.reply_to)
471 |         config.telegram.usernames = ["alice"]
472 | 
473 |     async def test_message(self):
474 |         mention = MessageEntity(type=MessageEntity.MENTION, offset=0, length=4)
475 |         update = self._create_update(11, text="@bot What is your name?", entities=(mention,))
476 |         await self.command(update, self.context)
477 |         self.assertEqual(self.bot.text, "What is your name?")
478 | 
479 |     async def test_no_mention(self):
480 |         update = self._create_update(11, text="What is your name?")
481 |         await self.command(update, self.context)
482 |         self.assertEqual(self.bot.text, "")
483 | 
484 | 
485 | class MessageLimitTest(unittest.IsolatedAsyncioTestCase, Helper):
486 |     def setUp(self):
487 |         self.ai = FakeGPT()
488 |         mock_text_asker(self.ai)
489 |         self.bot = FakeBot("bot")
490 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
491 |         self.chat.set_bot(self.bot)
492 |         self.application = FakeApplication(self.bot)
493 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
494 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
495 |         self.command = commands.Message(bot.reply_to)
496 |         config.telegram.usernames = ["alice"]
497 |         config.conversation.message_limit.count = 1
498 |         config.conversation.message_limit.period = "minute"
499 |         # a hack for testing purposes only
500 |         bot.filters.users = tg_filters.User(username=config.telegram.usernames)
501 | 
502 |     async def test_known_user(self):
503 |         update = self._create_update(11, text="What is your name?")
504 |         await self.command(update, self.context)
505 |         self.assertEqual(self.bot.text, "What is your name?")
506 | 
507 |         update = self._create_update(12, text="Where are you from?")
508 |         await self.command(update, self.context)
509 |         self.assertEqual(self.bot.text, "Where are you from?")
510 | 
511 |     async def test_unknown_user(self):
512 |         other_user = User(id=2, first_name="Bob", is_bot=False, username="bob")
513 | 
514 |         update = self._create_update(11, text="What is your name?", user=other_user)
515 |         await self.command(update, self.context)
516 |         self.assertEqual(self.bot.text, "What is your name?")
517 | 
518 |         update = self._create_update(12, text="Where are you from?", user=other_user)
519 |         await self.command(update, self.context)
520 |         self.assertTrue(self.bot.text.startswith("Please wait"))
521 | 
522 |     async def test_expired(self):
523 |         config.conversation.message_limit.count = 3
524 | 
525 |         user = User(id=2, first_name="Bob", is_bot=False, username="bob")
526 |         # the counter has reached the limit, but the value has expired
527 |         user_data = {
528 |             "message_counter": {
529 |                 "value": 3,
530 |                 "timestamp": dt.datetime.now() - dt.timedelta(hours=1),
531 |             }
532 |         }
533 |         self.application.user_data[user.id] = user_data
534 |         context = CallbackContext(self.application, chat_id=1, user_id=user.id)
535 | 
536 |         update = self._create_update(11, text="What is your name?", user=user)
537 |         await self.command(update, context)
538 |         self.assertEqual(self.bot.text, "What is your name?")
539 |         self.assertEqual(user_data["message_counter"]["value"], 1)
540 | 
541 |     async def test_unlimited(self):
542 |         config.conversation.message_limit.count = 0
543 |         other_user = User(id=2, first_name="Bob", is_bot=False, username="bob")
544 | 
545 |         update = self._create_update(11, text="What is your name?", user=other_user)
546 |         await self.command(update, self.context)
547 |         self.assertEqual(self.bot.text, "What is your name?")
548 | 
549 |         update = self._create_update(12, text="Where are you from?", user=other_user)
550 |         await self.command(update, self.context)
551 |         self.assertEqual(self.bot.text, "Where are you from?")
552 | 
553 | 
554 | class ErrorTest(unittest.IsolatedAsyncioTestCase, Helper):
555 |     def setUp(self):
556 |         mock_text_asker(FakeGPT())
557 |         self.bot = FakeBot("bot")
558 |         self.chat = Chat(id=1, type=ChatType.PRIVATE)
559 |         self.chat.set_bot(self.bot)
560 |         self.application = FakeApplication(self.bot)
561 |         self.context = CallbackContext(self.application, chat_id=1, user_id=1)
562 |         self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
563 |         config.telegram.usernames = ["alice"]
564 | 
565 |     async def test_error(self):
566 |         self.context.error = Exception("Something went wrong")
567 |         command = commands.Error()
568 |         update = self._create_update(11, "Something went wrong")
569 |         update._effective_chat = self.chat
570 |         await command(update, self.context)
571 |         self.assertEqual(self.bot.text, "⚠️ builtins.Exception: Something went wrong")
572 | 


--------------------------------------------------------------------------------