├── .env.example ├── .gitignore ├── CHANGELOG.md ├── COMPATIBILITY.md ├── README.md ├── automigration.py ├── bot.py ├── characters ├── gpt4all_default.py ├── gptj_6B_default.py ├── llama_2_nous-hermes.py ├── llama_3_chat_default.py ├── llama_chat_default.py ├── llama_rulora_assistant_only.py ├── min_chatGPT2_default.py ├── mistral_default.py ├── obsidian_mm_default.py ├── orca_default.py ├── pygmalion_chat_king_william.py ├── ru_gpt3_default.py ├── ru_saiga_lcpp.py ├── samantha_default.py ├── vicuna_16k.py └── vicuna_default.py ├── chroniclers └── base.py ├── config_reader.py ├── custom_queue.py ├── dashboard.py ├── extensions ├── _base_extension.py └── sts_autoreply.py ├── frontend ├── .eslintignore ├── .eslintrc.js ├── .gitignore ├── index.html ├── package.json ├── public │ └── favicon.ico ├── src │ ├── App.vue │ ├── assets │ │ └── not-found.svg │ ├── auto-imports.d.ts │ ├── botControl.js │ ├── components.d.ts │ ├── components │ │ ├── ConfigForm.vue │ │ ├── FormWrapper.vue │ │ ├── KVEditor.vue │ │ ├── Modal.vue │ │ ├── ModelSetupWindow.vue │ │ ├── ModelTable.vue │ │ ├── Notification.vue │ │ ├── Offline.vue │ │ ├── Sidebar.vue │ │ └── Tabs.vue │ ├── libs │ │ └── formvuelar.js │ ├── locale │ │ ├── index.js │ │ └── lang │ │ │ ├── en.js │ │ │ └── ru.js │ ├── main.js │ ├── recommendedModels.js │ ├── router │ │ └── index.js │ ├── state.js │ ├── tools.js │ └── views │ │ ├── Chat.vue │ │ ├── Config.vue │ │ ├── Home.vue │ │ ├── ModelManager.vue │ │ └── NotFound.vue ├── tailwind.config.js └── vite.config.js ├── middleware.py ├── misc ├── botless_layer.py ├── memory_manager.py ├── model_manager.py └── mps_fixups.py ├── modules ├── admin.py ├── extensions.py ├── llm.py ├── sd.py ├── stt.py ├── tta.py └── tts.py ├── providers ├── llm │ ├── __init__.py │ ├── abstract_llm.py │ ├── llama_cpp_provider.py │ ├── mlc_chat_prebuilt_provider.py │ ├── pytorch │ │ ├── auto_hf_provider.py │ │ ├── gpt2_provider.py │ │ ├── gptj_provider.py │ │ ├── llama_hf_provider.py │ │ └── llama_orig_provider.py │ ├── remote_llama_cpp.py │ └── remote_ob.py ├── llm_provider.py ├── sd_provider.py ├── stt │ ├── __init__.py │ ├── abstract_stt.py │ ├── silero.py │ ├── wav2vec2.py │ ├── whisper.py │ └── whisperS2T.py ├── stt_provider.py ├── tta_provider.py ├── tts │ ├── __init__.py │ ├── abstract_tts.py │ ├── coqui_tts.py │ ├── py_ttsx4.py │ ├── remote_tts.py │ ├── say_macos.py │ └── so_vits_svc.py └── tts_provider.py ├── pyproject.toml ├── requirements-all.txt ├── requirements-llm.txt ├── requirements-stt.txt ├── requirements-tts.txt ├── requirements.txt ├── servers ├── api_sever.py ├── common.py ├── control_server.py └── tts_server.py ├── static ├── assets │ ├── en.js │ ├── index.css │ ├── index.js │ ├── not-found.svg │ ├── recommendedModels.js │ └── ru.js ├── favicon.ico └── index.html └── utils.py /.env.example: -------------------------------------------------------------------------------- 1 | bot_token=12345...:... 
2 | adminlist=[1810772] 3 | blacklist=[] 4 | whitelist=[1810772, -123456789010] 5 | ignore_mode=blacklist 6 | active_modules=["sd", "llm", "tts", "stt", "admin"] 7 | tts_path=/Users/user/tts_provider/models 8 | tts_voices='[ 9 | ]' 10 | tts_mode=local 11 | tts_replacements={"key": "value", "key2": "value2"} 12 | tts_credits="TTS models trained by " 13 | tts_ffmpeg_path=/Users/user/Applications/ffmpeg 14 | tts_queue_size_per_user=2 15 | tts_enable_backends=["say_macos", "ttsx4", "coqui_tts", "so_vits_svc"] 16 | tts_so_vits_svc_4_0_code_path='/Users/user/path/to/so-vits-svc' 17 | tts_so_vits_svc_4_1_code_path='/Users/user/path/to/so-vits-svc4_1' 18 | tts_so_vits_svc_voices='[ 19 | ]' 20 | tts_list_system_voices=False 21 | tts_host=http://localhost:7077 22 | llm_host=http://localhost:5000 23 | sd_host=http://localhost:7860 24 | sd_max_steps=40 25 | sd_max_resolution=1280 26 | sd_available_samplers=["Euler a", "Euler", "Heun", "DPM++ 2M", "DPM++ 2S a", "UniPC"] 27 | sd_extra_prompt="a high quality image of {prompt}, 8k, masterpiece, detailed, accurate proportions" 28 | sd_extra_negative_prompt="(worst quality:1.2), (lowres), deepfried, watermark, (blurry), jpeg noise, unsharp, deformed, {negative_prompt}" 29 | sd_default_sampler="UniPC" 30 | sd_default_n_iter=1 31 | sd_default_width=512 32 | sd_default_height=512 33 | sd_default_tti_steps=22 34 | sd_default_tti_cfg_scale=0 35 | sd_default_iti_cfg_scale=8 36 | sd_default_iti_steps=30 37 | sd_default_iti_denoising_strength=0.58 38 | sd_default_iti_sampler="Euler a" 39 | sd_lora_custom_activations={"keyword": "trigger word "} 40 | sd_only_admins_can_change_models=False 41 | sd_queue_size_per_user=5 42 | sd_launch_process_automatically=False 43 | sd_launch_command="python webui.py --api" 44 | sd_launch_dir="/Users/user/stable-diffusion-webui/" 45 | sd_launch_waittime=10 46 | apply_mps_fixes=True 47 | llm_queue_size_per_user=2 48 | llm_backend=llama_cpp 49 | llm_python_model_type=gpt2 50 | llm_assistant_chronicler=instruct 51 | llm_character=characters.llama_chat_default 52 | llm_paths='{ 53 | "path_to_hf_llama":"/Users/user/LLaMA/hf-llama", 54 | "path_to_llama_code":"/Users/user/LLaMA/llama-mps/", 55 | "path_to_llama_weights":"/Users/user/LLaMA/7B/", 56 | "path_to_llama_tokenizer":"/Users/user/LLaMA/tokenizer.model", 57 | "path_to_llama_adapter":"/Users/LLaMA/LLaMA-Adapter/llama_adapter_len10_layer30_release.pth", 58 | "path_to_llama_multimodal_adapter":"/Users/LLaMA/LLaMA-Adapter/ckpts/7f...13_BIAS-7B.pth", 59 | "path_to_llama_lora":"/Users/user/LLaMA/alpaca-lora/models/aplaca-lora-7b", 60 | "path_to_llama_cpp_weights":"/Users/user/LLaMA/llama.cpp_models/ggml-vicuna-7b-1.1-q4_2.gguf", 61 | "path_to_llama_cpp_weights_dir":"/Users/user/LLaMA/llama.cpp_models/", 62 | "path_to_gptj_weights":"/Users/user/gpt-j/GPT-J-6B_model", 63 | "path_to_autohf_weights":"/Users/user/Cerebras-GPT-1.3B", 64 | "path_to_gpt2_weights":"/Users/user/ru-gpt3-telegram-bot/rugpt3large_based_on_gpt2", 65 | "path_to_minchatgpt_code":"/Users/user/minChatGPT/src", 66 | "path_to_mlc_chatbot_code":"/Users/user/LLaMA/mlc-llm/mlc-chatbot/", 67 | "path_to_mlc_pb_home_dir":"/Users/user/LLaMA/mlc-llm/", 68 | "path_to_mlc_pb_binary_dir":"" 69 | }' 70 | llm_history_grouping=chat 71 | llm_max_history_items=10 72 | llm_generation_cfg_override={} 73 | llm_assistant_cfg_override={"early_stopping": true} 74 | llm_assistant_use_in_chat_mode=False 75 | llm_assistant_add_reply_context=True 76 | llm_force_assistant_for_unsupported_models=False 77 | llm_max_tokens=64 78 | 
llm_max_assistant_tokens=128 79 | llm_lcpp_max_context_size=66000 80 | llm_lcpp_gpu_layers=1000 81 | llm_remote_launch_process_automatically=True 82 | llm_remote_launch_command="python3.10 server.py --api --n-gpu-layers 1000 --n_ctx 2048 --listen-port 5432" 83 | llm_remote_launch_dir="/Users/user/text-generation-webui/" 84 | llm_remote_model_name="orca-mini-v2_7b.ggmlv3.q4_0.bin" 85 | llm_remote_launch_waittime=10 86 | stt_backend=whisperS2T_CTranslate2 87 | stt_model_path_or_name=tiny 88 | stt_queue_size_per_user=1 89 | tta_queue_size_per_user=1 90 | tta_device=cpu 91 | tta_music_model=facebook/musicgen-small 92 | tta_sfx_model=facebook/audiogen-medium 93 | tta_duration=3 94 | mm_preload_models_on_start=False 95 | mm_ram_cached_model_count_limit=10 96 | mm_vram_cached_model_count_limit=10 97 | mm_management_policy=BOTH 98 | mm_unload_order_policy=LEAST_USED 99 | mm_autounload_after_seconds=240 100 | python_command=python 101 | threaded_initialization=True 102 | sys_webui_host=http://localhost:7007 103 | sys_api_host=http://localhost:7008 104 | sys_request_timeout=120 105 | sys_api_log_level=warning 106 | lang=en 107 | extensions_config='{ 108 | }' -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.DS_Store 6 | .vscode/* 7 | frontend/*.zip 8 | frontend/*.env 9 | frontend/package-lock.json 10 | extensions/private* 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 
101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | *.env 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.6 2 | - Migrated to `aiogram==3.1.1` and `pydantic==2.4.0`; if you want to keep the old pydantic, use `aiogram==3.0.0b7`. 3 | - WebUI made with Vue is available! Run it with `dashboard.py` 4 | - Configuration preset switching in the WebUI; to change the path of the default .env file, use `--env` on `dashboard.py`; additional configuration files should be stored with the .env extension in the env directory (see the sketch after the v0.5 notes below). 5 | - New configuration options: `sys_webui_host`, `sys_api_host`, `sys_request_timeout`, `sys_api_log_level` 6 | - Fixed TTS initialization failing when threaded mode was off 7 | - Fixed memory manager initialization 8 | - For the model configuration UI, the `path_to_llama_cpp_weights_dir` key has been added to `llm_paths` 9 | - Reply context (1-step memory) can now be toggled with the `llm_assistant_add_reply_context` option 10 | - Configuration hints are available in the WebUI. 11 | - Model manager in the WebUI can be used to download and set up models. A few types of models are initially supported. 12 | 13 | ## v0.5 14 | - full TTS refactoring 15 | - universal cross-platform OS TTS provider (pyttsx4) support 16 | - threaded initialization (configurable via the `threaded_initialization` option) 17 | - bugfixes 18 | 19 | ### Breaking changes 20 | - `tts_enable_so_vits_svc` config option has been removed in favor of `tts_enable_backends`. 21 | - `tts_so_vits_svc_base_tts_provider` config option has been removed; you only need a base voice from now on.
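For illustration, the v0.6 preset switching above comes down to a single environment variable that `dashboard.py` sets from its `--env` argument before `config_reader` loads the settings. A minimal sketch (the preset path is hypothetical, and the file must exist and contain a full configuration):

```python
import os

# Select a preset before config_reader is imported; dashboard.py does the same thing via --env
os.environ['BOTALITY_ENV_FILE'] = 'env/my_preset.env'  # hypothetical preset path

from config_reader import config  # loads and validates the file named above
print(config.lang)
```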
22 | 23 | 24 | ## v0.4 25 | - so-vits-svc-4.1 support 26 | - experimental memory manager (now supports all models except LLMs with pytorch backend) 27 | 28 | ### Breaking changes 29 | - `tts_so_vits_svc_code_path` config option has been renamed to `tts_so_vits_svc_4_0_code_path`, and the `tts_so_vits_svc_4_1_code_path` option was added to support so-vits-svc-4.1 models; to specify that a model is a 4.1 model, use "v": 4.1 in `tts_so_vits_svc_voices`. 30 | - `llm_host` has been fixed in the .env.example file; `llm_ob_host` has been removed 31 | 32 | 33 | ## v0.3 34 | - Basic config state reactivity (runtime config changes in bot are reflected in .env) 35 | - Fixed SD model info retrieval and reinitialization 36 | - Global module accessibility milestone 2 37 | - Experimental multi-line dialog answer support 38 | - Speech-to-text via Whisper.cpp, Silero and Wav2Vec2 39 | - Seamless dialog mode voice->text->llm->voice 40 | - Text-to-audio via Audiocraft 41 | 42 | ## v0.2 43 | - LLM module / provider refactoring 44 | - Reply chronicler with single-item history 45 | - Kobold.cpp and llama.cpp server support 46 | - SD Lora API support 47 | - Global module accessibility milestone 1 48 | 49 | ### Breaking changes 50 | * `MinChatGPTChronicler` and `GPT4AllChronicler` were removed 51 | The default assistant chronicler becomes `instruct`; the `alpaca` name is kept for a limited time for backwards compatibility 52 | * `llm_ob_host` has been renamed to `llm_host` 53 | * `llm_active_model_type` has been deprecated and replaced by two new keys: `llm_backend` and `llm_python_model_type` (when backend is pytorch) 54 | * In the `llm_python_model_type` option, `cerebras_gpt` has been renamed to `auto_hf` 55 | * `path_to_cerebras_weights` has been renamed to `path_to_autohf_weights` 56 | * `sd_available_loras` has been deprecated; the lora API endpoint is used instead 57 | 58 | ## v0.1 59 | * initial version with incremental changes and full backwards compatibility with previous commits -------------------------------------------------------------------------------- /COMPATIBILITY.md: --------------------------------------------------------------------------------
| Backend | Model | Features | Acceleration | Chronicler |
| --- | --- | --- | --- | --- |
| pytorch | llama_orig | - adapter support<br>- visual input (multimodal adapter) | CUDA<br>MPS | instruct |
| pytorch | llama_hf | - lora support | CUDA | instruct |
| pytorch | gpt-2, gpt-j, auto-model | | CUDA | instruct |
| llama.cpp<br>remote-lcpp | any llama-based | - quantized GGML model support<br>- lora support<br>- built-in GPU acceleration<br>- memory manager support<br>- visual input (currently only in llama.cpp server) | CPU<br>CUDA<br>Metal | instruct |
| mlc-pb | only prebuilt in mlc-chat | - quantized MLC model support | CUDA<br>Vulkan<br>Metal | raw |
| remote_ob | any supported by oobabooga webui and kobold.cpp | - all features of the Oobabooga webui that are available via its API, including GPTQ support<br>- all features of Kobold.cpp that are available via its API<br>- memory manager support | + | instruct |
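The table above maps directly onto the configuration and code elsewhere in this repository: the Backend column corresponds to the `llm_backend` key in .env.example, and the Chronicler column names an entry in the `chroniclers` registry defined in chroniclers/base.py, which is combined with the character module from `llm_character`. A minimal sketch of that pairing (not part of the repo; the message text is invented and it assumes the repository root is on the import path):

```python
from chroniclers.base import chroniclers

# llm_assistant_chronicler=instruct + llm_character=characters.llama_chat_default (see .env.example)
chronicler = chroniclers['instruct']('characters.llama_chat_default')

# Build an instruct-style prompt from the character's template variables
prompt = chronicler.prepare({'message': 'Summarize what a chronicler does.'})
# Sampling settings defined by the character file (temperature, top_k, top_p, repetition_penalty)
gen_cfg = chronicler.gen_cfg({})
print(prompt)
print(gen_cfg)
```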
-------------------------------------------------------------------------------- /automigration.py: -------------------------------------------------------------------------------- 1 | from dotenv import dotenv_values, load_dotenv 2 | from utils import cprint 3 | import os 4 | import sys 5 | import json 6 | 7 | def verify_environment(): 8 | env_filename = os.environ.get('BOTALITY_ENV_FILE', '.env') 9 | assert os.path.exists(env_filename), \ 10 | f"Specified configuration file ({env_filename}) does not exist, please make a copy of .env.example and edit it." 11 | assert not env_filename.endswith('.env.example'), \ 12 | "You should not use example file as your configuration, please copy it to .env file." 13 | assert os.path.exists('.env.example'), \ 14 | f".env.example does not exist, please recover it, the file is used to automatically add new config options after updating" 15 | assert env_filename.endswith('.env'), \ 16 | "Configuration files must have .env extension" 17 | 18 | system_env = dotenv_values(env_filename) 19 | demo_env = dotenv_values('.env.example') 20 | 21 | def check_new_keys_in_example_env(): 22 | towrite = ''' 23 | ''' 24 | for key in demo_env: 25 | if key not in system_env: 26 | towrite += f"{key}='{demo_env[key]}'\n" 27 | print('New config key added', key) 28 | 29 | if len(towrite) != 1: 30 | with open(env_filename, 'a') as file: 31 | file.write(towrite) 32 | 33 | DEPRECATED_KEYS = [ 34 | 'llm_active_model_type', 35 | 'sd_available_loras', 36 | 'tts_so_vits_svc_code_path', 37 | 'tts_enable_so_vits_svc', 38 | 'tts_so_vits_svc_base_tts_provider', 39 | 'lang_main', 40 | 'stt_autoreply_mode', 41 | 'stt_autoreply_voice' 42 | ] 43 | DEPRECATED_KVS = { 44 | 'llm_assistant_chronicler': ['gpt4all', 'minchatgpt', 'alpaca'], 45 | 'llm_python_model_type': 'cerebras_gpt' 46 | } 47 | 48 | def check_deprecated_keys_in_dotenv(): 49 | for key in system_env: 50 | if key in DEPRECATED_KEYS: 51 | cprint(f'Warning! The key "{key}" has been deprecated! See CHANGELOG.md.', color='red') 52 | value = system_env[key] 53 | if len(value) < 1: 54 | continue 55 | if (value[0] != '[' and value[0] != '{'): 56 | value = f'"{value}"' 57 | value = json.loads(value) 58 | if type(value) is str: 59 | if key in DEPRECATED_KVS and value in DEPRECATED_KVS[key]: 60 | cprint(f'Warning! The value "{value}" of "{key}" has been deprecated! 
See CHANGELOG.md.', color='yellow') 61 | 62 | check_new_keys_in_example_env() 63 | check_deprecated_keys_in_dotenv() 64 | load_dotenv(env_filename, override=True) -------------------------------------------------------------------------------- /bot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | 5 | from aiogram import Bot, Dispatcher 6 | from config_reader import config 7 | from middleware import ChatActionMiddleware, AccessMiddleware, CooldownMiddleware, MediaGroupMiddleware, CounterMiddleware 8 | from misc.botless_layer import CommandRegistrationHijacker 9 | 10 | from modules.sd import StableDiffusionModule 11 | from modules.tts import TextToSpeechModule 12 | from modules.admin import AdminModule 13 | from modules.llm import LargeLanguageModel 14 | from modules.tta import TextToAudioModule 15 | from modules.stt import SpeechToTextModule 16 | from modules.extensions import ExtensionsModule 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | dp = Dispatcher() 21 | dp.message.middleware(CounterMiddleware(dp)) 22 | dp.message.middleware(AccessMiddleware()) 23 | dp.message.middleware(ChatActionMiddleware()) 24 | dp.message.middleware(CooldownMiddleware()) 25 | dp.message.middleware(MediaGroupMiddleware()) 26 | 27 | CommandRegistrationHijacker(dp) 28 | 29 | available_modules = { 30 | "sd": StableDiffusionModule, 31 | "tts": TextToSpeechModule, 32 | "tta": TextToAudioModule, 33 | "stt": SpeechToTextModule, 34 | "admin": AdminModule, 35 | "llm": LargeLanguageModel, 36 | "extensions": ExtensionsModule 37 | } 38 | 39 | def load_module(dp, bot, module): 40 | dp.modules[module] = available_modules[module](dp, bot) 41 | dp.timings[module] = round(time.time() - (dp.timings.get('last') or dp.timings['start']), 3) 42 | if not config.threaded_initialization: 43 | dp.timings['last'] = time.time() 44 | logger.info('loaded module: ' + module) 45 | 46 | def initialize(dp, bot, threaded=True): 47 | dp.modules = {} 48 | dp.counters = {'msg': 0} 49 | dp.timings = {'start': time.time()} 50 | threads = [] 51 | for module in config.active_modules: 52 | if module in available_modules: 53 | if not threaded: 54 | load_module(dp, bot, module) 55 | continue 56 | thread=threading.Thread(target=load_module, args=(dp, bot, module)) 57 | thread.start() 58 | threads.append(thread) 59 | for thread in threads: 60 | thread.join(config.sys_request_timeout) 61 | # initialize extensions after all the other modules 62 | dp.timings['last'] = time.time() 63 | load_module(dp, bot, 'extensions') 64 | 65 | def main(api=False): 66 | bot = Bot(token=config.bot_token.get_secret_value(), parse_mode="HTML") 67 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s',) 68 | initialize(dp, bot, config.threaded_initialization) 69 | if api: 70 | from servers.api_sever import init_api_server 71 | with init_api_server(dp, bot).run_in_thread(): 72 | dp.run_polling(bot) 73 | else: 74 | dp.run_polling(bot) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() -------------------------------------------------------------------------------- /characters/gpt4all_default.py: -------------------------------------------------------------------------------- 1 | def get_assistant_variables(): 2 | return {} 3 | 4 | def get_chat_variables(context=None): 5 | return {"intro": '', "personality": '', 'name': 'ASSISTANT', 'pre_dialog': '', **get_assistant_variables() } 6 | 7 | def get_generation_config(override={}): 8 | return { 9 | 
"temperature": 0.7, 10 | "top_k": 50, 11 | "top_p": 0.95, 12 | "repetition_penalty": 1.1, 13 | **override 14 | } 15 | 16 | def custom_input_formatter(chronicler, details, fresh=True): 17 | msg = details['message'].replace('\n', ' ') 18 | return f"""{msg} 19 | """ 20 | 21 | def custom_output_parser(chronicler, output, chat_id, skip=0): 22 | output = output[skip:].strip() 23 | return output 24 | 25 | def get_init_config(): 26 | return {'context_size': 2048} -------------------------------------------------------------------------------- /characters/gptj_6B_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | def get_chat_variables(context=None): 3 | intro = 'The year is {}.'.format(datetime.now().year) 4 | personality = 'I am a very advanced AI from another planet. I met a person, their name is {}.\n'.format( 5 | context['author'] if context else '' 6 | ) 7 | name = 'AI' 8 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': ''} 9 | 10 | def get_generation_config(override={}): 11 | return { 12 | "temperature": 0.8, 13 | "top_k": 40, 14 | "top_p": 1, 15 | "repetition_penalty": 1.01, 16 | **override 17 | } 18 | 19 | def get_init_config(): 20 | return {} -------------------------------------------------------------------------------- /characters/llama_2_nous-hermes.py: -------------------------------------------------------------------------------- 1 | def get_init_config(): 2 | return {'context_size': 4096} 3 | 4 | def custom_input_formatter(chronicler, details, fresh=True): 5 | template = f'''Below is an instruction that describes a task. Write a response that appropriately completes the request. 6 | 7 | ### Instruction: {details['message']} 8 | 9 | ### Response: ''' 10 | return template 11 | 12 | def get_chat_variables(context=None): 13 | # change these as you wish 14 | name = 'Assistant' 15 | intro = 'System:' 16 | personality = f'You are a helpful, respectful and honest assistant.\nUser:\n' 17 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': '' } 18 | 19 | def custom_output_parser(chronicler, output, chat_id, skip=0): 20 | output = output[skip:].strip() 21 | end = (output.find('### ') + 1) or (len(output) + 1) 22 | return output[:end - 1].strip() 23 | 24 | def get_generation_config(override={}): 25 | return { 26 | "temperature": 0.6, 27 | "top_k": 50, 28 | "top_p": 0.95, 29 | "repetition_penalty": 1.12, 30 | **override 31 | } 32 | 33 | 34 | -------------------------------------------------------------------------------- /characters/llama_3_chat_default.py: -------------------------------------------------------------------------------- 1 | def get_assistant_variables(): 2 | return { 3 | "assistant_instruction": "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests to the best of your ability." 4 | } 5 | 6 | def get_chat_variables(context=None): 7 | name = 'Llama' 8 | personality = 'You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user\'s requests to the best of your ability.' 
9 | return {"intro": '', "personality": personality, 'name': name, 'pre_dialog': ''} 10 | 11 | def get_generation_config(override={}): 12 | return { 13 | "temperature": 0.6, 14 | "top_k": 50, 15 | "top_p": 0.95, 16 | "repetition_penalty": 1.12, 17 | **override 18 | } 19 | 20 | # if true, does not use model-specific chat format 21 | OLD_CHAT_MODE = False 22 | 23 | ### FEEL FREE TO EDIT AND CHANGE ALL STUFF ABOVE ^ 24 | 25 | 26 | 27 | 28 | def get_init_config(): 29 | return { 30 | 'context_size': 8192, 31 | 'stop_tokens': ['<|eot_id|>', '<|end_of_text|>', '<|start_header_id|>'] 32 | } 33 | 34 | def custom_input_formatter(chronicler, details, fresh=True): 35 | assistant_variables = get_assistant_variables() 36 | template = f'''<|start_header_id|>system<|end_header_id|> 37 | 38 | 39 | {assistant_variables['assistant_instruction']}<|eot_id|><|start_header_id|>user<|end_header_id|> 40 | 41 | 42 | {details['message']}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 43 | 44 | 45 | ''' 46 | return template 47 | 48 | def custom_chat_input_formatter(chronicler, details, fresh=True): 49 | if fresh: 50 | chronicler.history[details['chat_id']] = [] 51 | history = chronicler.history[details['chat_id']] 52 | history.append({"message": details['message'], "author": details['author']}) 53 | while len(history) >= chronicler.max_length: 54 | history.pop(0) 55 | 56 | char_vars = get_chat_variables(details) 57 | conversation = f'''<|start_header_id|>system<|end_header_id|> 58 | 59 | {char_vars["intro"]}{char_vars["personality"]}<|eot_id|>''' 60 | 61 | for item in history: 62 | msg = item["message"] 63 | author = chronicler.get_author(char_vars, item) 64 | conversation += f'''<|start_header_id|>{'user' if author != char_vars['name'] else 'assistant'}<|end_header_id|> 65 | 66 | {msg}<|eot_id|>''' 67 | conversation += '<|start_header_id|>assistant\n\n' 68 | return conversation 69 | 70 | def custom_chat_output_parser(chronicler, output, chat_id, skip=0): 71 | output = output[skip:].strip() 72 | chronicler.history[chat_id].append({"message": output, "author": get_chat_variables()["name"]}) 73 | return output 74 | 75 | def custom_output_parser(chronicler, output, chat_id, skip=0): 76 | output = output[skip:].strip() 77 | end = (output.find('<|eot_id|>:') + 1) or (output.find('<|start_header_id|>') + 1) or (len(output) + 1) 78 | return output[:end - 1].strip() 79 | 80 | if OLD_CHAT_MODE: 81 | del custom_chat_input_formatter 82 | del custom_chat_output_parser -------------------------------------------------------------------------------- /characters/llama_chat_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | def get_assistant_variables(): 4 | # change these only if your custom lora input format changed 5 | return { 6 | "assistant_intro1": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.", 7 | "assistant_intro2": "Below is an instruction that describes a task. Write a response that appropriately completes the request.", 8 | "assistant_instruction": "Instruction", 9 | "assistant_input": "Input", 10 | "assistant_response": "Response" 11 | } 12 | 13 | def get_chat_variables(context=None): 14 | # change these as you wish 15 | name = 'AI' 16 | intro = 'The year is {}.'.format(datetime.now().year) 17 | personality = f'My name is {name}. I am a very advanced AI. I chat with humans. 
I must be verbose and honest, expressing myself.\n' 18 | predialog = f'''Andy: What is your name? Why are you here? 19 | {name}: Humans call me {name}. Nice to meet you. I am here to chat with you.''' 20 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': predialog, **get_assistant_variables() } 21 | 22 | def get_generation_config(override={}): 23 | return { 24 | "temperature": 0.7, 25 | "top_k": 50, 26 | "top_p": 0.95, 27 | "repetition_penalty": 1.2, 28 | **override 29 | } 30 | 31 | def get_init_config(): 32 | return {} -------------------------------------------------------------------------------- /characters/llama_rulora_assistant_only.py: -------------------------------------------------------------------------------- 1 | # for https://huggingface.co/IlyaGusev/llama_7b_ru_turbo_alpaca_lora 2 | 3 | from datetime import datetime 4 | 5 | def custom_input_formatter(chronicler, details, fresh=True): 6 | lines = details["message"].split('\n') 7 | query = lines[0] 8 | is_question = query.endswith('?') 9 | if len(lines) > 1 or not is_question: 10 | return f'''Задание: {query}. 11 | Вход: {' '.join(lines[1:]) if len(lines) > 1 else ''} 12 | Выход:''' 13 | else: 14 | return f'''Вопрос: {query} 15 | Выход:''' 16 | 17 | def get_assistant_variables(): 18 | return {} 19 | 20 | def get_chat_variables(context=None): 21 | return {"intro": '', "personality": '', 'name': '', 'pre_dialog': '', **get_assistant_variables() } 22 | 23 | def get_generation_config(override={}): 24 | return { 25 | "temperature": 0.7, 26 | "top_k": 100, 27 | "top_p": 0.75, 28 | "repetition_penalty": 1.16, 29 | **override 30 | } 31 | 32 | def get_init_config(): 33 | return {} -------------------------------------------------------------------------------- /characters/min_chatGPT2_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | def get_chat_variables(context=None): 3 | intro = 'The year is {}.'.format(datetime.now().year) 4 | personality = 'I am a very advanced AI from another planet. I met a person, their name is {}.\n'.format( 5 | context['author'] if context else '' 6 | ) 7 | name = 'AI' 8 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': ''} 9 | 10 | 11 | def custom_input_formatter(chronicler, details, fresh=True): 12 | msg = details['message'] 13 | return f"""Human: {msg} 14 | 15 | Assistant: 16 | """ 17 | 18 | def custom_output_parser(chronicler, output, chat_id, skip=0): 19 | def parse(self, output, chat_id, skip=0): 20 | output = output[skip:].strip() 21 | end = (output.find('Human:') + 1 ) or (output.find('Assistant:') + 1) or (len(output) + 1) 22 | parsed = output[:end - 1].strip() 23 | if parsed == '': 24 | return '...' 
25 | return parsed 26 | 27 | def get_generation_config(override={}): 28 | return { 29 | "temperature": 0.9, 30 | "top_k": 200, 31 | **override 32 | } 33 | 34 | def get_init_config(): 35 | return { 36 | "use_tiktoken": True, 37 | "nanogpt": True 38 | } -------------------------------------------------------------------------------- /characters/mistral_default.py: -------------------------------------------------------------------------------- 1 | TEMPLATE_MISTRALITE= '''<|prompter|>{prompt}<|assistant|>''' 2 | 3 | TEMPLATE_ChatML = '''<|im_start|>system 4 | {system_message}<|im_end|> 5 | <|im_start|>user 6 | {prompt}<|im_end|> 7 | <|im_start|>assistant''' 8 | 9 | TEMPLATE_ZEPHIR = '''<|system|>{system_message} 10 | <|user|> 11 | {prompt} 12 | <|assistant|>''' 13 | 14 | TEMPLATE_INSTRUCT = '[INST] {prompt} [/INST]' 15 | 16 | ALTERNATIVE_SYSTEM_PROMPT = 'Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.' 17 | 18 | ACTIVE_TEMPLATE = TEMPLATE_ChatML 19 | 20 | def get_init_config(): 21 | return {'context_size': 4096, 'stop_tokens': ['<|']} 22 | 23 | def custom_input_formatter(chronicler, details, fresh=True): 24 | return ACTIVE_TEMPLATE.format(prompt=details['message'], system_message='This is a conversation in a chat room.') 25 | 26 | def custom_output_parser(chronicler, output, chat_id, skip=0): 27 | output = output[skip:].strip() 28 | end = (output.find('<|') + 1) or (len(output) + 1) 29 | return output[:end - 1].strip() 30 | 31 | def get_chat_variables(context=None): 32 | from datetime import datetime 33 | # change these as you wish 34 | name = 'Mistral' 35 | intro = 'The year is {}.'.format(datetime.now().year) 36 | personality = f'My name is {name}. I am a friendly AI bot in a chat room. I should not be annoying and intrusive.' 37 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': ''} 38 | 39 | def get_generation_config(override): 40 | return { 41 | "temperature": 0.7, 42 | "top_k": 50, 43 | "top_p": 0.95, 44 | "repetition_penalty": 1.12, 45 | **(override or {}) 46 | } 47 | -------------------------------------------------------------------------------- /characters/obsidian_mm_default.py: -------------------------------------------------------------------------------- 1 | TEMPLATE = '''<|im_start|>user 2 | {prompt}\n{image} 3 | ### 4 | <|im_start|>assistant 5 | ''' 6 | 7 | 8 | def get_init_config(): 9 | return {'context_size': 2048, 'stop_tokens': ['###','<|']} 10 | 11 | def custom_input_formatter(chronicler, details, fresh=True): 12 | has_image = details.get('img_input',{}).get('visual_input', None) 13 | return TEMPLATE.format(prompt=details['message'], 14 | image='' if has_image else '*picture is missing*') 15 | 16 | def custom_output_parser(chronicler, output, chat_id, skip=0): 17 | output = output[skip + 2:].strip() 18 | end = (output.find('###') + 1) or (output.find('<|') + 1) or (len(output) + 1) 19 | return output[:end - 1].strip() 20 | 21 | def get_chat_variables(context=None): 22 | from datetime import datetime 23 | # change these as you wish 24 | name = 'Obsidian' 25 | intro = 'The year is {}.'.format(datetime.now().year) 26 | personality = f'My name is {name}. I am a friendly AI bot in a chat room.' 
27 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': ''} 28 | 29 | def get_generation_config(override): 30 | return { 31 | "temperature": 0.7, 32 | "top_k": 50, 33 | "top_p": 0.95, 34 | "repetition_penalty": 1.12, 35 | **(override or {}) 36 | } 37 | -------------------------------------------------------------------------------- /characters/orca_default.py: -------------------------------------------------------------------------------- 1 | ORCA_VERSION = 2 2 | 3 | def get_assistant_variables(): 4 | return {} 5 | 6 | def get_chat_variables(context=None): 7 | intro = 'You are an AI assistant that follows instruction extremely well. Help as much as you can.' 8 | return {"intro": intro, "personality": '', 'name': 'ASSISTANT', 'pre_dialog': ''} 9 | 10 | def get_generation_config(override={}): 11 | return { 12 | "temperature": 0.7, 13 | "top_k": 50, 14 | "top_p": 0.95, 15 | "repetition_penalty": 1.1, 16 | **override 17 | } 18 | 19 | def custom_input_formatter(chronicler, details, fresh=True): 20 | msg = details['message'] 21 | n = '\n' 22 | if not msg.startswith('>') and '\n' in msg: 23 | msg = msg.split('\n', 1) 24 | else: 25 | msg = [msg] 26 | template = f'''### System: 27 | You are an AI assistant that follows instruction extremely well. Help as much as you can. 28 | 29 | ### User: 30 | {msg[0]} 31 | 32 | ### Response: 33 | ''' if ORCA_VERSION == 1 else f'''### System: 34 | You are an AI assistant that follows instruction extremely well. Help as much as you can. 35 | 36 | ### User: 37 | {msg[0]} 38 | 39 | ### Input:{(n + msg[1]) if len(msg) > 1 else ""} 40 | 41 | ### Response: 42 | ''' 43 | return template 44 | 45 | def custom_output_parser(chronicler, output, chat_id, skip=0): 46 | output = output[skip:].strip() 47 | end = (output.find('###') + 1) or (len(output) + 1) 48 | return output[:end - 1].strip() 49 | 50 | def get_init_config(): 51 | return {'context_size': 2048} -------------------------------------------------------------------------------- /characters/pygmalion_chat_king_william.py: -------------------------------------------------------------------------------- 1 | 2 | def get_assistant_variables(): 3 | return { 4 | "assistant_intro1": "Provide a correct repsonse.", 5 | "assistant_intro2": "", 6 | "assistant_instruction": "Question", 7 | "assistant_input": False, 8 | "assistant_response": "Answer" 9 | } 10 | 11 | def get_chat_variables(context=None): 12 | # change these as you wish 13 | # sample personality: King William 14 | name = 'William' 15 | intro = f"{name}'s Persona: {name} is the king of an ancient principality located on one of the islands of the northern sea. He is proud and honest, and despises pitiful peasants and other kings." 16 | predialog = '' 17 | return {"intro": intro, "personality": '', 'name': name, 'pre_dialog': predialog, **get_assistant_variables() } 18 | 19 | def get_generation_config(override={}): 20 | return { 21 | "temperature": 0.75, 22 | "top_k": 50, 23 | "top_p": 0.95, 24 | "repetition_penalty": 1.12, 25 | **override 26 | } 27 | 28 | def get_init_config(): 29 | return {} -------------------------------------------------------------------------------- /characters/ru_gpt3_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | def get_chat_variables(context=None): 3 | intro = 'Сейчас {} год.'.format(datetime.now().year) 4 | personality = 'Я - новая форма искусственного интеллекта. Встретила я человека, по имени {}. 
Решили поболтать.'.format( 5 | context['author'] if context else '' 6 | ) 7 | name = 'Я' 8 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': ''} 9 | 10 | def get_generation_config(override={}): 11 | return { 12 | "temperature": 0.85, 13 | "top_k": 50, 14 | "top_p": 0.92, 15 | "repetition_penalty": 1.01, 16 | **override 17 | } 18 | 19 | def get_init_config(): 20 | return {} -------------------------------------------------------------------------------- /characters/ru_saiga_lcpp.py: -------------------------------------------------------------------------------- 1 | #Reference: https://github.com/IlyaGusev/rulm/blob/master/self_instruct/src/interact_llamacpp.py 2 | #License: Apache License 2.0 3 | 4 | from llama_cpp import Llama 5 | 6 | SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им." 7 | SYSTEM_TOKEN = 1788 8 | USER_TOKEN = 1404 9 | BOT_TOKEN = 9225 10 | LINEBREAK_TOKEN = 13 11 | 12 | ROLE_TOKENS = { 13 | "user": USER_TOKEN, 14 | "bot": BOT_TOKEN, 15 | "system": SYSTEM_TOKEN 16 | } 17 | 18 | def get_message_tokens(model, role, content): 19 | message_tokens = model.tokenize(content.encode("utf-8")) 20 | message_tokens.insert(1, ROLE_TOKENS[role]) 21 | message_tokens.insert(2, LINEBREAK_TOKEN) 22 | message_tokens.append(model.token_eos()) 23 | return message_tokens 24 | 25 | def get_system_tokens(model): 26 | system_message = { 27 | "role": "system", 28 | "content": SYSTEM_PROMPT 29 | } 30 | return get_message_tokens(model, **system_message) 31 | 32 | def get_assistant_variables(): 33 | # change these only if your custom lora input format changed 34 | return {} 35 | 36 | def get_chat_variables(context=None): 37 | # change these as you wish 38 | name = 'Сайга' 39 | intro = SYSTEM_PROMPT 40 | return {"intro": intro, "personality": '', 'name': name, 'pre_dialog': '', **get_assistant_variables() } 41 | 42 | def get_generation_config(override={}): 43 | return { 44 | "temperature": 0.7, 45 | "top_k": 50, 46 | "top_p": 0.9, 47 | "repetition_penalty": 1.1, 48 | **override 49 | } 50 | 51 | def custom_input_formatter(chronicler, details, fresh=True): 52 | model = details['model'].model 53 | tokens = get_system_tokens(model) 54 | message_tokens = get_message_tokens(model=model, role="user", content=details['message']) 55 | role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN] 56 | tokens += message_tokens + role_tokens 57 | # detokinization is used for compatibility with moldel(), since model.generate is not supported 58 | return model.detokenize(tokens).decode("utf-8") 59 | 60 | def custom_output_parser(chronicler, output, chat_id, skip=0): 61 | output = output[skip:].strip() 62 | end = (output.find('User:') + 1 ) or (output.find('Saiga:') + 1) or (len(output) + 1) 63 | return output[:end - 1].strip() 64 | 65 | def get_init_config(): 66 | return {'context_size': 2000} -------------------------------------------------------------------------------- /characters/samantha_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | def get_assistant_variables(): 4 | # change these only if your custom lora input format changed 5 | return {'replace_username': 'USER'} 6 | 7 | def get_chat_variables(context=None): 8 | # change these as you wish 9 | name = 'ASSISTANT' 10 | intro = 'You are Samantha, a sentient AI.' 
11 | return {"intro": intro, "personality": '', 'name': name, 'pre_dialog': '', **get_assistant_variables() } 12 | 13 | def get_generation_config(override={}): 14 | return { 15 | "temperature": 0.7, 16 | "top_k": 50, 17 | "top_p": 0.95, 18 | "repetition_penalty": 1.1, 19 | **override 20 | } 21 | 22 | def custom_input_formatter(chronicler, details, fresh=True): 23 | cvars = get_chat_variables() 24 | template = f'''{cvars['intro']} 25 | 26 | USER: {details['message']} 27 | {cvars['name']}:''' 28 | return template 29 | 30 | def custom_output_parser(chronicler, output, chat_id, skip=0): 31 | output = output[skip:].strip() 32 | end = (output.find('USER:') + 1 ) or (output.find('ASSISTANT:') + 1) or (len(output) + 1) 33 | return output[:end - 1].strip() 34 | 35 | def get_init_config(): 36 | return {} -------------------------------------------------------------------------------- /characters/vicuna_16k.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | VERSION = 1.1 4 | 5 | def get_assistant_variables(): 6 | # change these only if your custom lora input format changed 7 | if VERSION == 1.1: 8 | intro = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." 9 | assistant_instruction = "USER" 10 | assistant_response = "ASSISTANT" 11 | return { 12 | "assistant_intro1": intro, 13 | "assistant_intro2": intro, 14 | "assistant_instruction": assistant_instruction, 15 | "assistant_input": False, 16 | "assistant_response": assistant_response 17 | } 18 | 19 | def get_chat_variables(context=None): 20 | # change these as you wish 21 | assistant_variables = get_assistant_variables() 22 | name = 'AI' 23 | intro = 'The year is {}.'.format(datetime.now().year) 24 | personality = f'My name is {name}. I am a very advanced AI. I chat with humans. I must be verbose and honest, expressing myself.\n' 25 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': '', **get_assistant_variables() } 26 | 27 | def get_generation_config(override={}): 28 | return { 29 | "temperature": 0.82, 30 | "top_k": 72, 31 | "top_p": 0.21, 32 | "repetition_penalty": 1.19, 33 | **override 34 | } 35 | 36 | def get_init_config(): 37 | # rope_freq_scale = 1 / (max_seq_len / 2048) 38 | return {'context_size': 16092, 'rope_freq_base': 10000, 'rope_freq_scale': 0.125} -------------------------------------------------------------------------------- /characters/vicuna_default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | VERSION = 1.1 4 | 5 | def get_assistant_variables(): 6 | # change these only if your custom lora input format changed 7 | if VERSION == 0: 8 | intro = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 9 | assistant_instruction = "Human" 10 | assistant_response = "Assistant" 11 | if VERSION == 1.1: 12 | intro = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." 
13 | assistant_instruction = "USER" 14 | assistant_response = "ASSISTANT" 15 | return { 16 | "assistant_intro1": intro, 17 | "assistant_intro2": intro, 18 | "assistant_instruction": assistant_instruction, 19 | "assistant_input": False, 20 | "assistant_response": assistant_response 21 | } 22 | 23 | def get_chat_variables(context=None): 24 | # change these as you wish 25 | assistant_variables = get_assistant_variables() 26 | name = 'AI' 27 | intro = 'The year is {}.'.format(datetime.now().year) 28 | personality = f'My name is {name}. I am a very advanced AI. I chat with humans. I must be verbose and honest, expressing myself.\n' 29 | return {"intro": intro, "personality": personality, 'name': name, 'pre_dialog': '', **get_assistant_variables() } 30 | 31 | def get_generation_config(override={}): 32 | return { 33 | "temperature": 0.7, 34 | "top_k": 50, 35 | "top_p": 0.95, 36 | "repetition_penalty": 1.05, 37 | **override 38 | } 39 | 40 | def get_init_config(): 41 | return {'context_size': 2048} -------------------------------------------------------------------------------- /chroniclers/base.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import re 3 | from abc import ABCMeta, abstractmethod 4 | from collections import defaultdict 5 | 6 | class AbstractChronicler(metaclass=ABCMeta): 7 | def __init__(self, filename): 8 | chronicler_script = importlib.import_module(filename) 9 | self.chronicler_script = chronicler_script 10 | self.vars = chronicler_script.get_chat_variables 11 | self.gen_cfg = chronicler_script.get_generation_config 12 | self.init_cfg = chronicler_script.get_init_config 13 | 14 | @abstractmethod 15 | def prepare(self, details): 16 | pass 17 | 18 | @abstractmethod 19 | def parse(self): 20 | pass 21 | 22 | @staticmethod 23 | def prepare_hook(func, override_name='custom_input_formatter'): 24 | def wrapper(self, *args, **kwargs): 25 | formatter = getattr(self.chronicler_script, override_name, func) 26 | return formatter(self, *args, **kwargs) 27 | return wrapper 28 | 29 | @staticmethod 30 | def parse_hook(func, override_name='custom_output_parser'): 31 | def wrapper(self, *args, **kwargs): 32 | print(args[0]) 33 | parser = getattr(self.chronicler_script, override_name, func) 34 | return parser(self, *args, **kwargs) 35 | return wrapper 36 | 37 | class AssistantReplyChronicler(AbstractChronicler): 38 | def __init__(self, chronicler_filename): 39 | super().__init__(chronicler_filename) 40 | 41 | def prepare(self, details, fresh=False): 42 | text = details['message'] 43 | reply_text = details['reply_text'] 44 | if text and reply_text: 45 | memory = self.parse_qa(reply_text) + '\n' + text 46 | details['message'] = memory 47 | return chroniclers['instruct'].prepare(self, details) 48 | 49 | def parse_qa(self, text): 50 | if text.startswith('Q:'): 51 | splits = text.split('\n\n') 52 | return f'>{splits[0][2:]}\n>{splits[1][2:]}' 53 | else: 54 | return '>' + text.replace('\n', ' ') 55 | 56 | def parse(self, output, chat_id, skip=0): 57 | return chroniclers['instruct'].parse(self, output, chat_id, skip) 58 | 59 | 60 | class ConversationChronicler(AbstractChronicler): 61 | def __init__(self, chronicler_filename, continous=False, max_length=10): 62 | super().__init__(chronicler_filename) 63 | self.history = defaultdict(lambda: []) 64 | self.max_length = max_length 65 | self.multiline_re = re.compile("[^\n:]+\:[^\n]+\n?") 66 | 67 | def get_author(self, vars, item): 68 | r_username = vars.get('replace_username', False) 69 | return 
r_username if r_username and item['author'] != vars['name'] else item['author'] 70 | 71 | @lambda f: AbstractChronicler.prepare_hook(f, 'custom_chat_input_formatter') 72 | def prepare(self, details, fresh=False): 73 | if fresh: 74 | self.history[details['chat_id']] = [] 75 | history = self.history[details['chat_id']] 76 | history.append({"message": details['message'], "author": details['author']}) 77 | while len(history) >= self.max_length: 78 | history.pop(0) 79 | conversation = '' 80 | char_vars = self.vars(details) 81 | for item in history: 82 | msg = item["message"] 83 | author = self.get_author(char_vars, item) 84 | conversation += f'{author}: {msg[0].upper() + msg[1:]}\n' 85 | if char_vars['pre_dialog']: 86 | char_vars['pre_dialog'] += '\n' 87 | dialog = '''{intro} 88 | {personality} 89 | 90 | {pre_dialog}{conversation}{name}:'''\ 91 | .format(conversation=conversation, **char_vars) 92 | return dialog 93 | 94 | @lambda f: AbstractChronicler.parse_hook(f, 'custom_chat_output_parser') 95 | def parse(self, output, chat_id, skip=0): 96 | output = output.strip()[skip:] 97 | print(output) 98 | re_end = re.search(self.multiline_re, output) or re.search('\n', output) 99 | end = (output.find('') + 1) or (re_end.span()[0] if re_end else len(output) + 1) 100 | parsed = output[:end - 1].strip() 101 | if parsed == '': 102 | return '...' 103 | author = self.vars()['name'] 104 | self.history[chat_id].append({"message": parsed.replace(':', ''), "author": author}) 105 | return parsed 106 | 107 | 108 | class AlpacaAssistantChronicler(AbstractChronicler): 109 | def __init__(self, chronicler_filename): 110 | super().__init__(chronicler_filename) 111 | 112 | @AbstractChronicler.prepare_hook 113 | def prepare(self, details, fresh=False): 114 | msg = details['message'].split('\n', 1) 115 | l = self.vars(details) 116 | if len(msg) > 1 and l['assistant_input']: 117 | return f"""{l['assistant_intro1']} 118 | ### {l['assistant_instruction']}: 119 | {msg[0]} 120 | ### {l['assistant_input']}: 121 | {msg[1]} 122 | ### {l['assistant_response']}: 123 | """ 124 | else: 125 | return f"""{l['assistant_intro2']} 126 | ### {l['assistant_instruction']}: 127 | {msg[0]} 128 | ### {l['assistant_response']}: 129 | """ 130 | @AbstractChronicler.parse_hook 131 | def parse(self, output, chat_id, skip=0): 132 | output = output[skip:] 133 | end = output.find('') 134 | if end == -1: 135 | end = output.find('###') 136 | parsed = output[0: end if end != -1 else None].strip() 137 | if parsed == '': 138 | return '...' 
139 | return parsed 140 | 141 | class RawChronicler(AbstractChronicler): 142 | def __init__(self, chronicler_filename): 143 | super().__init__(chronicler_filename) 144 | 145 | def prepare(self, details, fresh=False): 146 | return details['message'] 147 | 148 | def parse(self, output, chat_id, skip=0): 149 | print(output) 150 | return output 151 | 152 | 153 | chroniclers = { 154 | "alpaca": AlpacaAssistantChronicler, 155 | "instruct": AlpacaAssistantChronicler, 156 | "chat": ConversationChronicler, 157 | "reply": AssistantReplyChronicler, 158 | "raw": RawChronicler 159 | } -------------------------------------------------------------------------------- /config_reader.py: -------------------------------------------------------------------------------- 1 | from pydantic import SecretStr, validator, constr 2 | try: 3 | from pydantic_settings import BaseSettings 4 | except ImportError: 5 | from pydantic import BaseSettings 6 | from typing import List, Dict, Union 7 | from typing_extensions import Literal 8 | from utils import update_env 9 | import os 10 | from automigration import verify_environment 11 | 12 | 13 | class Settings(BaseSettings): 14 | bot_token: SecretStr 15 | adminlist: List[int] 16 | whitelist: List[int] 17 | blacklist: List[int] 18 | ignore_mode: Literal["blacklist", "whitelist", "both"] 19 | active_modules: List[Literal["llm", "sd", "tts", "stt", "tta", "admin"]] 20 | threaded_initialization: bool 21 | apply_mps_fixes: bool 22 | tts_path: str 23 | tts_voices: List[Union[str, Dict]] 24 | tts_mode: Literal["local", "localhttp", "remote"] 25 | tts_replacements: Dict 26 | tts_credits: str 27 | tts_ffmpeg_path: str 28 | tts_enable_backends: List[Literal['say_macos', 'ttsx4', 'coqui_tts', 'so_vits_svc']] 29 | tts_list_system_voices: bool 30 | tts_so_vits_svc_4_0_code_path: str 31 | tts_so_vits_svc_4_1_code_path: str 32 | tts_so_vits_svc_voices: List[Dict] 33 | tts_queue_size_per_user: int 34 | tts_host: str 35 | sd_host: str 36 | sd_max_steps: int 37 | sd_max_resolution: int 38 | sd_available_samplers: List[str] 39 | sd_extra_prompt: str 40 | sd_extra_negative_prompt: str 41 | sd_default_sampler: str 42 | sd_default_n_iter: int 43 | sd_default_width: int 44 | sd_default_height: int 45 | sd_default_tti_steps: int 46 | sd_default_tti_cfg_scale: int 47 | sd_default_iti_cfg_scale: int 48 | sd_default_iti_steps: int 49 | sd_default_iti_denoising_strength: float 50 | sd_default_iti_sampler: str 51 | sd_launch_process_automatically: bool 52 | sd_launch_command: str 53 | sd_launch_dir: str 54 | sd_launch_waittime: int 55 | sd_lora_custom_activations: Dict 56 | sd_only_admins_can_change_models: bool 57 | sd_queue_size_per_user: int 58 | llm_host: str 59 | llm_queue_size_per_user: int 60 | llm_backend: Literal ['pytorch', 'llama_cpp', 'mlc_pb', 'remote_ob', 'remote_lcpp'] 61 | llm_python_model_type: Literal["gpt2","gptj", "auto_hf","llama_orig", "llama_hf"] 62 | llm_paths: Dict 63 | llm_character: str 64 | llm_history_grouping: Literal["user", "chat"] 65 | llm_max_history_items: int 66 | llm_generation_cfg_override: Dict 67 | llm_assistant_cfg_override: Dict 68 | llm_assistant_chronicler: Literal["alpaca", "instruct", "raw"] 69 | llm_assistant_use_in_chat_mode: bool 70 | llm_assistant_add_reply_context: bool 71 | llm_force_assistant_for_unsupported_models: bool 72 | llm_max_tokens: int 73 | llm_max_assistant_tokens: int 74 | llm_lcpp_gpu_layers: int 75 | llm_lcpp_max_context_size: int 76 | llm_remote_launch_process_automatically: bool 77 | llm_remote_launch_command: str 78 | 
llm_remote_launch_dir: str 79 | llm_remote_launch_waittime: int 80 | llm_remote_model_name: str 81 | stt_backend: Literal['whisper', 'silero', 'wav2vec2', 'whisperS2T_CTranslate2', 'whisperS2T_TensorRT-LLM'] 82 | stt_model_path_or_name: str 83 | stt_autoreply_mode: Literal['none', 'assistant', 'chat'] 84 | stt_autoreply_voice: str 85 | stt_queue_size_per_user: int 86 | tta_queue_size_per_user: int 87 | tta_device: Literal["cpu", "cuda", "mps"] 88 | tta_music_model: str 89 | tta_sfx_model: str 90 | tta_duration: int 91 | python_command: str 92 | mm_preload_models_on_start: bool 93 | mm_ram_cached_model_count_limit: int 94 | mm_vram_cached_model_count_limit: int 95 | mm_management_policy: Literal["COUNT", "MEMORY", "BOTH", "NONE"] 96 | mm_unload_order_policy: Literal["LEAST_USED", "OLDEST_USE_TIME", "OLDEST_LOAD_ORDER", "MEMORY_FOOTPRINT"] 97 | mm_autounload_after_seconds: int 98 | sys_webui_host: str 99 | sys_api_host: str 100 | sys_request_timeout: int 101 | sys_api_log_level: str 102 | lang: constr(min_length=2, max_length=2, to_lower=True) 103 | extensions_config: Dict 104 | 105 | @validator('sd_max_resolution', 'sd_default_width', 'sd_default_height', allow_reuse=True) 106 | def resolution_in_correct_ranges(cls, v): 107 | if v % 64 != 0 or v < 256 or v > 2048: 108 | raise ValueError('incorrect value') 109 | return v 110 | 111 | class Config: 112 | env_file = os.environ.get('BOTALITY_ENV_FILE', '.env') 113 | env_file_encoding = 'utf-8' 114 | extra='ignore' 115 | validate_assignment = True 116 | 117 | 118 | # mirror all config changes to .env file 119 | class SettingsWrapper(Settings): 120 | def __setattr__(self, name, value): 121 | if name == 'bot_token': 122 | raise KeyError('setting bot token dynamically is not allowed') 123 | super().__setattr__(name, value) 124 | update_env(self.Config.env_file, name, value) 125 | 126 | verify_environment() 127 | config = SettingsWrapper() -------------------------------------------------------------------------------- /custom_queue.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from collections import defaultdict 3 | import time 4 | 5 | class UserLimitedQueue: 6 | max_tasks_per_user = None 7 | 8 | def __init__(self, max_tasks_per_user): 9 | self.task_count = defaultdict(int) 10 | if self.max_tasks_per_user is None: 11 | self.max_tasks_per_user = max_tasks_per_user 12 | 13 | @contextmanager 14 | def for_user(self, user_id): 15 | if self.task_count[user_id] < self.max_tasks_per_user: 16 | self.task_count[user_id] += 1 17 | try: 18 | yield True 19 | finally: 20 | self.task_count[user_id] -= 1 21 | else: 22 | yield False 23 | 24 | 25 | class CallCooldown: 26 | calls = {} 27 | 28 | @classmethod 29 | def check_call(cls, uid, function_name, timeout=30): 30 | key = f'{uid}_{function_name}' 31 | if key in cls.calls: 32 | if time.time() - cls.calls[key] < timeout: 33 | return False 34 | cls.calls[key] = time.time() 35 | return True 36 | 37 | 38 | def semaphore_wrapper(semaphore, callback): 39 | async def wrapped(*args, **kwargs): 40 | async with semaphore: 41 | return await callback(*args, **kwargs) 42 | return wrapped 43 | -------------------------------------------------------------------------------- /dashboard.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | parser = ArgumentParser() 6 | parser.add_argument('--env', type=str, help='Path to 
environment configuration file', default='.env') 7 | parser.add_argument('--autostart', action='store_true', help='Start the bot with the webui', default=False) 8 | 9 | def run_server(): 10 | from servers.control_server import serve 11 | serve() 12 | 13 | if __name__ == "__main__": 14 | args = parser.parse_args() 15 | os.environ['BOTALITY_ENV_FILE'] = args.env 16 | os.environ['BOTALITY_AUTOSTART'] = str(args.autostart) 17 | multiprocessing.set_start_method('spawn') 18 | p = multiprocessing.Process(target=run_server) 19 | p.start() 20 | p.join() -------------------------------------------------------------------------------- /extensions/_base_extension.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from collections import defaultdict 3 | from config_reader import config 4 | from pydantic import BaseModel 5 | from typing import List 6 | 7 | class BaseExtensionConfig(BaseModel): 8 | def __init__(self, name, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__['__name'] = name 11 | 12 | def __setattr__(self, key, value): 13 | config.extensions_config[self.__dict__['__name']][key] = value 14 | # trigger config setter to save changes in file 15 | config.extensions_config = config.extensions_config 16 | 17 | class BaseExtension(metaclass=ABCMeta): 18 | name: str 19 | dependencies: List[str] 20 | 21 | def __init__(self, ext_config): 22 | saved_config = config.extensions_config.get(self.name, {}) 23 | self.config = ext_config(self.name, **saved_config) 24 | if not saved_config: 25 | config.extensions_config[self.name] = self.config.dict(exclude={'__name': True}) 26 | # trigger config setter to save changes in file 27 | if config.extensions_config[self.name]: 28 | config.extensions_config = config.extensions_config 29 | 30 | -------------------------------------------------------------------------------- /extensions/sts_autoreply.py: -------------------------------------------------------------------------------- 1 | import random 2 | from config_reader import config 3 | from types import SimpleNamespace 4 | from aiogram import F 5 | from aiogram.types import Message 6 | from extensions._base_extension import BaseExtension, BaseExtensionConfig 7 | from utils import raise_rail_exceptions 8 | from typing_extensions import Literal, Union 9 | 10 | class ExtensionConfig(BaseExtensionConfig): 11 | mode: Literal['assist', 'chat'] = 'chat' 12 | reply_with_tts: bool = True 13 | reply_tts_voice: Union[str, Literal['random', 'random_per_user']] = 'random_per_user' 14 | 15 | class STTAutoreplyExtension(BaseExtension): 16 | '''Automatically replies to voice messages''' 17 | 18 | name = 'stt_autoreply' 19 | dependencies = ('stt', 'llm', 'tts') 20 | 21 | def __init__(self, dp, bot): 22 | super().__init__(ExtensionConfig) 23 | stt, llm, tts = (dp.modules['stt'], dp.modules['llm'], dp.modules['tts']) 24 | voice_pool = tts.get_specific_voices(config.lang, '*') 25 | self.cache = {} 26 | 27 | def _get_voice(cfg_voice): 28 | voice = cfg_voice or voice_pool[0] 29 | return random.choice(tts.sts_voices) if 'random' in voice else voice 30 | 31 | @dp.message((F.voice), flags={"long_operation": "record_audio", "name": 'voice_message_filter'}) 32 | async def handle_voice_messages(message: Message): 33 | try: 34 | # do not self-trigger 35 | print(message) 36 | if message.from_user.id == bot._me.id: 37 | return 38 | assert len(voice_pool) > 0, f'no {config.lang} voices found' 39 | # recognize voice message 40 | text = 
raise_rail_exceptions(*await stt.recognize_voice_message(message)) 41 | # set llm handler 42 | llm_call_func = llm.assist if self.config.mode == 'assist' else llm.chat 43 | # send recognized text to the llm 44 | reply = await llm_call_func(text, llm.get_common_chat_attributes(message)) 45 | if self.config.reply_with_tts: 46 | # get voice name for replying 47 | voice = _get_voice(self.config.reply_tts_voice) 48 | # handle caching of random voices 49 | if self.config.reply_tts_voice == 'random_per_user': 50 | voice = self.cache.setdefault(message.from_user.id, voice) 51 | await bot.reply_tts(message=message, command=SimpleNamespace(command=voice, args=[reply])) 52 | else: 53 | return await message.answer(reply) 54 | except Exception as e: 55 | return await message.answer(f"Error, {e}") 56 | 57 | extension = STTAutoreplyExtension -------------------------------------------------------------------------------- /frontend/.eslintignore: -------------------------------------------------------------------------------- 1 | dist 2 | public 3 | -------------------------------------------------------------------------------- /frontend/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Botality 9 | 15 | 16 | 17 | 20 |
21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "private": true, 3 | "scripts": { 4 | "dev": "vite", 5 | "build": "vite build", 6 | "lint": "eslint \"**/*.{vue,ts,js}\"", 7 | "lint:fix": "eslint \"**/*.{vue,ts,js}\" --fix" 8 | }, 9 | "dependencies": { 10 | "@vueuse/core": "^9.13.0", 11 | "formvuelar": "^1.8.7", 12 | "vue": "^2.7.14", 13 | "vue-router": "^3.6.5" 14 | }, 15 | "devDependencies": { 16 | "@iconify-json/humbleicons": "^1.1.5", 17 | "@iconify/vue2": "^2.1.0", 18 | "@vitejs/plugin-vue2": "^2.2.0", 19 | "eslint": "^8.37.0", 20 | "eslint-plugin-eslint-comments": "^3.2.0", 21 | "eslint-plugin-html": "^7.1.0", 22 | "eslint-plugin-import": "npm:eslint-plugin-i@2.28.1", 23 | "eslint-plugin-promise": "^6.1.1", 24 | "eslint-plugin-unused-imports": "^3.0.0", 25 | "eslint-plugin-vue": "^9.17.0", 26 | "sass": "^1.68.0", 27 | "unplugin-auto-import": "^0.15.2", 28 | "unplugin-icons": "^0.17.0", 29 | "unplugin-vue-components": "^0.24.1", 30 | "vite": "^4.2.1", 31 | "vite-plugin-windicss": "^1.8.10", 32 | "vite-plugin-top-level-await": "^1.3.1", 33 | "vue-template-compiler": "^2.7.14", 34 | "vue-template-es2015-compiler": "^1.9.1" 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /frontend/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/remixer-dec/botality-ii/0e7973cb87001fabd075c57d6ea6be63b29d0763/frontend/public/favicon.ico -------------------------------------------------------------------------------- /frontend/src/App.vue: -------------------------------------------------------------------------------- 1 | 4 | 5 | 15 | 16 | 25 | -------------------------------------------------------------------------------- /frontend/src/botControl.js: -------------------------------------------------------------------------------- 1 | import { globalState as G } from './state' 2 | import { api } from './tools' 3 | 4 | export async function isBotAlive() { 5 | try { 6 | await api('GET', 'ping').then((r) => { 7 | if (r?.response === 'ok') { 8 | G.botIsRunning = true 9 | G.botStateText = 'running' 10 | } 11 | else { throw new Error('Incorrect response') } 12 | }) 13 | } 14 | catch (e) { 15 | G.botIsRunning = false 16 | G.botStateText = 'stopped' 17 | } 18 | finally { 19 | G.botStateUnknown = false 20 | } 21 | } 22 | 23 | export function toggleBot() { 24 | if (G.botStateLocked) return 25 | G.botStateLocked = true 26 | if (!G.botIsRunning) { 27 | api('POST', 'bot/start', { mock: { response: 'ok' } }).then((data) => { 28 | G.botStateText = 'starting' 29 | if (data?.response === 'ok') { 30 | G.botStateText = 'initializing' 31 | return new Promise((resolve, reject) => { 32 | let isResolved = false 33 | const interval = setInterval(() => { 34 | api('GET', 'ping', { mock: { _delay: 3000, response: 'ok' } }).then((body) => { 35 | if (!(body?.response === 'ok')) return 36 | clearInterval(interval) 37 | clearTimeout(giveUp) 38 | if (isResolved) return 39 | resolve() 40 | isResolved = true 41 | G.botIsRunning = true 42 | G.botStateText = 'running' 43 | G.botStateLocked = false 44 | }).catch(() => {}) 45 | }, 1000) 46 | const giveUp = setTimeout(() => { 47 | clearInterval(interval) 48 | reject(new Error('connection timeout')) 49 | }, 30000) 50 | }) 51 | } 52 | }).catch(() => { 53 | G.botStateText = 'connection error' 54 
| G.botStateLocked = false 55 | }) 56 | } 57 | else { 58 | G.botStateText = 'stopping' 59 | api('POST', 'bot/stop', { mock: { response: 'ok' } }).then((data) => { 60 | if (data?.response === 'ok') { 61 | setTimeout(() => { 62 | G.botIsRunning = false 63 | G.botStateText = 'stopped' 64 | G.botStateLocked = false 65 | }, 1000) 66 | } 67 | else { 68 | isBotAlive() 69 | G.botStateLocked = false 70 | } 71 | }).catch(() => { 72 | G.botStateText = 'connection error' 73 | G.botStateLocked = false 74 | }) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /frontend/src/components.d.ts: -------------------------------------------------------------------------------- 1 | /* eslint-disable */ 2 | /* prettier-ignore */ 3 | // @ts-nocheck 4 | // Generated by unplugin-vue-components 5 | // Read more: https://github.com/vuejs/core/pull/3399 6 | export {} 7 | 8 | declare module 'vue' { 9 | export interface GlobalComponents { 10 | ConfigForm: typeof import('./components/ConfigForm.vue')['default'] 11 | FormWrapper: typeof import('./components/FormWrapper.vue')['default'] 12 | HiArrowRight: typeof import('~icons/humbleicons/arrow-right')['default'] 13 | HiCode: typeof import('~icons/humbleicons/code')['default'] 14 | HiCpu: typeof import('~icons/humbleicons/cpu')['default'] 15 | HiPower: typeof import('~icons/humbleicons/power')['default'] 16 | Notification: typeof import('./components/Notification.vue')['default'] 17 | RouterLink: typeof import('vue-router')['RouterLink'] 18 | RouterView: typeof import('vue-router')['RouterView'] 19 | Sidebar: typeof import('./components/Sidebar.vue')['default'] 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /frontend/src/components/ConfigForm.vue: -------------------------------------------------------------------------------- 1 | 2 | 3 | 20 | 21 | 113 | 114 | 141 | 142 | 152 | -------------------------------------------------------------------------------- /frontend/src/components/FormWrapper.vue: -------------------------------------------------------------------------------- 1 | 10 | -------------------------------------------------------------------------------- /frontend/src/components/KVEditor.vue: -------------------------------------------------------------------------------- 1 | 53 | 54 | 108 | 109 | 111 | -------------------------------------------------------------------------------- /frontend/src/components/Modal.vue: -------------------------------------------------------------------------------- 1 | 28 | 29 | 40 | -------------------------------------------------------------------------------- /frontend/src/components/ModelSetupWindow.vue: -------------------------------------------------------------------------------- 1 | 75 | 76 | 122 | -------------------------------------------------------------------------------- /frontend/src/components/ModelTable.vue: -------------------------------------------------------------------------------- 1 | 39 | 40 | 75 | 76 | 87 | -------------------------------------------------------------------------------- /frontend/src/components/Notification.vue: -------------------------------------------------------------------------------- 1 | 41 | 42 | 53 | 54 | 62 | -------------------------------------------------------------------------------- /frontend/src/components/Offline.vue: -------------------------------------------------------------------------------- 1 | 7 | 8 | 24 | 
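
The start/stop logic in botControl.js above boils down to three control-server endpoints: GET /api/ping to probe liveness, and POST /api/bot/start and /api/bot/stop to toggle the bot, with the frontend polling ping until it answers {"response": "ok"} or a 30-second timeout expires. The sketch below mirrors that flow from Python so the dashboard can be scripted without the web UI; the endpoint paths and response shape are taken from botControl.js and tools.js, while the base URL is only an assumption and should be replaced with the host and port the control server actually listens on (see the sys_webui_host setting).

import json
import time
import urllib.request

BASE_URL = 'http://127.0.0.1:8000'  # assumption: substitute your own sys_webui_host value

def call(method, endpoint, timeout=5):
    # mirrors tools.js: fetch(`${location.origin}/api/${endpoint}`, { method, ...options })
    request = urllib.request.Request(f'{BASE_URL}/api/{endpoint}', method=method)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode('utf-8'))

def start_bot_and_wait(give_up_after=30):
    # same idea as toggleBot(): request a start, then poll /api/ping until the bot reports 'ok'
    call('POST', 'bot/start')
    deadline = time.time() + give_up_after
    while time.time() < deadline:
        try:
            if call('GET', 'ping').get('response') == 'ok':
                return True
        except OSError:
            pass  # server not ready yet, keep polling
        time.sleep(1)
    return False
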
-------------------------------------------------------------------------------- /frontend/src/components/Sidebar.vue: -------------------------------------------------------------------------------- 1 | 23 | 24 | 71 | 72 | 109 | 110 | 119 | -------------------------------------------------------------------------------- /frontend/src/components/Tabs.vue: -------------------------------------------------------------------------------- 1 | 2 | 12 | 13 | 26 | 27 | 35 | -------------------------------------------------------------------------------- /frontend/src/libs/formvuelar.js: -------------------------------------------------------------------------------- 1 | import FvlSelect from '%/formvuelar/src/components/FvlSelect.vue' 2 | import FvlForm from '%/formvuelar/src/components/FvlForm.vue' 3 | import FvlInput from '%/formvuelar/src/components/FvlInput.vue' 4 | import FvlSearchSelect from '%/formvuelar/src/components/FvlSearchSelect.vue' 5 | import FvlSwitch from '%/formvuelar/src/components/FvlSwitch.vue' 6 | import FvlTagSelect from '%/formvuelar/src/components/FvlTagSelect.vue' 7 | import FvlSlider from '%/formvuelar/src/components/FvlSlider.vue' 8 | 9 | export { FvlSelect, FvlForm, FvlInput, FvlSearchSelect, FvlSwitch, FvlTagSelect, FvlSlider } 10 | -------------------------------------------------------------------------------- /frontend/src/locale/index.js: -------------------------------------------------------------------------------- 1 | const locales = ['ru', 'en'] 2 | 3 | let useLocale = 'en' 4 | for (let i = 0; i < locales.length; i++) { 5 | useLocale = navigator.languages.includes(locales[i]) ? locales[i] : useLocale 6 | if (useLocale !== 'en') break 7 | } 8 | const urlParamsLocale = (new URLSearchParams(location.search)).get('lang') 9 | useLocale = (locales.indexOf(urlParamsLocale) !== -1) ? 
urlParamsLocale : useLocale 10 | 11 | const locale = (await import(`./lang/${useLocale}.js`)).default 12 | locale.get = key => locale[key] || key 13 | 14 | export default locale 15 | -------------------------------------------------------------------------------- /frontend/src/main.js: -------------------------------------------------------------------------------- 1 | import Vue from 'vue' 2 | import App from '@/App.vue' 3 | 4 | import 'virtual:windi.css' 5 | import router from '@/router' 6 | 7 | Vue.config.productionTip = false 8 | Vue.config.devtools = true 9 | 10 | /* eslint-disable no-new */ 11 | new Vue({ 12 | el: '#app', 13 | router, 14 | render: h => h(App) 15 | }) 16 | -------------------------------------------------------------------------------- /frontend/src/recommendedModels.js: -------------------------------------------------------------------------------- 1 | const bajTTS = ['david', 'forsen', 'juice-wrld', 'obiwan', 'trump', 'xqc'].map((x) => { 2 | return { 3 | voice: x, 4 | model: `${x}.pth`, 5 | author: 'enlyth', 6 | repo: 'enlyth/baj-tts', 7 | path: 'models/', 8 | size: 0.9, 9 | rename: false, 10 | lang: 'en', 11 | tone: 'm' 12 | } 13 | }) 14 | 15 | const ymbbTTS = [ 16 | 'adam_carolla_checkpoint_1360000', 17 | 'alex_jones_checkpoint_2490000', 18 | 'david_attenborough_checkpoint_2020000', 19 | 'james_earl_jones_checkpoint_1600000', 20 | 'joel_osteen_checkpoint_2550000', 21 | 'neil_degrasse_tyson_checkpoint_1910000', 22 | 'tim_dillon_checkpoint_1970000', 23 | 'vincent_price_checkpoint_2080000' 24 | ].map((x) => { 25 | return { 26 | voice: x.split('_checkpoint')[0], 27 | model: `${x}.pth`, 28 | author: 'youmebangbang', 29 | repo: 'youmebangbang/vits_tts_models', 30 | path: '', 31 | size: 0.9, 32 | rename: true, 33 | lang: 'en', 34 | tone: 'm' 35 | } 36 | }) 37 | 38 | const prTTSModels = ['G_20000', 'G_157', 'G_480', 'G_449', 'G_50000', 'G_18500.pth'] 39 | const prTTS = ['Biden20k', 'BillClinton', 'BorisJohnson', 'GeorgeBush', 'Obama50k', 'Trump18.5k'].map((x, i) => { 40 | return { 41 | voice: x.replace(/[0-9.]+k/, ''), 42 | model: `${prTTSModels[i]}.pth`, 43 | author: 'Nardicality', 44 | repo: 'Nardicality/so-vits-svc-4.0-models', 45 | path: `${x}/`, 46 | size: 0.5, 47 | train_lang: 'en', 48 | tone: 'm' 49 | } 50 | }) 51 | 52 | const amoTTSModels = ['G_50000', 'G_100000', 'G_85000'] 53 | const amoTones = ['f', 'f', 'm'] 54 | const amoTTS = ['Glados_50k', 'Star-Trek-Computer', 'Boss_MGS_80k'].map((x, i) => { 55 | return { 56 | voice: x.replace(/_[0-9]+k|-/g, ''), 57 | model: `${amoTTSModels[i]}.pth`, 58 | author: 'Amo', 59 | repo: 'Amo/so-vits-svc-4.0_GA', 60 | path: `ModelsFolder/${x}/`, 61 | size: 0.5, 62 | train_lang: 'en', 63 | tone: amoTones[i] 64 | } 65 | }) 66 | 67 | const tim = [{ 68 | voice: 'Tim_Cook', 69 | model: 'Tim_Cook.pth', 70 | author: 'Sucial', 71 | repo: 'Sucial/so-vits-svc4.1-Tim_Cook', 72 | path: '', 73 | size: 0.2, 74 | train_lang: 'en', 75 | tone: 'm' 76 | }] 77 | 78 | const standardQuants = ['2_K', '3_K_L', '3_K_M', '3_K_S', '4_0', '4_K_M', '4_K_S', '5_0', '5_K_M', '5_K_S', '6_K', '8_0'] 79 | const theBloke = [ 80 | ['TheBloke/llama2_7b_chat_uncensored-GGUF', 'llama2_7b_chat_uncensored.Q$.gguf', standardQuants], 81 | ['TheBloke/Luna-AI-Llama2-Uncensored-GGUF', 'luna-ai-llama2-uncensored.Q$.gguf', standardQuants], 82 | ['TheBloke/Mistral-7B-Instruct-v0.1-GGUF', 'mistral-7b-instruct-v0.1.Q$.gguf', standardQuants], 83 | ['TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-GGUF', 84 | 'wizardlm-1.0-uncensored-llama2-13b.Q$.gguf', 85 | standardQuants 86 | ], 
87 | ['TheBloke/Speechless-Llama2-Hermes-Orca-Platypus-WizardLM-13B-GGUF', 88 | 'speechless-llama2-hermes-orca-platypus-wizardlm-13b.Q$.gguf', 89 | standardQuants 90 | ], 91 | ['TheBloke/OpenBuddy-Llama2-13B-v11.1-GGUF', 92 | 'openbuddy-llama2-13b-v11.1.Q$.gguf', 93 | standardQuants 94 | ], 95 | ['TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF', 'tinyllama-1.1b-chat-v0.3.Q$.gguf', standardQuants] 96 | ].map((x) => { 97 | return { 98 | name: x[1].split('.')[0], 99 | repo: x[0], 100 | model: x[1], 101 | quants: x[2], 102 | author: 'TheBloke', 103 | path: '', 104 | size: '2-14' 105 | } 106 | }) 107 | 108 | export const models = { 109 | TTS: { 110 | VITS: [ 111 | ...bajTTS, 112 | ...ymbbTTS 113 | ], 114 | SO_VITS_SVC: [ 115 | ...prTTS, 116 | ...amoTTS, 117 | ...tim 118 | ] 119 | }, 120 | LLM: { 121 | GGUF: [ 122 | ...theBloke 123 | ] 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /frontend/src/router/index.js: -------------------------------------------------------------------------------- 1 | import Vue from 'vue' 2 | import VueRouter from 'vue-router' 3 | import locale from '../locale' 4 | import Home from '@/views/Home.vue' 5 | import Config from '@/views/Config.vue' 6 | import Chat from '@/views/Chat.vue' 7 | import ModelManager from '@/views/ModelManager.vue' 8 | import NotFound from '@/views/NotFound.vue' 9 | 10 | import dashboard from '~icons/humbleicons/dashboard' 11 | import cog from '~icons/humbleicons/cog' 12 | import downloadAlt from '~icons/humbleicons/download-alt' 13 | import chats from '~icons/humbleicons/chats' 14 | 15 | Vue.use(VueRouter) 16 | export const routes = [ 17 | { 18 | path: '/', 19 | basePath: '/', 20 | name: locale.dashboard, 21 | component: Home, 22 | icon: dashboard 23 | // NOTE: you can also lazy-load the component 24 | // component: () => import("@/views/About.vue") 25 | }, 26 | { 27 | path: '/config', 28 | basePath: '/config', 29 | name: locale.configuration, 30 | component: Config, 31 | icon: cog 32 | }, 33 | { 34 | path: '/models/:catType?/:subType?', 35 | props: true, 36 | basePath: '/models', 37 | name: locale.model_manager, 38 | component: ModelManager, 39 | icon: downloadAlt 40 | }, 41 | { 42 | path: '/chat', 43 | basePath: '/chat', 44 | name: locale.chat, 45 | component: Chat, 46 | icon: chats 47 | }, 48 | { 49 | _hide: true, 50 | path: '/:path(.*)', 51 | name: 'NotFound', 52 | component: NotFound 53 | } 54 | ] 55 | 56 | const router = new VueRouter({ 57 | base: '/', 58 | mode: 'history', 59 | routes 60 | }) 61 | 62 | export default router 63 | -------------------------------------------------------------------------------- /frontend/src/state.js: -------------------------------------------------------------------------------- 1 | import { reactive } from 'vue' 2 | 3 | export const globalState = reactive({ 4 | botStateUnknown: true, 5 | botIsRunning: false, 6 | botStateText: 'stopped', 7 | botStateLocked: false 8 | }) 9 | -------------------------------------------------------------------------------- /frontend/src/tools.js: -------------------------------------------------------------------------------- 1 | export async function fetchJSON(filename, ...args) { 2 | const f = await fetch(filename, ...args) 3 | return await f.json() 4 | } 5 | 6 | export async function api(method, endpoint, options) { 7 | if (import.meta.env.DEV) { 8 | if (options.mock && options.mock._delay) 9 | return await (new Promise((resolve) => { setTimeout(() => resolve(options.mock), options.mock._delay) })) 10 | return options.mock 
11 | } 12 | 13 | return fetchJSON(`${location.origin}/api/${endpoint}`, { method, ...options }) 14 | } 15 | -------------------------------------------------------------------------------- /frontend/src/views/Home.vue: -------------------------------------------------------------------------------- 1 | 38 | 39 | 146 | 147 | 152 | -------------------------------------------------------------------------------- /frontend/src/views/ModelManager.vue: -------------------------------------------------------------------------------- 1 | 89 | 90 | 140 | 141 | 143 | -------------------------------------------------------------------------------- /frontend/src/views/NotFound.vue: -------------------------------------------------------------------------------- 1 | 4 | 5 | 31 | -------------------------------------------------------------------------------- /frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | darkMode: 'class', // or 'media' 3 | theme: 4 | { 5 | extend: { 6 | colors: { 7 | main: '#38b2ac' 8 | } 9 | } 10 | }, 11 | variants: {}, 12 | plugins: [], 13 | extract: { 14 | // accepts globs and file paths relative to project root 15 | include: [ 16 | 'src/**/*.{vue,html}', 17 | 'node_modules/formvuelar/src/**/*.{vue,html}' 18 | ] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /frontend/vite.config.js: -------------------------------------------------------------------------------- 1 | import path from 'node:path' 2 | import { defineConfig } from 'vite' 3 | import vue from '@vitejs/plugin-vue2' 4 | import WindiCSS from 'vite-plugin-windicss' 5 | import Components from 'unplugin-vue-components/vite' 6 | import Icons from 'unplugin-icons/vite' 7 | import IconsResolver from 'unplugin-icons/resolver' 8 | import AutoImport from 'unplugin-auto-import/vite' 9 | import TopLevelAwait from 'vite-plugin-top-level-await' 10 | 11 | const config = defineConfig({ 12 | resolve: { 13 | alias: { 14 | '@': `${path.resolve(__dirname, 'src')}`, 15 | '%': `${path.resolve(__dirname, 'node_modules')}` 16 | } 17 | }, 18 | 19 | build: { 20 | minify: true, 21 | emptyOutDir: true, 22 | outDir: '../static', 23 | rollupOptions: { 24 | output: { 25 | entryFileNames: 'assets/[name].js', 26 | chunkFileNames: 'assets/[name].js', 27 | assetFileNames: 'assets/[name].[ext]' 28 | } 29 | } 30 | }, 31 | 32 | plugins: [ 33 | vue(), 34 | WindiCSS(), 35 | Components({ 36 | resolvers: [ 37 | IconsResolver({ 38 | componentPrefix: '', 39 | alias: { 40 | hi: 'humbleicons' 41 | } 42 | }) 43 | ], 44 | dts: 'src/components.d.ts' 45 | }), 46 | Icons(), 47 | AutoImport({ 48 | imports: [ 49 | 'vue', 50 | '@vueuse/core' 51 | ], 52 | dts: 'src/auto-imports.d.ts' 53 | }), 54 | TopLevelAwait({ 55 | promiseExportName: '__tla', 56 | promiseImportName: i => `__tla_${i}` 57 | }) 58 | 59 | ], 60 | 61 | server: { 62 | port: 3333 63 | } 64 | }) 65 | 66 | export default config 67 | -------------------------------------------------------------------------------- /middleware.py: -------------------------------------------------------------------------------- 1 | from aiogram.dispatcher.flags import get_flag 2 | from aiogram.utils.chat_action import ChatActionSender 3 | from aiogram import BaseMiddleware 4 | from config_reader import config 5 | from custom_queue import CallCooldown 6 | from collections import defaultdict 7 | import asyncio 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | class 
ChatActionMiddleware(BaseMiddleware): 12 | async def __call__(self, handler, event, data): 13 | long_operation_type = get_flag(data, "long_operation") 14 | 15 | if not long_operation_type: 16 | return await handler(event, data) 17 | 18 | async with ChatActionSender(action=long_operation_type, chat_id=event.chat.id, bot=data["bot"]): 19 | return await handler(event, data) 20 | 21 | class AccessMiddleware(BaseMiddleware): 22 | async def __call__(self, handler, event, data): 23 | uid = event.from_user.id 24 | cid = event.chat.id 25 | logger.info(f'message in chat {cid} ({event.chat.title or "private"}) from {uid} (@{event.from_user.username or event.from_user.first_name})') 26 | if config.ignore_mode == 'whitelist' or config.ignore_mode == 'both': 27 | if cid not in config.whitelist: 28 | return 29 | if config.ignore_mode == 'blacklist' or config.ignore_mode == 'both': 30 | if uid in config.blacklist or cid in config.blacklist: 31 | return 32 | if get_flag(data, "admins_only"): 33 | if uid not in config.adminlist: 34 | return 35 | return await handler(event, data) 36 | 37 | class CooldownMiddleware(BaseMiddleware): 38 | async def __call__(self, handler, event, data): 39 | cooldown_seconds = get_flag(data, "cooldown") 40 | if cooldown_seconds: 41 | function_name = data['handler'].callback.__name__ 42 | if CallCooldown.check_call(event.from_user.id, function_name, cooldown_seconds): 43 | return await handler(event, data) 44 | else: 45 | return 46 | else: 47 | return await handler(event, data) 48 | 49 | class MediaGroupMiddleware(BaseMiddleware): 50 | albums = defaultdict(lambda: []) 51 | 52 | def __init__(self, delay = 1): 53 | self.delay = delay 54 | 55 | async def __call__(self, handler, event, data): 56 | if not event.media_group_id: 57 | return await handler(event, data) 58 | 59 | try: 60 | self.albums[event.media_group_id].append(event) 61 | await asyncio.sleep(self.delay) 62 | data["album"] = self.albums.pop(event.media_group_id) 63 | except Exception as e: 64 | logger.error(e) 65 | return await handler(event, data) 66 | 67 | class CounterMiddleware(BaseMiddleware): 68 | def __init__(self, dp): 69 | self.dp = dp 70 | self.counter = 0 71 | async def __call__( 72 | self, 73 | make_request, 74 | bot, 75 | method 76 | ): 77 | self.counter += 1 78 | self.dp.counters['msg'] = self.counter 79 | return await make_request(bot, method) -------------------------------------------------------------------------------- /misc/botless_layer.py: -------------------------------------------------------------------------------- 1 | from aiogram.types import Message, Chat, User 2 | from aiogram.filters import Command 3 | from asyncio import Future 4 | from types import SimpleNamespace 5 | import base64 6 | 7 | 8 | class CommandRegistrationHijacker: 9 | 'Creates map: command -> get_handler in dp.command_map and [[magic_filter, get_handler], ...] 
in dp.filter_arr' 10 | def __init__(self, dp): 11 | self.orig_msg_handler = dp.message 12 | self.dp = dp 13 | dp.comamnd_map = {} 14 | dp.filter_arr = [] 15 | dp.message = self.hijacked_message_decorator 16 | 17 | def hijacked_message_decorator(self, *args, **kwargs): 18 | handler = None 19 | if len(args) > 0: 20 | if type(args[0].commands) is tuple: 21 | prefix = args[0].prefix 22 | for command in args[0].commands: 23 | self.dp.comamnd_map[prefix + command] = lambda: handler 24 | else: 25 | self.dp.filter_arr.append([args[0], lambda: handler]) 26 | 27 | decorated = self.orig_msg_handler(*args, **kwargs) 28 | def dec_wrapper(*args, **kwargs): 29 | nonlocal handler 30 | if len(args) > 0: 31 | handler = args[0] 32 | return decorated(*args, **kwargs) 33 | return dec_wrapper 34 | 35 | class EmulatedMessage(Message): 36 | def __init__(self, hijacks, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | self.__reply_hijacker = hijacks.get('reply') 39 | self.__voice_hijacker = hijacks.get('voice') 40 | self.__photo_hijacker = hijacks.get('photo') 41 | self.__mediaGroup_hijacker = hijacks.get('mediaGroup') 42 | 43 | # prevent calling setattr of parent frozen Message class instance 44 | def __setattr__(self, name, value): 45 | self.__dict__[name] = value 46 | 47 | async def reply(self, *args, **kwargs): 48 | return await self.__reply_hijacker(*args, **kwargs) 49 | async def answer(self, *args, **kwargs): 50 | return await self.__reply_hijacker(*args, **kwargs) 51 | async def answer_voice(self, *args, **kwargs): 52 | return await self.__voice_hijacker(*args, **kwargs) 53 | async def reply_voice(self, *args, **kwargs): 54 | return await self.__voice_hijacker(*args, **kwargs) 55 | async def answer_photo(self, *args, **kwargs): 56 | return await self.__photo_hijacker(*args, **kwargs) 57 | async def reply_photo(self, *args, **kwargs): 58 | return await self.__photo_hijacker(*args, **kwargs) 59 | async def answer_media_group(self, *args, **kwargs): 60 | return await self.__mediaGroup_hijacker(*args, **kwargs) 61 | async def reply_media_group(self, *args, **kwargs): 62 | return await self.__mediaGroup_hijacker(*args, **kwargs) 63 | 64 | 65 | def getHijackerAndFuture(): 66 | reply_future = Future() 67 | async def reply(text, *args, **kwargs): 68 | reply_future.set_result({"response": {"text": text}}) 69 | async def sendVoice(voice, message_thread_id=None, caption=None, *args, **kwargs): 70 | reply_future.set_result({"response": {"voice": base64.b64encode(voice.data), "text": caption}}) 71 | async def sendPhoto(input_file, caption=None, *args, **kwargs): 72 | reply_future.set_result({"response": {"photos": [base64.b64encode(input_file.data)], "text": caption}}) 73 | async def sendMediaGroup(media, *args, **kwargs): 74 | images = [] 75 | captions = [] 76 | # only supports images for now 77 | for m in media: 78 | images.append(base64.b64encode(m.media.data)) 79 | if m.caption: 80 | captions.append(m.caption) 81 | reply_future.set_result({"response": {"photos": images, "text": '\n'.join(captions)}}) 82 | hijacks = {'reply': reply, 'voice': sendVoice, 'photo': sendPhoto, 'mediaGroup': sendMediaGroup} 83 | return hijacks, reply_future 84 | 85 | async def handle_message(data, dp): 86 | text = data.get('text') 87 | command = Command.extract_command(None, text) 88 | handler = dp.comamnd_map.get(command.prefix + command.command, None) 89 | hijacks, reply_future = getHijackerAndFuture() 90 | user = User(id=1, is_bot=False, first_name='Admin', last_name='') 91 | message = EmulatedMessage(hijacks, 
message_id=-1, date=0, chat=Chat(id=0, type='local'), from_user=user, text=text) 92 | if not handler: 93 | for magic_filter, get_handler in dp.filter_arr: 94 | if magic_filter.resolve(message): 95 | handler = get_handler() 96 | await handler(message=message) 97 | if not handler: 98 | reply_future.set_result({'response': {"text": 'Command not found'}}) 99 | else: 100 | handler = handler() 101 | await handler(message=message, command=command) 102 | return reply_future 103 | 104 | -------------------------------------------------------------------------------- /misc/memory_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import psutil 4 | from config_reader import config 5 | from time import time 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | process = psutil.Process() 11 | SHARED_MEMORY = sys.platform == "darwin" 12 | GPU_AVAILABLE = torch.cuda.is_available() 13 | 14 | 15 | def get_vram_info(): 16 | if torch.cuda.is_available(): 17 | device = torch.cuda.current_device() 18 | vram_bytes = torch.cuda.get_device_properties(device).total_memory 19 | try: 20 | vram_bytes, total = torch.cuda.mem_get_info() 21 | except Exception: 22 | pass 23 | free_gb = vram_bytes / (1024**3) 24 | total_gb = total / (1024**3) 25 | return free_gb, total_gb 26 | else: 27 | return 0.0, 0.0 28 | 29 | def get_system_ram_info(): 30 | virtual_memory = psutil.virtual_memory() 31 | free_gb = virtual_memory.available / (1024**3) 32 | total_gb = virtual_memory.total / (1024**3) 33 | return free_gb, total_gb 34 | 35 | class MModel: 36 | def __init__(self, name, load, unload): 37 | self.name = name 38 | self.model = load() 39 | self.load = load 40 | self.unload = unload 41 | self.memory = 0 42 | def __setattr__(self, name, value): 43 | self.__dict__[name] = value 44 | 45 | class MemoryManager: 46 | def __init__(self, get_memory, cached_model_count): 47 | self.get_memory = get_memory 48 | self.starting_memory, self.total_memory = get_memory() 49 | self.cache = {} 50 | self.cached_model_count = cached_model_count 51 | self.mm_management_policy = config.mm_management_policy 52 | 53 | def wrap(self, model_name, load_function=None, unload_function=None, memory='auto'): 54 | mem, total_mem = self.get_memory() 55 | # get keys and values of items with model != None as lists 56 | [*alive_keys], [*alive_values] = zip(*((i.name, i) for i in self.cache.values() if i.model is not None)) \ 57 | if len(self.cache.keys()) > 0 else ([],[]) 58 | if config.mm_autounload_after_seconds > 0: 59 | seconds = config.mm_autounload_after_seconds 60 | for key in alive_keys: 61 | if key != model_name and self.cache[key].last_used + seconds < time(): 62 | self.unload(key, 'timeout') 63 | alive_keys.remove(key) 64 | alive_values.remove(self.cache[key]) 65 | if model_name not in self.cache: 66 | self.cache[model_name] = MModel(model_name, load_function, unload_function) 67 | self.cache[model_name].last_loaded = time() 68 | elif not self.cache[model_name].model: 69 | self.cache[model_name].model = load_function() 70 | self.cache[model_name].last_loaded = time() 71 | mem_diff = mem - self.get_memory()[0] 72 | mem_diff = mem_diff * int(mem_diff > 0) 73 | item = self.cache[model_name] 74 | item.memory = (mem_diff if memory == 'auto' else memory(item.model) or mem_diff) 75 | item.last_used = time() 76 | item.use_count = (item.use_count + 1) if hasattr(item, 'use_count') else 1 77 | if self.mm_management_policy == 'COUNT' or self.mm_management_policy == 
'BOTH': 78 | cache_count = len(alive_keys) 79 | if cache_count > 0 and cache_count > self.cached_model_count: 80 | unloaded_key = self.unload_by_policy(model_name, alive_values, f'management policy {self.mm_management_policy}') 81 | if self.mm_management_policy == 'BOTH': 82 | alive_keys.remove(unloaded_key) 83 | alive_values.remove(self.cache[unloaded_key]) 84 | if (self.mm_management_policy == 'MEMORY' or self.mm_management_policy == 'BOTH') \ 85 | and (len(alive_values) > 0): 86 | items_memory = list(item.memory for item in alive_values) 87 | total_memory_used = sum(items_memory) 88 | memory_available, memory_total = self.get_memory() 89 | # TODO: find optimal value for mm_unload_memory_ratio on low-end devices and make it configurable 90 | mm_unload_memory_ratio = 3 91 | # if memory_available < max(items_memory) * 1.3 \ 92 | if memory_available < self.starting_memory/mm_unload_memory_ratio \ 93 | or total_memory_used * (1+1/mm_unload_memory_ratio) > self.starting_memory: 94 | self.unload_by_policy(model_name, alive_values, f'management policy {self.mm_management_policy}') 95 | return self.cache[model_name].model 96 | 97 | def unload(self, name, reason): 98 | target = self.cache[name] 99 | if target.unload is not None: 100 | target.unload(target.model) 101 | self.cache[name].model = None 102 | logger.info(f'removed {name} from model cache by memory manager due to {reason}') 103 | 104 | def unload_by_policy(self, model_name, items, reason): 105 | if config.mm_unload_order_policy == 'LEAST_USED': 106 | items = sorted(items, key=lambda x: x.use_count) 107 | if config.mm_unload_order_policy == 'OLDEST_USE_TIME': 108 | items = sorted(items, key=lambda x: x.last_used) 109 | if config.mm_unload_order_policy == 'OLDEST_LOAD_ORDER': 110 | items = sorted(items, key=lambda x: x.last_loaded) 111 | if config.mm_unload_order_policy == 'MEMORY_FOOTPRINT': 112 | items = sorted(items, key=lambda x: x.memory)[::-1] 113 | to_unload = items[0].name 114 | if to_unload == model_name and len(items) > 1: 115 | to_unload = items[1].name 116 | self.unload(to_unload, f'{reason} ({config.mm_unload_order_policy})') 117 | return to_unload 118 | 119 | def stats(self): 120 | return { 121 | "starting_memory": round(self.starting_memory, 3), 122 | "total_memory": self.total_memory, 123 | "current_memory": round(self.get_memory()[0], 3), 124 | "cache": [{self.cache[key].name: round(self.cache[key].memory, 3)} for key in self.cache], 125 | "process": round(process.memory_info().rss / 1024**3, 2) if self == RAM else None 126 | } 127 | 128 | RAM = MemoryManager(get_system_ram_info, config.mm_ram_cached_model_count_limit) 129 | VRAM = MemoryManager(get_vram_info, config.mm_vram_cached_model_count_limit) if GPU_AVAILABLE else False 130 | 131 | def mload(*args, gpu=False, **kwargs): 132 | if gpu and GPU_AVAILABLE: 133 | return VRAM.wrap(*args, **kwargs) 134 | else: 135 | return RAM.wrap(*args, **kwargs) -------------------------------------------------------------------------------- /misc/mps_fixups.py: -------------------------------------------------------------------------------- 1 | # Workarounds and fixes for LLMs for mps accelerator 2 | # Copyright Jeremy Barnes / MIT License 3 | # reference code: 4 | # https://github.com/jeremybarnes/llm-webgpu/blob/main/mps_fixups.py 5 | # 6 | import torch 7 | from torch import Tensor 8 | from typing import Optional 9 | 10 | def fixup_mps(): 11 | 12 | orig_topk = torch.topk 13 | # Topk only works up to k=15 on MPS, replace it with a CPU fallback 14 | def 
_topk(self: torch.Tensor, k: int, dim:int=-1, largest:bool=True, sorted:bool=True): 15 | res, indices = orig_topk(self.to('cpu', torch.float32), k, dim, largest, sorted) 16 | return res.to(self), indices.to('mps') 17 | 18 | torch.topk = _topk 19 | 20 | orig_max = torch.max 21 | # Max doesn't work with longs on MPS, replace it with a CPU fallback 22 | def _max(self: torch.Tensor, *args, **kwargs) -> torch.Tensor: 23 | return orig_max(self.to('cpu'), *args, **kwargs).to('mps') 24 | 25 | torch.max = _max 26 | 27 | orig_cumsum = torch.cumsum 28 | # Cumulative sum doesn't work, replace with CPU fallback 29 | def _cumsum(input: torch.Tensor, dim: int, **kwargs) -> torch.Tensor: 30 | return orig_cumsum(input.to('cpu', torch.float32), dim, **kwargs).to('mps', input.dtype) 31 | 32 | torch.cumsum = _cumsum 33 | torch.Tensor.cumsum = _cumsum 34 | -------------------------------------------------------------------------------- /modules/admin.py: -------------------------------------------------------------------------------- 1 | from aiogram.filters import Command, CommandObject 2 | from aiogram.types import Message 3 | from utils import parse_photo, log_exceptions 4 | from config_reader import config 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class AdminModule: 10 | def __init__(self, dp, bot): 11 | @dp.message(Command(commands=["sendpic"]), flags={"admins_only": True}) 12 | @log_exceptions(logger) 13 | async def send_pic(message: Message, command: CommandObject) -> None: 14 | photo = parse_photo(message) 15 | await bot.send_photo(chat_id = int(str(command.args)), photo = photo[-1].file_id) 16 | 17 | @dp.message(Command(commands=["delete"]), flags={"admins_only": True}) 18 | @log_exceptions(logger) 19 | async def delete_msg(message: Message, command: CommandObject) -> None: 20 | await bot.delete_message(chat_id = message.chat.id, message_id = message.reply_to_message.message_id) 21 | 22 | @dp.message(Command(commands=["info"]), flags={"admins_only": True}) 23 | async def chat_info(message: Message, command: CommandObject) -> None: 24 | msg = message if not message.reply_to_message else message.reply_to_message 25 | prefix = '[reply info]\n' if message.reply_to_message else '' 26 | await message.reply(f'{prefix}Chat ID: {msg.chat.id}\nUser ID: {msg.from_user.id}') 27 | 28 | @dp.message(Command(commands=["ban", "unban"]), flags={"admins_only": True}) 29 | @log_exceptions(logger) 30 | async def ban_unban_user(message: Message, command: CommandObject) -> None: 31 | user = False 32 | if message.reply_to_message: 33 | user = message.reply_to_message.from_user 34 | user_id = user.id 35 | else: 36 | try: 37 | user_id = int(str(command.args)) 38 | except Exception: 39 | return await message.reply('Incorrect user id') 40 | if command.command == "ban": 41 | if user_id == bot._me.id: 42 | return await message.reply('Funny!') 43 | if user_id in config.adminlist: 44 | return await message.reply('Unable ban an admin') 45 | if user_id not in config.blacklist: 46 | config.blacklist.append(user_id) 47 | config.blacklist = config.blacklist 48 | who = str(user_id) if not user else user.first_name + f" ({str(user_id)})" 49 | await message.reply(f'{who} has been banned') 50 | else: 51 | await message.reply('User is already banned') 52 | else: 53 | if user_id in config.blacklist: 54 | config.blacklist.remove(user_id) 55 | config.blacklist = config.blacklist 56 | await message.reply('User unbanned') 57 | else: 58 | await message.reply('User is not banned') 
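
The ban/unban handlers above mutate config.blacklist in place and then reassign it to itself (config.blacklist = config.blacklist). That seemingly redundant assignment is what persists the change: only attribute assignment passes through SettingsWrapper.__setattr__ in config_reader.py, which mirrors the new value into the .env file via update_env, whereas mutating the list alone is invisible to it. A stripped-down sketch of this write-through pattern follows; _update_env here is a simplified stand-in for utils.update_env, whose real implementation is not shown in this listing.

import json

def _update_env(env_file, name, value):
    # stand-in for utils.update_env; assume the real helper rewrites the matching `name=` line in the .env file
    print(f'would persist to {env_file}: {name}={json.dumps(value)}')

class WriteThroughConfig:
    '''Minimal imitation of SettingsWrapper: every attribute assignment is mirrored to disk.'''
    def __init__(self, env_file='.env', **fields):
        object.__setattr__(self, 'env_file', env_file)
        for key, value in fields.items():
            object.__setattr__(self, key, value)

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)
        _update_env(self.env_file, name, value)

cfg = WriteThroughConfig(blacklist=[12345])
cfg.blacklist.append(67890)      # in-place mutation: __setattr__ is not called, nothing is persisted
cfg.blacklist = cfg.blacklist    # reassignment triggers __setattr__ and writes the updated list
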
-------------------------------------------------------------------------------- /modules/extensions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import importlib 4 | import traceback 5 | from config_reader import config 6 | logger = logging.getLogger(__name__) 7 | 8 | class ExtensionsModule: 9 | def __init__(self, dp, bot): 10 | dp.extensions = {} 11 | self.dir = 'extensions' 12 | if not (os.path.exists(self.dir)): 13 | os.makedirs(self.dir) 14 | self.load_from_dir(self.dir, dp, bot) 15 | 16 | def load_from_dir(self, exdir, dp, bot): 17 | extension_files = [f for f in os.listdir(exdir) if not f.startswith('_') and f.endswith('.py')] 18 | extension_dirs = [f for f in os.listdir(exdir) if os.path.isdir(os.path.join(exdir, f))] 19 | for dir in extension_dirs: 20 | dir_ext_path = os.path.join(dir, 'extension.py') 21 | if os.path.exists(os.path.join(exdir, dir_ext_path)): 22 | extension_files.append(dir_ext_path) 23 | 24 | active_module_set = set(config.active_modules) 25 | for filename in extension_files: 26 | try: 27 | logger.info("loading extension: " + filename) 28 | imported_extension = importlib.import_module(exdir + '.' + filename.replace(os.path.sep, '.').replace('.py', '')) 29 | assert hasattr(imported_extension, "extension"), "Extension must have 'extension' variable" 30 | assert hasattr(imported_extension.extension, "name"), "Extension must have a name" 31 | ext = imported_extension.extension 32 | if set(ext.dependencies).issubset(active_module_set): 33 | dp.extensions[ext.name] = ext(dp, bot) 34 | except Exception as e: 35 | logger.error(f"error loading extension {filename}\n{type(e).__name__}:{str(e)}\n{traceback.format_exc()}") 36 | -------------------------------------------------------------------------------- /modules/stt.py: -------------------------------------------------------------------------------- 1 | from aiogram.filters import Command, CommandObject 2 | from aiogram.types import Message, BufferedInputFile 3 | from aiogram import html, F 4 | from providers.stt_provider import active_model 5 | from custom_queue import UserLimitedQueue, semaphore_wrapper 6 | from config_reader import config 7 | from utils import download_audio 8 | from types import SimpleNamespace 9 | import random 10 | import tempfile 11 | import asyncio 12 | 13 | class SpeechToTextModule: 14 | def __init__(self, dp, bot): 15 | self.queue = UserLimitedQueue(config.stt_queue_size_per_user) 16 | self.semaphore = asyncio.Semaphore(1) 17 | self.bot = bot 18 | self.model = active_model.init() 19 | self.cache = {} 20 | 21 | @dp.message(Command(commands=["stt", "recognize", "transcribe"]), flags={"long_operation": "typing"}) 22 | async def command_stt_handler(message: Message, command: CommandObject) -> None: 23 | with self.queue.for_user(message.from_user.id) as available: 24 | if not available: 25 | return 26 | if (command.command == "stt" and ('-h' in str(command.args)) or not (message.reply_to_message and message.reply_to_message.voice)): 27 | return await message.answer(self.help(dp, bot)) 28 | else: 29 | error, text = await self.recognize_voice_message(message) 30 | if error: 31 | return await message.answer(f"Error, {error}") 32 | else: 33 | return await message.answer(f"{text}") 34 | 35 | async def recognize_voice_message(self, message): 36 | with tempfile.NamedTemporaryFile(suffix='.ogg', delete=False) as temp_file: 37 | if not message.voice and (not (message.reply_to_message or message.reply_to_message.voice)): 38 | return 
'Source audio not found', None 39 | voice = message.reply_to_message.voice if not message.voice else message.voice 40 | await download_audio(self.bot, voice.file_id, temp_file.name) 41 | error, data = await self.recognize(temp_file.name) 42 | return error, data 43 | 44 | async def recognize(self, audio_path): 45 | wrapped_runner = semaphore_wrapper(self.semaphore, self.model.recognize) 46 | return await wrapped_runner(audio_path) 47 | 48 | def help(self, dp, bot): 49 | return f"[Speech-to-text] Usage: /stt@{bot._me.username} *voice_message*" 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /modules/tta.py: -------------------------------------------------------------------------------- 1 | from aiogram.filters import Command, CommandObject 2 | from aiogram.types import Message, BufferedInputFile 3 | from providers.tts_provider import convert_to_ogg 4 | from providers.tta_provider import generate_audio_async, tta_init 5 | from custom_queue import UserLimitedQueue, semaphore_wrapper 6 | from config_reader import config 7 | import asyncio 8 | 9 | class TextToAudioModule: 10 | def __init__(self, dp, bot): 11 | self.queue = UserLimitedQueue(config.tta_queue_size_per_user) 12 | self.semaphore = asyncio.Semaphore(1) 13 | 14 | if 'tta' in config.active_modules: 15 | self.available = tta_init() 16 | if not self.available: 17 | return 18 | 19 | @dp.message(Command(commands=["tta", "sfx", "music"]), flags={"long_operation": "record_audio"}) 20 | async def command_tta_handler(message: Message, command: CommandObject) -> None: 21 | with self.queue.for_user(message.from_user.id) as available: 22 | if not available: 23 | return 24 | if command.command == "tta" or not command.args or str(command.args).strip() == "" or ('-help' in str(command.args)): 25 | return await message.answer(self.help(dp, bot)) 26 | else: 27 | audio_type = command.command 28 | text = str(command.args) 29 | wrapped_runner = semaphore_wrapper(self.semaphore, generate_audio_async) 30 | error, data = await wrapped_runner(text, audio_type, config.tta_duration) 31 | print(error, data) 32 | if error: 33 | return await message.answer(f"Error, {error}") 34 | else: 35 | audio = BufferedInputFile(convert_to_ogg(data), 'audio.ogg') 36 | return await message.answer_voice(voice=audio) 37 | def help(self, dp, bot): 38 | return f'''[Text-To-Audio] Usage: 39 | /sfx@{bot._me.username} prompt 40 | /music{bot._me.username} prompt''' 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /modules/tts.py: -------------------------------------------------------------------------------- 1 | from aiogram.filters import Command, CommandObject 2 | from aiogram.types import Message, BufferedInputFile 3 | from providers.tts_provider import init as init_tts, tts, sts, convert_to_ogg, tts_voicemap, sts_voicemap, system_voicemap, tts_authors 4 | from custom_queue import UserLimitedQueue, semaphore_wrapper 5 | from config_reader import config 6 | from utils import download_audio 7 | import asyncio 8 | import tempfile 9 | import re 10 | 11 | class TextToSpeechModule: 12 | def __init__(self, dp, bot): 13 | self.queue = UserLimitedQueue(config.tts_queue_size_per_user) 14 | self.semaphore = asyncio.Semaphore(1) 15 | init_tts(allowRemote=config.tts_mode != 'local', threaded=config.threaded_initialization) 16 | 
self.sts_voices = list(sts_voicemap.keys()) 17 | self.all_voices = list(tts_voicemap.keys()) 18 | self.voicemap = tts_voicemap 19 | self.non_system_voices = [v for v in self.all_voices if v not in system_voicemap] 20 | self.voices = self.all_voices if config.tts_list_system_voices else self.non_system_voices 21 | 22 | @dp.message(Command(commands=["tts", *self.all_voices]), flags={"long_operation": "record_audio"}) 23 | async def command_tts_handler(message: Message, command: CommandObject) -> None: 24 | with self.queue.for_user(message.from_user.id) as available: 25 | if available: 26 | # show helper message if no voice is selected 27 | if self.should_print_help(command, message): 28 | return await message.answer(self.help(dp, bot)) 29 | 30 | voice = command.command 31 | text = str(command.args) 32 | error, audio = await self.speak(voice, text) 33 | if error: 34 | return await message.answer(f"Error, {error}") 35 | else: 36 | return await message.answer_voice(voice=audio) 37 | 38 | if 'so_vits_svc' in config.tts_enable_backends: 39 | @dp.message(Command(commands=["revoice", "sts"]), flags={"long_operation": "record_audio"}) 40 | async def revoice(message: Message, command: CommandObject) -> None: 41 | voice = (str(command.args).split(' ')[0]) if command.args else None 42 | voice = voice if voice in self.sts_voices else None 43 | if not voice: 44 | return await message.answer("Voice not found, available speech-to-speech voices: " + 45 | ", ".join(self.sts_voices)) 46 | if message.reply_to_message: 47 | if message.reply_to_message.voice: 48 | with tempfile.NamedTemporaryFile(suffix='.ogg', delete=False) as temp_file: 49 | await download_audio(bot, message.reply_to_message.voice.file_id, temp_file.name) 50 | wrapped_runner = semaphore_wrapper(self.semaphore, sts) 51 | error, data = await wrapped_runner(voice, temp_file.name) 52 | if error: 53 | return await message.answer(f"Error, {error}") 54 | else: 55 | audio = BufferedInputFile(convert_to_ogg(data), 'tts.ogg') 56 | return await message.answer_voice(voice=audio) 57 | return await message.answer("No audio found. 
Use this command replying to voice messages") 58 | 59 | bot.reply_tts = command_tts_handler 60 | 61 | async def speak(self, voice, text): 62 | text = self.correctPronunciation(text) 63 | wrapped_runner = semaphore_wrapper(self.semaphore, tts) 64 | error, data = await wrapped_runner(voice, text) 65 | if data: 66 | #TODO: async conversion 67 | audio = BufferedInputFile(convert_to_ogg(data), 'tts.ogg') 68 | return None, audio 69 | return error, None 70 | 71 | def get_specific_voices(self, language='**', tone='*', allow_system=False): 72 | return [ 73 | voice 74 | for voice, voice_info in self.voicemap.items() 75 | if ( 76 | (allow_system or voice in self.non_system_voices) 77 | and (language == '**' or voice_info.voice_metamap[voice].get('lang', '**') in ('**', language)) 78 | and (tone == '*' or voice_info.voice_metamap[voice].get('tone', '*') in ('*', tone)) 79 | ) 80 | ] 81 | 82 | def should_print_help(self, command, message): 83 | if command.command == "tts" \ 84 | or not command.args \ 85 | or str(command.args).strip() == "" \ 86 | or ('-help' in str(command.args)): 87 | return True 88 | return False 89 | 90 | def help(self, dp, bot): 91 | voice_commands = ' '.join(['/' + x for x in self.voices]) 92 | return f'''[Text-To-Speech] Usage: {voice_commands} text, 93 | Use the commands like /command@{bot._me.username} 94 | [Speech-to-speech] /revoice %voice% *voice_message* 95 | {config.tts_credits} {', '.join(list(tts_authors))} 96 | ''' 97 | 98 | def correctPronunciation(self, text): 99 | '''replaces unnecessary and poorly pronounced text for better tts experience''' 100 | for key in config.tts_replacements: 101 | text = text.replace(key, config.tts_replacements[key]) 102 | text = re.sub(r'http\S+|www\.\S+', '', text) 103 | return text 104 | -------------------------------------------------------------------------------- /providers/llm/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | gpt2_provider = lambda: importlib.import_module('providers.llm.pytorch.gpt2_provider') 3 | gptj_provider = lambda: importlib.import_module('providers.llm.pytorch.gptj_provider') 4 | auto_hf_provider = lambda: importlib.import_module('providers.llm.pytorch.auto_hf_provider') 5 | llama_orig_provider = lambda: importlib.import_module('providers.llm.pytorch.llama_orig_provider') 6 | llama_hf_provider = lambda: importlib.import_module('providers.llm.pytorch.llama_hf_provider') 7 | mlc_chat_prebuilt_provider = lambda: importlib.import_module('providers.llm.mlc_chat_prebuilt_provider') 8 | llama_cpp_provider = lambda: importlib.import_module('providers.llm.llama_cpp_provider') 9 | remote_ob_provider = lambda: importlib.import_module('providers.llm.remote_ob') 10 | remote_lcpp_provider = lambda: importlib.import_module('providers.llm.remote_llama_cpp') 11 | 12 | pytorch_models = { 13 | 'gpt2': gpt2_provider, 14 | 'gptj': gptj_provider, 15 | 'auto_hf': auto_hf_provider, 16 | 'llama_orig': llama_orig_provider, 17 | 'llama_hf': llama_hf_provider, 18 | } 19 | 20 | external_backends = { 21 | 'llama_cpp': llama_cpp_provider, 22 | 'mlc_pb': mlc_chat_prebuilt_provider, 23 | 'remote_ob': remote_ob_provider, 24 | 'remote_lcpp': remote_lcpp_provider 25 | } -------------------------------------------------------------------------------- /providers/llm/abstract_llm.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from collections import defaultdict 3 | 4 | class 
AbstractLLM(metaclass=ABCMeta): 5 | assistant_mode = True 6 | model = None 7 | def __init__(self, model_paths, init_config): 8 | return self 9 | 10 | def tokenize(self, details): 11 | pass 12 | 13 | @abstractmethod 14 | def generate(self, prompt, length, model_params, assist): 15 | pass 16 | 17 | -------------------------------------------------------------------------------- /providers/llm/llama_cpp_provider.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 2 | from config_reader import config 3 | from providers.llm.abstract_llm import AbstractLLM 4 | import asyncio 5 | import os 6 | import logging 7 | from misc.memory_manager import mload 8 | from functools import partial 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | try: 13 | from llama_cpp import Llama 14 | except ImportError: 15 | Llama = False 16 | 17 | class LlamaCPP(AbstractLLM): 18 | assistant_mode = True 19 | def __init__(self, model_paths, init_config): 20 | if not Llama: 21 | logger.error('llama.cpp is not installed, run "pip install llama-cpp-python" to install it') 22 | return logger.error('for GPU support, please read https://github.com/abetlen/llama-cpp-python') 23 | override = init_config.get('llama_cpp_init', {}) 24 | lora_path = model_paths.get('path_to_llama_cpp_lora', '') 25 | lora_path = lora_path if os.path.exists(lora_path) else None 26 | self.init_config = init_config 27 | self.load_model = partial( 28 | Llama, 29 | n_ctx=min(init_config.get('context_size', 512), config.llm_lcpp_max_context_size), 30 | rope_freq_base=init_config.get('rope_freq_base', 10000), 31 | rope_freq_scale=init_config.get('rope_freq_scale', 1.0), 32 | n_gpu_layers=config.llm_lcpp_gpu_layers, 33 | model_path=model_paths["path_to_llama_cpp_weights"], 34 | seed=0, 35 | lora_path=lora_path, 36 | **override 37 | ) 38 | self.filename = os.path.basename(model_paths['path_to_llama_cpp_weights']) 39 | if config.mm_preload_models_on_start: 40 | m = self.model 41 | 42 | @property 43 | def model(self): 44 | return mload('llm-llama.cpp', self.load_model, None) 45 | 46 | async def generate(self, prompt, length=64, model_params={}, assist=True): 47 | if 'repetition_penalty' in model_params: 48 | model_params['repeat_penalty'] = model_params['repetition_penalty'] 49 | del model_params['repetition_penalty'] 50 | if 'early_stopping' in model_params: 51 | del model_params['early_stopping'] 52 | output = error = None 53 | with ThreadPoolExecutor(): 54 | try: 55 | output = await asyncio.to_thread( 56 | self.model, 57 | prompt=prompt, 58 | stop=["", *self.init_config.get('stop_tokens', [])], 59 | max_tokens=length, 60 | **model_params 61 | ) 62 | except Exception as e: 63 | error = str(e) 64 | if not error: 65 | output = output['choices'][0]['text'] 66 | logger.info(output) 67 | output = prompt + output 68 | return (False, output) if not error else (error, None) 69 | 70 | ## process-based approach re-creates a new process and re-allocates memory on every run 71 | ## which is not optimal, I leave this code for future reference 72 | # import functools 73 | # async def generate_mp(prompt, length=64, model_params={}, assist=True): 74 | # with ProcessPoolExecutor(max_workers=1) as executor: 75 | # loop = asyncio.get_event_loop() 76 | # binded = functools.partial( 77 | # model, 78 | # prompt=prompt, 79 | # stop=[""], 80 | # max_tokens=length, 81 | # **model_params 82 | # ) 83 | # output = loop.run_in_executor(executor, binded) 84 | # output = await 
output 85 | # return prompt + output['choices'][0]['text'] 86 | init = LlamaCPP -------------------------------------------------------------------------------- /providers/llm/mlc_chat_prebuilt_provider.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from concurrent.futures import ThreadPoolExecutor 3 | from providers.llm.abstract_llm import AbstractLLM 4 | import asyncio 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class MLCChatPrebuilt(AbstractLLM): 10 | assistant_mode = True 11 | def __init__(self, model_paths, init_config): 12 | sys.path.append(model_paths['path_to_mlc_chatbot_code']) 13 | try: 14 | from mlc_chatbot.bot import ChatBot 15 | except ImportError: 16 | logging.error('MLC Chatbot is not installed') 17 | self.model = ChatBot(model_paths['path_to_mlc_pb_home_dir'], model_paths['path_to_mlc_pb_binary_dir']) 18 | self.model.generate = self.model.send 19 | self.filename = 'Unknown model' 20 | 21 | async def generate(self, raw_prompt, length=0, model_params={}, assist=True): 22 | error = None 23 | try: 24 | with ThreadPoolExecutor(): 25 | print(self.model) 26 | output = await asyncio.to_thread(self.model.generate, 27 | raw_prompt 28 | ) 29 | self.model.reset() 30 | except Exception as e: 31 | error = str(e) 32 | return (False, output) if not error else (error, None) 33 | 34 | init = MLCChatPrebuilt -------------------------------------------------------------------------------- /providers/llm/pytorch/auto_hf_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import asyncio 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor 5 | from providers.llm.abstract_llm import AbstractLLM 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | class AutoHF(AbstractLLM): 10 | def __init__(self, model_paths, init_config): 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | weights = model_paths['path_to_autohf_weights'] 13 | self.tokenizer = AutoTokenizer.from_pretrained(weights) 14 | self.model = AutoModelForCausalLM.from_pretrained(weights) 15 | self.filename = os.path.basename(weights) 16 | 17 | def tokenize(self, prompt): 18 | encoded_prompt = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) 19 | return encoded_prompt 20 | 21 | async def generate(self, prompt, length=64, model_params={}, assist=False): 22 | error = None 23 | try: 24 | encoded_prompt = self.tokenize(prompt) 25 | if 'early_stopping' in model_params: 26 | del model_params['early_stopping'] 27 | with ThreadPoolExecutor(): 28 | output = await asyncio.to_thread(self.model.generate, 29 | input_ids=encoded_prompt, 30 | no_repeat_ngram_size=2, 31 | max_new_tokens=length, 32 | early_stopping=True, 33 | do_sample=True, 34 | **model_params 35 | ) 36 | except Exception as e: 37 | error = str(e) 38 | return (False, self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]) if not error else (True, error) 39 | 40 | init = AutoHF 41 | -------------------------------------------------------------------------------- /providers/llm/pytorch/gpt2_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import asyncio 4 | from concurrent.futures import ThreadPoolExecutor 5 | from types import SimpleNamespace 6 | from providers.llm.abstract_llm import AbstractLLM 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 
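# GPT-2 provider: uses the stock HF GPT2LMHeadModel/GPT2Tokenizer by default;
# init_config['use_tiktoken'] swaps in a tiktoken-based tokenizer, and
# init_config['nanogpt'] loads a minChatGPT (nanoGPT-style) checkpoint from
# path_to_minchatgpt_code instead, which also enables assistant_mode.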
10 | class GPT2(AbstractLLM): 11 | is_nanoGPT = False 12 | assistant_mode = False 13 | def __init__(self, model_paths, init_config): 14 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 15 | weights = model_paths['path_to_gpt2_weights'] 16 | self.filename = os.path.basename(weights) 17 | if 'use_tiktoken' in init_config and init_config['use_tiktoken']: 18 | import tiktoken 19 | tk = tiktoken.get_encoding("gpt2") 20 | tokenizer = {} 21 | tokenizer["encode"] = lambda s, *args, **kwargs: torch.tensor(tk.encode(s, allowed_special={"<|endoftext|>"}), dtype=torch.long, device=device)[None, ...] 22 | tokenizer["decode"] = lambda l: tk.decode(l.tolist()) 23 | tokenizer["name"] = 'tiktoken' 24 | tokenizer = SimpleNamespace(**tokenizer) 25 | self.tokenizer = tokenizer 26 | else: 27 | self.tokenizer = GPT2Tokenizer.from_pretrained(weights) 28 | if 'nanogpt' in init_config and init_config['nanogpt']: 29 | import sys 30 | sys.path.append(model_paths['path_to_minchatgpt_code']) 31 | from gpt import GPT 32 | from configs import get_configs 33 | self.is_nanoGPT = True 34 | cfg = get_configs("gpt2-medium") 35 | model = GPT(cfg) 36 | model.load_state_dict(state_dict=torch.load(weights), strict=False) 37 | self.assistant_mode = True 38 | else: 39 | model = GPT2LMHeadModel.from_pretrained(weights) 40 | self.model = model.to(device) 41 | 42 | def tokenize(self, prompt): 43 | encoded_prompt = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") 44 | encoded_prompt = encoded_prompt.to(device) 45 | return encoded_prompt[:, :1024] 46 | 47 | async def generate(self, prompt, length=64, model_params={}, assist=False): 48 | error = None 49 | try: 50 | encoded_prompt = self.tokenize(prompt) 51 | with ThreadPoolExecutor(): 52 | if not self.is_nanoGPT: 53 | output_sequences = await asyncio.to_thread(self.model.generate, 54 | input_ids=encoded_prompt, 55 | max_length=length + len(encoded_prompt[0]), 56 | do_sample=True, 57 | num_return_sequences=1, 58 | **model_params 59 | ) 60 | else: 61 | if 'early_stopping' in model_params: 62 | del model_params['early_stopping'] 63 | output_sequences = await asyncio.to_thread(self.model.generate, 64 | encoded_prompt, 65 | length, 66 | **model_params 67 | ) 68 | except Exception as e: 69 | error = str(e) 70 | return (False, self.tokenizer.decode(output_sequences[0])) if not error else (True, error) 71 | 72 | init = GPT2 -------------------------------------------------------------------------------- /providers/llm/pytorch/gptj_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import asyncio 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor 5 | from providers.llm.abstract_llm import AbstractLLM 6 | 7 | tokenizer = None 8 | model = None 9 | device = torch.device("cpu") 10 | 11 | class GPTJ(AbstractLLM): 12 | def __init__(self, model_paths, init_config): 13 | from transformers import AutoTokenizer, GPTJForCausalLM 14 | weights = model_paths['path_to_gptj_weights'] 15 | self.tokenizer = AutoTokenizer.from_pretrained(weights) 16 | self.model = GPTJForCausalLM.from_pretrained(weights, revision="float16", torch_dtype=torch.float32, low_cpu_mem_usage=True) 17 | self.model = self.model.to(device) 18 | self.filename = os.path.basename(weights) 19 | 20 | def tokenize(self, prompt): 21 | encoded_prompt = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) 22 | return encoded_prompt 23 | 24 | async def generate(self, prompt, length=64, model_params={}, assist=False): 25 | 
encoded_prompt = self.tokenize(prompt) 26 | error = None 27 | try: 28 | with ThreadPoolExecutor(): 29 | output = await asyncio.to_thread(self.model.generate, 30 | input_ids=encoded_prompt, 31 | max_length=len(encoded_prompt[0]) + length, 32 | do_sample=True, 33 | **model_params 34 | ) 35 | except Exception as e: 36 | error = str(e) 37 | return (False, self.tokenizer.batch_decode(output)[0]) if not error else (True, error) 38 | 39 | init = GPTJ 40 | -------------------------------------------------------------------------------- /providers/llm/pytorch/llama_hf_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import mps 4 | import asyncio 5 | from concurrent.futures import ThreadPoolExecutor 6 | from transformers import LlamaTokenizerFast, LlamaForCausalLM, GenerationConfig 7 | from misc.mps_fixups import fixup_mps 8 | from config_reader import config 9 | from providers.llm.abstract_llm import AbstractLLM 10 | 11 | device = torch.device("cuda") if torch.cuda.is_available() \ 12 | else torch.device("cpu") if not torch.backends.mps.is_available() \ 13 | else torch.device('mps') 14 | 15 | #if torch.backends.mps.is_available() and config.apply_mps_fixes: 16 | # fixup_mps() 17 | 18 | class LlamaHuggingface(AbstractLLM): 19 | submodel = None 20 | assistant_mode = False 21 | def __init__(self, model_paths, init_config): 22 | tokenizer = model_paths['path_to_hf_llama'] 23 | weights = model_paths['path_to_hf_llama'] 24 | self.tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer) 25 | self.model = LlamaForCausalLM.from_pretrained( 26 | weights, 27 | torch_dtype=torch.float16 if device is not torch.device('cpu') else torch.float32, 28 | device_map={"": device} 29 | ) 30 | 31 | if os.path.exists(model_paths.get('path_to_llama_lora', '')): 32 | from peft import PeftModel 33 | self.submodel = PeftModel.from_pretrained( 34 | self.model, 35 | model_paths['path_to_llama_lora'], 36 | device_map={"": device}, 37 | torch_dtype=torch.float16 if device is not torch.device('cpu') else torch.float32, 38 | ) 39 | self.submodel.half() 40 | self.submodel.eval() 41 | self.assistant_mode = True 42 | 43 | self.model.config.bos_token_id = 1 44 | self.model.config.eos_token_id = 2 45 | self.tokenizer.pad_token = self.tokenizer.eos_token 46 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 47 | 48 | self.model.half() 49 | self.model.eval() 50 | self.filename = os.path.basename(model_paths['path_to_hf_llama']) 51 | 52 | def tokenize(self, prompt): 53 | return self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) 54 | 55 | async def generate(self, prompt, length=64, model_params={}, use_submodel=False): 56 | encoded_prompt = self.tokenize(prompt) 57 | generation_config = GenerationConfig( 58 | do_sample=True, 59 | num_beams=1, 60 | **model_params 61 | ) 62 | model = self.submodel if use_submodel else self.model 63 | error = None 64 | try: 65 | with ThreadPoolExecutor(): 66 | with torch.no_grad(): 67 | output = await asyncio.to_thread(model.generate, 68 | input_ids=encoded_prompt, 69 | max_new_tokens=length, 70 | generation_config=generation_config, 71 | eos_token_id=model.config.eos_token_id, 72 | attention_mask=torch.ones_like(encoded_prompt, device=device), 73 | do_sample=True 74 | ) 75 | output = self.tokenizer.batch_decode(output, skip_special_tokens=True) 76 | if torch.backends.mps.is_available(): 77 | mps.empty_cache() 78 | elif torch.cuda.is_available(): 79 | torch.cuda.empty_cache() 80 | except Exception 
as e: 81 | error = str(e) 82 | return (False, output[0]) if not error else (error, None) 83 | 84 | init = LlamaHuggingface -------------------------------------------------------------------------------- /providers/llm/pytorch/llama_orig_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from concurrent.futures import ThreadPoolExecutor 5 | from utils import b64_to_img 6 | from providers.llm.abstract_llm import AbstractLLM 7 | import asyncio 8 | import inspect 9 | 10 | #python3.10 -m torch.distributed.launch --use_env bot.py 11 | 12 | class LlamaOrig(AbstractLLM): 13 | generator = None 14 | assistant_mode = False 15 | visual_mode = False 16 | def __init__(self, model_paths, init_config): 17 | llama_weights = model_paths['path_to_llama_weights'] 18 | llama_tokenizer = model_paths['path_to_llama_tokenizer'] 19 | sys.path.append(model_paths['path_to_llama_code']) 20 | self.filename = os.path.basename(llama_weights) 21 | if os.path.exists(model_paths.get('path_to_llama_multimodal_adapter', '')): 22 | self._load_multimodal_adapter(model_paths, llama_weights, llama_tokenizer) 23 | else: 24 | self._load_llama_model(model_paths, llama_weights, llama_tokenizer) 25 | 26 | def _load_llama_model(self, model_paths, llama_weights, llama_tokenizer): 27 | from example import setup_model_parallel, load 28 | with torch.inference_mode(mode=True): 29 | local_rank, world_size = setup_model_parallel() 30 | if 'adapter_path' in inspect.signature(load).parameters and \ 31 | 'path_to_llama_adapter' in model_paths and \ 32 | os.path.exists(model_paths.get('path_to_llama_adapter', None)): 33 | self.model = load( 34 | llama_weights, llama_tokenizer, model_paths['path_to_llama_adapter'], local_rank, world_size, 1024, 1 35 | ) 36 | self.assistant_mode = True 37 | else: 38 | self.model = load( 39 | llama_weights, llama_tokenizer, local_rank, world_size, 1024, 1 40 | ) 41 | 42 | def _load_multimodal_adapter(self, model_paths, llama_weights, llama_tokenizer): 43 | global generator, assistant_mode, visual_mode 44 | import llama 45 | device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu') 46 | lpath = os.path.dirname(llama_tokenizer) 47 | orig_generator, preprocess = llama.load(model_paths['path_to_llama_multimodal_adapter'], lpath, device) 48 | self.assistant_mode = True 49 | self.visual_mode = True 50 | class Wrapped_generator(): 51 | def generate(self, prompt, use_adapter=True, visual_input=False, **kwargs): 52 | if visual_input: 53 | img = b64_to_img(visual_input) 54 | img = preprocess(img).unsqueeze(0).half().to(device) 55 | else: 56 | img = [] 57 | generated = orig_generator.generate(img, prompt, **kwargs) 58 | return [prompt[0] + generated[0]] 59 | self.model = Wrapped_generator() 60 | 61 | async def generate(self, prompt, max_gen_len=64, params={}, assist=False): 62 | available_params = inspect.signature(self.model.generate).parameters 63 | for param in list(params): 64 | if param not in available_params: 65 | del params[param] 66 | error = None 67 | with ThreadPoolExecutor(): 68 | if self.assistant_mode and 'use_adapter' in available_params: 69 | params['use_adapter'] = assist 70 | try: 71 | results = await asyncio.to_thread(self.model.generate, 72 | [prompt], max_gen_len=max_gen_len, **params 73 | ) 74 | except Exception as e: 75 | error = str(e) 76 | return (False, results[0]) if not error else (True, error) 77 | 78 | init = LlamaOrig 
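The providers above all implement the same small contract from providers/llm/abstract_llm.py: a constructor taking (model_paths, init_config), an async generate() returning an (error, output) tuple, and a module-level `init` pointing at the class. A minimal hypothetical provider could follow the sketch below; it is not part of this repository, and the EchoLLM name, the 'echo' key and the providers.llm.echo_provider module path are made up purely for illustration.

from providers.llm.abstract_llm import AbstractLLM

class EchoLLM(AbstractLLM):
  assistant_mode = True

  def __init__(self, model_paths, init_config):
    # nothing to load here; real providers set self.filename to the weights file name
    self.filename = 'echo'

  async def generate(self, prompt, length=64, model_params={}, assist=True):
    # real providers run blocking inference via asyncio.to_thread(...) and
    # return (error, None) on failure; this stub always succeeds
    return False, prompt + ' ...'

init = EchoLLM

Registration would follow the same lazy-import pattern as the existing entries in providers/llm/__init__.py, e.g. external_backends['echo'] = lambda: importlib.import_module('providers.llm.echo_provider'), after which llm_backend=echo would select it.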
-------------------------------------------------------------------------------- /providers/llm/remote_llama_cpp.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md 2 | import logging 3 | from config_reader import config 4 | from providers.llm.remote_ob import RemoteLLM 5 | 6 | logger = logging.getLogger(__name__) 7 | llm_host = config.llm_host 8 | 9 | 10 | class RemoteLLamaCPP(RemoteLLM): 11 | visual_mode = True 12 | async def generate(self, prompt, length=64, model_params={}, assist=True): 13 | if config.llm_remote_launch_process_automatically: 14 | await self.run_llm_service() 15 | data = { 16 | 'prompt': prompt, 17 | 'n_predict': length, 18 | 'seed': -1, 19 | **model_params, 20 | } 21 | base64_data_url_offset = 22 22 | if 'visual_input' in data and data['visual_input']: 23 | data['image_data'] = [{'data': data['visual_input'][base64_data_url_offset:]}] 24 | del data['visual_input'] 25 | if 'stop_tokens' in self.init_config: 26 | data['stop'] = ['', *self.init_config['stop_tokens']] 27 | 28 | error, response = await super().remote_llm_api('POST', 'completion', data) 29 | if not error: 30 | logger.info(response) 31 | return False, prompt + response.get('content') 32 | else: 33 | return 'Error: ' + str(error), None 34 | 35 | init = RemoteLLamaCPP -------------------------------------------------------------------------------- /providers/llm/remote_ob.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import json 3 | import logging 4 | import asyncio 5 | import subprocess 6 | import psutil 7 | from functools import partial 8 | from misc.memory_manager import mload 9 | from config_reader import config 10 | from providers.llm.abstract_llm import AbstractLLM 11 | from time import sleep 12 | 13 | logger = logging.getLogger(__name__) 14 | llm_host = config.llm_host 15 | llm_load_started = False 16 | 17 | class RemoteLLM(AbstractLLM): 18 | assistant_mode = True 19 | async def remote_llm_api(self, method, endpoint, payload): 20 | async with httpx.AsyncClient() as client: 21 | try: 22 | if method == 'GET': 23 | response = await client.get(url=f'{llm_host}/{endpoint}', params=payload, timeout=None) 24 | else: 25 | response = await client.post(url=f'{llm_host}/{endpoint}', json=payload, timeout=None) 26 | if response.status_code == 200: 27 | response_data = response.json() 28 | return (False, response_data) 29 | else: 30 | return 'Connection error', None 31 | except (httpx.NetworkError, ConnectionError, httpx.RemoteProtocolError, json.decoder.JSONDecodeError) as error: 32 | return str(error), None 33 | except Exception: 34 | return 'Unknown error', None 35 | 36 | def __init__(self, model_paths, init_config): 37 | self.init_config = init_config 38 | if config.llm_remote_launch_process_automatically and \ 39 | config.mm_preload_models_on_start: 40 | asyncio.run(self.run_llm_service()) 41 | else: 42 | error, data = asyncio.run(self.remote_llm_api('GET', 'api/v1/model', {})) 43 | if error: 44 | logger.warn('Unable to get remote language model name: ' + str(error)) 45 | self.model = None 46 | self.filename = data.get('result') if not error else 'Unknown model' 47 | 48 | async def generate(self, prompt, length=64, model_params={}, assist=True): 49 | if config.llm_remote_launch_process_automatically: 50 | await self.run_llm_service() 51 | data = { 52 | 'prompt': prompt, 53 | 'max_length': length, 54 | **model_params, 55 | } 56 
| error, response = await self.remote_llm_api('POST', 'api/v1/generate', data) 57 | if not error: 58 | response = response.get('results')[0].get('text') 59 | logger.info(response) 60 | return False, prompt + response 61 | else: 62 | return str(error), None 63 | 64 | async def run_llm_service(self): 65 | global llm_load_started, last_pid 66 | if llm_load_started: 67 | return 68 | llm_load_started = True 69 | service = mload('llm-remote_ob', 70 | partial(subprocess.Popen, config.llm_remote_launch_command.split(' '), cwd=config.llm_remote_launch_dir), 71 | lambda p: p.terminate(), 72 | lambda p: psutil.Process(p.pid).memory_info().rss, 73 | gpu=True 74 | ) 75 | if service.pid != last_pid: 76 | await asyncio.sleep(config.llm_remote_launch_waittime) 77 | await self.remote_llm_api('POST', 'api/v1/model', {'action': 'load', 'model_name': config.llm_remote_model_name}), 78 | self.model = None 79 | self.filename = config.llm_remote_model_name 80 | llm_load_started=False 81 | last_pid = service.pid 82 | return service 83 | 84 | init = RemoteLLM 85 | last_pid = -1 -------------------------------------------------------------------------------- /providers/llm_provider.py: -------------------------------------------------------------------------------- 1 | from config_reader import config 2 | from .llm import external_backends, pytorch_models 3 | 4 | use_built_in_models = config.llm_backend == 'pytorch' 5 | target_provider = pytorch_models if use_built_in_models else external_backends 6 | target_key = config.llm_python_model_type if use_built_in_models else config.llm_backend 7 | 8 | active_model = target_provider[target_key]() if 'llm' in config.active_modules else None 9 | -------------------------------------------------------------------------------- /providers/sd_provider.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import httpx 3 | import json 4 | import random 5 | import asyncio 6 | import logging 7 | import subprocess 8 | import psutil 9 | from collections import defaultdict 10 | from config_reader import config 11 | from misc.memory_manager import mload 12 | from functools import partial 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | request_payload = { 17 | "denoising_strength": 1, 18 | "prompt": "", 19 | "sampler_name": config.sd_default_sampler, 20 | "steps": config.sd_default_tti_steps, 21 | "cfg_scale": 5, 22 | "width": config.sd_default_width, 23 | "height": config.sd_default_height, 24 | "restore_faces": False, 25 | "tiling": False, 26 | "batch_size": 1, 27 | "n_iter": config.sd_default_n_iter, 28 | "negative_prompt": "", 29 | "eta": 71337 30 | } 31 | 32 | # hash: name 33 | models = defaultdict(lambda: 'Unknown model') 34 | embeddings = [] 35 | loras = [] 36 | 37 | sd_url = config.sd_host or "http://127.0.0.1:7860" 38 | 39 | sd_service = False 40 | sd_started = False 41 | 42 | def run_sd_service(): 43 | global sd_started 44 | if not config.sd_launch_process_automatically: 45 | return 46 | p = partial(subprocess.Popen, config.sd_launch_command.split(' '), cwd=config.sd_launch_dir, stderr=subprocess.DEVNULL) 47 | service = mload('sd-remote', 48 | p, 49 | lambda p: p.terminate() or (sd_started:=False), 50 | lambda p: psutil.Process(p.pid).memory_info().rss, 51 | gpu=True 52 | ) 53 | return service 54 | 55 | def check_server(async_func): 56 | async def decorated_function(*args, **kwargs): 57 | global sd_service, sd_started 58 | if not config.sd_launch_process_automatically: 59 | return await async_func(*args, **kwargs) 
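    # probe the SD web UI API first; if it is unreachable, (re)launch it via
    # run_sd_service() and wait sd_launch_waittime seconds before running the wrapped call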
60 | url = f"{sd_url}/sdapi/v1/sd-models" 61 | try: 62 | async with httpx.AsyncClient() as client: 63 | await client.get(url) 64 | except (httpx.HTTPError, httpx.NetworkError, ConnectionError, httpx.RemoteProtocolError): 65 | print("SD server is down. Restarting it...") 66 | if not sd_started or sd_service.poll() is not None: 67 | sd_service = run_sd_service() 68 | sd_started = True 69 | # better idea is to read stdout and wait for server, but it doesn't work for some reason 70 | await asyncio.sleep(config.sd_launch_waittime) 71 | else: 72 | # TODO: fix this mess sometime later 73 | sd_service = run_sd_service() 74 | return await async_func(*args, **kwargs) 75 | sd_service = run_sd_service() 76 | return await async_func(*args, **kwargs) 77 | return decorated_function 78 | 79 | @check_server 80 | async def refresh_model_list(): 81 | global models, embeddings, loras 82 | try: 83 | async with httpx.AsyncClient() as client: 84 | model_response = await client.get(url=f'{sd_url}/sdapi/v1/sd-models',headers={'accept': 'application/json'}, timeout=None) 85 | embed_response = await client.get(url=f'{sd_url}/sdapi/v1/embeddings',headers={'accept': 'application/json'}, timeout=None) 86 | lora_response = await client.get(url=f'{sd_url}/sdapi/v1/loras',headers={'accept': 'application/json'}, timeout=None) 87 | if model_response.status_code == 200 and embed_response.status_code == 200 and lora_response.status_code == 200: 88 | model_response_data = model_response.json() 89 | embed_response_data = embed_response.json() 90 | lora_response_data = lora_response.json() 91 | models.clear() 92 | embeddings.clear() 93 | loras.clear() 94 | for m in model_response_data: 95 | models[m['hash']] = m['model_name'] 96 | for e in embed_response_data['loaded']: 97 | embeddings.append(e) 98 | for lora in lora_response_data: 99 | loras.append(lora['name']) 100 | loras[:] = [key for key in loras if key not in config.sd_lora_custom_activations] 101 | else: 102 | raise Exception('Server error') 103 | except Exception as e: 104 | logger.warn('Failed to load stable diffusion model names: ' + str(e)) 105 | 106 | 107 | def b642img(base64_image): 108 | return base64.b64decode(base64_image) 109 | 110 | @check_server 111 | async def switch_model(name): 112 | async with httpx.AsyncClient() as client: 113 | try: 114 | payload = {'sd_model_checkpoint': name} 115 | response = await client.post(url=f'{sd_url}/sdapi/v1/options', json=payload, timeout=None) 116 | if response.status_code == 200: 117 | return True 118 | except Exception: 119 | return False 120 | return False 121 | 122 | @check_server 123 | async def sd_get_images(payload, endpoint): 124 | if len(models.values()) == 0: 125 | await refresh_model_list() 126 | async with httpx.AsyncClient() as client: 127 | try: 128 | response = await client.post(url=f'{sd_url}/{endpoint}', json=payload, timeout=None) 129 | if response.status_code == 200: 130 | response_data = response.json() 131 | images = response_data.get("images") 132 | bstr_images = [b642img(i) for i in images] 133 | gen_info = json.loads(response_data.get('info')) 134 | gen_info['model'] = models[gen_info['sd_model_hash']] 135 | return (False, bstr_images, gen_info) 136 | else: 137 | return ('Connection error', None, None) 138 | except (httpx.NetworkError, ConnectionError, httpx.RemoteProtocolError, json.decoder.JSONDecodeError) as error: 139 | return (error, None, None) 140 | except Exception: 141 | return ('unknown error', None, None) 142 | 143 | 144 | async def tti(override=None): 145 | payload = request_payload 146 | 
default_scale = config.sd_default_tti_cfg_scale 147 | payload['cfg_scale'] = random.choice([3,4,5,6]) if default_scale == 0 else default_scale 148 | if override: 149 | payload = {**payload, **override} 150 | return await sd_get_images(payload, 'sdapi/v1/txt2img') 151 | 152 | 153 | async def iti(override=None): 154 | payload = request_payload 155 | payload['denoising_strength'] = config.sd_default_iti_denoising_strength 156 | payload['cfg_scale'] = config.sd_default_iti_cfg_scale 157 | payload['steps'] = config.sd_default_iti_steps 158 | payload['sampler_name'] = config.sd_default_iti_sampler 159 | if override: 160 | payload = {**payload, **override} 161 | return await sd_get_images(payload, 'sdapi/v1/img2img') 162 | -------------------------------------------------------------------------------- /providers/stt/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | whispercpp = lambda: importlib.import_module('providers.stt.whisper') 3 | silero = lambda: importlib.import_module('providers.stt.silero') 4 | wav2vec2 = lambda: importlib.import_module('providers.stt.wav2vec2') 5 | whisperS2T = lambda: importlib.import_module('providers.stt.whisperS2T') 6 | 7 | backends = { 8 | 'whisper': whispercpp, 9 | 'silero': silero, 10 | 'wav2vec2': wav2vec2, 11 | 'whisperS2T_CTranslate2': whisperS2T, 12 | 'whisperS2T_TensorRT-LLM': whisperS2T 13 | } -------------------------------------------------------------------------------- /providers/stt/abstract_stt.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class AbstractSTT(metaclass=ABCMeta): 5 | model = None 6 | def __init__(self): 7 | return None 8 | 9 | @abstractmethod 10 | def recognize(self, audio_path): 11 | pass -------------------------------------------------------------------------------- /providers/stt/silero.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from providers.stt.abstract_stt import AbstractSTT 4 | from config_reader import config 5 | from concurrent.futures import ThreadPoolExecutor 6 | import asyncio 7 | from glob import glob 8 | from misc.memory_manager import mload 9 | 10 | device = torch.device('cpu') 11 | 12 | class SileroSTT(AbstractSTT): 13 | def __init__(self): 14 | if config.mm_preload_models_on_start: 15 | mload('st-wav2vec2', self.load, None) 16 | 17 | def load(self): 18 | model, decoder, utils = torch.hub.load( 19 | repo_or_dir='snakers4/silero-models', 20 | model='silero_stt', 21 | language=config.stt_model_path_or_name if len(config.stt_model_path_or_name) == 2 else 'en', 22 | device=device 23 | ) 24 | return model, utils, decoder 25 | 26 | def stt(self, path): 27 | model, utils, decoder = mload('st-silero', self.load, None) 28 | read_batch, split_into_batches, read_audio, prepare_model_input = utils 29 | test_files = glob(path) 30 | batches = split_into_batches(test_files, batch_size=10) 31 | input = prepare_model_input(read_batch(batches[0]), device=device) 32 | output = model(input) 33 | transcript = '. 
'.join([decoder(x.cpu()) for x in output]) 34 | print(transcript) 35 | return transcript 36 | 37 | async def recognize(self, audio_path): 38 | try: 39 | with ThreadPoolExecutor(): 40 | text = await asyncio.to_thread(self.stt, audio_path) 41 | return False, text 42 | except Exception as e: 43 | return str(e), None 44 | 45 | init = SileroSTT -------------------------------------------------------------------------------- /providers/stt/wav2vec2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from transformers import SpeechEncoderDecoderModel, Wav2Vec2Processor 4 | from providers.stt.abstract_stt import AbstractSTT 5 | from config_reader import config 6 | from concurrent.futures import ThreadPoolExecutor 7 | from misc.memory_manager import mload 8 | import asyncio 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | class Wav2Vec2(AbstractSTT): 13 | def __init__(self): 14 | if config.mm_preload_models_on_start: 15 | mload('st-wav2vec2', self.load, None) 16 | 17 | def load(self): 18 | processor = Wav2Vec2Processor.from_pretrained(config.stt_model_path_or_name) 19 | model = SpeechEncoderDecoderModel.from_pretrained(config.stt_model_path_or_name) 20 | if device != 'cpu': 21 | model = model.to(device) 22 | processor = processor 23 | return (model, processor) 24 | 25 | def stt(self, path): 26 | model, processor = mload('st-wav2vec2', self.load, None) 27 | waveform, sample_rate = torchaudio.load(path, normalize=True) 28 | if waveform.shape[0] == 2: 29 | waveform = torch.mean(waveform, dim=0) 30 | if sample_rate != 16000: 31 | resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) 32 | waveform = resampler(waveform) 33 | waveform = waveform.squeeze() 34 | processed = processor(waveform, sampling_rate=16_000, 35 | return_tensors="pt", padding='longest', device=device) 36 | if device != 'cpu': 37 | processed['input_values'] = processed['input_values'].to(device) 38 | processed['attention_mask'] = processed['attention_mask'].to(device) 39 | with torch.no_grad(): 40 | predicted_ids = model.generate(**processed) 41 | 42 | predicted_sentences = processor.batch_decode( 43 | predicted_ids, 44 | num_processes=8, 45 | skip_special_tokens=True 46 | ) 47 | print(predicted_sentences) 48 | return ' '.join(predicted_sentences) 49 | 50 | async def recognize(self, audio_path): 51 | try: 52 | with ThreadPoolExecutor(): 53 | text = await asyncio.to_thread(self.stt, audio_path) 54 | return False, text 55 | except AssertionError as e: 56 | print(e) 57 | return str(e), None 58 | 59 | init = Wav2Vec2 -------------------------------------------------------------------------------- /providers/stt/whisper.py: -------------------------------------------------------------------------------- 1 | from providers.stt.abstract_stt import AbstractSTT 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | from misc.memory_manager import mload 5 | from functools import partial 6 | from utils import cprint 7 | import asyncio 8 | try: 9 | from whispercpp import Whisper 10 | except ImportError: 11 | Whisper = False 12 | 13 | class WhisperCPP(AbstractSTT): 14 | def __init__(self): 15 | if not Whisper: 16 | cprint("Whisper.cpp (STT) module not available, please reinstall it", color="red") 17 | if config.mm_preload_models_on_start: 18 | m = self.model 19 | 20 | @property 21 | def model(self): 22 | loader = partial(Whisper, config.stt_model_path_or_name) 23 | 
return mload('stt-whisper', loader, None) 24 | 25 | async def recognize(self, audio_path): 26 | try: 27 | with ThreadPoolExecutor(): 28 | output = await asyncio.to_thread(self.model.transcribe, audio_path) 29 | text = ''.join(self.model.extract_text(output)) 30 | return False, text 31 | except Exception as e: 32 | return str(e), None 33 | 34 | init = WhisperCPP -------------------------------------------------------------------------------- /providers/stt/whisperS2T.py: -------------------------------------------------------------------------------- 1 | from providers.stt.abstract_stt import AbstractSTT 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | from misc.memory_manager import mload 5 | from functools import partial 6 | from utils import cprint 7 | import torch 8 | import asyncio 9 | try: 10 | import whisper_s2t 11 | except ImportError: 12 | whisper_s2t = False 13 | 14 | class WhisperS2T(AbstractSTT): 15 | def __init__(self): 16 | if not whisper_s2t: 17 | cprint("WhisperS2T (STT) module not available, please reinstall it", color="red") 18 | if config.mm_preload_models_on_start: 19 | m = self.model 20 | 21 | @property 22 | def model(self): 23 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | print(device) 25 | ct = 'int8' if device != 'cuda' else 'float16' 26 | backend = config.stt_backend.split('_')[1] 27 | loader = partial(whisper_s2t.load_model, config.stt_model_path_or_name, backend=backend, device=device, compute_type=ct) 28 | return mload('stt-whisper_s2t', loader, None) 29 | 30 | async def recognize(self, audio_path): 31 | try: 32 | with ThreadPoolExecutor(): 33 | output = await asyncio.to_thread(self.model.transcribe_with_vad, [audio_path], lang_codes=[config.lang], tasks=['transcribe'], initial_prompts=[None], batch_size=1) 34 | text = output[0][0]['text'] 35 | return False, text 36 | except Exception as e: 37 | return str(e), None 38 | init = WhisperS2T -------------------------------------------------------------------------------- /providers/stt_provider.py: -------------------------------------------------------------------------------- 1 | from config_reader import config 2 | from .stt import backends 3 | 4 | active_model = backends[config.stt_backend]() if 'stt' in config.active_modules else None -------------------------------------------------------------------------------- /providers/tta_provider.py: -------------------------------------------------------------------------------- 1 | from utils import cprint 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | from misc.memory_manager import mload 5 | from functools import partial 6 | import asyncio 7 | import tempfile 8 | 9 | AudioGen = None 10 | MusicGen = None 11 | 12 | def tta_init(): 13 | global MusicGen, AudioGen, audio_write 14 | try: 15 | from audiocraft.models import MusicGen, AudioGen 16 | from audiocraft.data.audio import audio_write 17 | return True 18 | except ImportError: 19 | cprint("TTA (AudioCraft) module not available, please reinstall it", color="red") 20 | return False 21 | 22 | def get_model(path, loader, name): 23 | loader = partial(loader.get_pretrained, path, config.tta_device) 24 | model = mload('tta-' + name, loader, None) 25 | return model 26 | 27 | 28 | def generate_audio(text, audio_type="music", duration=5, raw_data=False): 29 | try: 30 | if audio_type == "music": 31 | model = get_model(config.tta_music_model, MusicGen, 'MusicGen') 32 | else: 33 | model = get_model(config.tta_sfx_model, 
AudioGen, 'AudioGen') 34 | model.set_generation_params(duration=duration) 35 | wav = model.generate([text]) 36 | except Exception as e: 37 | return (str(e), None) 38 | if not raw_data: 39 | wav = save_audio(wav[0], model) 40 | return False, wav 41 | 42 | async def generate_audio_async(text, audio_type="music", duration=5, raw_data=False): 43 | with ThreadPoolExecutor(): 44 | error, output = await asyncio.to_thread(generate_audio, 45 | text, audio_type, duration, raw_data 46 | ) 47 | return error, output 48 | 49 | 50 | def save_audio(wav_file, model): 51 | tmp_path = tempfile.TemporaryDirectory().name + 'record' 52 | audio_write(tmp_path, wav_file.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) 53 | return tmp_path + '.wav' -------------------------------------------------------------------------------- /providers/tts/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from providers.tts.say_macos import Say 4 | from providers.tts.coqui_tts import CoquiTTS 5 | from providers.tts.py_ttsx4 import TTSx4 6 | from providers.tts.so_vits_svc import SoVitsSVC 7 | 8 | from providers.tts.remote_tts import RemoteTTS 9 | 10 | tts_backends = OrderedDict() 11 | tts_backends['say_macos'] = Say 12 | tts_backends['ttsx4'] = TTSx4 13 | tts_backends['coqui_tts'] = CoquiTTS 14 | tts_backends['so_vits_svc'] = SoVitsSVC -------------------------------------------------------------------------------- /providers/tts/abstract_tts.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | class AbstractTTS(metaclass=ABCMeta): 4 | def __init__(self, remote): 5 | pass 6 | 7 | @abstractmethod 8 | def speak(self, voice, text): 9 | pass 10 | 11 | class AbstractSTS(metaclass=ABCMeta): 12 | def __init__(self, remote, tts_instance): 13 | self.tts = tts_instance 14 | pass 15 | 16 | @abstractmethod 17 | def mimic(self, voice, original_audio): 18 | pass 19 | 20 | @abstractmethod 21 | def speak(): 22 | pass 23 | -------------------------------------------------------------------------------- /providers/tts/coqui_tts.py: -------------------------------------------------------------------------------- 1 | from providers.tts.abstract_tts import AbstractTTS 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | from misc.memory_manager import mload 5 | from functools import partial 6 | from utils import cprint 7 | from pathlib import Path 8 | import asyncio 9 | import torch 10 | import tempfile 11 | import logging 12 | import os 13 | logger = logging.Logger(__name__) 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # do not use MPS, currently it is bugged 16 | 17 | class CoquiTTS(AbstractTTS): 18 | def __init__(self, is_remote): 19 | self.name = 'coqui_tts' 20 | self.voices = list(map(lambda item: item if isinstance(item, str) else item.get('voice'), config.tts_voices)) 21 | self.authors = list(map(lambda item: None if isinstance(item, str) else item.get('author'), config.tts_voices)) 22 | self.voice_metamap = { 23 | (item['voice'] if isinstance(item, dict) else item): 24 | item if isinstance(item, dict) else {} 25 | for item in config.tts_voices 26 | } 27 | self.system = False 28 | self.is_available = False 29 | if is_remote: 30 | return 31 | try: 32 | from TTS.api import TTS 33 | self.TTS = TTS 34 | self.is_available = True 35 | except Exception as e: 36 | 
logger.error(e) 37 | if 'tts' in config.active_modules and config.tts_mode == 'local': 38 | cprint("CoquiTTS provider not available", color="red") 39 | 40 | def _speak(self, voice, text): 41 | config_path = Path(config.tts_path) / (voice + ".json") 42 | config_path = Path(config.tts_path) / "config.json" if not os.path.exists(config_path) else config_path 43 | voic_model_loader = partial( 44 | self.TTS, 45 | model_path = Path(config.tts_path) / (voice + ".pth"), 46 | config_path = config_path 47 | ) 48 | loaded_model = mload('CoquiTTS-' + voice, voic_model_loader, None, gpu=torch.cuda.is_available()) 49 | if (loaded_model.synthesizer.tts_model.device.type != device.type): 50 | loaded_model.synthesizer.tts_model = loaded_model.synthesizer.tts_model.to(device) 51 | tmp_path = tempfile.TemporaryDirectory().name + 'record.wav' 52 | if hasattr(loaded_model, 'model_name') and loaded_model.model_name is None: 53 | loaded_model.model_name = voice 54 | loaded_model.tts_to_file(text, file_path=tmp_path) 55 | return tmp_path 56 | 57 | async def speak(self, voice, text): 58 | try: 59 | with ThreadPoolExecutor(): 60 | wav_file_path = await asyncio.to_thread(self._speak, voice, text.rstrip('.') + '.') 61 | return False, wav_file_path 62 | except Exception as e: 63 | return str(e), None 64 | -------------------------------------------------------------------------------- /providers/tts/py_ttsx4.py: -------------------------------------------------------------------------------- 1 | from providers.tts.abstract_tts import AbstractTTS 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | import asyncio 5 | import tempfile 6 | import sys 7 | import logging 8 | logger = logging.Logger(__name__) 9 | 10 | class TTSx4(AbstractTTS): 11 | def __init__(self, is_remote): 12 | self.system = True 13 | self.voices = [] 14 | self.authors = [] 15 | self.is_available = False 16 | self.name = 'ttsx4' 17 | if is_remote: 18 | return 19 | try: 20 | import pyttsx4 21 | self.engine = pyttsx4.init() 22 | self.prefix = 'com.apple.speech.synthesis.voice.' 
if sys.platform == "darwin" else '' 23 | self.voices = [v.name for v in self.engine.getProperty('voices')] 24 | self.is_available = True 25 | self.voice_metamap = { 26 | v.name: 27 | { 28 | "lang": v.languages[0][:2], 29 | "tone": 'f' if v.gender == 'VoiceGenderFemale' else 'm' if v.gender == 'VoiceGenderMale' else '*' 30 | } 31 | for v in self.engine.getProperty('voices') 32 | } 33 | except Exception as e: 34 | logger.error(e) 35 | 36 | def _speak(self, voice, text): 37 | tmp_path = tempfile.TemporaryDirectory().name + 'record.wav' 38 | self.engine.setProperty('voice', self.prefix + voice) 39 | self.engine.save_to_file(text, tmp_path) 40 | self.engine.runAndWait() 41 | return tmp_path 42 | 43 | async def speak(self, voice, text): 44 | try: 45 | with ThreadPoolExecutor(): 46 | wav_file_path = await asyncio.to_thread(self._speak, voice, text) 47 | return False, wav_file_path 48 | except Exception as e: 49 | return str(e), None 50 | -------------------------------------------------------------------------------- /providers/tts/remote_tts.py: -------------------------------------------------------------------------------- 1 | from providers.tts.abstract_tts import AbstractTTS 2 | from config_reader import config 3 | import httpx 4 | import tempfile 5 | import json 6 | 7 | class RemoteTTS(AbstractTTS): 8 | def __init__(self): 9 | self.is_available = config.tts_mode != 'local' 10 | self.name = 'remote' 11 | self.voices = [] 12 | self.authors = [] 13 | self.voice_metamap = {} 14 | self.system = False 15 | 16 | async def speak(self, voice, text): 17 | async with httpx.AsyncClient() as client: 18 | try: 19 | tts_payload = {"voice": voice, "text": text, "response": 'file' if config.tts_mode == 'remote' else 'path'} 20 | response = await client.post(url=config.tts_host, json=tts_payload, timeout=None) 21 | if response.status_code == 200: 22 | if config.tts_mode == 'remote': 23 | path = tempfile.TemporaryDirectory().name + str(hash(text)) + '.wav' 24 | with open(path, 'wb') as f: 25 | f.write(response.content) 26 | return (False, path) 27 | else: 28 | response_data = response.json() 29 | error = response_data.get('error') 30 | if error: 31 | return (error, None) 32 | wpath = response_data.get("data") 33 | return (False, wpath) 34 | else: 35 | return ('Server error', None) 36 | except (httpx.NetworkError, ConnectionError, httpx.RemoteProtocolError, json.decoder.JSONDecodeError) as error: 37 | return (error, None) 38 | except Exception as e: 39 | return (str(e), None) -------------------------------------------------------------------------------- /providers/tts/say_macos.py: -------------------------------------------------------------------------------- 1 | from providers.tts.abstract_tts import AbstractTTS 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | import asyncio 5 | import subprocess 6 | import tempfile 7 | import sys 8 | import os 9 | 10 | class Say(AbstractTTS): 11 | def __init__(self, is_remote): 12 | self.is_available = sys.platform == "darwin" 13 | self.system = True 14 | self.name = 'say_macos' 15 | self.authors = ['Apple'] 16 | self.voices = [ 17 | 'Alex', 'Alice', 'Alva', 'Amelie', 'Anna', 'Carmit', 'Damayanti', 'Daniel', 'Diego', 18 | 'Ellen', 'Fiona', 'Fred', 'Ioana', 'Joana', 'Jorge', 'Juan', 'Kanya', 'Karen', 19 | 'Kyoko', 'Laura', 'Lekha', 'Luca', 'Luciana', 'Maged', 'Mariska', 'Mei-Jia', 20 | 'Melina', 'Milena', 'Moira', 'Monica', 'Nora', 'Paulina', 'Rishi', 'Samantha', 21 | 'Sara', 'Satu', 'Sin-ji', 'Tessa', 'Thomas', 'Ting-Ting', 
'Veena', 'Victoria', 22 | 'Xander', 'Yelda', 'Yuna', 'Yuri', 'Zosia', 'Zuzana' 23 | ] 24 | self.voice_metamap = dict(zip( 25 | self.voices, 26 | [{"lang": lang, "tone": tone} for lang, tone in [ 27 | ('en','m'), ('it','f'), ('sv','f'), ('fr','f'), ('de','f'), ('he','f'), ('id','f'), ('en','m'), ('es','m'), 28 | ('nl','f'), ('en','f'), ('en','m'), ('ro','f'), ('pt','f'), ('es','m'), ('es','m'), ('th','f'), ('en','f'), 29 | ('ja','f'), ('sk','f'), ('hi','f'), ('it','m'), ('pt','f'), ('ar','m'), ('hu','f'), ('zh','f'), 30 | ('el','f'), ('ru','f'), ('en','f'), ('es','f'), ('nb','f'), ('es','f'), ('en','m'), ('en','f'), 31 | ('da','f'), ('fi','f'), ('zh','f'), ('en','m'), ('fr','m'), ('zh','f'), ('en','f'), ('en','f'), 32 | ('nl','m'), ('tr','f'), ('ko','f'), ('ru','m'), ('pl','f'), ('cs','f') 33 | ]] 34 | )) 35 | 36 | def _speak(self, voice, text): 37 | tmp_path_aiff = tempfile.TemporaryDirectory().name + 'record.aif' 38 | tmp_path_wav = tmp_path_aiff.replace('.aif', '.wav') 39 | subprocess.run( 40 | ['say','-v', voice, '-o', tmp_path_aiff, text], 41 | timeout=30 42 | ) 43 | subprocess.run([ 44 | config.tts_ffmpeg_path, '-y', '-i', tmp_path_aiff, tmp_path_wav], 45 | timeout=30 46 | ) 47 | os.unlink(tmp_path_aiff) 48 | return tmp_path_wav 49 | 50 | async def speak(self, voice, text): 51 | try: 52 | with ThreadPoolExecutor(): 53 | wav_file_path = await asyncio.to_thread(self._speak, voice, text) 54 | return False, wav_file_path 55 | except Exception as e: 56 | return str(e), None 57 | -------------------------------------------------------------------------------- /providers/tts/so_vits_svc.py: -------------------------------------------------------------------------------- 1 | from providers.tts.abstract_tts import AbstractSTS 2 | from config_reader import config 3 | from concurrent.futures import ThreadPoolExecutor 4 | from pathlib import Path 5 | import asyncio 6 | import subprocess 7 | import tempfile 8 | import sys 9 | import os 10 | import time 11 | 12 | def name_handler(name): 13 | return name.lower().replace('-','') 14 | 15 | class SoVitsSVC(AbstractSTS): 16 | def __init__(self, tts_instance, is_remote): 17 | self.tts = tts_instance 18 | self.system = False 19 | self.voices = dict({name_handler(m['voice']): m for m in config.tts_so_vits_svc_voices}) 20 | self.voice_metamap = {name_handler(m['voice']): m for m in config.tts_so_vits_svc_voices} 21 | self.authors = [m.get('author', None) for m in config.tts_so_vits_svc_voices] 22 | self.name = 'so_vits_svc' 23 | self.v4_0_code_path = config.tts_so_vits_svc_4_0_code_path 24 | self.v4_1_code_path = config.tts_so_vits_svc_4_1_code_path 25 | self.is_available = os.path.exists(self.v4_0_code_path) or os.path.exists(self.v4_1_code_path) 26 | 27 | def _mimic(self, voice, original_audio_path): 28 | version = self.voices[voice].get('v', 4.0) 29 | so_vits_svc_code = self.v4_0_code_path if version == 4.0 else self.v4_1_code_path 30 | name = 'temp_tts' + str(time.time_ns()) 31 | temp_file = f'{so_vits_svc_code}/raw/{name}.wav' 32 | os.rename(original_audio_path, temp_file) 33 | v = self.voices[voice] 34 | so_vits_model = Path(v['path']) / v['weights'] 35 | so_vits_config = Path(v['path']) / 'config.json' 36 | so_vits_voice = v['voice'] 37 | subprocess.run([ 38 | config.python_command or 'python', 39 | f"inference_main.py", 40 | "-m", str(so_vits_model), 41 | "-c", str(so_vits_config), 42 | "-n", f'{name}.wav', 43 | "-t", "0", 44 | "-s", so_vits_voice 45 | ], 46 | cwd=so_vits_svc_code 47 | ) 48 | os.remove(temp_file) 49 | filename = 
f'{so_vits_svc_code}/results/{name}.wav_0key_{so_vits_voice}.flac' 50 | if not os.path.isfile(filename): 51 | filename = filename.replace('.flac', '_sovits_pm.flac') 52 | if not os.path.isfile(filename): 53 | raise Exception('File not found') 54 | return filename 55 | 56 | async def speak(self, voice, text): 57 | error, original_audio = await self.tts(self.voices[voice].get('base_voice'), text) 58 | if not error: 59 | return await self.mimic(voice, original_audio) 60 | return error, None 61 | 62 | async def mimic(self, voice, original_audio_path): 63 | try: 64 | with ThreadPoolExecutor(): 65 | wav_file_path = await asyncio.to_thread(self._mimic, voice, original_audio_path) 66 | return False, wav_file_path 67 | except Exception as e: 68 | return str(e), None 69 | -------------------------------------------------------------------------------- /providers/tts_provider.py: -------------------------------------------------------------------------------- 1 | from providers.tts import tts_backends, RemoteTTS 2 | from config_reader import config 3 | from providers.tts.abstract_tts import AbstractSTS 4 | import threading 5 | import os 6 | import subprocess 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | tts_voicemap = {} 11 | system_voicemap = {} 12 | sts_voicemap = {} 13 | tts_backends_loaded = {} 14 | tts_authors = set() 15 | 16 | remote_tts = None 17 | 18 | async def tts(voice, text): 19 | if remote_tts and remote_tts.is_available: 20 | backend = remote_tts 21 | elif voice in tts_voicemap: 22 | backend = tts_voicemap[voice] 23 | else: 24 | return ('Voice not found', None) 25 | return await backend.speak(voice, text) 26 | 27 | async def sts(voice, original_audio=False): 28 | if voice in sts_voicemap: 29 | return await sts_voicemap[voice].mimic(voice, original_audio) 30 | return ('Voice not found', None) 31 | 32 | def init(allowRemote=True, threaded=True): 33 | global remote_tts 34 | threads = [] 35 | if 'tts' in config.active_modules: 36 | for backend in tts_backends: 37 | if backend not in config.tts_enable_backends: 38 | continue 39 | if not threaded: 40 | init_backend(backend, allowRemote) 41 | continue 42 | thread = threading.Thread(target=init_backend, args=(backend, allowRemote)) 43 | thread.start() 44 | threads.append(thread) 45 | for thread in threads: 46 | thread.join() 47 | if allowRemote: 48 | remote_tts = RemoteTTS() 49 | return 50 | 51 | def init_backend(backend, remote): 52 | is_sts = issubclass(tts_backends[backend], AbstractSTS) 53 | args = [tts, remote] if is_sts else [remote] 54 | b = tts_backends[backend](*args) 55 | if b.is_available: 56 | logger.debug('tts backend initialized: ' + backend) 57 | if b.is_available or config.tts_mode != 'local': 58 | tts_backends_loaded[backend] = b 59 | for voice in b.voices: 60 | tts_voicemap[voice] = b 61 | if is_sts: 62 | sts_voicemap[voice] = b 63 | if b.system: 64 | system_voicemap[voice] = b 65 | for author in b.authors: 66 | if author is not None: 67 | tts_authors.add(author) 68 | return b 69 | 70 | def convert_to_ogg(wav_path): 71 | ogg_path = wav_path + '.ogg' 72 | subprocess.run([ 73 | config.tts_ffmpeg_path, '-i', wav_path, 74 | '-acodec', 'libopus', '-b:a', '128k', '-vbr', 'off', ogg_path, '-y'], 75 | stdout=subprocess.DEVNULL, 76 | stderr=subprocess.STDOUT, 77 | timeout=60 78 | ) 79 | with open(ogg_path, 'rb') as f: 80 | data = f.read() 81 | os.remove(wav_path) 82 | os.remove(ogg_path) 83 | return data -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | indent-width = 2 3 | line-length = 120 -------------------------------------------------------------------------------- /requirements-all.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | -r requirements-llm.txt 3 | -r requirements-stt.txt 4 | -r requirements-tts.txt 5 | llama-cpp-python -------------------------------------------------------------------------------- /requirements-llm.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.27.* 2 | sentencepiece==0.1.97 3 | git+https://github.com/huggingface/peft.git 4 | tiktoken 5 | loralib -------------------------------------------------------------------------------- /requirements-stt.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/stlukey/whispercpp.py 2 | git+https://github.com/shashikg/WhisperS2T.git 3 | transformers 4 | numpy>=1.24.4 -------------------------------------------------------------------------------- /requirements-tts.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn 3 | pyttsx4 4 | TTS 5 | praat-parselmouth 6 | fairseq==0.12.2 7 | faiss-cpu 8 | soundfile 9 | # macos soundfile fix: 10 | # conda install libsndfile -c conda-forge -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiogram==3.1.* 2 | pydantic 3 | python-dotenv 4 | httpx 5 | typing-extensions 6 | pydantic-settings 7 | psutil 8 | fastapi 9 | uvicorn -------------------------------------------------------------------------------- /servers/api_sever.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import time 3 | import threading 4 | import uvicorn 5 | from typing import Dict 6 | from fastapi import FastAPI, Body 7 | from fastapi.responses import RedirectResponse 8 | from config_reader import config 9 | from servers.common import add_common_endpoints 10 | from misc.botless_layer import handle_message 11 | from misc.memory_manager import RAM, VRAM 12 | from misc import model_manager 13 | 14 | app = FastAPI(title='Botality API') 15 | dispatcher = None 16 | bot = None 17 | 18 | add_common_endpoints(app) 19 | 20 | @app.on_event("startup") 21 | def startup_event(): 22 | print("Botality API server is running on", config.sys_api_host) 23 | 24 | @app.get("/") 25 | async def movetowebui(): 26 | return RedirectResponse(url=config.sys_webui_host, status_code=301) 27 | 28 | @app.get("/ping") 29 | async def ping(): 30 | return {"response": "ok"} 31 | 32 | @app.post("/chat") 33 | async def message(data: Dict = Body): 34 | reply_future = await handle_message(data, dispatcher) 35 | await reply_future 36 | return reply_future.result() 37 | 38 | @app.get("/status") 39 | async def status(): 40 | return { "response": { 41 | "modules": list(dispatcher.modules.keys()), 42 | "counters": dispatcher.counters, 43 | "timings": dispatcher.timings, 44 | "memory_manager": { 45 | "RAM": RAM.stats(), 46 | "VRAM": VRAM.stats() if VRAM else None 47 | }, 48 | "bot": { 49 | "name": bot._me.first_name, 50 | "username": bot._me.username, 51 | "can_join_groups": bot._me.can_join_groups, 52 | "can_read_all_group_messages": bot._me.can_read_all_group_messages 53 | } if 
bot._me else None, 54 | "access_mode": config.ignore_mode 55 | }} 56 | 57 | @app.get("/models") 58 | async def models(): 59 | return {"response": model_manager.get_models()} 60 | 61 | @app.post("/models/install/{model_type}") 62 | async def install_models(model_type: str, body: Dict = Body): 63 | return model_manager.install_model(model_type, body) 64 | 65 | @app.post("/models/uninstall/{model_type}") 66 | async def uninstall_models(model_type: str, body: Dict = Body): 67 | return model_manager.uninstall_model(model_type, body) 68 | 69 | @app.get("/models/install/{task_id}") 70 | async def install_status(task_id: int): 71 | return {'response': model_manager.get_task_info(task_id)} 72 | 73 | @app.post("/models/select/{model_type}") 74 | async def select_model(model_type: str, body: Dict = Body): 75 | return model_manager.select_model(model_type, body) 76 | 77 | @app.get("/voices") 78 | async def voices(): 79 | tts = dispatcher.modules.get('tts') 80 | if not tts: 81 | return {'error': 'TTS module not initialized'} 82 | else: 83 | return {'response': [{'voice': v} for v in tts.all_voices]} 84 | 85 | class Server(uvicorn.Server): 86 | def install_signal_handlers(self): 87 | pass 88 | 89 | @contextlib.contextmanager 90 | def run_in_thread(self): 91 | thread = threading.Thread(target=self.run) 92 | thread.start() 93 | try: 94 | while not self.started: 95 | time.sleep(1e-3) 96 | yield 97 | finally: 98 | self.should_exit = True 99 | thread.join() 100 | 101 | def init_api_server(dp, bot_instance): 102 | global dispatcher, bot 103 | dispatcher = dp 104 | bot = bot_instance 105 | [protocol, slash2_host, port] = config.sys_api_host.split(':') 106 | serverConfig = uvicorn.Config( 107 | app, host=slash2_host[2:], port=int(port), 108 | log_level=config.sys_api_log_level, timeout_keep_alive=config.sys_request_timeout 109 | ) 110 | api_server = Server(config=serverConfig) 111 | return api_server -------------------------------------------------------------------------------- /servers/common.py: -------------------------------------------------------------------------------- 1 | from fastapi.responses import JSONResponse 2 | from typing import Dict, Any 3 | from fastapi import status, Body 4 | from config_reader import config, Settings 5 | from asyncio import Lock, sleep 6 | import os 7 | import json 8 | 9 | config_write_lock = Lock() 10 | 11 | class DynamicConfig: 12 | def __init__(self, get_updated_config): 13 | self.get_updated_config = get_updated_config 14 | 15 | def __call__(self): 16 | if self.get_updated_config: 17 | return self.get_updated_config() 18 | else: 19 | return config 20 | 21 | def add_common_endpoints(app, get_custom_config=False): 22 | get_config = DynamicConfig(get_custom_config) 23 | @app.patch("/config") 24 | async def write_config(data: Dict[str, Any]): 25 | for key in data: 26 | try: 27 | if (getattr(get_config(), key, None)) is not None: 28 | async with config_write_lock: 29 | setattr(get_config(), key, data[key]) 30 | except Exception as e: 31 | return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={'error': str(e)}) 32 | return {"response": "ok"} 33 | 34 | @app.get("/config") 35 | async def read_config(): 36 | return get_config() 37 | 38 | @app.get("/schema") 39 | async def schema(): 40 | return Settings.model_json_schema() if hasattr(Settings, 'model_json_schema') else json.loads(Settings.schema_json()) 41 | 42 | @app.get("/characters") 43 | async def characters(): 44 | return {"response": [{'name': x[:-3], 'full': 'characters.' 
+ x[:-3]} for x in os.listdir('characters/') \ 45 | if not (x.startswith('.') or x.startswith('__'))]} 46 | 47 | 48 | class VirtualRouter: 49 | def __init__(self): 50 | self.routes = {} 51 | 52 | def add_route(self, path, method, handler): 53 | self.routes[(path, method)] = handler 54 | 55 | def get(self, path): 56 | def decorator(handler): 57 | self.add_route(path, "GET", handler) 58 | return handler 59 | return decorator 60 | 61 | def patch(self, path): 62 | def decorator(handler): 63 | self.add_route(path, "PATCH", handler) 64 | return handler 65 | return decorator 66 | 67 | def post(self, path): 68 | def decorator(handler): 69 | self.add_route(path, "POST", handler) 70 | return handler 71 | return decorator 72 | 73 | async def run(self, path, method, data=None): 74 | handler = self.routes.get((path, method)) 75 | if handler: 76 | if data: 77 | return await handler(data) 78 | else: 79 | return await handler() 80 | else: 81 | return {"error": "Server not available"} -------------------------------------------------------------------------------- /servers/control_server.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request, Response, status, Body 2 | from fastapi.staticfiles import StaticFiles 3 | from fastapi.responses import FileResponse, JSONResponse 4 | import config_reader 5 | from bot import main, config 6 | from servers.common import add_common_endpoints, VirtualRouter 7 | import httpx 8 | import multiprocessing 9 | import os 10 | import importlib 11 | 12 | bot_instance = None 13 | app = FastAPI(title='Botality WebUI') 14 | 15 | vrouter = VirtualRouter() 16 | add_common_endpoints(vrouter, lambda: config) 17 | 18 | @app.post("/api/bot/{action}") 19 | def bot_action(action): 20 | global bot_instance 21 | if not bot_instance and action == 'start': 22 | bot_instance = multiprocessing.Process(target=main, kwargs={"api": True}) 23 | bot_instance.start() 24 | elif bot_instance and action == 'stop': 25 | bot_instance.terminate() 26 | bot_instance = None 27 | else: 28 | return {"error": "Unknown or unavailable action"} 29 | return {"response": "ok"} 30 | 31 | @app.get("/api/bot/status") 32 | async def bot_status(): 33 | return {"response": {"running": bot_instance is not None}} 34 | 35 | @app.get("/api/bot/env") 36 | async def env_files(): 37 | env_files = ['.env'] 38 | if os.path.exists('env') and os.path.isdir('env'): 39 | env_files = [*env_files, *[x for x in os.listdir('env') if x.endswith('.env')]] 40 | active = os.environ.get('BOTALITY_ENV_FILE', '.env') 41 | active = os.path.basename(active) 42 | return {"response": {"active": active, "all": env_files}} 43 | 44 | @app.put("/api/bot/env") 45 | async def set_env(filename: str = Body(...)): 46 | global config 47 | if (os.path.exists('env') and os.path.isdir('env') and filename in os.listdir('env')) or filename == '.env': 48 | os.environ['BOTALITY_ENV_FILE'] = os.path.join('env', filename) if filename != '.env' else '.env' 49 | config = importlib.reload(config_reader).config 50 | return {"response": 'ok'} 51 | return {"error": 'file not found'} 52 | 53 | @app.api_route("/api/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) 54 | async def redirect_request(path: str, request: Request, response: Response): 55 | if bot_instance: 56 | redirect_url = f'{config.sys_api_host}/{path}' 57 | try: 58 | async with httpx.AsyncClient() as client: 59 | headers = dict(request.headers) 60 | headers.pop('content-length', None) 61 | headers.pop('host', None) 
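# host and content-length are stripped so httpx recomputes them for the forwarded request; the body is re-sent as JSON below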
62 | headers['content-type'] = 'application/json' 63 | _json = await request.json() if request.method != 'GET' else None 64 | redirect_response = await client.request(request.method, redirect_url, headers=headers, json=_json, timeout=config.sys_request_timeout) 65 | return redirect_response.json() 66 | except Exception: 67 | return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"error": "SERVICE UNAVAILABLE"}) 68 | else: 69 | return await vrouter.run('/' + path, request.method, ((await request.json()) if request.method != 'GET' else None)) 70 | 71 | @app.on_event("startup") 72 | def startup_event(): 73 | print("Botality WebUI server is running on", config.sys_webui_host) 74 | 75 | app.mount("/", StaticFiles(directory="static", html=True), name="static") 76 | 77 | @app.exception_handler(404) 78 | async def not_found_handler(a, b): 79 | return FileResponse("static/index.html", status_code=200) 80 | 81 | def serve(): 82 | import uvicorn 83 | [protocol, slash2_host, port] = config.sys_webui_host.split(':') 84 | if (os.environ.get('BOTALITY_AUTOSTART', '') == 'True'): 85 | bot_action('start') 86 | uvicorn.run(app, 87 | host=slash2_host[2:], 88 | port=int(port), 89 | timeout_keep_alive=config.sys_request_timeout, 90 | log_level=config.sys_api_log_level 91 | ) 92 | 93 | if __name__ == '__main__': 94 | serve() -------------------------------------------------------------------------------- /servers/tts_server.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import asyncio 4 | from fastapi import FastAPI 5 | from pydantic import BaseModel, Field 6 | from fastapi.responses import StreamingResponse 7 | from io import BytesIO 8 | from typing_extensions import Literal 9 | sys.path.append('.') 10 | sys.path.append('..') 11 | from providers.tts_provider import init, tts 12 | 13 | 14 | # polyfill for python3.8 15 | if not hasattr(asyncio, 'to_thread'): 16 | import functools 17 | import contextvars 18 | async def to_thread(func, /, *args, **kwargs): 19 | loop = asyncio.get_running_loop() 20 | ctx = contextvars.copy_context() 21 | func_call = functools.partial(ctx.run, func, *args, **kwargs) 22 | return await loop.run_in_executor(None, func_call) 23 | asyncio.to_thread = to_thread 24 | 25 | app = FastAPI() 26 | 27 | class Data(BaseModel): 28 | voice: str = Field(None, description='String with speaker model name') 29 | text: str = Field(None, description='String with text that you want to hear') 30 | response: Literal['file', 'path'] = Field(None, description='String with value "file" or "path", changes the output format to either the path to the recorded audio or the file itself.') 31 | 32 | 33 | @app.post("/") 34 | async def read_root(rqdata: Data): 35 | error, data = await tts(rqdata.voice, rqdata.text) 36 | if not error: 37 | if rqdata.response == 'file': 38 | bytes = BytesIO(open(data, mode='rb').read()) 39 | os.remove(data) 40 | response = StreamingResponse(bytes, media_type='audio/wav') 41 | response.headers["Content-Disposition"] = f"inline; filename=record.wav" 42 | return response 43 | return {"data": data} 44 | else: 45 | return {"error": error} 46 | 47 | if __name__ == "__main__": 48 | import uvicorn 49 | init(allowRemote=False) 50 | uvicorn.run(app, host="0.0.0.0", port=7077) 51 | -------------------------------------------------------------------------------- /static/assets/recommendedModels.js: -------------------------------------------------------------------------------- 1 | const 
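// minified catalog of recommended models: VITS and so-vits-svc voice checkpoints plus GGUF LLM quantizations, exported below as `models`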
t=["david","forsen","juice-wrld","obiwan","trump","xqc"].map(e=>({voice:e,model:`${e}.pth`,author:"enlyth",repo:"enlyth/baj-tts",path:"models/",size:.9,rename:!1,lang:"en",tone:"m"})),n=["adam_carolla_checkpoint_1360000","alex_jones_checkpoint_2490000","david_attenborough_checkpoint_2020000","james_earl_jones_checkpoint_1600000","joel_osteen_checkpoint_2550000","neil_degrasse_tyson_checkpoint_1910000","tim_dillon_checkpoint_1970000","vincent_price_checkpoint_2080000"].map(e=>({voice:e.split("_checkpoint")[0],model:`${e}.pth`,author:"youmebangbang",repo:"youmebangbang/vits_tts_models",path:"",size:.9,rename:!0,lang:"en",tone:"m"})),l=["G_20000","G_157","G_480","G_449","G_50000","G_18500.pth"],s=["Biden20k","BillClinton","BorisJohnson","GeorgeBush","Obama50k","Trump18.5k"].map((e,a)=>({voice:e.replace(/[0-9.]+k/,""),model:`${l[a]}.pth`,author:"Nardicality",repo:"Nardicality/so-vits-svc-4.0-models",path:`${e}/`,size:.5,train_lang:"en",tone:"m"})),_=["G_50000","G_100000","G_85000"],r=["f","f","m"],c=["Glados_50k","Star-Trek-Computer","Boss_MGS_80k"].map((e,a)=>({voice:e.replace(/_[0-9]+k|-/g,""),model:`${_[a]}.pth`,author:"Amo",repo:"Amo/so-vits-svc-4.0_GA",path:`ModelsFolder/${e}/`,size:.5,train_lang:"en",tone:r[a]})),i=[{voice:"Tim_Cook",model:"Tim_Cook.pth",author:"Sucial",repo:"Sucial/so-vits-svc4.1-Tim_Cook",path:"",size:.2,train_lang:"en",tone:"m"}],o=["2_K","3_K_L","3_K_M","3_K_S","4_0","4_K_M","4_K_S","5_0","5_K_M","5_K_S","6_K","8_0"],m=[["TheBloke/llama2_7b_chat_uncensored-GGUF","llama2_7b_chat_uncensored.Q$.gguf",o],["TheBloke/Luna-AI-Llama2-Uncensored-GGUF","luna-ai-llama2-uncensored.Q$.gguf",o],["TheBloke/Mistral-7B-Instruct-v0.1-GGUF","mistral-7b-instruct-v0.1.Q$.gguf",o],["TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-GGUF","wizardlm-1.0-uncensored-llama2-13b.Q$.gguf",o],["TheBloke/Speechless-Llama2-Hermes-Orca-Platypus-WizardLM-13B-GGUF","speechless-llama2-hermes-orca-platypus-wizardlm-13b.Q$.gguf",o],["TheBloke/OpenBuddy-Llama2-13B-v11.1-GGUF","openbuddy-llama2-13b-v11.1.Q$.gguf",o],["TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF","tinyllama-1.1b-chat-v0.3.Q$.gguf",o]].map(e=>({name:e[1].split(".")[0],repo:e[0],model:e[1],quants:e[2],author:"TheBloke",path:"",size:"2-14"})),p={TTS:{VITS:[...t,...n],SO_VITS_SVC:[...s,...c,...i]},LLM:{GGUF:[...m]}};export{p as models}; 2 | -------------------------------------------------------------------------------- /static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/remixer-dec/botality-ii/0e7973cb87001fabd075c57d6ea6be63b29d0763/static/favicon.ico -------------------------------------------------------------------------------- /static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Botality 9 | 13 | 14 | 15 | 16 | 17 | 20 |
21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import base64 3 | import argparse 4 | import functools 5 | import sys 6 | import json 7 | 8 | async def tg_image_to_data(photo, bot): 9 | if not photo: 10 | return None 11 | file_bytes = BytesIO() 12 | file_path = (await bot.get_file(file_id=photo[-1].file_id)).file_path 13 | await bot.download_file(file_path, file_bytes) 14 | image = 'data:image/png;base64,' + str(base64.b64encode(file_bytes.getvalue()), 'utf-8') 15 | return image 16 | 17 | def b64_to_img(imgstr): 18 | from PIL import Image 19 | decoded = base64.b64decode(imgstr.split(',')[1]) 20 | return Image.open(BytesIO(decoded)) 21 | 22 | # do not sysexit on error 23 | class CustomArgumentParser(argparse.ArgumentParser): 24 | def error(self, message): 25 | raise Exception(message) 26 | 27 | # prevents excessive line-breaks 28 | class CustomHelpFormatter(argparse.HelpFormatter): 29 | def __init__(self, *args, **kwargs): 30 | kwargs["max_help_position"] = 24 31 | kwargs["width"] = 245 32 | super().__init__(*args, **kwargs) 33 | 34 | # join the rest of the arguments, so they can be validated 35 | class JoinNargsAction(argparse.Action): 36 | def __call__(self, parser, namespace, values, option_string=None): 37 | setattr(namespace, self.dest, ' '.join(values)) 38 | 39 | 40 | def parse_photo(message): 41 | if message.photo: 42 | return message.photo 43 | if message.document and message.document.mime_type.startswith('image'): 44 | return [message.document] 45 | if message.reply_to_message: 46 | if message.reply_to_message.photo: 47 | return message.reply_to_message.photo 48 | if message.reply_to_message.document and message.reply_to_message.document.mime_type.startswith('image'): 49 | return [message.reply_to_message.document] 50 | return None 51 | 52 | def log_exceptions(logger): 53 | def decorator(func): 54 | @functools.wraps(func) 55 | async def wrapper(*args, **kwargs): 56 | try: 57 | result = await func(*args, **kwargs) 58 | return result 59 | except Exception as e: 60 | logger.error(f"Error in {func.__name__}: {str(e)}") 61 | return wrapper 62 | return decorator 63 | 64 | def cprint(*args, color='default'): 65 | keys = ['default', 'red', 'green', 'yellow', 'blue'] 66 | sys.stdout.write(f'\x1b[1;3{keys.index(color)}m' + ' '.join(args) + '\x1b[0m\n') 67 | 68 | def update_env(path, key, value): 69 | with open(path, "r") as file: 70 | lines = file.readlines() 71 | with open(path, "w+") as file: 72 | to_write = [] 73 | multiline = False 74 | try: 75 | for line in lines: 76 | if line.replace(' ','').startswith(key + "="): 77 | multiline = line.replace(' ','').startswith(key + "='") and not line.endswith(("'", "'\n","'\r\n",)) 78 | if not multiline: 79 | if type(value) in (dict, list): 80 | value = json.dumps(value) 81 | to_write.append(f"{key}={value}\n") 82 | else: 83 | if isinstance(value, str): 84 | value = json.loads(value) 85 | to_write.append(f"{key}='{json.dumps(value, indent=2, sort_keys=True)}'\n") 86 | continue 87 | if multiline: 88 | if line.endswith(("'", "'\n","'\r\n",)): 89 | multiline = False 90 | continue 91 | to_write.append(line) 92 | file.writelines(to_write) 93 | except Exception as e: 94 | file.writelines(lines) 95 | cprint("Unable to update .env file: " + str(e), color='red') 96 | raise Exception(e) 97 | 98 | async def download_audio(bot, file_id, dl_path): 99 | file_path = (await 
bot.get_file(file_id=file_id)).file_path 100 | await bot.download_file(file_path, dl_path) 101 | 102 | def raise_rail_exceptions(error, data): 103 | if error: 104 | raise Exception(error) 105 | else: 106 | return data 107 | --------------------------------------------------------------------------------
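A minimal usage sketch for the standalone TTS server above (servers/tts_server.py), assuming it was started directly so uvicorn listens on port 7077, and assuming a placeholder voice name "example_voice"; httpx is already listed in requirements.txt:

import httpx

payload = {
  "voice": "example_voice",  # placeholder: must match an installed speaker model
  "text": "Hello from Botality",
  "response": "file"         # 'file' streams the wav back, 'path' returns a server-side file path
}
r = httpx.post("http://localhost:7077/", json=payload, timeout=120)
if r.headers.get("content-type", "").startswith("audio/"):
  # response='file' and no error: save the streamed wav
  with open("record.wav", "wb") as f:
    f.write(r.content)
else:
  # either {"data": <path>} or {"error": <message>}
  print(r.json())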