├── .dockerignore ├── .env.example ├── .gitignore ├── Dockerfile ├── LICENSE.md ├── README.md ├── clio.jpg ├── cliobot ├── bots │ ├── __init__.py │ ├── command_handler.py │ └── telegram_bot.py ├── commands │ ├── __init__.py │ ├── audio.py │ ├── help.py │ ├── images.py │ ├── session.py │ └── text.py ├── config.py ├── db │ ├── __init__.py │ ├── inmemory.py │ ├── sqlite.py │ └── utils.py ├── errors │ ├── __init__.py │ └── sentry.py ├── metrics.py ├── metrics │ ├── __init__.py │ └── mixpanel_metrics.py ├── ollama │ └── client.py ├── openai │ └── client.py ├── replicate │ └── client.py ├── storage │ ├── __init__.py │ └── s3.py ├── translator │ ├── __init__.py │ └── google.py ├── utils.py └── webui │ └── client.py ├── config.example.yml ├── config.full.yml ├── fullbot.py ├── i18n └── en.yml ├── poetry.lock ├── pyproject.toml ├── schema.sql ├── test ├── res │ ├── hello.mp3 │ └── sandwich.jpg ├── test_command_handler.py ├── test_ollama_client.py ├── test_openai_client.py ├── test_replicate_client.py └── test_webui_client.py └── working.jpg /.dockerignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | cache/ 3 | tmp/ 4 | config.yml 5 | .env 6 | *.pyc 7 | *.pyo 8 | *.pyd 9 | .Python 10 | .idea 11 | .git 12 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_API_TOKEN=xxx 2 | 3 | REPLICATE_API_TOKEN=xxx 4 | 5 | OPENAI_AZURE_API_KEY_GPT35=xxx 6 | OPENAI_AZURE_BASE_URL_GPT35=https://xxx.openai.azure.com 7 | 8 | OPENAI_AZURE_API_KEY_GPT4=xxx 9 | OPENAI_AZURE_BASE_URL_GPT4=https://xxx.openai.azure.com 10 | 11 | OPENAI_AZURE_BASE_URL_EMBEDDINGS=https://xxx.openai.azure.com 12 | OPENAI_AZURE_API_KEY_EMBEDDINGS=xxx 13 | 14 | OPENAI_AZURE_BASE_URL_DALLE3=https://xxx.openai.azure.com 15 | OPENAI_AZURE_API_KEY_DALLE3=xxx 16 | 17 | OPENAI_AZURE_BASE_URL_WHISPER=https://xxx.openai.azure.com 18 | 
OPENAI_AZURE_API_KEY_WHISPER=xxx 19 | 20 | TELEGRAM_TOKEN=xxx 21 | 22 | HF_HOME=data/huggingface -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | .idea 3 | config.yml 4 | /.env.development 5 | **.pyc 6 | /clibot.db 7 | /tmp 8 | cache/ 9 | .env 10 | data/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 2 | 3 | # Set the working directory in the container 4 | WORKDIR /content 5 | 6 | RUN apt-get update -y && \ 7 | apt-get upgrade -y && \ 8 | apt-get install -y \ 9 | libgl1 \ 10 | libglib2.0-0 \ 11 | python3-pip \ 12 | python-is-python3 \ 13 | python3.10-dev \ 14 | bash \ 15 | gcc \ 16 | build-essential \ 17 | # wget \ 18 | # git \ 19 | # git-lfs \ 20 | # curl \ 21 | libffi-dev \ 22 | libssl-dev \ 23 | openssl \ 24 | # tcl-dev \ 25 | # tk-dev \ 26 | ca-certificates && \ 27 | rm -rf /var/lib/apt/lists/* 28 | 29 | COPY requirements.txt /content/requirements.txt 30 | RUN pip3 install -r /content/requirements.txt 31 | 32 | COPY . 
/content 33 | 34 | CMD ["python", "app.py"] -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2022 Herval Freire 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cliobot - multimodal generative AI bot for chat platforms 2 | 3 | ![Clio Bot](clio.jpg) 4 | 5 | Cliobot is a modular bot platform for generative AI agents. Its goal is to provide a simple, easy to use and extendable 6 | platform for running generative AI agents that can handle audio, video, text and images, on any chat platform. 7 | 8 | It can easily be extended to use multiple APIs and services, from Stable Diffusion to OpenAI, and you can run it on 9 | your own device or deploy it online.
10 | 11 | It comes with Telegram support and multiuser handling out of the box, and minimal dependencies. 12 | 13 | Important: This repo is a work in progress - I'm porting over code from a startup I was working on, so it's still a bit 14 | rough and subject to multiple rewrites. 15 | 16 | ## The Basics 17 | 18 | Cliobot has two main working modes: __command mode__ or __LLM mode__ 19 | 20 | In __command mode__, you interact by using __slash commands__ (messages starting with a /). It comes with a default set 21 | of slash commands and you can easily create your own. 22 | 23 | [WIP] In __LLM mode__, the bot works like chatgpt & other multimodal chatbots out there: it follows a configurable 24 | system prompt that defines its core behavior and can use functions to perform actions (including executing code or 25 | browsing the web). 26 | 27 | Notice both modes use the same command definitions, so the only difference between them is a tradeoff between more 28 | natural language interpretation versus cost (since running GPT4 & other models can get expensive quickly). 29 | 30 | ### Running the bot 31 | 32 | 1. Install all dependencies with: 33 | 34 | ``` 35 | poetry install 36 | ``` 37 | 38 | 2. Rename `config.example.yml` to `config.yml` and set the appropriate variables you want. 39 | 40 | 3. Run the bot using the following command: 41 | 42 | ``` 43 | poetry run python fullbot.py 44 | ``` 45 | 46 | The codebase includes other examples, such as a simple chat-only bot that uses OpenAI's API to respond to messages (`chatbot.py`). Documentation for these examples is still a work in progress. 47 | 48 | 49 | ### Running tests 50 | ``` 51 | poetry run pytest 52 | ``` 53 | 54 | ## Built-in commands 55 | 56 | Cliobot comes with a set of built-in commands that you can use out of the box. You can also easily add your own! 57 | 58 | ### /image 59 | 60 | Generates an image from a text prompt. 61 | 62 | Built-in implementations: DALL-E 3, any image model hosted on Replicate.com.
63 | 64 | ### /describe [WIP] 65 | 66 | Describe an image using text. 67 | 68 | Built-in implementations: OpenAI GPT4V, Ollama (Llava, etc), any image to text model hosted on Replicate.com. 69 | 70 | ### /transcribe 71 | 72 | Transcribes an audio file into text. 73 | 74 | Built-in implementations: OpenAI Whisper-1 75 | 76 | ### /ask 77 | 78 | Ask a question to an LLM agent. This doesn't take any conversation context. 79 | 80 | Built-in implementations: GPT-4 or any model supported by [Ollama](https://github.com/jmorganca/ollama) running in 81 | server mode, any LLM hosted on Replicate.com. 82 | 83 | ### /chat [WIP] 84 | 85 | Chat with an LLM agent, including a backlog of context 86 | 87 | ## Command syntax 88 | 89 | Cliobot uses a simple prompt parsing system (common across apps such as Midjourney & others). It's based on the 90 | following format: 91 | 92 | ``` 93 | / ? [-- ]+ 94 | ``` 95 | 96 | Each command handler is defined as a pydantic model, and the parameters are automatically parsed and validated. 97 | 98 | When a certain command requires multiple inputs, such as image, the bot will ask for them in sequence, then run the 99 | command after you provide all the inputs. 100 | 101 | An example of a command using the default dalle3 image generation command would be as follows: 102 | 103 | ``` 104 | /image a giant hamster in space --size 1024x1024 --model dalle3 105 | ``` 106 | 107 | ## Installing 108 | 109 | Running a bot locally is simple: 110 | 111 | - Clone this repo 112 | - Setup the python env 113 | - Rename `config.example.yml` to `config.yml` and set the appropriate variables you want. 114 | - Install all dependencies with: 115 | 116 | ``` 117 | python -m venv create venv 118 | source venv/bin/activate 119 | pip install -r requirements.txt 120 | ``` 121 | 122 | ### Running Cliobot on Telegram 123 | 124 | The bare minimum you'll need is an API Token for a Telegram bot. 
Please refer to 125 | the [official documentation](https://core.telegram.org/bots/tutorial#obtain-your-bot-token) for how to obtain an API 126 | token. It should look like this: `4839574812:AAFD39kkdpWt3ywyRZergyOLMaJhac60qc` 127 | 128 | Once you get a token, change your `config.yml` to include the following section: 129 | 130 | ``` 131 | bot: 132 | platform: telegram 133 | token: "1234567890:ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" 134 | ``` 135 | 136 | Then, run the bot using the following command: 137 | 138 | ``` 139 | source venv/bin/activate 140 | python app.py 141 | ``` 142 | 143 | ## Running with Docker 144 | 145 | If you don't have Python on your system or just prefer to keep things simple, you can run Cliobot using Docker too: 146 | 147 | ``` 148 | docker build -t cliobot . 149 | docker run -it --rm -v $(pwd)/data:/content/data -v $(pwd)/.env:/content/.env -v $(pwd)/config.yml:/content/config.yml cliobot 150 | ``` 151 | 152 | ## Using Automatic1111 WebUI as a backend 153 | 154 | You can plug in [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and use it as a backend for image generation! To do so, you'll need to set the following variables on your config.yml: 155 | 156 | ``` 157 | webui: 158 | endpoint: http://localhost:7860 159 | auth: user:pass 160 | ``` 161 | 162 | Notice you'll need to start webui with the `--api` flag. The `auth` field is optional (you can leave it blank if you don't use API authentication). For more information on how to use the API, please refer to the [official documentation](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API). 163 | 164 | ### Supported operations 165 | 166 | You can use any Stable Diffusion model that's installed alongside webui with the `/image` command.
The following is an example using all the supported parameters: 167 | 168 | 169 | ``` 170 | /image a hamster in space --negative cartoon, drawing, illustration --model sdxl1.0 --steps 20 --sampler 'DPM++ 2M SDE' --cfg 7 --seed 1234 --steps 50 --width 1024 --height 1024 --batchcount 1 --batchsize 4 171 | ``` 172 | 173 | 174 | ## Configuring OpenAI 175 | 176 | To use OpenAI models (gpt, dalle3, whisper, etc), include the following in your config.yml: 177 | 178 | ``` 179 | openai: 180 | endpoints: 181 | - api_key: sk-.... 182 | api_type: open_ai 183 | base_url: https://api.openai.com/v1/ 184 | 185 | - api_key: xxx 186 | api_type: azure 187 | api_version: 2023-10-01-preview 188 | base_url: https://xxx.openai.azure.com 189 | model: gpt4 190 | kind: gpt-4 191 | 192 | - api_key: xxx 193 | api_type: azure 194 | api_version: 2 195 | base_url: https://xxx.openai.azure.com 196 | model: embeddings 197 | kind: embeddings 198 | 199 | - api_key: xxx 200 | api_type: azure 201 | api_version: 2023-12-01-preview 202 | base_url: https://xxx.openai.azure.com 203 | model: dalle3 204 | kind: dall-e-3 205 | 206 | - api_key: xxx 207 | api_type: azure 208 | api_version: 2023-12-01-preview 209 | base_url: https://xxx.openai.azure.com 210 | model: whisper1 211 | kind: whisper-1 212 | ``` 213 | 214 | Notice that for Azure deployments, you'll need to set one entry per model 215 | kind (`dall-e-3`, `whisper-1`, `embeddings`, `gpt-4`). The API key can be the same for all of them. 216 | 217 | ## Configuring Ollama 218 | 219 | In order to use any LLM via Ollama, simply include the following in your config.yml: 220 | 221 | ``` 222 | ollama: 223 | endpoint: http://localhost:11434 224 | models: 225 | - llama2 226 | ``` 227 | 228 | Each model on the `models` list will be exposed as a model on the bot. You can then use it by using the `/ask` command: 229 | 230 | ``` 231 | /ask what's the meaning of life? 
--model llama2 232 | ``` 233 | 234 | ## Configuring Replicate 235 | 236 | You can use any model hosted on [Replicate](https://replicate.com/) by mapping it out on your config.yml. The mapping is 237 | a bit more involved than other models, since you need to map out each parameter. Here's a complete example using SDXL 238 | hosted on Replicate: 239 | 240 | ``` 241 | replicate: 242 | api_token: xxx 243 | endpoints: 244 | - model: 'sdxl' 245 | kind: 'image' 246 | version: 'stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b' 247 | params: 248 | prompt: 249 | kind: str 250 | required: true 251 | negative_prompt: 252 | alias: no 253 | kind: str 254 | width: 255 | kind: int 256 | default: 1024 257 | height: 258 | kind: int 259 | default: 1024 260 | num_outputs: 261 | alias: num 262 | kind: int 263 | default: 1 264 | num_inference_steps: 265 | alias: steps 266 | kind: int 267 | default: 25 268 | guidance_scale: 269 | alias: cfg 270 | kind: float 271 | default: 7.5 272 | prompt_strength: 273 | alias: ps 274 | kind: float 275 | default: 0.8 276 | seed: 277 | kind: int 278 | apply_watermark: 279 | alias: watermark 280 | kind: bool 281 | default: true 282 | scheduler: 283 | kind: str 284 | default: 'KarrasDPM' 285 | refine: 286 | kind: str 287 | alias: refiner 288 | default: 'no_refiner' 289 | value_map: 290 | no: no_refiner 291 | expert: expert_ensemble_refiner 292 | base: base_image_refiner 293 | refine_steps: 294 | kind: int 295 | alias: rs 296 | ``` 297 | 298 | With the above config, you'll be able to generate images using the following command: 299 | 300 | ``` 301 | /image photo of a giant hamster in space --model sdxl --no illustration, cartoon, drawing --width 1280 --num 4 --steps 50 --rs 8 --refiner expert 302 | ``` 303 | 304 | Notice the parameter names on your slash command will match the param name on the config, _or_ an optional `alias`. 
This 305 | allows you to use shorter parameter names on your commands (eg typing out `--no` instead of `--negative_prompt`). 306 | 307 | ## Built-in extensions 308 | 309 | These are all deactivated by default, but easily enabled: 310 | 311 | - Sentry.io support for error reporting/tracking 312 | - Automatic message translation using Google Translate API 313 | - Utilization metrics using MixPanel 314 | - S3 for file storage 315 | 316 | ## Features 317 | 318 | - OpenAI API support for DALL-E, GPT-3, GPT-4 and Whisper, including Azure support and multiple API keys 319 | - Ollama support for any LLM model (including image to text) 320 | - Support for any model hosted on Replicate.com 321 | - Multiuser support 322 | - File storage support (local & S3) 323 | - Automatic message translation using Google Translate API 324 | - Persistent preferences to reduce repetitive prompt parameters 325 | 326 | ## Running on K8s 327 | 328 | TODO 329 | 330 | ## Writing plugins 331 | 332 | TODO 333 | 334 | ## Planned features 335 | 336 | - Discord integration 337 | - Whatsapp integration 338 | - Stable Diffusion 339 | - StableHorde processing 340 | 341 | ## TODO 342 | 343 | - RAG mode 344 | - chat history 345 | - Finish the LLM mode 346 | - save generated images to storage 347 | - save uploads 348 | - i18n support 349 | - img2txt commands 350 | - llama implementation 351 | 352 | -------------------------------------------------------------------------------- /clio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/herval/cliobot/c7b2155e548f33ab87216263b765796f5338e50e/clio.jpg -------------------------------------------------------------------------------- /cliobot/bots/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import queue 4 | import threading 5 | import traceback 6 | from sys import exc_info 7 | from typing import Callable 8 | 9 
class Message:
    """Platform-neutral representation of one incoming chat message.

    Carries the identifiers of the message, its author, chat and bot, the
    optional payloads (text, image, audio, voice, video) and, when the
    message is a reply, the message being replied to.
    """

    def __init__(self,
                 message_id,
                 user_id,
                 chat_id,
                 user,
                 reply_to_message=None,
                 reply_to_message_id=None,
                 text=None,
                 image=None,
                 audio=None,
                 voice=None,
                 video=None,
                 bot_id=None,
                 metadata=None,  # metadata is transient and survives the current request only
                 is_forward=False,
                 ):
        # Who sent what, and where.
        self.message_id = message_id
        self.user_id = user_id
        self.chat_id = chat_id
        self.user = user
        self.bot_id = bot_id
        self.is_forward = is_forward

        # Payloads; any of them may be absent.
        self.text = text
        self.image = image
        self.audio = audio
        self.voice = voice
        self.video = video

        # A fresh dict per instance when the caller supplies none.
        self.metadata = {} if metadata is None else metadata

        # Reply linkage: fall back to the replied-to message's own id when
        # no explicit id was passed.
        self.reply_to_message = reply_to_message
        if reply_to_message and not reply_to_message_id:
            reply_to_message_id = reply_to_message.message_id
        self.reply_to_message_id = reply_to_message_id

    def translate(self, translator):
        """Replace ``text`` with its translation, in place.

        Nothing happens when there is no text, no translator, or the
        translator returns a falsy result.
        """
        if self.text is None or not translator:
            return
        translated = translator.translate(self.text)
        if translated:
            self.text = translated

    def full_text(self):
        """Return the textual content of the message."""
        return self.text

    def __str__(self):
        return f"Message({self.message_id}, {self.chat_id}, {self.user_id}, {self.text})"

    def __repr__(self):
        return str(self)


class User:
    """Profile details of the person behind a message."""

    def __init__(self, username, phone, full_name, language):
        self.full_name = full_name
        self.username = username
        self.phone = phone
        self.language = language
class Session:
    """Per-user conversation state.

    ``context`` is the short-term memory of the bot: it survives across
    requests until :meth:`clear` is called. ``preferences`` is the long-term
    memory.
    """

    def __init__(self, user_id, chat_id, context, preferences):
        self.user_id = user_id
        self.chat_id = chat_id
        # CachedSession feeds these straight from ``dict.get``, which yields
        # None for a missing key; normalise to empty dicts so every accessor
        # below can rely on a mapping. Supplied dicts are kept by identity.
        self.context = {} if context is None else context
        self.preferences = {} if preferences is None else preferences

    def pop(self, key):
        """Remove ``key`` from the context and return its value (None if absent)."""
        return self.context.pop(key, None)

    def __str__(self):
        return f"Session({self.context}, {self.preferences})"

    def __repr__(self):
        return self.__str__()

    def set(self, key, value):
        """Store ``value`` under ``key`` in the short-term context."""
        print('setting', key, value)
        self.context[key] = value

    def get(self, key, default=None, include_preferences=True):
        """Look ``key`` up in the context, then (optionally) the preferences.

        Falsy values (None, 0, '', empty containers) count as missing at
        both levels, so ``default`` is returned in their place.
        """
        value = self.context.get(key)
        if not value and include_preferences:
            value = self.preferences.get(key, default)
        return value or default

    def clear(self, clear_user=False):
        """Drop the whole short-term context.

        ``clear_user`` is accepted for interface compatibility; it currently
        has no effect.
        """
        self.context = {}

    def to_dict(self, include_preferences=True):
        """Return a flat copy of the state; context entries win over preferences.

        The ``'buffer'`` context entry is never included.
        """
        merged = dict(self.preferences) if include_preferences else {}
        merged.update((k, v) for k, v in self.context.items() if k != 'buffer')
        return merged

    def _context_with_suffix(self, suffix) -> dict[str, str]:
        # Non-None context entries whose key ends with ``suffix``.
        return {
            k: v for k, v in self.context.items()
            if k.endswith(suffix) and v is not None
        }

    def images(self) -> dict[str, str]:
        """Context entries whose key ends with ``_image``."""
        return self._context_with_suffix('_image')

    def audios(self) -> dict[str, str]:
        """Context entries whose key ends with ``_audio``."""
        return self._context_with_suffix('_audio')

    def set_preference(self, key, val):
        """Store ``val`` under ``key`` in the long-term preferences."""
        self.preferences[key] = val


class CachedSession(Session):
    """A Session that tracks modifications so it is only written back when dirty."""

    def __init__(self, chat_session, chat_id):
        super().__init__(user_id=chat_session.get('external_user_id'),
                         chat_id=chat_id,
                         context=chat_session.get('context'),
                         preferences=chat_session.get('preferences'),
                         )
        self.dirty = False

    def pop(self, key):
        if key in self.context:
            self.dirty = True
        return super().pop(key)

    def persist(self, db):
        """Write context and preferences to ``db`` if anything changed."""
        if self.dirty:  # commit changes
            db.set_chat_context(self.user_id, self.context, self.preferences)
            self.dirty = False

    def set_preference(self, key, val):
        if self.preferences.get(key) != val:
            self.dirty = True
        super().set_preference(key, val)

    def set(self, key, value):
        if self.context.get(key) != value:
            self.dirty = True
        super().set(key, value)

    def clear(self, clear_user=False):
        if len(self.context) > 0:
            self.dirty = True
        super().clear(clear_user)

    @classmethod
    def from_cache(cls, db, user_id, chat_id):
        """Build a session from the record ``db`` holds (or creates) for ``user_id``."""
        data = db.create_or_get_chat_session(user_id)
        return cls(
            chat_session=data,
            chat_id=chat_id,
        )


class MessageHandler:
    """
    base class for message handlers
    """

    def __init__(self):
        self.running = True
        self.sender_loop = None  # created by listen(), in the worker thread

    async def process(self, message: "Message", session: CachedSession, bot):
        """Handle one message; subclasses must implement this."""
        raise NotImplementedError

    def listen(self, bot):
        """Run the polling loop on a fresh event loop owned by the calling thread."""
        self.sender_loop = asyncio.new_event_loop()

        asyncio.set_event_loop(self.sender_loop)
        self.sender_loop.run_until_complete(self._poll(bot))

    def stop(self):
        """Ask the handler to stop; safe to call from any thread, even before listen()."""
        self.running = False
        loop = self.sender_loop
        # The loop runs in the worker thread, so it must be stopped through
        # the thread-safe scheduling hook; there is nothing to stop when
        # listen() has not started it.
        if loop is not None and loop.is_running():
            loop.call_soon_threadsafe(loop.stop)

    async def _poll(self, bot):
        """Consume ``bot.internal_queue`` until the handler is stopped."""
        while self.running:
            fetched = False
            try:
                message = bot.internal_queue.get()
                fetched = True
                await self._handle_message(message, bot)
            except (KeyboardInterrupt, SystemExit):
                print("Shutting down...")
                self.running = False
                return
            except Exception:
                traceback.print_exc()
                bot.metrics.capture_exception(exc_info(), 'anonymous')
            finally:
                # Only acknowledge an item that was actually taken off the
                # queue; otherwise task_done() itself would raise.
                if fetched:
                    bot.internal_queue.task_done()

    async def _handle_message(self, message: "Message", bot):
        """Load the sender's session, record the message, then dispatch to process()."""
        print('on_message', message.__str__())
        session = CachedSession.from_cache(
            db=bot.db,
            user_id=message.user_id,
            chat_id=message.chat_id)

        bot.metrics.error_handler.set_context({
            "id": session.user_id,
        })

        # Saving the message is best-effort: a failure is reported but does
        # not prevent the message from being processed.
        try:
            bot.db.save_message(
                user_id=message.user_id,
                chat_id=message.chat_id,
                text=message.text or '',
                external_id=message.message_id,
                image=message.image,
                video=message.video,
                audio=message.audio,
                voice=message.voice,
                is_forward=message.is_forward,
            )
        except Exception as e:
            bot.metrics.capture_exception(e, session.user_id)

        if session.user_id is None:
            # create_or_get_chat_session returns the raw stored record (the
            # value from_cache wraps), so it has to be wrapped again before
            # being handed to process(), which expects a CachedSession.
            record = bot.db.create_or_get_chat_session(message.user_id)
            print(record)
            session = CachedSession(chat_session=record, chat_id=message.chat_id)

        if message.reply_to_message_id and not message.reply_to_message:
            print("Loading reply...")
            message.reply_to_message = await bot.messaging_service.get_message(message.reply_to_message_id)

        message.translate(bot.translator)

        await self.process(message, session, bot)
class BaseBot:
    """Platform-independent bot core.

    Owns the inbound message queue and a pool of handler threads that consume
    it. Subclasses provide :meth:`initialize` and :meth:`start`.
    """

    def __init__(self,
                 handler_fn: "Callable[[], MessageHandler]",
                 messaging_service: "MessagingService",
                 db,
                 storage=None,
                 internal_queue=None,
                 bot_id=None,
                 bot_language='en',
                 cache=None,
                 translator=None,
                 metrics=None,
                 workers=None,
                 ):
        """
        :param handler_fn: zero-argument factory, called once per worker to
            build that worker's message handler.
        :param messaging_service: platform adapter used to talk to the chat service.
        :param db: persistence backend handed to the handlers.
        :param internal_queue: queue of incoming messages; a ``queue.Queue``
            is created when omitted.
        :param cache: defaults to an ``InMemoryCache``.
        :param metrics: defaults to ``BaseMetrics(BaseErrorHandler())``.
        :param workers: number of handler threads; defaults to the CPU count.
        """
        self.messaging_service = messaging_service
        # Compare against None instead of relying on truthiness: an empty
        # container-like queue or cache is falsy and would otherwise be
        # silently swapped for the default.
        self.internal_queue = queue.Queue() if internal_queue is None else internal_queue
        self.translator = translator
        self.db = db
        self.storage = storage
        self.bot_id = bot_id
        self.bot = None
        self.bot_language = bot_language
        self.cache = InMemoryCache() if cache is None else cache
        self.metrics = BaseMetrics(BaseErrorHandler()) if metrics is None else metrics
        self.models = {}
        self.handler_fn = handler_fn

        if workers is None:
            # os.cpu_count() returns None when the count cannot be determined.
            workers = os.cpu_count() or 1
        self.senders = [handler_fn() for _ in range(int(workers))]

        self.threads = [
            threading.Thread(target=handler.listen, args=(self,), daemon=True) for handler in self.senders
        ]

    async def initialize(self):
        """One-off asynchronous setup, run by listen() before the workers start."""
        raise NotImplementedError()

    def start(self):
        """Blocking platform main loop; listen() shuts the workers down when it returns."""
        raise NotImplementedError()

    def listen(self):
        """Initialise the bot, start the workers, run start(), then stop the workers."""
        # initialize the bot commands list and stuff
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.initialize())
        finally:
            loop.close()  # release the loop even when initialize() fails

        # start everything
        for thread in self.threads:
            thread.start()
        print("Bot ready")
        self.start()
        print("blowing things up, stay calm...")
        for sender in self.senders:
            sender.stop()

    async def enqueue(self, update):
        """Put ``update`` on the internal queue for the workers to pick up."""
        self.internal_queue.put(update)
Exception as e: 49 | traceback.print_exc() 50 | bot.metrics.capture_exception(exc_info(), session.user_id) 51 | finally: 52 | session.persist(bot.db) 53 | 54 | async def process(self, message: Message, session: CachedSession, bot): 55 | inf = self.infer_command(message, session) 56 | if inf is not None: # handle as a command input 57 | bot.metrics.send_event( 58 | event="user_command", 59 | user_id=session.user_id, 60 | params={ 61 | 'command': inf.command, 62 | 'chat_id': message.chat_id, 63 | } 64 | ) 65 | await self.exec(inf, message, session, bot) # command handled, all good 66 | else: 67 | bot.metrics.send_event( 68 | event="user_message", 69 | user_id=session.user_id, 70 | params={ 71 | 'chat_id': message.chat_id, 72 | } 73 | ) 74 | 75 | if message.audio and 'audio' in self.fallback_commands: 76 | fallback = self.fallback_commands['audio'] 77 | elif message.video and 'video' in self.fallback_commands: 78 | fallback = self.fallback_commands['video'] 79 | elif message.voice and 'voice' in self.fallback_commands: 80 | fallback = self.fallback_commands['voice'] 81 | elif message.image and 'image' in self.fallback_commands: 82 | fallback = self.fallback_commands['image'] 83 | elif message.text and 'text' in self.fallback_commands: 84 | fallback = self.fallback_commands['text'] 85 | else: 86 | fallback = None 87 | 88 | if fallback in self.command_handlers: 89 | message.text = f'/{fallback} {message.text}'.strip() 90 | await self.exec(self.command_handlers[fallback], message, session, bot) 91 | -------------------------------------------------------------------------------- /cliobot/bots/telegram_bot.py: -------------------------------------------------------------------------------- 1 | import contextvars 2 | from pathlib import Path 3 | 4 | import httpcore 5 | from retry import retry 6 | from telegram import BotCommand, InputMediaPhoto, InputMediaDocument, InlineKeyboardButton, InlineKeyboardMarkup, \ 7 | ReplyKeyboardMarkup, KeyboardButton 8 | from telegram._bot 
def convert_exceptions(func):
    """
    Decorate an async Telegram call so library errors surface as cliobot errors.

    Mapping performed by the wrapper:
    - ``httpcore.ConnectTimeout`` / ``TimedOut``        -> ``TransientFailure``
    - ``Forbidden`` ("bot was blocked by the user")     -> ``UserBlocked``
    - any other ``Forbidden``                           -> ``UnknownError``
    - ``BadRequest`` "Chat not found"                   -> logged, returns ``None``
    - ``BadRequest`` message to delete/edit not found   -> ``MessageNoLongerExists``
    - ``BadRequest`` message cannot be edited/replied   -> ``MessageNotModifiable``
    - ``BadRequest`` not modified / cancelled by edit   -> ``TransientFailure``
    - any other ``BadRequest``                          -> re-raised unchanged
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)  # preserve name/docstring of the wrapped coroutine
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except (httpcore.ConnectTimeout, TimedOut) as to:
            raise TransientFailure(to) from to
        except Forbidden as f:
            if 'bot was blocked by the user' in f.message:
                raise UserBlocked(f) from f
            raise UnknownError(f) from f
        except BadRequest as _e:
            if _e.message == 'Chat not found':
                # deliberately swallowed: the call yields None in this case
                print("Chat not found, ignoring (could be using a different bot instance?)")
            elif _e.message in [
                'Message to delete not found',
                'Message to edit not found']:
                raise MessageNoLongerExists(_e) from _e
            elif _e.message in [
                'There is no media in the message to edit',
                'Replied message not found',
                'Message can\'t be edited']:
                raise MessageNotModifiable(_e) from _e
            elif (str(_e) in
                  [
                      'Message is not modified: specified new message content and reply markup are exactly the same as a current content and reply markup of the message',
                      'Canceled by new editmessagemedia request',
                  ]):
                # benign and common; logged, then reported as transient so callers may retry or ignore
                print(str(_e))
                raise TransientFailure(_e) from _e
            else:
                raise

    return wrapper
def parse_callback_string(callback_data):
    """
    Parse inline-button callback data of the form ``op:job_id[:index]``.

    Returns a dict with ``command`` and ``job_id`` (plus ``index`` for
    ``select`` and ``shuffle``). Unknown operations, empty/``None`` input and
    payloads missing a required part yield an empty dict instead of raising.
    """
    if not callback_data:
        return {}

    op, *rest = callback_data.split(':')

    if op in ('upvote', 'downvote', 'reroll_job', 'retry'):
        fields = ('job_id',)
    elif op in ('select', 'shuffle'):
        fields = ('job_id', 'index')
    else:
        return {}

    # stale or hand-crafted payloads may lack parts; treat them as unparseable
    if len(rest) < len(fields):
        return {}

    return {'command': op, **dict(zip(fields, rest))}
@convert_exceptions
@retry(TimedOut, tries=2, delay=0.5)
async def send_media(self, chat_id, media, reply_to_message_id=None, session=None, text=None, reply_buttons=None,
                     buttons=None):
    """
    Send ``media['image']`` as a photo to ``chat_id`` and record it in the db.

    ``text`` becomes the caption. ``url`` buttons are appended to the caption
    as ``"<text>: <url>"`` lines. ``reply_buttons`` (inline keyboard) take
    precedence over ``buttons`` (reply keyboard). Returns the sent Telegram
    message.
    """
    # NOTE(review): `retry` is applied to a coroutine function; confirm the
    # installed retry package actually retries awaited failures.
    bot = await self.initialize()

    caption = text
    if buttons:
        # buttons is a list of rows (see buttons_markup), so flatten it exactly
        # like send_message does before inspecting individual buttons
        for b in flatten(buttons):
            if b.get('kind') == 'url':
                link = f"{b['text']}: {b['url']}"
                # caption may be None when only media was requested
                caption = f"{caption}\n\n{link}" if caption else link

    res = await bot.send_photo(
        chat_id=chat_id,
        photo=media['image'],
        caption=caption,
        reply_to_message_id=reply_to_message_id,
        reply_markup=reply_markup(reply_buttons) or buttons_markup(buttons),
    )

    self.db.save_message(
        user_id=self.bot_id,
        chat_id=chat_id,
        text=caption or '',
        external_id=res.id,
        image=res.photo[-1].file_id,  # last entry of the photo-size list
        is_forward=False,
    )

    return res
@convert_exceptions
@retry((TimedOut, TransientFailure), tries=2, delay=0.5)
async def send_message(self, text, chat_id, session=None, reply_to_message_id=None, reply_buttons=None,
                       buttons=None):
    """
    Send a text message to ``chat_id`` and record it in the db.

    ``url`` buttons are appended to the text as ``"<text>: <url>"`` lines.
    ``reply_buttons`` (inline keyboard) take precedence over ``buttons``
    (reply keyboard). Returns the sent Telegram message.
    """
    bot = await self.initialize()

    body = text or ''  # text may be None; never concatenate onto None
    if buttons:
        for b in flatten(buttons):
            # 'kind' is optional on a button (buttons_markup reads it with .get)
            if b.get('kind') == 'url':
                link = f"{b['text']}: {b['url']}"
                body = f"{body}\n\n{link}" if body else link

    res = await bot.send_message(
        chat_id=chat_id,
        text=body,
        reply_to_message_id=reply_to_message_id,
        reply_markup=reply_markup(reply_buttons) or buttons_markup(buttons),
    )

    self.db.save_message(
        user_id=self.bot_id,
        chat_id=chat_id,
        text=body,
        external_id=res.id,
        is_forward=False,
    )

    return res
async def initialize(self):
    """
    Register the catch-all update handler and publish the command list.

    Every message and callback query is routed through ``handler_adapter``.
    When the first sender is a ``CommandHandler`` its commands are registered
    with Telegram via ``set_my_commands``.
    """
    self.app.add_handler(
        ConversationHandler(
            entry_points=[
                CallbackQueryHandler(self.handler_adapter('message_handler')),
                MessageHandler(filters.ALL,
                               self.handler_adapter('message_handler')),
            ],
            states={},
            fallbacks=[],
        ))

    # sharing a bot between threads blows things up, so use a dedicated one;
    # `async with` initializes it and shuts it down again so its HTTP
    # resources are released once the commands are published
    async with Bot(self.apikey) as bot:
        if self.senders and isinstance(self.senders[0], CommandHandler):
            await bot.set_my_commands(
                commands=[
                    BotCommand(
                        c.command,
                        c.description,
                    ) for c in self.senders[0].commands
                ])
378 | if callback_query is not None: 379 | txt = callback_query.data 380 | meta = parse_callback_string(txt) 381 | if context.args is not None and len(context.args) > 0: 382 | txt = ' '.join(context.args) 383 | 384 | audio = None 385 | image = None 386 | reply = None 387 | voice = None 388 | video = None 389 | if message.reply_to_message is not None: 390 | m = message.reply_to_message 391 | reply = await self._parse_message( 392 | message=m, 393 | user=m.from_user, 394 | chat=m.chat, 395 | context=context, 396 | callback_query=None) 397 | 398 | # if command is in reply to a photo, use that photo 399 | if message.reply_to_message.photo is not None and len( 400 | message.reply_to_message.photo) > 0: # a photo 401 | image = message.reply_to_message.photo[-1].file_id 402 | elif message.reply_to_message.document is not None and message.reply_to_message.document.mime_type.startswith( 403 | 'image'): 404 | image = message.reply_to_message.document.file_id # a document 405 | 406 | if message.reply_to_message.audio is not None: 407 | audio = message.reply_to_message.audio.file_id 408 | 409 | if message.reply_to_message.video is not None: 410 | video = message.reply_to_message.video.file_id 411 | elif message.photo is not None and len( 412 | message.photo) > 0: # uploaded a new image 413 | image = message.photo[-1].file_id 414 | elif message.document is not None and message.document.mime_type.startswith( 415 | 'image'): 416 | image = message.document.file_id 417 | 418 | if message.audio is not None: 419 | audio = message.audio.file_id 420 | 421 | if message.video is not None: 422 | video = message.video.file_id 423 | 424 | if message.voice is not None: 425 | voice = message.voice.file_id 426 | 427 | phone = message.contact.phone_number if message.contact and message.contact.user_id == user.id else None 428 | return Message( 429 | user=User( 430 | username=user.username, 431 | phone=phone, 432 | full_name=f'{user.first_name} {user.last_name}' if user.first_name and 
def to_params(message, session) -> dict:
    """
    Turn a command message into a flat params dict.

    The text is split on whitespace: the first token is the command (leading
    ``/`` stripped), the tokens up to the first ``--flag`` form ``prompt``, and
    each ``--key v1 v2`` group becomes ``params['key'] = 'v1 v2'``. A flag
    directly followed by another flag maps to ``None``; a flag that is the very
    last token is ignored.

    Attachments on the message (audio, voice, video, image) are stored in the
    session context first. Context entries are then copied into the params,
    explicit flags override them, and session preferences fill in any key that
    is still missing.
    """
    context = session.context
    preferences = session.preferences

    # media-only messages may have no text at all
    tokens = (message.text or '').split()

    for kind in ('audio', 'voice', 'video', 'image'):
        attachment = getattr(message, kind)
        if attachment:
            context.set(kind, attachment)

    command = None
    if tokens:
        command = tokens[0]
        if command.startswith('/'):
            command = command[1:]

    args = tokens[1:]
    # index of the first --flag; everything before it is the free-text prompt
    first_flag = next(
        (idx for idx, tok in enumerate(args) if tok.startswith('--')),
        len(args),
    )

    params = {
        'command': command,
    }

    # a message like "/cmd --flag x" has no prompt; don't pass the flags off as one
    if first_flag > 0:
        params['prompt'] = ' '.join(args[:first_flag])

    for k, v in context.items():
        params[k] = v

    # collect key-value pairs (e.g., --bla abc def)
    pos = first_flag
    while pos < len(args):
        key = args[pos][2:]
        pos += 1

        values = []
        while pos < len(args) and not args[pos].startswith('--'):
            values.append(args[pos])
            pos += 1

        # a value-less flag at the very end is dropped (historical behavior)
        if values or pos < len(args):
            params[key] = ' '.join(values) if values else None

    for k, v in preferences.items():
        if k not in params:
            params[k] = v

    return params
class ImageUrl(BaseModel):
    """A single generated image, optionally with the prompt attached to it."""

    # where the image can be fetched from
    # NOTE(review): TextToImage also passes this value as a local path - confirm
    url: str
    # optional: declared as such so an explicit None validates
    prompt: Optional[str] = None


class GenerationResults(BaseModel):
    """Output of ``Model.generate``: text results and/or image results."""

    # both fields default to None, so annotate them as optional
    texts: Optional[List[str]] = None
    images: Optional[List[ImageUrl]] = None
class ModelBackedCommand(BaseCommand):
    """
    A command that delegates its work to one of several named models.

    ``models`` maps a model name to a model object; ``default_model`` names the
    entry used when the request does not select one.
    """

    def __init__(self, command, name, description, examples, models, default_model=None, reply_only=False):
        super().__init__(command, name, description, examples, reply_only)
        self.models = models
        self.default_model = default_model

    async def run_model(self, parsed, model, message, session, bot) -> bool:
        """
        execute the command and return True if the command was completely handled, or False if the command was either
        not handled or needs more data (eg. a file upload is pending).

        After a command is handled, the current context is cleared.
        """
        raise NotImplementedError()

    def _resolve_model(self, model_name):
        """Return the model registered for ``model_name``, or None if nothing matches."""
        if model_name is None:
            # no explicit choice and no default: use the first registered model.
            # The None default avoids StopIteration escaping a coroutine when
            # the registry is empty.
            return next(iter(self.models.values()), None)

        model = self.models.get(model_name)
        if model is None and model_name in MODEL_ALIASES:
            model = self.models.get(MODEL_ALIASES[model_name])

        # unknown names fall back to the configured default model
        if model is None and self.default_model:
            model = self.models.get(self.default_model)

        return model

    async def process(self, message, session, bot) -> bool:
        """
        Parse the message, pick the model to use and run it.

        Returns False when no model matches or the params fail validation (the
        user is notified in both cases); otherwise returns whatever
        ``run_model`` returns.
        """
        params = to_params(message, session)
        model_name = params.get('model', self.default_model)

        model = self._resolve_model(model_name)
        if model is None:
            await send_error_message(bot.messaging_service,
                                     f"Model {model_name} not found",
                                     message.chat_id)
            return False

        if model.prompt_class:
            try:
                parsed = model.prompt_class(**params)
                if parsed.model is None:
                    parsed.model = params.get('model', self.default_model)
            except ValidationError as e:
                await notify_errors(e, bot.messaging_service, message.chat_id, message.message_id)
                return False
        else:
            # models without a prompt class receive the raw params dict
            parsed = params

        try:
            return await self.run_model(parsed, model, message, session, bot)
        except Exception as e:
            # notify_errors reports validation errors and re-raises anything else
            await notify_errors(e, bot.messaging_service, message.chat_id, message.message_id)
            return False
def help_commands_str(context, commands):
    """Render the command list as text, one ``/command - description`` per line."""
    # `context` is accepted for signature compatibility but not used here
    rendered = (f'/{cmd.command} - {cmd.description}' for cmd in commands)
    return '\n'.join(rendered)
class TextToImage(ModelBackedCommand):
    """The ``/image`` command: generate one or more images from a text prompt."""

    def __init__(self, models, default_model):
        super().__init__(
            'image',
            "image",
            "Create an image from a text prompt",
            [
                "/image a hamster astronaut floating in space",
            ],
            models,
            default_model,
        )

    async def run_model(self, parsed, model, message, session, bot):
        """
        Show a placeholder image, run the model, then deliver the results.

        A single result replaces the placeholder in place; several results are
        sent as separate photos and the placeholder is deleted afterwards. Any
        failure during generation or delivery is reported to the chat.
        """
        placeholder = await bot.messaging_service.send_media(
            text="Generating image, please wait...",
            chat_id=message.chat_id,
            media={
                'image': abs_path('working.jpg'),
            },
            reply_to_message_id=message.message_id,
        )

        try:
            generated = (await model.generate(parsed)).images

            # store every output before showing anything to the user
            for item in generated:
                upload_asset(
                    session=session,
                    local_path=item.url,
                    db=bot.db,
                    storage=bot.storage,
                    folder='outputs',
                )

            total = len(generated)
            if total == 1:
                only = generated[0]
                await bot.messaging_service.edit_message_media(
                    chat_id=placeholder.chat_id,
                    message_id=placeholder.id,
                    media={
                        'image': only.url
                    },
                    text=only.prompt,
                )
            elif total > 1:
                for item in generated:
                    await bot.messaging_service.send_media(
                        chat_id=placeholder.chat_id,
                        media={
                            'image': item.url
                        },
                        text=item.prompt,
                    )
                await bot.messaging_service.delete_message(
                    message_id=placeholder.id,
                    chat_id=placeholder.chat_id,
                )
        except Exception as err:
            await send_error_message_image(bot.messaging_service, str(err), message)
async def run_model(self, parsed, model, message, session, bot):
    """
    Describe the image referenced by the request and send each description.

    ``parsed`` is either a validated prompt object or, for models without a
    prompt class, the raw params dict. Its ``image`` entry is a file id that is
    resolved through ``cached_get_file`` before the model runs. When no image
    is available the user is told so and False is returned.
    """
    unparsed = isinstance(parsed, dict)  # unparsed prompt...
    image_id = parsed.get('image', None) if unparsed else getattr(parsed, 'image', None)

    if not image_id:
        # nothing to describe: ask for an image instead of failing downstream
        await send_error_message_image(
            bot.messaging_service,
            'Please send or reply to an image to describe',
            message,
        )
        return False

    image = await cached_get_file(
        session=session,
        file_id=image_id,
        bot=bot,
    )
    if unparsed:
        parsed['image'] = image
    else:
        parsed.image = image

    res = await model.generate(parsed)
    # texts defaults to None on GenerationResults; treat that as "no output"
    for r in res.texts or []:
        await bot.messaging_service.send_message(
            text=r,
            chat_id=message.chat_id,
            reply_to_message_id=message.message_id,
        )
class SetPreference(BaseCommand):
    """The ``/set`` command: store a ``name value`` preference on the session."""

    def __init__(self):
        super().__init__(
            command='set',
            name="set_preference",
            description="Sets a preference for the current chat. Preferences have the same name as the parameters you want to hardcode (eg /set model dalle3 will always pick the dalle3 model by default, unless you override it on your prompt).",
            examples=[
                "/set preference_name preference_value",
            ],
        )

    async def run(self, parsed, message, session, bot) -> bool:
        """
        Split the prompt into ``name`` and ``value`` and save the preference.

        Returns False (after showing usage) when the prompt does not contain
        both a name and a value; otherwise returns the result of sending the
        confirmation message.
        """
        # prompt is None for a bare "/set" and has one part for "/set name"
        parts = (parsed.prompt or '').split(' ', 1)
        if len(parts) != 2:
            await bot.messaging_service.send_message(
                text='🚨 Usage: /set preference_name preference_value',
                chat_id=message.chat_id,
            )
            return False

        attr, val = parts

        session.set_preference(attr, val)

        return await bot.messaging_service.send_message(
            text=t('results.preference_set', attr=attr, value=val, locale=locale(session)),
            chat_id=message.chat_id,
        )
def substitute_env_vars(data):
    """
    Recursively replace ``$ENV_VAR`` references in a loaded config structure.

    Dicts and lists are updated in place and returned; strings are returned
    with every ``$NAME`` replaced by the value of that environment variable.
    References to unset variables are left untouched. Other types are returned
    unchanged.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = substitute_env_vars(value)

    elif isinstance(data, list):
        for i, value in enumerate(data):
            data[i] = substitute_env_vars(value)

    elif isinstance(data, str):
        # substitute each match exactly where it occurs: a sequential
        # str.replace would let "$FOO" clobber the prefix of "$FOO_BAR"
        data = re.sub(
            r'\$([A-Za-z_][A-Za-z0-9_]*)',
            lambda m: os.environ.get(m.group(1), m.group(0)),
            data,
        )

    return data
def __init__(self, configPath, tmpFolder=None):
    """Load the config file at ``configPath`` and choose the temp folder (``tmp`` by default)."""
    self.config = load_config(configPath)
    # fall back to the project-relative "tmp" folder when none is given
    self.tmp_folder = tmpFolder if tmpFolder else abs_path("tmp")
for v in models: 116 | cli = ReplicateEndpoint( 117 | v['kind'], 118 | self.config['replicate']['api_token'], 119 | v['version'], 120 | v['params'], 121 | ) 122 | if v['kind'] == 'describe': 123 | describe_models[v['model']] = cli 124 | elif v['kind'] == 'transcribe': 125 | transcribe_models[v['model']] = cli 126 | elif v['kind'] == 'image': 127 | txt2img_models[v['model']] = cli 128 | elif v['kind'] == 'ask': 129 | ask_models[v['model']] = cli 130 | 131 | if self.config.get('ollama', None): 132 | from cliobot.ollama.client import OllamaText 133 | 134 | print("**** Using Ollama API ****") 135 | endpoint = self.config['ollama']['endpoint'] 136 | for v in self.config['ollama']['models']: 137 | m = OllamaText( 138 | endpoint=endpoint, 139 | ) 140 | h = None 141 | 142 | if v['kind'] == 'ask': 143 | h = ask_models 144 | elif v['kind'] == 'describe': 145 | h = describe_models 146 | elif v['kind'] == 'transcribe': 147 | h = transcribe_models 148 | elif v['kind'] == 'image': 149 | h = txt2img_models 150 | 151 | if h: 152 | h[v['model']] = m 153 | 154 | if self.config.get('openai', None): 155 | print("**** Using OpenAI API ****") 156 | from cliobot.openai.client import OpenAIClient, GPTPrompt, Whisper1, Dalle3, Gpt4Vision 157 | 158 | openai_client = OpenAIClient( 159 | endpoints=self.config['openai']['endpoints'], 160 | metrics=metrics, 161 | ) 162 | 163 | models = self.config['openai']['models'] 164 | if 'dall-e-3' in models: 165 | txt2img_models['dall-e-3'] = Dalle3(openai_client) 166 | 167 | if 'whisper-1' in models: 168 | transcribe_models['whisper-1'] = Whisper1(openai_client) 169 | 170 | if 'gpt-4' in models: 171 | ask_models['gpt-4'] = GPTPrompt(openai_client) 172 | 173 | if 'gpt-4-vision-preview' in models: 174 | describe_models['gpt-4-vision-preview'] = Gpt4Vision(openai_client) 175 | 176 | if self.config.get('webui', None): 177 | from cliobot.webui.client import WebuiClient, Txt2img 178 | 179 | print("**** Using Auto1111 WebUI API ****") 180 | client = 
WebuiClient( 181 | self.config['webui']['endpoint'], 182 | self.config['webui'].get('auth', None), 183 | temp_dir=self.tmp_folder, 184 | ) 185 | 186 | # get all models on boot 187 | ms = client.get_models() 188 | for m in ms: 189 | print("MODEL:", m['model_name'], 'loaded from webui') 190 | txt2img_models[m['model_name']] = Txt2img(m['model_name'], client) 191 | 192 | 193 | defaults = self.config.get('default_models', {}) 194 | 195 | if len(txt2img_models) > 0: 196 | commands.append(TextToImage(txt2img_models, defaults.get('image', None))) 197 | 198 | if len(transcribe_models) > 0: 199 | commands.append(Transcribe(transcribe_models, defaults.get('transcribe', None))) 200 | 201 | if len(describe_models) > 0: 202 | commands.append(DescribeImage(describe_models, defaults.get('describe', None))) 203 | 204 | if len(ask_models) > 0: 205 | commands.append(Ask(ask_models, defaults.get('ask', None))) 206 | 207 | commands.append( 208 | ListModels( 209 | txt2img_models, 210 | transcribe_models, 211 | describe_models, 212 | ask_models, 213 | )) 214 | commands.append(Help(commands)) 215 | 216 | self.internal_queue = queue.Queue() 217 | 218 | i18n.load_path.append(abs_path('i18n')) 219 | i18n.set('filename_format', '{locale}.{format}') 220 | 221 | cache = InMemoryCache() 222 | db = db 223 | 224 | translator = NullTranslator() 225 | 226 | if self.config['mode'] == 'command': 227 | handler = lambda: CommandHandler( 228 | fallback_commands=self.config['fallback_commands'], 229 | commands=commands, 230 | ) 231 | else: 232 | raise Exception('unsupported mode:', self.config['mode']) 233 | 234 | plat = self.config['bot']['platform'] 235 | if plat == 'telegram': 236 | from cliobot.bots.telegram_bot import TelegramBot 237 | 238 | apikey = self.config['bot']['token'] 239 | 240 | return TelegramBot( 241 | internal_queue=self.internal_queue, 242 | db=db, 243 | storage=storage, 244 | translator=translator, 245 | apikey=apikey, 246 | bot_language='en', 247 | cache=cache, 248 | 
metrics=metrics, 249 | handler_fn=handler, 250 | ) 251 | else: 252 | raise Exception('unsupported platform:', plat) 253 | -------------------------------------------------------------------------------- /cliobot/db/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class Database: 5 | 6 | def set_chat_context(self, user_id, context, preferences): 7 | raise NotImplementedError() 8 | 9 | def create_or_get_chat_session(self, user_id): 10 | raise NotImplementedError() 11 | 12 | def save_message(self, 13 | user_id, chat_id, text, external_id, 14 | image=None, 15 | audio=None, 16 | voice=None, 17 | video=None, 18 | is_forward=False, context=None): 19 | raise NotImplementedError() 20 | 21 | def get_asset(self, external_id, user_id, chat_id) -> Optional[dict]: 22 | raise NotImplementedError() 23 | 24 | def save_asset(self, external_id, user_id, chat_id, storage_path) -> dict: 25 | raise NotImplementedError() 26 | 27 | 28 | -------------------------------------------------------------------------------- /cliobot/db/inmemory.py: -------------------------------------------------------------------------------- 1 | from cliobot.bots import Session 2 | from cliobot.db import Database 3 | 4 | 5 | class InMemoryDb(Database): 6 | def __init__(self): 7 | print("**** Keeping state in memory only ****") 8 | 9 | self.jobs = {} 10 | self.profiles = {} 11 | self.messages = {} 12 | self.chats = {} 13 | 14 | def update_job(self, job_id, fields): 15 | pass 16 | 17 | def get_model(self, slug, kind): 18 | return None 19 | 20 | def set_chat_context(self, user_id, context, preferences): 21 | self.chats[user_id] = context 22 | 23 | def create_or_get_chat_session(self, user_id): 24 | return Session( 25 | user_id=user_id, 26 | chat_id='1', 27 | context={}, 28 | preferences={}, 29 | ) 30 | 31 | def save_message(self, 32 | user_id, 33 | chat_id, 34 | text, 35 | external_id, 36 | image=None, 37 | audio=None, 38 | 
class SqliteDb(Database):
    """SQLite-backed implementation of the ``Database`` interface.

    Rows are returned as plain dicts (see ``dict_factory``). The schema is
    created from ``schema.sql`` when the object is constructed.
    """

    def __init__(self, file):
        """Open (or create) the database at *file* and ensure tables exist."""
        self.conn = sqlite3.connect(file, check_same_thread=False)
        self.conn.row_factory = dict_factory
        self._create_tables()

    def create_or_get_chat_session(self, external_user_id):
        """Return the chat session row for the user, creating it if missing.

        The ``context`` and ``preferences`` columns are decoded from JSON.
        """
        cur = self.conn.cursor()  # TODO fix prefs
        cur.execute(
            "SELECT * FROM chat_sessions WHERE external_user_id = ?",
            # Parameters must be a sequence: note the trailing comma.
            (external_user_id,)
        )
        res = cur.fetchone()
        if res is None:
            cur.execute(
                "INSERT INTO chat_sessions (external_user_id) VALUES (?)",
                (external_user_id,)
            )
            self.conn.commit()
            return self.create_or_get_chat_session(external_user_id)

        # NOTE(review): schema.sql is not visible here; a freshly inserted row
        # may hold NULL/empty JSON columns, so fall back to an empty dict.
        res['context'] = json.loads(res['context']) if res['context'] else {}
        res['preferences'] = json.loads(res['preferences']) if res['preferences'] else {}
        return res

    def save_message(self,
                     user_id,
                     chat_id,
                     text,
                     external_id,
                     image=None,
                     audio=None,
                     voice=None,
                     video=None,
                     is_forward=False,
                     context=None):
        """Persist one chat message.

        ``context`` is accepted for signature parity with
        ``Database.save_message`` but is not stored by this backend.
        """
        cur = self.conn.cursor()
        cur.execute(
            "INSERT INTO chat_messages(external_id, external_user_id, external_chat_id, text, "
            "external_image_id, external_audio_id, external_voice_id, external_video_id, is_forward) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (external_id, user_id, chat_id, text, image, audio, voice, video, is_forward)
        )
        self.conn.commit()

    def set_chat_context(self, user_id, context, preferences):
        """Store the JSON-encoded context and preferences for *user_id*."""
        cur = self.conn.cursor()
        cur.execute(
            "UPDATE chat_sessions SET context = ?, preferences = ? WHERE external_user_id = ?",
            (json.dumps(context), json.dumps(preferences), user_id)
        )
        self.conn.commit()

    def get_asset(self, external_id, user_id, chat_id) -> Optional[dict]:
        """Return the matching asset row, or ``None`` when there is none."""
        cur = self.conn.cursor()
        cur.execute(
            "SELECT * FROM assets WHERE external_id = ? AND external_user_id = ? AND external_chat_id = ?",
            (external_id, user_id, chat_id)
        )
        return cur.fetchone()

    def save_asset(self, external_id, user_id, chat_id, storage_path) -> Optional[dict]:
        """Insert an asset row and return it as read back from the database."""
        cur = self.conn.cursor()
        cur.execute(
            "INSERT INTO assets(external_id, external_user_id, external_chat_id, storage_path) VALUES (?, ?, ?, ?)",
            (external_id, user_id, chat_id, storage_path)
        )
        self.conn.commit()
        return self.get_asset(external_id, user_id, chat_id)

    def _create_tables(self):
        """Run ``schema.sql`` against the connection."""
        with open(abs_path('schema.sql'), 'r') as f:
            self.conn.executescript(f.read())
def get_data(filepath):
    """Return the raw bytes referenced by *filepath*.

    Resolution order:

    1. ``data:`` URLs are decoded with ``base64_to_bytes``.
    2. An existing local path is read from disk.
    3. Anything starting with ``http`` is downloaded.
    4. Otherwise the path is opened anyway so the caller gets the usual
       ``FileNotFoundError``.

    :raises FileNotFoundError: for a non-URL path that does not exist.
    :raises requests.HTTPError: when the download answers with an error
        status (previously the error body was returned as file content).
    """
    if filepath.startswith('data:'):
        return base64_to_bytes(filepath)

    if not os.path.exists(filepath) and filepath.startswith('http'):
        # Bound the wait so a stalled server cannot hang the bot forever.
        response = requests.get(filepath, timeout=60)
        response.raise_for_status()
        return response.content

    with open(filepath, 'rb') as f:
        return f.read()
def upload_asset(
        session,
        local_path,
        db,
        storage,
        folder,
        file_id=None):
    """Copy an asset into storage and record it in the database.

    The bytes are obtained through ``get_data`` and written under
    ``<folder>/<user_id>/<hashed name>``. The mimetype is guessed from
    *local_path*. The database row is keyed by *file_id* when given,
    otherwise by the md5 of *local_path*.

    :return: whatever ``db.save_asset`` returns.
    """
    payload = get_data(local_path)

    destination = asset_filename(
        folder,
        session.user_id,
        hashed_filename(local_path),
    )
    guessed_type = mimetypes.guess_type(local_path)[0]
    stored_at = storage.save_data(payload, destination, mimetype=guessed_type)

    asset_key = file_id or md5_hash(local_path)
    return db.save_asset(
        external_id=asset_key,
        user_id=session.user_id,
        chat_id=session.chat_id,
        storage_path=stored_at,
    )
class MixpanelMetrics(BaseMetrics):
    """Metrics sink that forwards events to Mixpanel.

    When constructed with an empty key no Mixpanel client is created and
    events are printed instead.
    """

    def __init__(self, key, error_handler):
        self.mp = None
        self.error_handler = error_handler
        if key == '':
            return
        self.mp = Mixpanel(key, consumer=AsyncBufferedConsumer())

    def capture_exception(self, exception, user_id='anonymous'):
        """Report *exception* to the error handler and emit an 'error' event."""
        self.error_handler.capture_exception(exception)
        details = {'error': str(exception)}
        self.send_event('error', user_id, details)

    @retry(tries=2, delay=1)
    def send_event(self, event, user_id='anonymous', params=None):
        """Track *event* for *user_id*; failures go to the error handler."""
        payload = {} if params is None else params

        try:
            if not self.mp:
                print(f"EVENT: {event} / {user_id} / {payload}")
            else:
                # Pass a copy of the payload rather than the caller's dict.
                self.mp.track(user_id, event, {**payload})
        except Exception as err:
            self.error_handler.capture_exception(err)
class OllamaText(Model):
    """Text (and image-prompted) generation through an Ollama server."""

    def __init__(self, endpoint):
        """:param endpoint: base URL of the Ollama server."""
        super().__init__(
            prompt_class=OllamaPrompt,
        )
        self.endpoint = endpoint

    async def generate(self, parsed) -> GenerationResults:
        """Stream a completion from ``/api/generate`` and return the full text.

        :param parsed: prompt object; ``model``, ``prompt`` and ``image`` are read.
        :raises Exception: when the server reports an ``error`` in the stream.
        :raises requests.HTTPError: on a non-success HTTP status.
        """
        params = {
            'model': parsed.model,
            'prompt': parsed.prompt,
        }

        if parsed.image:
            image = decode_image(parsed.image)
            # Only the part after the data-URL header is sent.
            if image.startswith('data:image'):
                image = image.split(',')[1]
            params['images'] = [image]

        chunks = []
        # The context manager releases the streamed connection even when
        # an error is raised part-way through the stream.
        with requests.post(f'{self.endpoint}/api/generate',
                           json=params,
                           stream=True) as r:
            r.raise_for_status()

            for line in r.iter_lines():
                if not line:
                    # iter_lines may yield empty keep-alive lines.
                    continue
                body = json.loads(line)

                # Surface server errors before consuming the chunk.
                if 'error' in body:
                    raise Exception(body['error'])

                response_part = body.get('response', '')
                print(response_part, end='', flush=True)
                chunks.append(response_part)

        return GenerationResults(
            texts=[''.join(chunks).strip()]
        )
class Dalle3(Model):
    """Text-to-image model backed by the OpenAI client's DALL-E 3 call."""

    def __init__(self, openai_client):
        super().__init__(prompt_class=Dalle3Prompt)
        self.openai_client = openai_client

    async def generate(self, parsed) -> GenerationResults:
        """Generate one image for ``parsed.prompt`` at ``parsed.size``.

        Each returned image carries the prompt as revised by the API.
        """
        generated = self.openai_client.dalle3_txt2img(
            prompt=parsed.prompt,
            num=1,
            size=parsed.size,
        )

        images = []
        for item in generated:
            images.append(ImageUrl(url=item.url, prompt=item.revised_prompt))

        return GenerationResults(images=images)
def dalle_size(size, valid_sizes=None):
    """Map a ``WIDTHxHEIGHT`` string onto a size DALL-E 3 accepts.

    A size that is already valid is returned unchanged. Any other size is
    replaced by the valid size whose aspect ratio is closest to the
    requested one, so the result is always a member of *valid_sizes*.

    :param size: requested size, e.g. ``'512x768'``.
    :param valid_sizes: accepted sizes; defaults to ``VALID_DALLE3_SIZES``.
    :return: one of *valid_sizes*.
    :raises ValueError: if *size* is not of the form ``<number>x<number>``.
    """
    if valid_sizes is None:
        valid_sizes = VALID_DALLE3_SIZES

    if size in valid_sizes:
        return size

    def aspect_ratio(dimensions):
        width, height = (float(part) for part in dimensions.split('x'))
        return width / height

    wanted = aspect_ratio(size)
    return min(valid_sizes, key=lambda candidate: abs(aspect_ratio(candidate) - wanted))
class ReplicateEndpoint(Model):
    """Generic wrapper around one Replicate model version.

    ``params`` describes the inputs of the remote model. Each entry maps
    the remote input name to a dict that may contain ``kind``, ``default``,
    ``alias`` (the name used in the incoming prompt) and ``value_map``.
    """

    def __init__(self, kind, api_key, version, params):
        super().__init__(
            prompt_class=None,
        )
        self.kind = kind
        self.version = version
        self.api_key = api_key
        self.params = params
        self.client = Client(api_key)

    async def generate(self, params) -> GenerationResults:
        """Run the remote model and collect its output.

        :param params: mapping of user-supplied values (``.get`` is used).
        :return: images when ``kind == 'image'``, otherwise the output
            pieces concatenated into one text.
        """
        input_args = {}
        for k, v in self.params.items():
            default = v.get('default', None)
            param_name = v.get('alias', k)
            value_map = v.get('value_map', None)
            raw = params.get(param_name, default)

            if v['kind'] == 'image':
                # An optional image that was not supplied stays None
                # instead of crashing inside decode_image.
                value = decode_image(raw) if raw is not None else None
            elif v['kind'] == 'bool':
                # Accept a real boolean (e.g. a YAML default) as well as
                # the string 'true' coming from user input.
                value = raw is True or raw == 'true'
            else:
                value = raw
                if value_map:
                    value = value_map.get(value, value)

            # Unset inputs are omitted.
            if value is not None:
                input_args[k] = value

        res = await self.client.async_run(
            self.version,
            input_args
        )

        if isinstance(res, AsyncGeneratorType):
            result = []
            async for item in res:
                result.append(item)
            res = result

        txt = ''
        imgs = []

        for r in res:
            if self.kind == 'image':
                imgs.append(
                    ImageUrl(
                        url=r,
                        prompt=params['prompt'],
                    )
                )
            else:
                txt += r

        txts = [txt] if len(txt) > 0 else []

        return GenerationResults(
            texts=txts,
            images=imgs,
        )
class LocalStorage:
    """Store assets as plain files below a base folder on the local disk."""

    def __init__(self, folder="./"):
        print(f"**** Local Storage at: {folder} ****")
        self.folder = folder
        os.makedirs(folder, exist_ok=True)

    def _resolve(self, path):
        # Location of *path* relative to the base folder.
        return os.path.join(self.folder, path)

    def save_data(self, data, path, mimetype):
        """Write *data* to *path* and return the stored relative path.

        A single leading ``/`` is dropped. *mimetype* is accepted but unused.
        """
        relative = path[1:] if path.startswith('/') else path
        parent = os.path.join(self.folder, os.path.dirname(relative))
        os.makedirs(parent, exist_ok=True)

        with open(self._resolve(relative), 'wb') as out:
            out.write(data)
        return relative

    def full_path(self, path):
        """Return *path* joined onto the base path."""
        base = self.base_path()
        return os.path.join(base, path)

    def base_path(self):
        """Return the folder this storage writes into."""
        return self.folder

    def exists(self, path):
        """Tell whether *path* exists below the base folder."""
        return os.path.exists(self._resolve(path))

    def get_data(self, path) -> Optional[bytes]:
        """Return the bytes stored at *path*."""
        with open(self._resolve(path), 'rb') as src:
            content = src.read()
        return content

    def stream_data(self, path, localpath):
        """Copy the stored file at *path* to *localpath*.

        :raises FileNotFoundError: when the stored file does not exist.
        """
        source = self._resolve(path)
        if os.path.exists(source):
            with open(source, 'rb') as src, open(localpath, 'wb') as dst:
                dst.write(src.read())
            return
        raise FileNotFoundError(source)
class S3Storage:
    """Asset storage backed by an S3 bucket.

    Offers the same methods as the local storage driver so either can be
    handed to code that saves and reads assets.
    """

    def __init__(self, access_key, secret, bucket, region):
        self.bucket = bucket
        self.s3 = boto3.client('s3',
                               aws_access_key_id=access_key,
                               aws_secret_access_key=secret,
                               region_name=region)
        print('Connected to S3')

    def full_path(self, path):
        """Return the public URL of the object stored under *path*."""
        return self.base_path() + path

    def upload_file(self, file, key):
        """Upload the local file *file* to the bucket under *key*."""
        self.s3.upload_file(file, self.bucket, key)

    def base_path(self):
        """Return the URL prefix of the bucket."""
        return f'https://{self.bucket}.s3.amazonaws.com/'

    def save_data(self, data, path, mimetype):
        """Store *data* under *path* and return the stored path.

        Returning the path matches the local storage driver; callers record
        the return value as the asset's storage path. ``ContentType`` is
        only sent when a mimetype is known, because ``None`` is not an
        accepted value for that parameter.
        """
        extra = {'ContentType': mimetype} if mimetype else {}
        self.s3.put_object(Body=data, Bucket=self.bucket, Key=path, **extra)
        return path

    def get_data(self, path):
        """Return the object's content as bytes."""
        return self.fetch_file(path)['Body'].read()

    def exists(self, path):
        """Tell whether an object exists under *path*."""
        try:
            self.s3.head_object(Bucket=self.bucket, Key=path)
            return True
        except ClientError as e:
            if e.response['Error']['Code'] == "404":
                return False
            raise

    def get_data_text(self, path):
        """Return the object's content decoded as UTF-8 text."""
        return self.fetch_file(path)['Body'].read().decode('UTF-8')

    def fetch_file(self, key):
        """Return the raw ``get_object`` response for *key*."""
        return self.s3.get_object(Bucket=self.bucket, Key=key)

    def stream_data(self, path, localpath):
        """Download the object at *path* into the local file *localpath*.

        :raises FileNotFoundError: when the object does not exist.
        """
        try:
            return self.s3.download_file(self.bucket, path, localpath)
        except ClientError as e:
            if e.response['Error']['Code'] == "404":
                print("The object does not exist.")
                raise FileNotFoundError(path)
            raise
def base64_to_bytes(base64_string):
    """Decode base64 text, with or without a data-URL header, into bytes.

    Any ``data:<mime>;base64,`` prefix is removed before decoding, not only
    the PNG one, so data URLs for other media types decode correctly.

    :param base64_string: raw base64 text or a base64 data URL.
    :return: the decoded bytes.
    """
    if base64_string.startswith('data:'):
        header, separator, payload = base64_string.partition(',')
        # Strip the header only for base64-encoded data URLs.
        if separator and header.endswith(';base64'):
            base64_string = payload
    return base64.b64decode(base64_string)
def flatten(lst):
    """Return a new flat list with every nested ``list`` expanded in order.

    Only ``list`` instances are expanded; tuples, strings and other
    iterables are kept as single items.
    """
    def _walk(items):
        for entry in items:
            if not isinstance(entry, list):
                yield entry
            else:
                yield from _walk(entry)

    return list(_walk(lst))
def save_image(base64data, folder):
    """Decode a base64 image payload and write it as a PNG inside ``folder``.

    :param base64data: encoded image, decoded through ``base64_to_bytes``
    :param folder: destination directory (joined with a random file name and
        resolved through ``abs_path``)
    :return: the path the image was written to
    """
    image = Image.open(io.BytesIO(base64_to_bytes(base64data)))
    filepath = abs_path(os.path.join(
        folder,
        f'{str(int(random.random() * 100000000))}.png'))
    # The temp folder is not guaranteed to exist yet; create it so the save
    # below does not fail with FileNotFoundError.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    image.save(filepath)
    return filepath


class Txt2imgPrompt(BasePrompt):
    """Prompt parameters accepted by the WebUI txt2img endpoint."""

    steps: int = 20  # sent as 'steps'
    sampler: str = 'DPM++ 2M SDE'  # sent as 'sampler_name'
    width: int = 512
    height: int = 512
    batchcount: int = 1  # sent as 'n_iter'
    batchsize: int = 1  # sent as 'batch_size'
    cfg: int = 7  # sent as 'cfg_scale'
    seed: int = -1
    negative: str = ''  # sent as 'negative_prompt'


class Txt2img(Model):
    """Model adapter that forwards txt2img prompts to a ``WebuiClient``."""

    def __init__(self, model, client):
        super().__init__(
            prompt_class=Txt2imgPrompt,
        )
        self.client = client
        self.model = model

    async def generate(self, prompt):
        """Run ``prompt`` through the client's txt2img call and return its results."""
        return await self.client.txt2img(prompt)


class WebuiClient:
    """Small HTTP client for the ``/sdapi/v1`` endpoints of a WebUI server."""

    def __init__(self, endpoint, auth, temp_dir='tmp'):
        """
        :param endpoint: base URL; request paths are appended to it verbatim
        :param auth: credentials string sent base64-encoded as HTTP Basic auth
            when truthy (presumably ``user:password`` - confirm with config)
        :param temp_dir: folder where generated images are written
        """
        self.endpoint = endpoint
        self.auth = auth
        self.temp_dir = temp_dir

    def get_models(self):
        """Return the decoded JSON of ``/sdapi/v1/sd-models``."""
        return self._get('/sdapi/v1/sd-models')

    async def txt2img(self, parsed) -> GenerationResults:
        """Generate images for ``parsed`` and save them under ``temp_dir``.

        :param parsed: prompt object exposing the ``Txt2imgPrompt`` fields
            plus ``prompt``
        :return: ``GenerationResults`` with no texts and one entry per image
        """
        import asyncio  # local import keeps this block self-contained

        params = {
            'cfg_scale': parsed.cfg,
            'width': parsed.width,
            'height': parsed.height,
            'batch_size': parsed.batchsize,
            'n_iter': parsed.batchcount,
            'seed': parsed.seed,
            'prompt': parsed.prompt,
            'negative_prompt': parsed.negative,
            'steps': parsed.steps,
            'sampler_name': parsed.sampler,
        }  # TODO previews

        # requests is blocking; run it in a worker thread so the event loop
        # stays responsive while the server renders.
        r = await asyncio.to_thread(self._post, '/sdapi/v1/txt2img', params)

        imgs = [
            {
                'url': save_image(i, self.temp_dir),
                'prompt': parsed.prompt,
            }
            for i in r['images']
        ]

        return GenerationResults(
            texts=[],
            images=imgs,
        )

    # TODO: txt2img request parameters not exposed yet:
    #   styles, subseed, subseed_strength, seed_resize_from_h,
    #   seed_resize_from_w, restore_faces, tiling, do_not_save_samples,
    #   do_not_save_grid, eta, denoising_strength, s_min_uncond, s_churn,
    #   s_tmax, s_tmin, s_noise, override_settings,
    #   override_settings_restore_afterwards, refiner_checkpoint,
    #   refiner_switch_at, disable_extra_networks, comments, enable_hr,
    #   firstphase_width, firstphase_height, hr_scale, hr_upscaler,
    #   hr_second_pass_steps, hr_resize_x, hr_resize_y, hr_checkpoint_name,
    #   hr_sampler_name, hr_prompt, hr_negative_prompt, script_name,
    #   script_args, send_images, save_images, alwayson_scripts

    def _request_args(self):
        """Build the keyword arguments shared by every request (auth header)."""
        args = {}
        if self.auth:
            encoded_credentials = base64.b64encode(self.auth.encode()).decode()
            args['headers'] = {
                'Authorization': f'Basic {encoded_credentials}',
            }
        return args

    def _get(self, path):
        """GET ``endpoint + path`` and return the decoded JSON body."""
        return requests.get(self.endpoint + path, **self._request_args()).json()

    def _post(self, path, params):
        """POST ``params`` as JSON to ``endpoint + path`` and return the decoded JSON body."""
        return requests.post(self.endpoint + path, json=params, **self._request_args()).json()
-------------------------------------------------------------------------------- 1 | env: development 2 | mode: command 3 | locale: en 4 | 5 | openai: 6 | models: 7 | - gpt-4 8 | - whisper-1 9 | - dall-e-3 10 | - gpt-4-vision 11 | 12 | endpoints: 13 | - api_key: $OPENAI_API_TOKEN 14 | api_type: open_ai 15 | base_url: https://api.openai.com/v1/ 16 | 17 | - api_key: $OPENAI_AZURE_API_KEY 18 | api_type: azure 19 | api_version: 2023-10-01-preview 20 | base_url: $OPENAI_AZURE_BASE_URL 21 | model: gpt4 22 | kind: gpt-4 23 | 24 | - api_key: $OPENAI_AZURE_API_KEY 25 | api_type: azure 26 | api_version: 2 27 | base_url: $OPENAI_AZURE_BASE_URL 28 | model: embeddings 29 | kind: embeddings 30 | 31 | - api_key: $OPENAI_AZURE_API_KEY 32 | api_type: azure 33 | api_version: 2023-12-01-preview 34 | base_url: $OPENAI_AZURE_BASE_URL 35 | model: dalle3 36 | kind: dall-e-3 37 | 38 | - api_key: $OPENAI_AZURE_API_KEY 39 | api_type: azure 40 | api_version: 2023-12-01-preview 41 | base_url: $OPENAI_AZURE_BASE_URL 42 | model: whisper1 43 | kind: whisper-1 44 | 45 | 46 | ollama: 47 | endpoint: http://localhost:11434 48 | models: 49 | - model: llama2 50 | kind: 'ask' 51 | - model: llava 52 | kind: 'describe' 53 | 54 | replicate: 55 | api_token: $REPLICATE_API_TOKEN 56 | endpoints: 57 | - model: 'llava13' 58 | kind: 'describe' 59 | version: "yorickvp/llava-13b:6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc" 60 | params: 61 | prompt: 62 | kind: str 63 | required: true 64 | image: 65 | kind: image 66 | required: true 67 | max_tokens: 68 | kind: int 69 | default: 1024 70 | - model: 'sdxl' 71 | kind: 'image' 72 | version: 'stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b' 73 | params: 74 | prompt: 75 | kind: str 76 | required: true 77 | negative_prompt: 78 | alias: no 79 | kind: str 80 | width: 81 | kind: int 82 | default: 1024 83 | height: 84 | kind: int 85 | default: 1024 86 | num_outputs: 87 | alias: num 88 | kind: int 89 | default: 1 90 | 
num_inference_steps: 91 | alias: steps 92 | kind: int 93 | default: 25 94 | guidance_scale: 95 | alias: cfg 96 | kind: float 97 | default: 7.5 98 | prompt_strength: 99 | alias: ps 100 | kind: float 101 | default: 0.8 102 | seed: 103 | kind: int 104 | apply_watermark: 105 | alias: watermark 106 | kind: bool 107 | default: true 108 | scheduler: 109 | kind: str 110 | default: 'KarrasDPM' 111 | refine: 112 | kind: str 113 | alias: refiner 114 | default: 'no_refiner' 115 | value_map: 116 | no: no_refiner 117 | expert: expert_ensemble_refiner 118 | base: base_image_refiner 119 | refine_steps: 120 | kind: int 121 | alias: rs 122 | 123 | bot: 124 | platform: telegram 125 | token: $TELEGRAM_TOKEN 126 | 127 | db: 128 | # driver: inmemory 129 | driver: sqlite3 130 | file: data/db.sqlite3 131 | 132 | storage: 133 | driver: local 134 | folder: data/ 135 | # driver: s3 136 | # bucket: my-bucket 137 | # zone: eu-west-1 138 | # access_key: my-key 139 | # secret_key: my-secret 140 | 141 | 142 | fallback_commands: 143 | audio: transcribe 144 | voice: transcribe 145 | text: ask 146 | 147 | 148 | default_models: 149 | ask: llama2 150 | transcribe: whisper-1 151 | image: dall-e-3 152 | describe: gpt-4-vision-preview -------------------------------------------------------------------------------- /config.full.yml: -------------------------------------------------------------------------------- 1 | env: development 2 | mode: command 3 | locale: en 4 | 5 | openai: 6 | models: 7 | - gpt-4 8 | - whisper-1 9 | - dall-e-3 10 | - gpt-4-vision-preview 11 | 12 | endpoints: 13 | - api_key: $OPENAI_API_TOKEN 14 | api_type: open_ai 15 | base_url: https://api.openai.com/v1/ 16 | 17 | - api_key: $OPENAI_AZURE_API_KEY_GPT35 18 | api_type: azure 19 | api_version: 2023-10-01-preview 20 | base_url: $OPENAI_AZURE_BASE_URL_GPT35 21 | model: gpt35turbo16k 22 | kind: gpt-3.5 23 | 24 | - api_key: $OPENAI_AZURE_API_KEY_GPT4 25 | api_type: azure 26 | api_version: 2023-10-01-preview 27 | base_url: 
$OPENAI_AZURE_BASE_URL_GPT4 28 | model: gpt4 29 | kind: gpt-4 30 | 31 | - api_key: $OPENAI_AZURE_API_KEY_EMBEDDINGS 32 | api_type: azure 33 | api_version: 2 34 | base_url: $OPENAI_AZURE_BASE_URL_EMBEDDINGS 35 | model: embeddings 36 | kind: embeddings 37 | 38 | - api_key: $OPENAI_AZURE_API_KEY_DALLE3 39 | api_type: azure 40 | api_version: 2023-12-01-preview 41 | base_url: $OPENAI_AZURE_BASE_URL_DALLE3 42 | model: dalle3 43 | kind: dall-e-3 44 | 45 | - api_key: $OPENAI_AZURE_API_KEY_WHISPER 46 | api_type: azure 47 | api_version: 2023-12-01-preview 48 | base_url: $OPENAI_AZURE_BASE_URL_WHISPER 49 | model: whisper1 50 | kind: whisper-1 51 | 52 | 53 | ollama: 54 | endpoint: http://localhost:11434 55 | models: 56 | - model: llama2 57 | kind: 'ask' 58 | - model: llava 59 | kind: 'describe' 60 | 61 | 62 | replicate: 63 | api_token: $REPLICATE_API_TOKEN 64 | endpoints: 65 | - model: 'llava13' 66 | kind: 'describe' 67 | version: "yorickvp/llava-13b:6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc" 68 | params: 69 | prompt: 70 | kind: str 71 | required: true 72 | default: "what's this?" 
73 | image: 74 | kind: image 75 | required: true 76 | max_tokens: 77 | kind: int 78 | default: 1024 79 | - model: 'sdxl' 80 | kind: 'image' 81 | version: 'stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b' 82 | params: 83 | prompt: 84 | kind: str 85 | required: true 86 | negative_prompt: 87 | alias: no 88 | kind: str 89 | width: 90 | kind: int 91 | default: 1024 92 | height: 93 | kind: int 94 | default: 1024 95 | num_outputs: 96 | alias: num 97 | kind: int 98 | default: 1 99 | num_inference_steps: 100 | alias: steps 101 | kind: int 102 | default: 25 103 | guidance_scale: 104 | alias: cfg 105 | kind: float 106 | default: 7.5 107 | prompt_strength: 108 | alias: ps 109 | kind: float 110 | default: 0.8 111 | seed: 112 | kind: int 113 | apply_watermark: 114 | alias: watermark 115 | kind: bool 116 | default: true 117 | scheduler: 118 | kind: str 119 | default: 'KarrasDPM' 120 | refine: 121 | kind: str 122 | alias: refiner 123 | default: 'no_refiner' 124 | value_map: 125 | no: no_refiner 126 | expert: expert_ensemble_refiner 127 | base: base_image_refiner 128 | refine_steps: 129 | kind: int 130 | alias: rs 131 | 132 | bot: 133 | platform: telegram 134 | token: $TELEGRAM_TOKEN 135 | 136 | db: 137 | # driver: inmemory 138 | driver: sqlite3 139 | file: data/db.sqlite3 140 | 141 | storage: 142 | driver: local 143 | folder: data/ 144 | # driver: s3 145 | # bucket: my-bucket 146 | # zone: eu-west-1 147 | # access_key: my-key 148 | # secret_key: my-secret 149 | 150 | 151 | fallback_commands: 152 | audio: transcribe 153 | voice: transcribe 154 | text: ask 155 | 156 | 157 | default_models: 158 | ask: llama2 159 | transcribe: whisper-1 160 | image: dall-e-3 161 | describe: gpt-4-vision-preview -------------------------------------------------------------------------------- /fullbot.py: -------------------------------------------------------------------------------- 1 | from cliobot.config import ConfigLoader 2 | 3 | if __name__ == '__main__': 4 | 
bot = ConfigLoader('config.yml', 'tmp').build() 5 | 6 | bot.listen() 7 | -------------------------------------------------------------------------------- /i18n/en.yml: -------------------------------------------------------------------------------- 1 | en: 2 | instructions: 3 | help: "Please refer to the documentation for a list of commands and how to use them, or use one of the following commands:\n\n%{commands}" 4 | buttons: 5 | read_docs: Read the Documentation 6 | results: 7 | preference_set: "Preference %{attr} set to %{value}" 8 | current_preferences: "Preferences:\n%{preferences}" 9 | current_context: "Current context:\n%{context}" -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cliobot" 3 | version = "0.1.0" 4 | description = "a modular bot platform for generative AI agents" 5 | authors = ["herval "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | httpcore = "^1.0.4" 12 | aiohttp = "^3.9.3" 13 | aiodns = "^3.1.1" 14 | python-i18n = {extras = ["yaml"], version = "^0.3.9"} 15 | pillow = "10.1.0" 16 | pydantic = "^2.6.4" 17 | requests = "^2.31.0" 18 | pyyaml = "^6.0.1" 19 | python-dotenv = "^1.0.1" 20 | retry = "^0.9.2" 21 | deep-translator = "^1.11.4" 22 | 23 | 24 | [tool.poetry.group.openai.dependencies] 25 | openai = "1.3.9" 26 | 27 | 28 | [tool.poetry.group.replicate.dependencies] 29 | replicate = "^0.25.0" 30 | 31 | 32 | [tool.poetry.group.sentry.dependencies] 33 | sentry-sdk = "^1.42.0" 34 | 35 | 36 | [tool.poetry.group.mixpanel.dependencies] 37 | mixpanel = "^4.10.1" 38 | mixpanel-py-async = "^0.3.0" 39 | 40 | 41 | [tool.poetry.group.aws.dependencies] 42 | boto3 = "^1.34.66" 43 | 44 | 45 | [tool.poetry.group.llama.dependencies] 46 | llama-index = "^0.10.20" 47 | 48 | 49 | [tool.poetry.group.telegram.dependencies] 50 | python-telegram-bot 
= "20.7" 51 | 52 | 53 | [tool.poetry.group.redis.dependencies] 54 | redis = "^5.0.3" 55 | 56 | 57 | [tool.poetry.group.dev.dependencies] 58 | pytest = "^8.1.1" 59 | 60 | [build-system] 61 | requires = ["poetry-core"] 62 | build-backend = "poetry.core.masonry.api" 63 | -------------------------------------------------------------------------------- /schema.sql: -------------------------------------------------------------------------------- 1 | create table if not exists assets 2 | ( 3 | id integer primary key autoincrement, 4 | created_at timestamp with time zone default current_timestamp not null, 5 | external_user_id text not null, 6 | external_id text not null, 7 | external_chat_id text not null, 8 | storage_path text not null 9 | ); 10 | 11 | create table if not exists chat_messages 12 | ( 13 | id integer primary key autoincrement, 14 | created_at timestamp with time zone default current_timestamp not null, 15 | text text not null, 16 | external_id text not null, 17 | external_user_id text, 18 | external_chat_id text, 19 | external_image_id text, 20 | external_audio_id text, 21 | external_voice_id text, 22 | external_video_id text, 23 | is_forward boolean default false 24 | ); 25 | 26 | create table if not exists jobs 27 | ( 28 | id integer primary key autoincrement, 29 | created_at timestamp with time zone default current_timestamp not null, 30 | status text default 'created' not null, 31 | params jsonb not null, 32 | chat_session_id integer not null references chat_sessions, 33 | external_id text, 34 | outputs jsonb, 35 | public boolean default false not null, 36 | nsfw boolean default false not null, 37 | deleted_at timestamp with time zone, 38 | external_status text 39 | ); 40 | 41 | create table if not exists chat_sessions 42 | ( 43 | id integer primary key autoincrement, 44 | created_at timestamp with time zone default current_timestamp not null, 45 | logged_in_at timestamp with time zone, 46 | external_user_id text unique, 47 | context jsonb default 
'{}', 48 | preferences jsonb default '{}' 49 | ); 50 | -------------------------------------------------------------------------------- /test/res/hello.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/herval/cliobot/c7b2155e548f33ab87216263b765796f5338e50e/test/res/hello.mp3 -------------------------------------------------------------------------------- /test/res/sandwich.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/herval/cliobot/c7b2155e548f33ab87216263b765796f5338e50e/test/res/sandwich.jpg -------------------------------------------------------------------------------- /test/test_command_handler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock 3 | 4 | from pydantic_core._pydantic_core import ValidationError 5 | 6 | from cliobot.bots import MessagingService, Message, Session 7 | from cliobot.commands import to_params, notify_errors 8 | 9 | 10 | def msg(txt): 11 | return Message( 12 | text=txt, 13 | user_id='123', 14 | chat_id='456', 15 | message_id='789', 16 | user={}, 17 | ) 18 | 19 | 20 | def session(ctx, prefs): 21 | return Session( 22 | user_id='123', 23 | chat_id='456', 24 | context=ctx, 25 | preferences=prefs, 26 | ) 27 | 28 | 29 | class TestCommandHandler(unittest.IsolatedAsyncioTestCase): 30 | 31 | async def test_parse_message(self): 32 | valid_res = to_params( 33 | msg("/test hello world ! 
class TestOllamaClient(unittest.IsolatedAsyncioTestCase):
    """Integration tests for ``OllamaText``.

    They read ``config.yml`` and call the configured Ollama endpoint, so a
    reachable server is required.
    """

    def setUp(self):
        config = load_config('config.yml')
        self.text_client = OllamaText(
            config['ollama']['endpoint'],
        )

    async def test_text(self):
        res = await self.text_client.generate(
            OllamaPrompt(
                command='',
                prompt='Hello there',
                model='llama2',
            )
        )
        print(res)
        # assertIsNot(x, []) compares identity with a brand-new list and can
        # never fail; assert that some text actually came back instead.
        self.assertTrue(res.texts)

    async def test_describe(self):
        res = await self.text_client.generate(
            OllamaPrompt(
                command='',
                prompt='whats this?',
                model='llava',
                image=abs_path('test/res/sandwich.jpg'),
            )
        )
        print(res)
        self.assertTrue(res.texts)
class TestReplicateClient(unittest.IsolatedAsyncioTestCase):
    """Integration tests for ``ReplicateEndpoint``.

    They read ``config.yml`` and use the first two configured replicate
    endpoints (describe, then image), so valid credentials are required.
    """

    def setUp(self):
        config = load_config('config.yml')
        replicate = config['replicate']
        describe, txt2img = replicate['endpoints'][0], replicate['endpoints'][1]
        self.describe_client = ReplicateEndpoint(
            describe['kind'],
            replicate['api_token'],
            describe['version'],
            describe['params'],
        )
        self.txt2img_client = ReplicateEndpoint(
            txt2img['kind'],
            replicate['api_token'],
            txt2img['version'],
            txt2img['params'],
        )

    async def test_describe(self):
        res = await self.describe_client.generate(
            {
                'prompt': 'whats this?',
                'image': abs_path('test/res/sandwich.jpg'),
            }
        )
        print(res)
        # Identity checks against a fresh [] always pass, and `is` on ints
        # only works by accident of small-int caching; compare by value.
        self.assertTrue(res.texts)
        self.assertEqual(len(res.images), 0)

    async def test_txt2img(self):
        res = await self.txt2img_client.generate(
            {
                'prompt': 'a blue coffee cup on top of a red table, besides a white plate',
                'width': 1280,
                'no': 'photography, realistic'
            }
        )
        print(res)
        self.assertEqual(len(res.texts), 0)
        self.assertTrue(res.images)