# NOTE(review): the directory tree and the files from bot/__init__.py through bot/markdown.py are
# collapsed onto the few physical lines below; each "N |" marker starts a line of the original file.
# NOTE(review): bot/cli.py tests `len(sys.argv) == 0`, which is never true because sys.argv[0] is
# always present, so running it without a question fails with IndexError on `sys.argv[1]`;
# `len(sys.argv) < 2` looks intended.
# NOTE(review): in bot/markdown.py, to_html's `.replace("&", "&")` is a no-op and its `sub(r"...")`
# replacement strings are split across lines - HTML entities/tags appear to have been lost when this
# dump was rendered; compare with the real file before relying on this copy.
├── bot ├── __init__.py ├── ai │ ├── __init__.py │ ├── images.py │ └── chat.py ├── commands │ ├── __init__.py │ ├── retry.py │ ├── constants.py │ ├── start.py │ ├── message.py │ ├── error.py │ ├── help.py │ ├── imagine.py │ ├── model.py │ ├── version.py │ ├── prompt.py │ └── config.py ├── shortcuts.py ├── cli.py ├── markdown.py ├── filters.py ├── fetcher.py ├── askers.py ├── questions.py ├── models.py ├── bot.py └── config.py ├── tests ├── __init__.py ├── test_shortcuts.py ├── test_filters.py ├── test_markdown.py ├── mocks.py ├── test_ai.py ├── test_fetcher.py ├── test_askers.py ├── test_questions.py ├── test_models.py ├── test_config.py └── test_commands.py ├── requirements.txt ├── Dockerfile ├── compose.yml ├── Makefile ├── .github └── workflows │ └── build.yml ├── LICENSE ├── config.example.yml ├── .gitignore └── README.md /bot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bot/ai/__init__.py: -------------------------------------------------------------------------------- 1 | from . import chat 2 | from . 
import images 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | httpcore==1.0.2 2 | httpx==0.25.1 3 | beautifulsoup4==4.12.2 4 | python-telegram-bot==20.6 5 | PyYAML==6.0.1 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-alpine 2 | 3 | ARG PIP_DISABLE_PIP_VERSION_CHECK=1 4 | ARG PIP_NO_CACHE_DIR=1 5 | ENV PYTHONDONTWRITEBYTECODE=1 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | WORKDIR /code 9 | COPY ./requirements.txt . 10 | RUN pip3 install -r requirements.txt 11 | 12 | COPY ./bot ./bot 13 | 14 | CMD ["python","-m","bot.bot"] 15 | -------------------------------------------------------------------------------- /compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | pokitoki: 3 | container_name: pokitoki 4 | restart: unless-stopped 5 | build: 6 | context: . 7 | dockerfile: Dockerfile 8 | volumes: 9 | - ./config.yml:/code/config.yml 10 | - data:/code/data 11 | volumes: 12 | data: 13 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: start stop test 2 | .SILENT: start stop test 3 | 4 | start: 5 | CONFIG=config.$(name).yml nohup env/bin/python -m bot.bot > $(name).log 2>&1 & echo $$! 
> $(name).pid 6 | echo "Started $(name) bot" 7 | 8 | stop: 9 | kill $(shell cat $(name).pid) 10 | rm -f $(name).pid 11 | echo "Stopped $(name) bot" 12 | 13 | test: 14 | env/bin/python -m unittest discover 15 | -------------------------------------------------------------------------------- /bot/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import BOT_COMMANDS 2 | from .config import ConfigCommand as Config 3 | from .error import ErrorCommand as Error 4 | from .help import HelpCommand as Help 5 | from .imagine import ImagineCommand as Imagine 6 | from .message import MessageCommand as Message 7 | from .model import ModelCommand as Model 8 | from .prompt import PromptCommand as Prompt 9 | from .retry import RetryCommand as Retry 10 | from .start import StartCommand as Start 11 | from .version import VersionCommand as Version 12 | -------------------------------------------------------------------------------- /bot/commands/retry.py: -------------------------------------------------------------------------------- 1 | """/retry command.""" 2 | 3 | from typing import Awaitable 4 | from telegram import Update 5 | from telegram.ext import CallbackContext 6 | from bot.models import UserData 7 | 8 | 9 | class RetryCommand: 10 | """Retries asking the last question (if any).""" 11 | 12 | def __init__(self, reply_func: Awaitable) -> None: 13 | self.reply_func = reply_func 14 | 15 | async def __call__(self, update: Update, context: CallbackContext) -> None: 16 | user = UserData(context.user_data) 17 | last_message = user.messages.pop() 18 | if not last_message: 19 | await update.message.reply_text("No message to retry 🤷♂️") 20 | return 21 | await self.reply_func( 22 | update=update, message=update.message, context=context, question=last_message.question 23 | ) 24 | -------------------------------------------------------------------------------- /bot/shortcuts.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Working with shortcuts. 3 | A shortcut is an action that preprocesses a question before asking it of the AI. 4 | """ 5 | 6 | import re 7 | from bot.config import config 8 | 9 | shortcut_re = re.compile(r"^!(\w+)\b") 10 | 11 | 12 | def extract(question: str) -> tuple[str, str]: 13 | """Extracts a shortcut from the question.""" 14 | match = shortcut_re.match(question) 15 | if not match: 16 | raise ValueError("failed to extract shortcut") 17 | name = match.group(1) 18 | question = question.removeprefix(match.group(0)).strip() 19 | return name, question 20 | 21 | 22 | def apply(name: str, question: str) -> str: 23 | """Applies a given shortcut to a text.""" 24 | prompt = config.shortcuts.get(name) 25 | if not prompt: 26 | raise ValueError(f"unknown shortcut: {name}") 27 | return f"{prompt}\n\n{question}" 28 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | paths: 7 | - .github/** 8 | - bot/** 9 | - tests/** 10 | pull_request: 11 | branches: [main] 12 | workflow_dispatch: 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.11" 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r requirements.txt 29 | 30 | - name: Run tests 31 | run: | 32 | cp config.example.yml config.yml 33 | python -m unittest discover 34 | -------------------------------------------------------------------------------- /tests/test_shortcuts.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from bot.config import config 4 | from bot import 
shortcuts 5 | 6 | 7 | class ExtractTest(unittest.TestCase): 8 | def test_extract(self): 9 | name, question = shortcuts.extract("!translate Ciao") 10 | self.assertEqual(name, "translate") 11 | self.assertEqual(question, "Ciao") 12 | 13 | def test_failed(self): 14 | config.shortcuts["translate"] = "Translate into English." 15 | with self.assertRaises(ValueError): 16 | shortcuts.extract("Ciao") 17 | 18 | 19 | class ApplyTest(unittest.TestCase): 20 | def test_apply(self): 21 | config.shortcuts["translate"] = "Translate into English." 22 | question = shortcuts.apply("translate", "Ciao") 23 | self.assertEqual(question, "Translate into English.\n\nCiao") 24 | 25 | def test_unknown_shortcut(self): 26 | with self.assertRaises(ValueError): 27 | shortcuts.apply("sing", "Ciao") 28 | -------------------------------------------------------------------------------- /bot/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command-line interface to the OpenAI API. 3 | 4 | Usage example: 5 | $ python -m bot.cli "What is your name?" 
6 | """ 7 | 8 | import asyncio 9 | import os 10 | import sys 11 | import textwrap 12 | 13 | from bot.config import config 14 | from bot.fetcher import Fetcher 15 | import bot.ai.chat 16 | 17 | 18 | async def main(question): 19 | print(f"> {question}") 20 | fetcher = Fetcher() 21 | question = await fetcher.substitute_urls(question) 22 | ai = init_model() 23 | answer = await ai.ask(prompt=config.openai.prompt, question=question, history=[]) 24 | await fetcher.close() 25 | lines = textwrap.wrap(answer, width=60) 26 | for line in lines: 27 | print(line) 28 | 29 | 30 | def init_model(): 31 | name = os.getenv("OPENAI_MODEL") or config.openai.model 32 | return bot.ai.chat.Model(name) 33 | 34 | 35 | if __name__ == "__main__": 36 | if len(sys.argv) == 0: 37 | exit(1) 38 | asyncio.run(main(sys.argv[1])) 39 | -------------------------------------------------------------------------------- /bot/commands/constants.py: -------------------------------------------------------------------------------- 1 | """Bot command constants.""" 2 | 3 | HELP_MESSAGE = """Send me a question, and I will do my best to answer it. Please be specific, as I'm not very clever. 4 | 5 | I don't remember chat context by default. To ask follow-up questions, reply to my messages or start your questions with a '+' sign. 6 | 7 | Built-in commands: 8 | {commands}{admin_commands} 9 | 10 | AI shortcuts: 11 | {shortcuts} 12 | 13 | [More features →](https://github.com/nalgeon/pokitoki#readme) 14 | """ 15 | 16 | PRIVACY_MESSAGE = ( 17 | "⚠️ The bot does not have access to group messages, " 18 | "so it cannot reply in groups. 
Use @botfather " 19 | "to give the bot access (Bot Settings > Group Privacy > Turn off)" 20 | ) 21 | 22 | BOT_COMMANDS = [ 23 | ("retry", "retry the last question"), 24 | ("imagine", "generate described image"), 25 | ("version", "show debug information"), 26 | ("help", "show help"), 27 | ] 28 | 29 | ADMIN_COMMANDS = { 30 | "config": "view or edit the config", 31 | } 32 | -------------------------------------------------------------------------------- /bot/ai/images.py: -------------------------------------------------------------------------------- 1 | """OpenAI-compatible image generation model.""" 2 | 3 | import httpx 4 | from bot.config import config 5 | 6 | client = httpx.AsyncClient(timeout=60.0) 7 | 8 | 9 | class Model: 10 | """AI API wrapper.""" 11 | 12 | async def imagine(self, prompt: str, size: str) -> str: 13 | """Generates an image of the specified size according to the description.""" 14 | response = await client.post( 15 | f"{config.openai.url}/images/generations", 16 | headers={"Authorization": f"Bearer {config.openai.api_key}"}, 17 | json={ 18 | "model": config.openai.image_model, 19 | "prompt": prompt, 20 | "size": size, 21 | "n": 1, 22 | }, 23 | ) 24 | resp = response.json() 25 | if "data" not in resp: 26 | raise Exception(resp) 27 | if len(resp["data"]) == 0: 28 | raise Exception("received an empty answer") 29 | return resp["data"][0]["url"] 30 | -------------------------------------------------------------------------------- /bot/commands/start.py: -------------------------------------------------------------------------------- 1 | """/start command.""" 2 | 3 | from telegram import Update 4 | from telegram.ext import CallbackContext 5 | from telegram.constants import ParseMode 6 | 7 | from bot.config import config 8 | from . import constants 9 | from . 
import help 10 | 11 | 12 | class StartCommand: 13 | """Answers the `start` command.""" 14 | 15 | async def __call__(self, update: Update, context: CallbackContext) -> None: 16 | if update.effective_user.username not in config.telegram.usernames: 17 | text = ( 18 | "Sorry, I don't know you. To setup your own bot, " 19 | "visit https://github.com/nalgeon/pokitoki" 20 | ) 21 | await update.message.reply_text(text) 22 | return 23 | 24 | text = "Hi! I'm a humble AI-driven chat bot.\n\n" 25 | text += help.generate_message(update.effective_user.username) 26 | if not context.bot.can_read_all_group_messages: 27 | text += f"\n\n{constants.PRIVACY_MESSAGE}" 28 | await update.message.reply_text( 29 | text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Karim Iskakov, Anton Zhiyanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bot/commands/message.py: -------------------------------------------------------------------------------- 1 | """Text message handler.""" 2 | 3 | import logging 4 | from typing import Awaitable 5 | from telegram import Chat, Update 6 | from telegram.ext import CallbackContext 7 | from bot import questions 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class MessageCommand: 13 | """Answers a question from the user.""" 14 | 15 | def __init__(self, reply_func: Awaitable) -> None: 16 | self.reply_func = reply_func 17 | 18 | async def __call__(self, update: Update, context: CallbackContext) -> None: 19 | message = update.message or update.edited_message 20 | logger.debug(update) 21 | 22 | # the bot is meant to answer questions in private chats, 23 | # but it can also answer a specific question in a group when mentioned 24 | if message.chat.type == Chat.PRIVATE: 25 | question = await questions.extract_private(message, context) 26 | else: 27 | question, message = await questions.extract_group(message, context) 28 | 29 | if not question: 30 | # this is not a question to the bot, so ignore it 31 | return 32 | 33 | await self.reply_func(update=update, message=message, context=context, question=question) 34 | -------------------------------------------------------------------------------- /bot/markdown.py: -------------------------------------------------------------------------------- 1 | """Markdown/HTML text formatting.""" 2 | 3 | import re 4 | 5 | # Code blocks, e.g.: 6 | # ```sql 7 | # select count(*) from messages; 8 | # ``` 9 | pre_re = re.compile(r"^[ ]*```\w*$(.+?)^```$", re.MULTILINE | 
re.DOTALL) 10 | 11 | # Inline code, e.g.: 12 | # `print(message)` displays the message. 13 | code_re = re.compile(r"`([^`\n]+)`") 14 | 15 | # Bold text, e.g.: 16 | # **Note**. Cats **do** like boxes. 17 | bold_re = re.compile(r"\*\*([^<*]+?)\*\*") 18 | 19 | # Unordered list items, e.g.: 20 | # * Wake up. 21 | # * Have breakfast. 22 | # * Go to work. 23 | bullet_re = re.compile(r"^\*\s\s+(.+)$", re.MULTILINE) 24 | 25 | 26 | def to_html(text: str) -> str: 27 | """ 28 | Converts Markdown text to "Telegram HTML", which supports only some of the tags. 29 | See https://core.telegram.org/bots/api#html-style for details. 30 | Escapes certain entities and converts `code` and `pre`, 31 | but ignores all other formatting. 32 | """ 33 | text = text.replace("&", "&").replace("<", "<").replace(">", ">") 34 | text = pre_re.sub(r"
\1", text) 35 | text = code_re.sub(r"
\1", text)
36 | text = bold_re.sub(r"\1", text)
37 | text = bullet_re.sub(r"— \1", text)
38 | return text
39 |
--------------------------------------------------------------------------------
/bot/commands/error.py:
--------------------------------------------------------------------------------
1 | """Generic error handler."""
2 |
3 | import logging
4 | import textwrap
5 | from telegram import Chat, Update
6 | from telegram.ext import CallbackContext
7 |
logger = logging.getLogger(__name__)


class ErrorCommand:
    """Logs errors raised while handling updates and reports them to the chat.

    The error is logged together with its traceback. If the failed update
    belongs to a chat, a shortened description of the error is sent there;
    in non-private chats it is sent as a reply to the failed message.
    """

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        error = context.error
        if not update:
            # telegram.error.NetworkError or a similar error, there is no chat to respond to.
            # Not sure if we should completely silence such errors.
            logger.warning("General exception: %s", error, exc_info=error)
            return

        class_name = f"{error.__class__.__module__}.{error.__class__.__qualname__}"
        error_text = f"{class_name}: {error}"
        # exc_info attaches the traceback to the log record
        logger.warning(
            "Exception while handling an update %s: %s", update, error_text, exc_info=error
        )

        chat = update.effective_chat
        if chat is None:
            # the update is not tied to any chat, so there is nowhere to report the error
            return

        text = textwrap.shorten(f"⚠️ {error_text}", width=255, placeholder="...")
        # edited messages arrive in `edited_message`, not in `message`
        message = update.message or update.edited_message
        # outside private chats, reply to the failed message so it is clear which one failed
        reply_to_message_id = message.id if message and message.chat.type != Chat.PRIVATE else None
        await context.bot.send_message(chat.id, text, reply_to_message_id=reply_to_message_id)
31 |
--------------------------------------------------------------------------------
/bot/commands/help.py:
--------------------------------------------------------------------------------
1 | """/help command."""
2 |
3 | from telegram import Update
4 | from telegram.ext import CallbackContext
5 | from telegram.constants import ParseMode
6 | from bot.config import config
7 | from . import constants
8 |
9 |
class HelpCommand:
    """Replies to the `help` command with the generated help text."""

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        requester = update.effective_user.username
        # the text depends on who asks: admins also get the admin-only commands
        reply = generate_message(requester)
        await update.message.reply_text(
            reply,
            parse_mode=ParseMode.MARKDOWN,
            disable_web_page_preview=True,
        )
18 |
19 |
def generate_message(username: str) -> str:
    """Generates a help message, including a list of allowed commands.

    Args:
        username: Telegram username of the user asking for help. Admin-only
            commands are listed only when it is in `config.telegram.admins`.

    Returns:
        `constants.HELP_MESSAGE` with the user commands, the admin commands
        (empty for non-admins) and the configured shortcuts filled in.
    """

    # user commands
    commands = "\n".join(f"/{cmd} - {descr}" for cmd, descr in constants.BOT_COMMANDS)

    # admin commands: list every entry of ADMIN_COMMANDS instead of a hard-coded one
    admin_commands = ""
    if username in config.telegram.admins:
        admin_list = "\n".join(
            f"/{cmd} - {descr}" for cmd, descr in constants.ADMIN_COMMANDS.items()
        )
        admin_commands = f"\n\nAdmin-only commands:\n{admin_list}".rstrip()

    # shortcuts
    if config.shortcuts:
        shortcuts = "\n".join(f"`!{shortcut}`" for shortcut in config.shortcuts)
    else:
        shortcuts = "none"

    return constants.HELP_MESSAGE.format(
        commands=commands, admin_commands=admin_commands, shortcuts=shortcuts
    )
42 |
--------------------------------------------------------------------------------
/bot/commands/imagine.py:
--------------------------------------------------------------------------------
1 | """/imagine command."""
2 |
3 | from typing import Awaitable
4 | from telegram import Update
5 | from telegram.ext import CallbackContext
6 | from telegram.constants import ParseMode
7 | from bot.config import config
8 |
9 |
class ImagineCommand:
    """Generates an image according to the description given after `/imagine`."""

    def __init__(self, reply_func: Awaitable) -> None:
        # invoked as reply_func(update=..., message=..., context=..., question=...)
        self.reply_func = reply_func

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        # edited commands arrive in `edited_message`, not in `message`
        message = update.message or update.edited_message

        if config.imagine.enabled == "none":
            await message.reply_text(
                "The `imagine` command is disabled.", parse_mode=ParseMode.MARKDOWN
            )
            return

        if (
            config.imagine.enabled == "users_only"
            and update.effective_user.username not in config.telegram.usernames
        ):
            # only users listed in config.telegram.usernames may generate images
            await message.reply_text(
                "The `imagine` command is disabled for group users.", parse_mode=ParseMode.MARKDOWN
            )
            return

        if not context.args:
            # `/imagine` without a description
            await message.reply_text(
                "Please describe an image. "
                "For example:\n/imagine a lazy cat on a sunny day",
                parse_mode=ParseMode.HTML,
            )
            return

        # Pass the resolved message: for an edited command `update.message` is None,
        # and reply_func would otherwise receive None.
        await self.reply_func(
            update=update, message=message, context=context, question=message.text
        )
44 |
--------------------------------------------------------------------------------
/bot/commands/model.py:
--------------------------------------------------------------------------------
1 | """/model command."""
2 |
3 | from telegram import Chat, Update
4 | from telegram.ext import CallbackContext
5 | from telegram.constants import ParseMode
6 |
7 | from bot.config import config
8 | from bot.models import ChatData
9 |
HELP_MESSAGE = """Syntax:
/model [AI model name]

For example:
/model gpt-4o

To use the default model:
/model reset"""

RESET = "reset"


class ModelCommand:
    """Sets an AI model for the current chat.

    - `/model` shows the chat's model, or the help text if none is set.
    - `/model reset` clears the chat's model and reports `config.openai.model`.
    - `/model NAME` stores NAME as the chat's model.

    In group chats only admins may use the command; other users are ignored.
    """

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        # Model names are typed by users and echoed back with ParseMode.HTML;
        # Telegram requires "<", ">" and "&" outside tags to be escaped.
        from html import escape

        message = update.message or update.edited_message

        if (
            message.chat.type != Chat.PRIVATE
            and update.effective_user.username not in config.telegram.admins
        ):
            # Only admins are allowed to change the model in group chats.
            return

        chat = ChatData(context.chat_data)
        _, _, model = message.text.partition(" ")
        if not model:
            # /model without arguments
            if chat.model:
                # the model is already set, show it
                await message.reply_text(
                    f"Using model:\n{escape(chat.model, quote=False)}",
                    parse_mode=ParseMode.HTML,
                )
            else:
                # the model is not set, show help message
                await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML)
            return

        if model == RESET:
            # /model with "reset" argument
            chat.model = ""
            # str() keeps the previous f-string output if the config value is not a string
            default_model = escape(str(config.openai.model), quote=False)
            await message.reply_text(
                f"✓ Using default model:\n{default_model}",
                parse_mode=ParseMode.HTML,
            )
            return

        # /model with a name
        chat.model = model
        await message.reply_text(
            f"✓ Set model:\n{escape(model, quote=False)}",
            parse_mode=ParseMode.HTML,
        )
66 |
--------------------------------------------------------------------------------
/bot/filters.py:
--------------------------------------------------------------------------------
1 | """Bot message filters."""
2 |
3 | from typing import Union
4 | from dataclasses import dataclass
5 | from telegram.ext import filters
6 | from bot.config import config
7 |
8 |
@dataclass
class Filters:
    """Filters for the incoming Telegram messages.

    Built from `config.telegram`; `reload()` refreshes them after the config changes.
    """

    users: Union[filters.MessageFilter, filters.User]
    admins: filters.User
    chats: Union[filters.MessageFilter, filters.Chat]

    users_or_chats: filters.BaseFilter
    admins_private: filters.BaseFilter
    messages: filters.BaseFilter

    def __init__(self) -> None:
        """Defines users and chats that are allowed to use the bot.

        Without `config.telegram.usernames`, both `users` and `chats` are
        `filters.ALL`. `admins` is always a `filters.User` (possibly empty),
        so `reload()` can update its usernames.
        """
        if config.telegram.usernames:
            self.users = filters.User(username=config.telegram.usernames)
            self.chats = filters.Chat(chat_id=config.telegram.chat_ids)
        else:
            self.users = filters.ALL
            self.chats = filters.ALL

        if config.telegram.admins:
            self.admins = filters.User(username=config.telegram.admins)
        else:
            self.admins = filters.User(username=[])

        # The combined filters are built from the objects above;
        # reload() updates those objects in place instead of replacing them.
        self.users_or_chats = self.users | self.chats
        self.admins_private = self.admins & filters.ChatType.PRIVATE
        self.messages = (
            (filters.TEXT | filters.Document.TEXT) & ~filters.COMMAND & self.users_or_chats
        )

    def reload(self) -> None:
        """Reloads users, chats and admins from config.

        Raises:
            Exception: if the bot was started without usernames (so `users` and
                `chats` are `filters.ALL`) and usernames are configured now.
        """
        if self.users == filters.ALL:
            if config.telegram.usernames:
                # cannot update the filter from ALL to specific usernames without a restart
                raise Exception("Restart the bot for changes to take effect")
            # Users and chats stay unrestricted. filters.ALL is a shared module-level
            # filter, not a User/Chat filter, so do not assign usernames/chat_ids on it;
            # only the admins below need refreshing.
        else:
            # NOTE(review): an empty usernames list here leaves an empty User filter,
            # which (unlike __init__) does not fall back to filters.ALL - confirm intent.
            self.users.usernames = config.telegram.usernames
            self.chats.chat_ids = config.telegram.chat_ids
        self.admins.usernames = config.telegram.admins

    def is_known_user(self, username: str) -> bool:
        """Checks if the username is included in the `users` filter.

        Always False when `users` is `filters.ALL` (no usernames configured).
        """
        if self.users == filters.ALL:
            return False
        return username in self.users.usernames
55 |
--------------------------------------------------------------------------------
/bot/commands/version.py:
--------------------------------------------------------------------------------
1 | """/version command."""
2 |
3 | from urllib import parse
4 |
5 | from telegram import Update
6 | from telegram.ext import CallbackContext
7 | from telegram.constants import ParseMode
8 |
9 | from bot.config import config
10 | from . import constants
11 |
12 |
class VersionCommand:
    """Answers the `version` command with chat, bot and AI debug information."""

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        # The reply uses ParseMode.HTML, where Telegram requires "<", ">" and "&"
        # outside tags to be escaped. A group title such as "Tom & Jerry" would
        # otherwise make Telegram reject the whole message.
        from html import escape

        chat = update.message.chat
        # chat information
        # chat.title may be None (e.g. private chats); str() keeps the previous "None" output
        title = escape(str(chat.title), quote=False)
        text = (
            ""
            "Chat information:\n"
            f"- id: {chat.id}\n"
            f"- title: {title}\n"
            f"- type: {chat.type}"
            ""
        )
        bot = await context.bot.get_me()
        usernames = (
            "all" if not config.telegram.usernames else f"{len(config.telegram.usernames)} users"
        )
        admins = "none" if not config.telegram.admins else f"{len(config.telegram.admins)} users"

        # bot information
        text += (
            "\n\n"
            "Bot information:\n"
            f"- id: {bot.id}\n"
            f"- name: {bot.name}\n"
            f"- version: {config.version}\n"
            f"- usernames: {usernames}\n"
            f"- admins: {admins}\n"
            f"- chat IDs: {config.telegram.chat_ids}\n"
            f"- access to messages: {bot.can_read_all_group_messages}"
            ""
        )
        if not bot.can_read_all_group_messages:
            text += f"\n\n{constants.PRIVACY_MESSAGE}"

        # AI information
        provider = parse.urlparse(config.openai.url).hostname
        # config values can be edited at runtime via /config, so escape them as well
        model = escape(str(config.openai.model), quote=False)
        shortcuts = escape(", ".join(config.shortcuts.keys()), quote=False)
        text += (
            "\n\n"
            "AI information:\n"
            f"- provider: {provider}\n"
            f"- model: {model}\n"
            f"- history depth: {config.conversation.depth}\n"
            f"- imagine: {config.imagine.enabled}\n"
            f"- shortcuts: {shortcuts}"
            ""
        )
        await update.message.reply_text(text, parse_mode=ParseMode.HTML)
62 |
--------------------------------------------------------------------------------
/bot/commands/prompt.py:
--------------------------------------------------------------------------------
1 | """/prompt command."""
2 |
3 | from telegram import Chat, Update
4 | from telegram.ext import CallbackContext
5 | from telegram.constants import ParseMode
6 |
7 | from bot.config import config
8 | from bot.models import ChatData
9 |
HELP_MESSAGE = """Syntax:
/prompt [custom prompt]

For example:
/prompt You are an evil genius. Reply with an evil laugh.

To use the default prompt:
/prompt reset"""

RESET = "reset"


class PromptCommand:
    """Sets a custom chat prompt.

    - `/prompt` shows the chat's custom prompt, or the help text if none is set.
    - `/prompt reset` clears the custom prompt and reports `config.openai.prompt`.
    - `/prompt TEXT` stores TEXT as the chat's custom prompt.

    In group chats only admins may use the command; other users are ignored.
    """

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        # Prompts are free text typed by users (e.g. "Answer in <language>") and are
        # echoed back with ParseMode.HTML. Telegram requires "<", ">" and "&" outside
        # tags to be escaped, so escape them before replying.
        from html import escape

        message = update.message or update.edited_message

        if (
            message.chat.type != Chat.PRIVATE
            and update.effective_user.username not in config.telegram.admins
        ):
            # Only admins are allowed to change the prompt in group chats.
            return

        chat = ChatData(context.chat_data)
        _, _, prompt = message.text.partition(" ")
        if not prompt:
            # /prompt without arguments
            if chat.prompt:
                # custom prompt is already set, show it
                await message.reply_text(
                    f"Using custom prompt:\n{escape(chat.prompt, quote=False)}",
                    parse_mode=ParseMode.HTML,
                )
            else:
                # custom prompt is not set, show help message
                await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML)
            return

        if prompt == RESET:
            # /prompt with "reset" argument
            chat.prompt = ""
            # str() keeps the previous f-string output if the config value is not a string
            default_prompt = escape(str(config.openai.prompt), quote=False)
            await message.reply_text(
                f"✓ Using default prompt:\n{default_prompt}",
                parse_mode=ParseMode.HTML,
            )
            return

        # /prompt with a custom prompt
        chat.prompt = prompt
        await message.reply_text(
            f"✓ Set custom prompt:\n{escape(prompt, quote=False)}",
            parse_mode=ParseMode.HTML,
        )
66 |
--------------------------------------------------------------------------------
/tests/test_filters.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from telegram.ext import filters as tg_filters
4 | from bot.config import config
5 | from bot.filters import Filters
6 |
7 |
class FiltersTest(unittest.TestCase):
    """Filters built while `config.telegram.usernames` is not empty."""

    @staticmethod
    def _set_access(usernames, chat_ids, admins):
        # overwrite the shared config that Filters() reads
        config.telegram.usernames = usernames
        config.telegram.chat_ids = chat_ids
        config.telegram.admins = admins

    def test_init(self):
        self._set_access(["alice", "bob"], [-100, -200], ["admin"])
        built = Filters()
        self.assertEqual(built.users.usernames, {"alice", "bob"})
        self.assertEqual(built.chats.chat_ids, {-100, -200})
        self.assertEqual(built.admins.usernames, {"admin"})

    def test_reload(self):
        self._set_access(["alice", "bob"], [-100, -200], ["admin"])
        built = Filters()

        # change the config after the filters exist, then pick the changes up
        self._set_access(["alice", "bob", "cindy"], [-300], ["zappa", "xanos"])
        built.reload()
        self.assertEqual(built.users.usernames, {"alice", "bob", "cindy"})
        self.assertEqual(built.chats.chat_ids, {-300})
        self.assertEqual(built.admins.usernames, {"zappa", "xanos"})

    def test_is_known_user(self):
        config.telegram.usernames = ["alice", "bob"]
        built = Filters()
        self.assertTrue(built.is_known_user("alice"))
        self.assertFalse(built.is_known_user("cindy"))
37 |
38 |
class EmptyTest(unittest.TestCase):
    """Filters built while `config.telegram.usernames` is empty."""

    def test_init(self):
        config.telegram.usernames = []
        config.telegram.chat_ids = [-100, -200]
        config.telegram.admins = ["admin"]
        built = Filters()
        # without usernames, neither users nor chats are restricted
        self.assertEqual(built.users, tg_filters.ALL)
        self.assertEqual(built.chats, tg_filters.ALL)
        self.assertEqual(built.admins.usernames, {"admin"})

    def test_reload(self):
        config.telegram.usernames = []
        built = Filters()
        config.telegram.usernames = ["alice", "bob"]
        # switching from "everyone" to specific users is expected to fail
        with self.assertRaises(Exception):
            built.reload()

    def test_is_known_user(self):
        config.telegram.usernames = []
        built = Filters()
        for name in ("alice", "cindy"):
            self.assertFalse(built.is_known_user(name))
61 |
--------------------------------------------------------------------------------
/bot/commands/config.py:
--------------------------------------------------------------------------------
1 | """/config command."""
2 |
3 | from telegram import Update
4 | from telegram.ext import CallbackContext
5 | from telegram.constants import ParseMode
6 |
7 | from bot.config import config, ConfigEditor
8 | from bot.filters import Filters
9 |
HELP_MESSAGE = """Syntax:
/config property [value]

E.g. to view the property value:
/config openai.prompt

E.g. to change the property value:
/config openai.prompt You are an AI assistant"""

editor = ConfigEditor(config)


class ConfigCommand:
    """Shows a config property or assigns it a new value (`/config`)."""

    def __init__(self, filters: Filters) -> None:
        # refreshed after edits to the properties that affect access control
        self.filters = filters

    async def __call__(self, update: Update, context: CallbackContext) -> None:
        message = update.message or update.edited_message

        words = message.text.split()
        if len(words) == 1:
            # bare `/config`: explain the syntax
            await message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML)
            return

        prop_name = words[1]
        current = editor.get_value(prop_name)
        if current is None:
            current = "(empty)"

        if len(words) == 2:
            # `/config {property}`: show the current value
            await message.reply_text(f"`{current}`", parse_mode=ParseMode.MARKDOWN)
            return

        # `/config {property} {new value...}`: words after the name are re-joined by single spaces
        requested = " ".join(words[2:])
        changed, immediate, stored = editor.set_value(prop_name, requested)

        if not changed:
            await message.reply_text(
                f"✗ The `{prop_name}` property already equals to `{requested}`",
                parse_mode=ParseMode.MARKDOWN,
            )
            return

        editor.save()
        if self._should_reload_filters(prop_name):
            self.filters.reload()

        reply = f"✓ Changed the `{prop_name}` property: `{current}` → `{stored}`"
        if not immediate:
            reply += "\n❗️Restart the bot for changes to take effect."
        await message.reply_text(reply, parse_mode=ParseMode.MARKDOWN)

    def _should_reload_filters(self, property: str) -> bool:
        """Tells whether `property` is one of the settings the message filters are built from."""
        access_properties = {"telegram.usernames", "telegram.chat_ids", "telegram.admins"}
        return property in access_properties
70 |
--------------------------------------------------------------------------------
/tests/test_markdown.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from bot import markdown
3 |
4 | TEXT_MD = r"""You can easily regular expressions them using the `sqlean-regexp` extension.
5 |
6 | > **Note**. Unlike other DBMS, adding extensions to SQLite is a breeze.
7 |
8 | With `sqlean-regexp`, matching a string against a pattern becomes as easy as:
9 |
10 | ```sql
11 | select count(*) from messages
12 | where msg_text regexp '\d+';
13 | ```
14 |
15 | `regexp_like(source, pattern)` checks if the source string matches the pattern:
16 |
17 | ```sql
18 | select regexp_like('Meet me at 10:30', '\d+:\d+');
19 | select 10 > 5 = true;
20 | ```
21 |
22 | See [Documentation](https://github.com/nalgeon/sqlean) for reference.
23 | """
24 |
25 | TEXT_HTML = r"""You can easily regular expressions them using the sqlean-regexp extension.
26 |
27 | > Note. Unlike other DBMS, adding extensions to SQLite is a breeze.
28 |
29 | With sqlean-regexp, matching a string against a pattern becomes as easy as:
30 |
31 | 32 | select count(*) from messages 33 | where msg_text regexp '\d+'; 34 |35 | 36 |
regexp_like(source, pattern) checks if the source string matches the pattern:
37 |
38 |
39 | select regexp_like('Meet me at 10:30', '\d+:\d+');
40 | select 10 > 5 = true;
41 |
42 |
43 | See [Documentation](https://github.com/nalgeon/sqlean) for reference.
44 | """
45 |
46 |
class Test(unittest.TestCase):
    """Tests for markdown.to_html()."""

    def test_to_html(self):
        # Whole-document conversion against the TEXT_MD / TEXT_HTML fixtures.
        text = markdown.to_html(TEXT_MD)
        self.assertEqual(text, TEXT_HTML)

    def test_ticks(self):
        # Inline code: a closed pair is converted; an unclosed tick or a
        # pair split across lines is left as-is.
        text = markdown.to_html("one `two` three")
        self.assertEqual(text, "one two three")
        text = markdown.to_html("one `two three")
        self.assertEqual(text, "one `two three")
        text = markdown.to_html("one `two\n` three")
        self.assertEqual(text, "one `two\n` three")

    def test_bold(self):
        # Only double asterisks are bold; single or unclosed ones and `**`
        # inside inline code are left alone.
        text = markdown.to_html("one **two** three")
        self.assertEqual(text, "one two three")
        text = markdown.to_html("one *two* three")
        self.assertEqual(text, "one *two* three")
        text = markdown.to_html("one **two three")
        self.assertEqual(text, "one **two three")
        text = markdown.to_html("operator `**` is so ** powerful")
        self.assertEqual(text, "operator ** is so ** powerful")

    def test_bullet(self):
        # NOTE(review): the first two inputs below look identical yet expect
        # different outputs; presumably they differ in whitespace - confirm.
        text = markdown.to_html("* one two three")
        self.assertEqual(text, "— one two three")
        text = markdown.to_html("* one two three")
        self.assertEqual(text, "* one two three")
        text = markdown.to_html(" * one two three")
        self.assertEqual(text, " * one two three")
77 |
--------------------------------------------------------------------------------
/tests/mocks.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from telegram import User
3 | from bot import askers
4 |
5 |
class FakeGPT:
    """Stand-in chat model: remembers its last call and echoes the question.

    If constructed with `error`, every `ask` call raises it after recording
    the arguments.
    """

    def __init__(self, error: Optional[Exception] = None):
        self.error = error
        self.prompt = self.question = self.history = None

    async def ask(self, prompt: str, question: str, history: list) -> str:
        self.prompt, self.question, self.history = prompt, question, history
        if not self.error:
            return question
        raise self.error
20 |
21 |
class FakeDalle:
    """Stand-in image model: remembers its last call and returns "image".

    If constructed with `error`, every `imagine` call raises it after
    recording the arguments.
    """

    def __init__(self, error: Optional[Exception] = None):
        self.error = error
        self.prompt = self.size = None

    async def imagine(self, prompt: str, size: str) -> str:
        self.prompt, self.size = prompt, size
        if not self.error:
            return "image"
        raise self.error
34 |
35 |
class FakeFile:
    """Stand-in Telegram file whose download always yields the same bytes."""

    def __init__(self, file_id: str) -> None:
        self.file_id = file_id

    async def download_as_bytearray(self, buf=None, **kwargs) -> bytearray:
        payload = b"file content"
        return bytearray(payload)
42 |
43 |
class FakeBot:
    """Stand-in for telegram.Bot that records the text of the last message it sent."""

    def __init__(self, username: str) -> None:
        self.user = User(
            id=42,
            first_name=username,
            is_bot=True,
            username=username,
            can_read_all_group_messages=True,
        )
        # Text (or "caption: payload" for documents/photos) of the most recent send_* call.
        self.text = ""

    @property
    def username(self) -> str:
        return self.user.username

    @property
    def name(self) -> str:
        return f"@{self.username}"

    @property
    def can_read_all_group_messages(self) -> bool:
        return self.user.can_read_all_group_messages

    async def get_file(self, file_id, **kwargs):
        return FakeFile(file_id)

    async def send_chat_action(self, **kwargs) -> None:
        pass

    async def send_message(self, chat_id: int, text: str, **kwargs) -> None:
        self.text = text

    async def send_document(
        self, chat_id: int, document: object, caption: str, filename: str, **kwargs
    ) -> None:
        self.text = f"{caption}: {filename}"

    async def send_photo(
        self, chat_id: int, photo: str, caption: Optional[str] = None, **kwargs
    ) -> None:
        # caption defaults to None, so the annotation must be Optional.
        self.text = f"{caption}: {photo}"

    async def get_me(self, **kwargs) -> User:
        return self.user
86 |
87 |
class FakeApplication:
    """Minimal application stub: the bot plus empty chat/user data for id 1."""

    def __init__(self, bot: FakeBot) -> None:
        self.bot = bot
        self.user_data = {1: {}}
        self.chat_data = {1: {}}
93 |
94 |
def mock_text_asker(ai: FakeGPT) -> None:
    """Make every newly created askers.TextAsker use `ai` as its model.

    Replaces TextAsker.__init__ globally; the original is not restored.
    """

    def fake_init(asker, model_name=None) -> None:
        # Accept the same argument as the real __init__ (positionally or by
        # keyword) but ignore it.
        asker.model = ai

    askers.TextAsker.__init__ = fake_init
98 |
--------------------------------------------------------------------------------
/bot/fetcher.py:
--------------------------------------------------------------------------------
1 | """Retrieves remote content over HTTP."""
2 |
3 | import re
4 | import httpx
5 | from bs4 import BeautifulSoup
6 |
7 |
class Fetcher:
    """Retrieves remote content over HTTP."""

    # Matches non-quoted URLs in text. The quote checks are lookarounds
    # rather than consumed characters, so URLs separated by a single
    # whitespace character are all found.
    url_re = re.compile(r"(?<!['\"])\b(https?://\S+)\b(?!['\"])")
    timeout = 3  # seconds

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True, timeout=self.timeout)

    async def substitute_urls(self, text: str) -> str:
        """
        Extracts URLs from text, fetches their contents,
        and appends the contents to the text.
        """
        urls = self._extract_urls(text)
        for url in urls:
            content = await self._fetch_url(url)
            text += f"\n\n---\n{url} contents:\n\n{content}\n---"
        return text

    async def close(self) -> None:
        """Frees network connections."""
        await self.client.aclose()

    def _extract_urls(self, text: str) -> list[str]:
        """Extracts non-quoted URLs from text, in order of appearance."""
        return self.url_re.findall(text)

    async def _fetch_url(self, url: str) -> str:
        """Retrieves URL content and returns it as text.

        Best-effort: any failure is reported as a short placeholder naming
        the exception class instead of being raised.
        """
        try:
            response = await self.client.get(url)
            response.raise_for_status()
            content = Content(response)
            return content.extract_text()
        except Exception as exc:
            class_name = f"{exc.__class__.__module__}.{exc.__class__.__qualname__}"
            return f"Failed to fetch ({class_name})"
48 |
49 |
class Content:
    """Extracts resource content as human-readable text."""

    # Non-"text/*" media types that are still returned as plain text.
    allowed_content_types = {
        "application/json",
        "application/sql",
        "application/xml",
    }

    def __init__(self, response: httpx.Response) -> None:
        self.response = response
        # The header may be missing. Media types are case-insensitive and may
        # carry parameters (e.g. "text/html; charset=utf-8").
        header = response.headers.get("content-type", "")
        content_type, _, _ = header.partition(";")
        self.content_type = content_type.strip().lower()

    def extract_text(self) -> str:
        """Extracts resource content as human-readable text."""
        if not self.is_text():
            return "Unknown binary content"
        if self.content_type != "text/html":
            return self.response.text
        html = BeautifulSoup(self.response.text, "html.parser")
        # Prefer <main>, then <body>; fall back to the whole document for
        # fragments that have neither (find() returns None then).
        article = html.find("main") or html.find("body") or html
        return article.get_text()

    def is_text(self) -> bool:
        """Checks if the content type is plain text."""
        if not self.content_type:
            return False
        if self.content_type.startswith("text/"):
            return True
        return self.content_type in self.allowed_content_types
85 |
--------------------------------------------------------------------------------
/config.example.yml:
--------------------------------------------------------------------------------
1 | # Bot configuration parameters.
2 |
3 | # Config schema version. Increments for backward-incompatible changes.
4 | schema_version: 3
5 |
6 | # Telegram settings.
7 | telegram:
8 | # Telegram Bot API token.
9 | token: ""
10 |
11 | # The list of Telegram usernames allowed to chat with the bot.
12 | # If empty, the bot will be available to anyone.
13 | usernames: []
14 |
15 | # The list of Telegram usernames allowed to change bot settings.
16 | # If empty, then no one is allowed to change the settings.
17 | admins: []
18 |
19 | # The list of Telegram group ids, whose members are allowed to chat with the bot.
20 | # If empty, the bot will only be available to `telegram.usernames`.
21 | chat_ids: []
22 |
23 | # AI settings.
24 | openai:
25 | # AI API endpoint. You can use any OpenAI-compatible provider like
26 | # OpenRouter (https://openrouter.ai/api/v1) or Nebius (https://api.studio.nebius.ai/v1).
27 | url: "https://api.openai.com/v1"
28 |
29 | # AI API key.
30 | api_key: ""
31 |
32 | # Chat model name.
33 | # See https://platform.openai.com/docs/models for description.
34 | model: "gpt-4o-mini"
35 |
36 | # Image generation model name.
37 | image_model: "dall-e-3"
38 |
39 | # Context window size in tokens.
40 | # Applies only to custom models.
41 | window: 128000
42 |
43 | # Model prompt.
44 | prompt: "You are an AI assistant."
45 |
46 | # Model parameters.
47 | # See https://platform.openai.com/docs/api-reference/chat/create for description.
48 | params:
49 | temperature: 0.7
50 | max_tokens: 4096
51 |
52 | conversation:
53 | # The maximum number of previous messages
54 | # the bot will remember when talking to a user.
55 | depth: 3
56 |
57 | # The maximum number of messages a user can send during the specified time period.
58 | # Only applies to users who are not listed in `telegram.usernames`.
59 | # `count` = an integer number of messages (0 = unlimited)
60 | # `period` = minute | hour | day
61 | message_limit:
62 | count: 0
63 | period: hour
64 |
65 | # Image generation settings.
66 | imagine:
67 | # Enable/disable image generation:
68 | # - none = disabled for everyone
69 | # - users_only = enabled only for users listed in `telegram.usernames`
70 | # - users_and_groups = enabled for both users listed in `telegram.usernames`
71 | # and members of `telegram.chat_ids`
72 | enabled: none
73 |
74 | # Where to store the chat context file.
75 | persistence_path: "./data/persistence.pkl"
76 |
77 | # Custom AI commands (additional prompts)
78 | shortcuts:
79 | bugfix: "Examine the following code. Rewrite it if necessary to fix bugs and various problems. Explain the changes you've made."
80 |
81 | proofread: "Proofread the following text. Correct grammar and punctuation errors. Rephrase if necessary. Make sure the resulting text is clear, concise, and easy to read. Explain the changes you've made."
82 |
83 | summarize: "Explain the following text in simple terms. Use no more than two paragraphs."
84 |
85 | translate: "Translate the following text into English."
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # Custom
156 | config.yml
157 | *.pkl
--------------------------------------------------------------------------------
/tests/test_ai.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from bot.config import config
3 | from bot.ai import chat
4 | from bot.models import UserMessage
5 |
6 |
class ModelTest(unittest.TestCase):
    """Tests for chat.Model message assembly."""

    def setUp(self) -> None:
        self.model = chat.Model("gpt")

    def test_generate_messages(self):
        history = [
            UserMessage("Hello", "Hi"),
            UserMessage("Is it cold today?", "Yep!"),
        ]
        messages = self.model._generate_messages(
            prompt_role="system", prompt="", question="What's your name?", history=history
        )
        # An empty prompt falls back to the configured one; history pairs are
        # unrolled into user/assistant turns, followed by the new question.
        expected = [
            ("system", config.openai.prompt),
            ("user", "Hello"),
            ("assistant", "Hi"),
            ("user", "Is it cold today?"),
            ("assistant", "Yep!"),
            ("user", "What's your name?"),
        ]
        self.assertEqual(len(messages), 6)
        for msg, (role, content) in zip(messages, expected):
            self.assertEqual(msg["role"], role)
            self.assertEqual(msg["content"], content)
35 |
36 |
class ShortenTest(unittest.TestCase):
    """Tests for chat.shorten(): trimming a conversation to a length budget."""

    @staticmethod
    def _msg(role, content):
        """Build a fresh chat message dict."""
        return {"role": role, "content": content}

    def _system(self):
        return self._msg("system", "You are an AI assistant.")

    def _dialog(self):
        """System prompt, one full exchange, then a new question."""
        return [
            self._system(),
            self._msg("user", "What is your name?"),
            self._msg("assistant", "My name is Alice."),
            self._msg("user", "Is it cold today?"),
        ]

    def test_do_not_shorten(self):
        messages = [self._system(), self._msg("user", "Hello")]
        self.assertEqual(chat.shorten(messages, length=11), messages)

    def test_remove_messages_1(self):
        shortened = chat.shorten(self._dialog(), length=11)
        self.assertEqual(
            shortened,
            [self._system(), self._msg("user", "Is it cold today?")],
        )

    def test_remove_messages_2(self):
        shortened = chat.shorten(self._dialog(), length=15)
        self.assertEqual(
            shortened,
            [
                self._system(),
                self._msg("assistant", "My name is Alice."),
                self._msg("user", "Is it cold today?"),
            ],
        )

    def test_shorten_question(self):
        messages = [
            self._system(),
            self._msg("user", "Is it cold today? I think it's rather cold"),
        ]
        shortened = chat.shorten(messages, length=10)
        self.assertEqual(
            shortened,
            [self._system(), self._msg("user", "Is it cold today?")],
        )
92 |
--------------------------------------------------------------------------------
/bot/askers.py:
--------------------------------------------------------------------------------
1 | """
2 | Asker is an abstraction that sends questions to the AI
3 | and responds to the user with answers provided by the AI.
4 | """
5 |
6 | import io
7 | import re
8 | import textwrap
9 |
10 | from telegram import Chat, Message
11 | from telegram.constants import MessageLimit, ParseMode
12 | from telegram.ext import CallbackContext
13 |
14 | from bot import ai
15 | from bot import markdown
16 |
17 |
class Asker:
    """Base asker: sends a question to the AI and replies with the answer.

    TextAsker and ImagineAsker override both methods; the base versions do
    nothing and return None.
    """

    async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str:
        """Asks AI a question."""

    async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
        """Replies with an answer from AI."""
28 |
29 |
class TextAsker(Asker):
    """Answers questions using a chat completion model."""

    def __init__(self, model_name: str) -> None:
        self.model = ai.chat.Model(model_name)

    async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str:
        """Forwards the question (with prompt and history) to the chat model."""
        answer = await self.model.ask(prompt, question, history)
        return answer

    async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
        """Sends the answer as HTML, or as a Markdown attachment if it is too long."""
        html_answer = markdown.to_html(answer)
        fits_in_message = len(html_answer) <= MessageLimit.MAX_TEXT_LENGTH
        if fits_in_message:
            await message.reply_text(html_answer, parse_mode=ParseMode.HTML)
            return

        # Too long for a single message: attach the raw Markdown as a file
        # and use a shortened preview as the caption.
        preview = textwrap.shorten(answer, width=255, placeholder="...")
        caption = preview + " (see attachment for the rest)"
        is_private = message.chat.type == Chat.PRIVATE
        reply_to_message_id = None if is_private else message.id
        await context.bot.send_document(
            chat_id=message.chat_id,
            caption=caption,
            filename=f"{message.id}.md",
            document=io.StringIO(answer),
            reply_to_message_id=reply_to_message_id,
        )
60 |
61 |
class ImagineAsker(Asker):
    """Works with image generation AI."""

    model = ai.images.Model()
    # Size hint inside the question: "256", "512", "1024" or "1792", optionally
    # written as "NxN" (or "1792x1024") and/or followed by "px". The word
    # boundary and the digit lookahead keep it from matching inside longer
    # numbers such as "5120" or "10240".
    size_re = re.compile(
        r"""
        \b(256|512|1024|1792)(?!\d)   # width, not part of a longer number
        (?:x\1|(?<=1792)x1024)?       # optional "NxN", or "1792x1024"
        \s?(?:px)?                    # optional "px" suffix
        """,
        re.VERBOSE,
    )
    sizes = {
        "256": "256x256",
        "512": "512x512",
        "1024": "1024x1024",
        "1792": "1792x1024",
    }
    default_size = "1024x1024"

    def __init__(self) -> None:
        # Caption of the last question passed to ask(); reused by reply().
        self.caption = ""

    async def ask(self, prompt: str, question: str, history: list[tuple[str, str]]) -> str:
        """Generates an image for the question (minus its size hint)."""
        size = self._extract_size(question)
        self.caption = self._extract_caption(question)
        return await self.model.imagine(prompt=self.caption, size=size)

    async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
        """Replies with the generated image, captioned with the question text."""
        await message.reply_photo(answer, caption=self.caption)

    def _extract_size(self, question: str) -> str:
        """Returns the requested image size, or default_size if none is given."""
        match = self.size_re.search(question)
        if not match:
            return self.default_size
        width = match.group(1)
        return self.sizes.get(width, width)

    def _extract_caption(self, question: str) -> str:
        """Returns the question with every size hint removed."""
        caption = self.size_re.sub("", question).strip()
        return caption
98 |
99 |
def create(model: str, question: str) -> Asker:
    """Picks the asker that matches the question.

    Questions starting with "/imagine" go to image generation; all others go
    to the chat model named by `model`.
    """
    wants_image = question.startswith("/imagine")
    return ImagineAsker() if wants_image else TextAsker(model)
105 |
--------------------------------------------------------------------------------
/tests/test_fetcher.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from httpx import Request, Response
3 |
4 | from bot.fetcher import Fetcher, Content
5 |
6 |
class FakeClient:
    """Stand-in HTTP client serving canned responses (or raising canned errors) by URL."""

    def __init__(self, responses: dict[str, Response | Exception]) -> None:
        self.responses = responses

    async def get(self, url: str) -> Response:
        request = Request(method="GET", url=url)
        canned = self.responses[url]
        if not isinstance(canned, Exception):
            # Rebuild the response so it is bound to a request, as a real client's would be.
            return Response(
                status_code=canned.status_code,
                headers=canned.headers,
                text=canned.text,
                request=request,
            )
        raise canned
22 |
23 |
class FetcherTest(unittest.IsolatedAsyncioTestCase):
    """Tests for Fetcher: URL extraction and appending fetched contents."""

    def setUp(self) -> None:
        # Fresh Fetcher per test; tests that fetch swap its client for a FakeClient.
        self.fetcher = Fetcher()

    async def test_substitute_urls(self):
        # Contents of every URL are appended to the text, in order of appearance.
        resp_1 = Response(status_code=200, headers={"content-type": "text/plain"}, text="first")
        resp_2 = Response(status_code=200, headers={"content-type": "text/plain"}, text="second")
        self.fetcher.client = FakeClient(
            {
                "https://example.org/first": resp_1,
                "https://example.org/second": resp_2,
            }
        )
        text = "Compare https://example.org/first and https://example.org/second"
        text = await self.fetcher.substitute_urls(text)
        self.assertEqual(
            text,
            """Compare https://example.org/first and https://example.org/second

---
https://example.org/first contents:

first
---

---
https://example.org/second contents:

second
---""",
        )

    async def test_fetch_url(self):
        # A failed request yields a placeholder naming the exception class instead of raising.
        resp = Response(status_code=200, headers={"content-type": "text/plain"}, text="hello")
        exc = ConnectionError("timeout")
        self.fetcher.client = FakeClient({"https://success.org": resp, "https://failure.org": exc})
        text = await self.fetcher._fetch_url("https://success.org")
        self.assertEqual(text, "hello")
        text = await self.fetcher._fetch_url("https://failure.org")
        self.assertEqual(text, "Failed to fetch (builtins.ConnectionError)")

    async def test_ignore_quoted(self):
        # Quoted URLs are not extracted, so the text comes back unchanged.
        # Uses the real client from setUp; no request is made because nothing is extracted.
        src = "What is 'https://example.org/first'?"
        text = await self.fetcher.substitute_urls(src)
        self.assertEqual(text, src)

    async def test_nothing_to_substitute(self):
        src = "How are you?"
        text = await self.fetcher.substitute_urls(src)
        self.assertEqual(text, src)

    def test_extract_urls(self):
        # Plain URLs are found; trailing punctuation is excluded; double-quoted URLs are skipped.
        text = "Compare https://example.org/first and https://example.org/second"
        urls = self.fetcher._extract_urls(text)
        self.assertEqual(urls, ["https://example.org/first", "https://example.org/second"])

        text = "Extract https://example.org/first."
        urls = self.fetcher._extract_urls(text)
        self.assertEqual(urls, ["https://example.org/first"])

        text = 'Extract "https://example.org/first"'
        urls = self.fetcher._extract_urls(text)
        self.assertEqual(urls, [])
87 |
88 |
89 | class ContentTest(unittest.TestCase):
90 | def test_extract_as_is(self):
91 | resp = Response(
92 | status_code=200, headers={"content-type": "application/sql"}, text="select 42;"
93 | )
94 | content = Content(resp)
95 | text = content.extract_text()
96 | self.assertEqual(text, "select 42;")
97 |
98 | def test_extract_html(self):
99 | html = "Chat information:"))
110 | self.assertTrue("- name: @bot" in self.bot.text)
111 | self.assertTrue("- version: 101" in self.bot.text)
112 | self.assertTrue("- usernames: 2 users" in self.bot.text)
113 | self.assertTrue("- admins: 1 users" in self.bot.text)
114 | self.assertTrue("- chat IDs: [-100500]" in self.bot.text)
115 | self.assertTrue("- access to messages: True" in self.bot.text)
116 |
117 | self.assertTrue("AI information:" in self.bot.text)
118 | self.assertTrue("- provider: api.openai.com" in self.bot.text)
119 | self.assertTrue("- model: gpt-4" in self.bot.text)
120 | self.assertTrue("- history depth: 10" in self.bot.text)
121 | self.assertTrue("- imagine: none" in self.bot.text)
122 | self.assertTrue("- shortcuts: translate_en, translate_fr" in self.bot.text)
123 |
124 |
class ConfigTest(unittest.IsolatedAsyncioTestCase, Helper):
    """Tests for the /config command."""

    def setUp(self):
        from unittest import mock

        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        config.telegram.usernames = ["alice"]
        config.telegram.admins = ["alice"]
        self.filters = Filters()
        self.command = commands.Config(self.filters)
        # Never write the real config file from tests. Patch `save` for every
        # test (including the telegram.* ones, which previously relied on an
        # earlier test's leftover monkeypatch) and undo it afterwards.
        patcher = mock.patch.object(commands.config.editor, "save", lambda: None)
        patcher.start()
        self.addCleanup(patcher.stop)

    async def test_help(self):
        update = self._create_update(11, "/config")
        await self.command(update, self.context)
        self.assertTrue(self.bot.text.startswith("Syntax:"))

    async def test_view(self):
        config.openai.model = "gpt-3.5-turbo"
        update = self._create_update(11, "/config openai.model")
        await self.command(update, self.context)
        self.assertEqual(self.bot.text, "`gpt-3.5-turbo`")

    async def test_change(self):
        config.openai.model = "gpt-3.5-turbo"
        update = self._create_update(11, "/config openai.model gpt-4")
        await self.command(update, self.context)
        self.assertTrue(self.bot.text.startswith("✓ Changed the `openai.model` property"))

    async def test_conversation_depth(self):
        config.conversation.depth = 3
        user = models.UserData(self.context.user_data)
        self.assertEqual(user.messages.messages.maxlen, 3)
        update = self._create_update(11, "/config conversation.depth 5")
        await self.command(update, self.context)
        user = models.UserData(self.context.user_data)
        self.assertEqual(user.messages.messages.maxlen, 5)

    async def test_not_changed(self):
        config.openai.model = "gpt-3.5-turbo"
        update = self._create_update(11, "/config openai.model gpt-3.5-turbo")
        await self.command(update, self.context)
        self.assertEqual(
            self.bot.text,
            "✗ The `openai.model` property already equals to `gpt-3.5-turbo`",
        )

    async def test_delayed(self):
        config.persistence_path = "./data/persistence.pkl"
        update = self._create_update(11, "/config persistence_path /tmp/data.pkl")
        await self.command(update, self.context)
        self.assertTrue("Restart the bot" in self.bot.text)

    async def test_telegram_usernames(self):
        update = self._create_update(11, '/config telegram.usernames ["alice", "bob"]')
        await self.command(update, self.context)
        self.assertEqual(self.filters.users.usernames, frozenset(["alice", "bob"]))

    async def test_telegram_admins(self):
        update = self._create_update(11, '/config telegram.admins ["alice", "bob"]')
        await self.command(update, self.context)
        self.assertEqual(self.filters.admins.usernames, frozenset(["alice", "bob"]))

    async def test_telegram_chat_ids(self):
        update = self._create_update(11, "/config telegram.chat_ids [-100500]")
        await self.command(update, self.context)
        self.assertEqual(self.filters.chats.chat_ids, frozenset([-100500]))
197 |
198 |
class ModelPrivateTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/model in a private chat: help, set, show and reset the per-chat model."""

    def setUp(self):
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        config.telegram.usernames = ["alice"]
        self.command = commands.Model()

    async def _run(self, *texts: str) -> None:
        """Feed each message text to the /model command, in order."""
        for text in texts:
            update = self._create_update(11, text)
            await self.command(update, self.context)

    def _chat_data(self) -> dict:
        return self.application.chat_data[1]

    async def test_help(self):
        await self._run("/model")
        self.assertTrue(self.bot.text.startswith("Syntax:"))
        self.assertEqual(self._chat_data(), {})

    async def test_set(self):
        await self._run("/model gpt-5")
        self.assertTrue(self.bot.text.startswith("✓ Set model"))
        self.assertEqual(self._chat_data()["model"], "gpt-5")

    async def test_show(self):
        await self._run("/model gpt-5", "/model")
        self.assertTrue(self.bot.text.startswith("Using model"))
        self.assertEqual(self._chat_data()["model"], "gpt-5")

    async def test_reset(self):
        await self._run("/model gpt-5", "/model reset")
        self.assertTrue(self.bot.text.startswith("✓ Using default model"))
        self.assertEqual(self._chat_data()["model"], "")
239 |
240 |
class ModelGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/model in a group: a user in telegram.admins may set it, anyone else is ignored."""

    def setUp(self):
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.GROUP)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.alice = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
        # alice is an admin, bob is not
        config.telegram.admins = ["alice"]
        self.command = commands.Model()

    async def _invoke_as(self, sender, text):
        """Run the /model command on an update with `text` sent by `sender`."""
        incoming = self._create_update(11, text, user=sender)
        await self.command(incoming, self.context)

    async def test_allowed(self):
        await self._invoke_as(self.alice, "/model@bot gpt-5")
        self.assertTrue(self.bot.text.startswith("✓ Set model"))
        self.assertEqual(self.application.chat_data[1]["model"], "gpt-5")

    async def test_not_allowed(self):
        await self._invoke_as(self.bob, "/model@bot gpt-5")
        # no reply is sent and the chat data stays untouched
        self.assertEqual(self.bot.text, "")
        self.assertEqual(self.application.chat_data[1], {})
264 |
265 |
class PromptPrivateTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/prompt in a private chat: show syntax, set, show and reset a custom prompt."""

    def setUp(self):
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        # alice is a known user; set before the command object is built
        config.telegram.usernames = ["alice"]
        self.command = commands.Prompt()

    async def _invoke(self, message_id, text):
        """Build an update carrying `text` and run the /prompt command on it."""
        incoming = self._create_update(message_id, text=text)
        await self.command(incoming, self.context)

    async def test_not_set(self):
        # without arguments and with no prompt set, the command prints its syntax
        await self._invoke(11, "/prompt")
        self.assertTrue(self.bot.text.startswith("Syntax:"))
        self.assertEqual(self.application.chat_data[1], {})

    async def test_set_custom(self):
        await self._invoke(11, "/prompt Laugh")
        self.assertTrue(self.bot.text.startswith("✓ Set custom prompt"))
        self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")

    async def test_show_custom(self):
        await self._invoke(11, "/prompt Laugh")
        # once a prompt is set, a bare /prompt reports it instead of the syntax
        await self._invoke(11, "/prompt")
        self.assertTrue(self.bot.text.startswith("Using custom prompt"))
        self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")

    async def test_reset(self):
        await self._invoke(11, "/prompt Laugh")
        await self._invoke(11, "/prompt reset")
        self.assertTrue(self.bot.text.startswith("✓ Using default prompt"))
        # reset stores an empty string rather than removing the key
        self.assertEqual(self.application.chat_data[1]["prompt"], "")
306 |
307 |
class PromptGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/prompt in a group: a user in telegram.admins may set it, anyone else is ignored."""

    def setUp(self):
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.GROUP)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.alice = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
        # alice is an admin, bob is not
        config.telegram.admins = ["alice"]
        self.command = commands.Prompt()

    async def _invoke_as(self, sender, text):
        """Run the /prompt command on an update with `text` sent by `sender`."""
        incoming = self._create_update(11, text, user=sender)
        await self.command(incoming, self.context)

    async def test_allowed(self):
        await self._invoke_as(self.alice, "/prompt@bot Laugh")
        self.assertTrue(self.bot.text.startswith("✓ Set custom prompt"))
        self.assertEqual(self.application.chat_data[1]["prompt"], "Laugh")

    async def test_not_allowed(self):
        await self._invoke_as(self.bob, "/prompt@bot Laugh")
        # no reply is sent and the chat data stays untouched
        self.assertEqual(self.bot.text, "")
        self.assertEqual(self.application.chat_data[1], {})
331 |
332 |
class RetryTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/retry replays the last question stored in the user's message history."""

    def setUp(self):
        # the fake text model is installed before anything else is built
        mock_text_asker(FakeGPT())
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.command = commands.Retry(bot.reply_to)
        config.telegram.usernames = ["alice"]

    async def test_retry(self):
        stored = models.UserData(self.context.user_data)
        stored.messages.add("What is your name?", "My name is AI.")

        await self.command(self._create_update(11), self.context)

        # the fake model answers with the question it was asked
        self.assertEqual(self.bot.text, "What is your name?")
351 |
352 |
class ImagineTest(unittest.IsolatedAsyncioTestCase, Helper):
    """/imagine is gated by config.imagine.enabled.

    Covered values: "none" disables the command for everyone, "users_only"
    allows only users in telegram.usernames, "users_and_groups" lets any
    user through.
    """

    def setUp(self):
        askers.ImagineAsker.model = FakeDalle()
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.command = commands.Imagine(bot.reply_to)
        config.telegram.usernames = ["alice"]

    async def _imagine_cat(self, **update_kwargs):
        """Send "/imagine a cat" and run the command.

        `update_kwargs` are forwarded to `_create_update` (e.g. `user=...`);
        they are omitted entirely when not given so the helper's own
        defaults apply.
        """
        update = self._create_update(11, "/imagine a cat", **update_kwargs)
        self.context.args = ["a", "cat"]
        await self.command(update, self.context)

    async def test_imagine(self):
        config.imagine.enabled = "users_only"
        await self._imagine_cat()
        self.assertEqual(self.bot.text, "a cat: image")

    async def test_disabled(self):
        config.imagine.enabled = "none"
        await self._imagine_cat()
        # assertIn shows both operands on failure, unlike assertTrue(x in y)
        self.assertIn("command is disabled", self.bot.text)

    async def test_users_only(self):
        config.imagine.enabled = "users_only"
        # bob is not listed in telegram.usernames
        bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
        await self._imagine_cat(user=bob)
        self.assertIn("command is disabled", self.bot.text)

    async def test_users_and_groups(self):
        config.imagine.enabled = "users_and_groups"
        bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
        await self._imagine_cat(user=bob)
        self.assertEqual(self.bot.text, "a cat: image")
394 |
395 |
class MessageTest(unittest.IsolatedAsyncioTestCase, Helper):
    """Plain messages in a private chat are answered by the (fake) text model."""

    def setUp(self):
        self.ai = FakeGPT()
        mock_text_asker(self.ai)
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.command = commands.Message(bot.reply_to)
        config.telegram.usernames = ["alice"]

    async def test_message(self):
        update = self._create_update(11, text="What is your name?")
        await self.command(update, self.context)
        # the fake model echoes the question back as its answer
        self.assertEqual(self.bot.text, "What is your name?")
        self.assertEqual(self.ai.question, "What is your name?")
        self.assertEqual(self.ai.history, [])

    async def test_follow_up(self):
        update = self._create_update(11, text="What is your name?")
        await self.command(update, self.context)
        self.assertEqual(self.ai.question, "What is your name?")
        self.assertEqual(self.ai.history, [])

        # A "+ " prefix continues the conversation: the prefix is stripped
        # from the question sent to the model, and earlier exchanges are
        # passed along as history.
        update = self._create_update(12, text="+ And why is that?")
        await self.command(update, self.context)
        self.assertEqual(self.ai.question, "And why is that?")
        self.assertEqual(self.ai.history, [("What is your name?", "What is your name?")])

        # the stored question keeps its original "+ " prefix
        update = self._create_update(13, text="+ Where are you?")
        await self.command(update, self.context)
        self.assertEqual(self.ai.question, "Where are you?")
        self.assertEqual(
            self.ai.history,
            [
                ("What is your name?", "What is your name?"),
                ("+ And why is that?", "And why is that?"),
            ],
        )

    async def test_forward(self):
        update = self._create_update(11, text="What is your name?", forward_date=dt.datetime.now())
        await self.command(update, self.context)
        self.assertTrue(self.bot.text.startswith("This is a forwarded message"))

    async def test_document(self):
        # A very long answer is cut short in the message body; the reply
        # points to an attachment (apparently named after the message id).
        update = self._create_update(11, text="I have so much to say" + "." * 5000)
        await self.command(update, self.context)
        self.assertEqual(self.bot.text, "I have so much to... (see attachment for the rest): 11.md")

    async def test_exception(self):
        ai = FakeGPT(error=Exception("connection timeout"))
        mock_text_asker(ai)
        update = self._create_update(11, text="What is your name?")
        await self.command(update, self.context)
        # the model's exception is reported to the user, not raised
        self.assertTrue(self.bot.text.startswith("⚠️ builtins.Exception:"))
        # assertIn shows both operands on failure, unlike assertTrue(x in y)
        self.assertIn("connection timeout", self.bot.text)
455 |
456 |
class MessageGroupTest(unittest.IsolatedAsyncioTestCase, Helper):
    """In a group chat the bot answers only messages that mention it."""

    def setUp(self):
        mock_text_asker(FakeGPT())
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.GROUP)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        # pre-create per-user storage for both human participants
        self.application.user_data[1] = {}
        self.application.user_data[2] = {}
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.user_erik = User(id=2, first_name="Erik", is_bot=False, username="erik")
        self.user_bot = User(id=42, first_name="Bot", is_bot=True, username="bot")
        self.command = commands.Message(bot.reply_to)
        config.telegram.usernames = ["alice"]

    async def test_message(self):
        # the entity covers the leading "@bot" (offset 0, length 4)
        at_bot = MessageEntity(type=MessageEntity.MENTION, offset=0, length=4)
        incoming = self._create_update(11, text="@bot What is your name?", entities=(at_bot,))
        await self.command(incoming, self.context)
        self.assertEqual(self.bot.text, "What is your name?")

    async def test_no_mention(self):
        incoming = self._create_update(11, text="What is your name?")
        await self.command(incoming, self.context)
        # without a mention the bot stays silent
        self.assertEqual(self.bot.text, "")
483 |
484 |
class MessageLimitTest(unittest.IsolatedAsyncioTestCase, Helper):
    """Per-user message limits (config.conversation.message_limit).

    Users listed in telegram.usernames are not limited. Other users get a
    "Please wait" reply once they exceed `count` messages per `period`. A
    count of 0 disables the limit, and an expired counter starts over.
    """

    def setUp(self):
        self.ai = FakeGPT()
        mock_text_asker(self.ai)
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        self.command = commands.Message(bot.reply_to)
        config.telegram.usernames = ["alice"]
        config.conversation.message_limit.count = 1
        config.conversation.message_limit.period = "minute"
        # A hack for testing purposes only: rebuild the module-level users
        # filter so it reflects the telegram.usernames value set above.
        # NOTE(review): the config values and bot.filters.users set here are
        # never restored by this class, so they stay in effect for tests
        # that run afterwards. Confirm nothing relies on that.
        bot.filters.users = tg_filters.User(username=config.telegram.usernames)

    async def _say(self, message_id, text, context=None, **update_kwargs):
        """Send `text` as message `message_id` and run the command.

        `context` defaults to `self.context`. `update_kwargs` (e.g.
        `user=...`) are forwarded to `_create_update` only when given.
        """
        incoming = self._create_update(message_id, text=text, **update_kwargs)
        await self.command(incoming, self.context if context is None else context)

    async def test_known_user(self):
        # alice is listed in telegram.usernames, so the limit of 1 never applies
        for message_id, question in ((11, "What is your name?"), (12, "Where are you from?")):
            await self._say(message_id, question)
            self.assertEqual(self.bot.text, question)

    async def test_unknown_user(self):
        bob = User(id=2, first_name="Bob", is_bot=False, username="bob")

        await self._say(11, "What is your name?", user=bob)
        self.assertEqual(self.bot.text, "What is your name?")

        # the second message exceeds the limit of 1 per minute
        await self._say(12, "Where are you from?", user=bob)
        self.assertTrue(self.bot.text.startswith("Please wait"))

    async def test_expired(self):
        config.conversation.message_limit.count = 3

        bob = User(id=2, first_name="Bob", is_bot=False, username="bob")
        # the counter has reached the limit, but its timestamp is an hour old
        an_hour_ago = dt.datetime.now() - dt.timedelta(hours=1)
        bob_data = {
            "message_counter": {
                "value": 3,
                "timestamp": an_hour_ago,
            }
        }
        self.application.user_data[bob.id] = bob_data
        bob_context = CallbackContext(self.application, chat_id=1, user_id=bob.id)

        await self._say(11, "What is your name?", context=bob_context, user=bob)
        self.assertEqual(self.bot.text, "What is your name?")
        # the stale counter was reset and now counts only this message
        self.assertEqual(bob_data["message_counter"]["value"], 1)

    async def test_unlimited(self):
        config.conversation.message_limit.count = 0
        bob = User(id=2, first_name="Bob", is_bot=False, username="bob")

        # with a count of 0 even an unknown user is never throttled
        for message_id, question in ((11, "What is your name?"), (12, "Where are you from?")):
            await self._say(message_id, question, user=bob)
            self.assertEqual(self.bot.text, question)
552 |
553 |
class ErrorTest(unittest.IsolatedAsyncioTestCase, Helper):
    """The error handler reports `context.error` to the chat it came from."""

    def setUp(self):
        mock_text_asker(FakeGPT())
        self.bot = FakeBot("bot")
        self.chat = Chat(id=1, type=ChatType.PRIVATE)
        self.chat.set_bot(self.bot)
        self.application = FakeApplication(self.bot)
        self.context = CallbackContext(self.application, chat_id=1, user_id=1)
        self.user = User(id=1, first_name="Alice", is_bot=False, username="alice")
        config.telegram.usernames = ["alice"]

    async def test_error(self):
        self.context.error = Exception("Something went wrong")
        handler = commands.Error()
        failing_update = self._create_update(11, "Something went wrong")
        # pin the chat the error report should be sent to
        failing_update._effective_chat = self.chat
        await handler(failing_update, self.context)
        # the report names the exception's module and class, then its message
        self.assertEqual(self.bot.text, "⚠️ builtins.Exception: Something went wrong")
572 |
--------------------------------------------------------------------------------