├── .dockerignore ├── .env-sample ├── .flake8 ├── .gitattributes ├── .gitignore ├── .gitmodules ├── .idea ├── .gitignore ├── app.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── .vscode ├── launch.json ├── redis-xplorer.redis └── settings.json ├── LICENSE ├── app ├── __init__.py ├── auth │ └── admin.py ├── common │ ├── app_settings.py │ ├── app_settings_llama_cpp.py │ ├── config.py │ ├── constants.py │ └── lotties.py ├── contents │ ├── browsing_demo.gif │ ├── browsing_demo.png │ ├── code_demo.png │ ├── edit_title_demo.png │ ├── embed_file_demo.png │ ├── favicon.ico │ ├── llama_api.png │ ├── model_selection_demo.png │ ├── site_logo.png │ ├── state_of_the_union.txt │ ├── ui_demo.png │ └── upload_demo.gif ├── database │ ├── __init__.py │ ├── connection.py │ ├── crud │ │ ├── api_keys.py │ │ ├── api_whitelists.py │ │ ├── deprecated_chatgpt.py │ │ └── users.py │ └── schemas │ │ ├── __init__.py │ │ ├── auth.py │ │ └── deprecated_chatgpt.py ├── dependencies.py ├── deprecated_chatgpt.html ├── errors │ ├── api_exceptions.py │ └── chat_exceptions.py ├── middlewares │ ├── token_validator.py │ └── trusted_hosts.py ├── mixins │ └── enum.py ├── models │ ├── base_models.py │ ├── chat_commands.py │ ├── chat_models.py │ ├── completion_models.py │ ├── function_calling │ │ ├── base.py │ │ └── functions.py │ ├── llm_tokenizers.py │ └── llms.py ├── routers │ ├── __init__.py │ ├── auth.py │ ├── index.py │ ├── services.py │ ├── user_services.py │ ├── users.py │ ├── v1.py │ └── websocket.py ├── shared.py ├── utils │ ├── api │ │ ├── completion.py │ │ ├── duckduckgo.py │ │ ├── translate.py │ │ └── weather.py │ ├── auth │ │ ├── api_keys.py │ │ ├── register_validation.py │ │ └── token.py │ ├── chat │ │ ├── buffer.py │ │ ├── build_llama_shared_lib.py │ │ ├── chat_rooms.py │ │ ├── commands │ │ │ ├── browsing.py │ │ │ ├── core.py │ │ │ ├── llm_parameter.py │ │ │ ├── prompt.py │ │ │ ├── server.py │ │ │ ├── summarize.py │ │ │ ├── testing.py │ │ │ └── vectorstore.py │ │ ├── embeddings │ │ │ ├── __init__.py │ │ │ ├── sentence_encoder.py │ │ │ └── transformer.py │ │ ├── file_loader.py │ │ ├── managers │ │ │ ├── cache.py │ │ │ ├── message.py │ │ │ ├── stream.py │ │ │ ├── vectorstore.py │ │ │ └── websocket.py │ │ ├── messages │ │ │ ├── converter.py │ │ │ ├── handler.py │ │ │ └── turn_templates.py │ │ ├── text_generations │ │ │ ├── __init__.py │ │ │ ├── completion_api.py │ │ │ ├── converter.py │ │ │ ├── exllama.py │ │ │ ├── llama_cpp.py │ │ │ ├── path.py │ │ │ └── summarization.py │ │ └── tokens.py │ ├── colorama.py │ ├── date_utils.py │ ├── encoding_utils.py │ ├── errors.py │ ├── function_calling │ │ ├── callbacks │ │ │ ├── click_link.py │ │ │ ├── full_browsing.py │ │ │ ├── lite_browsing.py │ │ │ ├── translate.py │ │ │ └── vectorstore_search.py │ │ ├── parser.py │ │ ├── query.py │ │ ├── request.py │ │ └── token_count.py │ ├── huggingface.py │ ├── js_initializer.py │ ├── langchain │ │ ├── chat_llama_cpp.py │ │ ├── embeddings_api.py │ │ ├── qdrant_vectorstore.py │ │ ├── redis_vectorstore.py │ │ ├── structured_tool.py │ │ ├── token_text_splitter.py │ │ └── web_search.py │ ├── logger.py │ ├── module_reloader.py │ ├── params_utils.py │ ├── system.py │ ├── tests │ │ ├── random.py │ │ └── timeit.py │ └── types.py ├── viewmodels │ ├── admin.py │ └── status.py └── web │ ├── .last_build_id │ ├── assets │ ├── AssetManifest.bin │ ├── AssetManifest.json │ ├── FontManifest.json │ ├── NOTICES │ ├── assets │ │ ├── images │ │ │ ├── ai_profile.png │ │ │ ├── favicon.ico 
│ │ │ ├── openai_profile.png │ │ │ ├── site_logo.png │ │ │ ├── user_profile.png │ │ │ └── vicuna_profile.jpg │ │ ├── lotties │ │ │ ├── click.json │ │ │ ├── fail.json │ │ │ ├── file-upload.json │ │ │ ├── go-back.json │ │ │ ├── ok.json │ │ │ ├── read.json │ │ │ ├── scroll-down.json │ │ │ ├── search-doc.json │ │ │ ├── search-web.json │ │ │ └── translate.json │ │ └── svgs │ │ │ └── search-web.svg │ ├── fonts │ │ └── MaterialIcons-Regular.otf │ ├── packages │ │ └── cupertino_icons │ │ │ └── assets │ │ │ └── CupertinoIcons.ttf │ └── shaders │ │ └── ink_sparkle.frag │ ├── canvaskit │ ├── canvaskit.js │ ├── canvaskit.wasm │ ├── chromium │ │ ├── canvaskit.js │ │ └── canvaskit.wasm │ ├── skwasm.js │ ├── skwasm.wasm │ └── skwasm.worker.js │ ├── favicon.ico │ ├── flutter.js │ ├── flutter_service_worker.js │ ├── icons │ ├── Icon-192.png │ ├── Icon-512.png │ ├── Icon-maskable-192.png │ └── Icon-maskable-512.png │ ├── index.html │ ├── main.dart.js │ ├── manifest.json │ ├── site_logo.png │ └── version.json ├── build-web.bat ├── docker-compose-local.yaml ├── docker-compose-prod.yaml ├── dockerfile ├── init-git.sh ├── llama_models ├── ggml │ └── llama_cpp_models_here.txt └── gptq │ └── exllama_models_here.txt ├── main.py ├── maintools.py ├── notebook ├── function-calling-example.ipynb └── langchain-summarization.ipynb ├── pyproject.toml ├── qdrant ├── .lock └── meta.json ├── readme.md ├── requirements-exllama.txt ├── requirements-llama-cpp.txt ├── requirements-sentence-encoder.txt ├── requirements-transformer-embedding-cpu.txt ├── requirements-transformer-embedding-gpu.txt ├── requirements.txt ├── run_local.bat ├── tests ├── __init__.py ├── __pycache__ │ └── test_completion_api.cpython-311-pytest-7.4.0.pyc.8996 ├── conftest.py ├── test_api_service.py ├── test_auth.py ├── test_chat.py ├── test_completion_api.py ├── test_database.py ├── test_dynamic_llm_load.py ├── test_memory_deallocation.py ├── test_translate.py └── test_vectorstore.py └── update-git.sh
/.dockerignore: -------------------------------------------------------------------------------- 1 | .env 2 | llama_models/* --------------------------------------------------------------------------------
/.env-sample: -------------------------------------------------------------------------------- 1 | 2 | # API_ENV can be "local", "test", "prod" 3 | API_ENV="local" 4 | 5 | # Port for Docker. If you run this app without Docker, this is ignored and the port defaults to 8001 6 | PORT=8000 7 | 8 | # Default LLM model for each chat. 9 | DEFAULT_LLM_MODEL="gpt_3_5_turbo" 10 | 11 | # Your MySQL DB info 12 | MYSQL_DATABASE="traffic" 13 | MYSQL_TEST_DATABASE="testing_db" 14 | MYSQL_ROOT_PASSWORD="YOUR_MYSQL_PASSWORD_HERE" 15 | MYSQL_USER="traffic_admin" 16 | MYSQL_PASSWORD="YOUR_DB_ADMIN_PASSWORD_HERE" 17 | 18 | # Your Redis DB info 19 | REDIS_DATABASE="0" 20 | REDIS_PASSWORD="YOUR_REDIS_PASSWORD_HERE" 21 | 22 | # Your JWT secret key 23 | JWT_SECRET="ANY_PASSWORD_FOR_JWT_TOKEN_GENERATION_HERE" 24 | 25 | # Your OpenAI API key 26 | OPENAI_API_KEY="sk-*************" 27 | 28 | # Chatbot settings 29 | # Summarize for chat: summarize any message whose token count exceeds SUMMARIZATION_THRESHOLD 30 | SUMMARIZE_FOR_CHAT=True 31 | SUMMARIZATION_THRESHOLD=512 32 | 33 | # Text to be embedded is split into chunks of EMBEDDING_TOKEN_CHUNK_SIZE tokens; 34 | # EMBEDDING_TOKEN_CHUNK_OVERLAP is the number of tokens shared between adjacent chunks. 35 | EMBEDDING_TOKEN_CHUNK_SIZE=512 36 | EMBEDDING_TOKEN_CHUNK_OVERLAP=128 37 | 38 | # The shared vector collection name. This collection is shared by all users.
39 | QDRANT_COLLECTION="SharedCollection" 40 | 41 | # If you want to add a prefix or suffix to every prompt sent to the LLM, set these. 42 | GLOBAL_PREFIX="" 43 | GLOBAL_SUFFIX="" 44 | 45 | # If you want to use local embedding instead of OpenAI's Ada-002, 46 | # set LOCAL_EMBEDDING_MODEL to "intfloat/e5-large-v2" or another Hugging Face embedding model repo. 47 | # Warning: Local embedding needs a lot of computing resources!!! 48 | LOCAL_EMBEDDING_MODEL=None 49 | 50 | 51 | # Define these if you want to run a production server with API_ENV="prod" 52 | HOST_IP="OPTIONAL_YOUR_IP_HERE e.g. 192.168.0.2" 53 | HOST_MAIN="OPTIONAL_YOUR_DOMAIN_HERE e.g. yourdomain.com; required for TLS certificate registration when running with API_ENV as prod" 54 | HOST_SUB="OPTIONAL_YOUR_SUB_DOMAIN_HERE e.g. mobile.yourdomain.com" 55 | MY_EMAIL="OPTIONAL_YOUR_EMAIL_HERE e.g. you@yourdomain.com; required for TLS certificate registration when running with API_ENV as prod" 56 | 57 | # Not used. 58 | AWS_ACCESS_KEY="OPTIONAL_IF_YOU_NEED" 59 | AWS_SECRET_KEY="OPTIONAL_IF_YOU_NEED" 60 | AWS_AUTHORIZED_EMAIL="OPTIONAL_IF_YOU_NEED" 61 | SAMPLE_JWT_TOKEN="OPTIONAL_IF_YOU_NEED_FOR_TESTING e.g. Bearer XXXXX" 62 | SAMPLE_ACCESS_KEY="OPTIONAL_IF_YOU_NEED_FOR_TESTING" 63 | SAMPLE_SECRET_KEY="OPTIONAL_IF_YOU_NEED_FOR_TESTING" 64 | KAKAO_RESTAPI_TOKEN="OPTIONAL_IF_YOU_NEED e.g. Bearer XXXXX" 65 | WEATHERBIT_API_KEY="OPTIONAL_IF_YOU_NEED" 66 | NASA_API_KEY="OPTIONAL_IF_YOU_NEED" 67 | 68 | # For translation. If you don't need translation, you can ignore these. 69 | PAPAGO_CLIENT_ID="OPTIONAL_FOR_TRANSLATION" 70 | PAPAGO_CLIENT_SECRET="OPTIONAL_FOR_TRANSLATION" 71 | GOOGLE_CLOUD_PROJECT_ID="OPTIONAL_FOR_TRANSLATION e.g. top-abcd-01234" 72 | GOOGLE_TRANSLATE_API_KEY="OPTIONAL_FOR_TRANSLATION" 73 | GOOGLE_TRANSLATE_OAUTH_ID="OPTIONAL_FOR_TRANSLATION" 74 | GOOGLE_TRANSLATE_OAUTH_SECRET="OPTIONAL_FOR_TRANSLATION" 75 | RAPIDAPI_KEY="OPTIONAL_FOR_TRANSLATION" 76 | CUSTOM_TRANSLATE_URL="OPTIONAL_FOR_TRANSLATION" 77 | 78 | --------------------------------------------------------------------------------
/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ; ignore = W503, E501 --------------------------------------------------------------------------------
/.gitattributes: -------------------------------------------------------------------------------- 1 | *.js linguist-detectable=false 2 | --------------------------------------------------------------------------------
/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | acme.json 3 | PRIVATE_* 4 | venv/ 5 | .cache/ 6 | *.pyc 7 | *.log 8 | llama_models/* 9 | !llama_models/ggml/llama_cpp_models_here.txt 10 | !llama_models/gptq/exllama_models_here.txt 11 | deprecated_* --------------------------------------------------------------------------------
/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "frontend"] 2 | path = frontend 3 | url = https://github.com/c0sogi/LLMChat-frontend 4 | [submodule "repositories/llama_cpp"] 5 | path = repositories/llama_cpp 6 | url = https://github.com/abetlen/llama-cpp-python 7 | [submodule "repositories/exllama"] 8 | path = repositories/exllama 9 | url = https://github.com/turboderp/exllama 10 | --------------------------------------------------------------------------------
/.idea/.gitignore:
-------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/app.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 34 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Flutter", 6 | "type": "dart", 7 | "request": "launch", 8 | "program": "frontend/lib/main.dart", 9 | "deviceId": "chrome" 10 | }, 11 | { 12 | "name": "Python: FastAPI", 13 | "type": "python", 14 | "request": "launch", 15 | "module": "main", 16 | "jinja": true, 17 | "justMyCode": true, 18 | "console": "integratedTerminal", 19 | "env": { 20 | "PYTHONPATH": "${workspaceFolder}" 21 | }, 22 | "args": [ 23 | "uvicorn", 24 | "main:app", 25 | "--reload" 26 | ] 27 | } 28 | ] 29 | } -------------------------------------------------------------------------------- /.vscode/redis-xplorer.redis: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/.vscode/redis-xplorer.redis -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests", 4 | "-s", 5 | "-v", 6 | "--capture=no" 7 | ], 8 | "python.testing.unittestEnabled": false, 9 | "python.testing.pytestEnabled": true, 10 | "files.exclude": { 11 | "**/.git": true, 12 | "**/.svn": true, 13 | "**/.hg": true, 14 | "**/CVS": true, 15 | "**/.DS_Store": true, 16 | "**/Thumbs.db": true, 17 | "**/__pycache__": true, 18 | "**/.pytest_cache": true, 19 | }, 20 | "python.analysis.typeCheckingMode": "basic", 21 | "python.linting.flake8Enabled": true, 22 | "python.linting.mypyEnabled": false, 23 | "python.linting.enabled": true, 24 | "files.associations": { 25 | "*.inc": "cpp" 26 | }, 27 | "python.linting.pylintEnabled": false, 28 | "pylint.path": [ 29 | "-" 30 | ], 31 | } -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Seongbin Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/__init__.py -------------------------------------------------------------------------------- /app/auth/admin.py: -------------------------------------------------------------------------------- 1 | from starlette.requests import Request 2 | from starlette.responses import Response 3 | from starlette_admin.auth import AdminUser, AuthProvider 4 | from starlette_admin.exceptions import FormValidationError, LoginFailed 5 | from app.common.config import config 6 | 7 | 8 | class MyAuthProvider(AuthProvider): 9 | async def login( 10 | self, 11 | username: str, 12 | password: str, 13 | remember_me: bool, 14 | request: Request, 15 | response: Response, 16 | ) -> Response: 17 | if len(username) < 3: 18 | """Form data validation""" 19 | raise FormValidationError( 20 | {"username": "Ensure username has at least 03 characters"} 21 | ) 22 | 23 | if username == config.mysql_user and password == config.mysql_password: 24 | """Save `username` in session""" 25 | request.session.update({"username": username}) 26 | return response 27 | 28 | raise LoginFailed("Invalid username or password") 29 | 30 | async def is_authenticated(self, request) -> bool: 31 | if request.session.get("username", None) == config.mysql_user: 32 | """ 33 | Save current `user` object in the request state. Can be used later 34 | to restrict access to connected user. 
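(As written, nothing is actually stored on the request state; the check relies solely on the username saved in the session at login.)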
35 | """ 36 | return True 37 | 38 | return False 39 | 40 | def get_admin_user(self, request: Request) -> AdminUser: 41 | return AdminUser(username=config.mysql_user) 42 | 43 | async def logout(self, request: Request, response: Response) -> Response: 44 | request.session.clear() 45 | return response 46 | -------------------------------------------------------------------------------- /app/common/app_settings_llama_cpp.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from multiprocessing import Process 3 | from os import kill 4 | from signal import SIGINT 5 | from threading import Event 6 | from urllib import parse 7 | 8 | import requests 9 | from fastapi import FastAPI 10 | from starlette.middleware.cors import CORSMiddleware 11 | 12 | from app.shared import Shared 13 | from app.utils.logger import ApiLogger 14 | 15 | from .config import Config 16 | 17 | 18 | def check_health(url: str) -> bool: 19 | """Check if the given url is available or not""" 20 | try: 21 | schema = parse.urlparse(url).scheme 22 | netloc = parse.urlparse(url).netloc 23 | if requests.get(f"{schema}://{netloc}/health").status_code != 200: 24 | return False 25 | return True 26 | except Exception: 27 | return False 28 | 29 | 30 | def start_llama_cpp_server(config: Config, shared: Shared): 31 | """Start Llama CPP server. if it is already running, terminate it first.""" 32 | 33 | if shared.process.is_alive(): 34 | ApiLogger.cwarning("Terminating existing Llama CPP server") 35 | shared.process.terminate() 36 | shared.process.join() 37 | 38 | if config.llama_server_port is None: 39 | raise NotImplementedError("Llama CPP server port is not set") 40 | 41 | ApiLogger.ccritical("Starting Llama CPP server") 42 | shared.process = Process( 43 | target=run_llama_cpp, args=(config.llama_server_port,), daemon=True 44 | ) 45 | shared.process.start() 46 | 47 | 48 | def shutdown_llama_cpp_server(shared: Shared): 49 | """Shutdown Llama CPP server.""" 50 | ApiLogger.ccritical("Shutting down Llama CPP server") 51 | if shared.process.is_alive() and shared.process.pid: 52 | kill(shared.process.pid, SIGINT) 53 | shared.process.join() 54 | 55 | 56 | def monitor_llama_cpp_server( 57 | config: Config, 58 | shared: Shared, 59 | ) -> None: 60 | """Monitors the Llama CPP server and handles server availability. 61 | 62 | Parameters: 63 | - `config: Config`: An object representing the server configuration. 
64 | - `shared: Shared`: An object representing shared data.""" 65 | thread_sigterm: Event = shared.thread_terminate_signal 66 | if not config.llama_completion_url: 67 | return 68 | while True: 69 | if not check_health(config.llama_completion_url): 70 | if thread_sigterm.is_set(): 71 | break 72 | if config.is_llama_booting: 73 | continue 74 | ApiLogger.cerror("Llama CPP server is not available") 75 | config.is_llama_available = False 76 | config.is_llama_booting = True 77 | try: 78 | start_llama_cpp_server(config=config, shared=shared) 79 | except (ImportError, NotImplementedError): 80 | ApiLogger.cerror( 81 | "ImportError: Llama CPP server is not available" 82 | ) 83 | return 84 | except Exception: 85 | ApiLogger.cexception( 86 | "Unknown error: Llama CPP server is not available" 87 | ) 88 | config.is_llama_booting = False 89 | continue 90 | else: 91 | config.is_llama_booting = False 92 | config.is_llama_available = True 93 | shutdown_llama_cpp_server(shared) 94 | 95 | 96 | @asynccontextmanager 97 | async def lifespan_llama_cpp(app: FastAPI): 98 | ApiLogger.ccritical("🦙 Llama.cpp server is running") 99 | yield 100 | ApiLogger.ccritical("🦙 Shutting down llama.cpp server...") 101 | 102 | 103 | def create_app_llama_cpp(): 104 | from app.routers import v1 105 | 106 | new_app = FastAPI( 107 | title="🦙 llama.cpp Python API", 108 | version="0.0.1", 109 | lifespan=lifespan_llama_cpp, 110 | ) 111 | new_app.add_middleware( 112 | CORSMiddleware, 113 | allow_origins=["*"], 114 | allow_credentials=True, 115 | allow_methods=["*"], 116 | allow_headers=["*"], 117 | ) 118 | 119 | @new_app.get("/health") 120 | async def health(): 121 | return "ok" 122 | 123 | new_app.include_router(v1.router) 124 | return new_app 125 | 126 | 127 | def run_llama_cpp(port: int) -> None: 128 | from uvicorn import Config, Server 129 | 130 | from maintools import initialize_before_launch 131 | 132 | initialize_before_launch() 133 | 134 | Server( 135 | config=Config( 136 | create_app_llama_cpp(), 137 | host="0.0.0.0", 138 | port=port, 139 | log_level="warning", 140 | ) 141 | ).run() 142 | 143 | 144 | if __name__ == "__main__": 145 | run_llama_cpp(port=8002) 146 | -------------------------------------------------------------------------------- /app/common/lotties.py: -------------------------------------------------------------------------------- 1 | from app.mixins.enum import EnumMixin 2 | 3 | 4 | class Lotties(EnumMixin): 5 | CLICK = "lottie-click" 6 | READ = "lottie-read" 7 | SCROLL_DOWN = "lottie-scroll-down" 8 | GO_BACK = "lottie-go-back" 9 | SEARCH_WEB = "lottie-search-web" 10 | SEARCH_DOC = "lottie-search-doc" 11 | OK = "lottie-ok" 12 | FAIL = "lottie-fail" 13 | TRANSLATE = "lottie-translate" 14 | 15 | def format(self, contents: str, end: bool = True) -> str: 16 | return f"\n```{self.get_value(self)}\n{contents}" + ( 17 | "\n```\n" if end else "" 18 | ) 19 | 20 | 21 | if __name__ == "__main__": 22 | print(Lotties.CLICK.format("hello")) 23 | -------------------------------------------------------------------------------- /app/contents/browsing_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/browsing_demo.gif -------------------------------------------------------------------------------- /app/contents/browsing_demo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/browsing_demo.png -------------------------------------------------------------------------------- /app/contents/code_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/code_demo.png -------------------------------------------------------------------------------- /app/contents/edit_title_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/edit_title_demo.png -------------------------------------------------------------------------------- /app/contents/embed_file_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/embed_file_demo.png -------------------------------------------------------------------------------- /app/contents/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/favicon.ico -------------------------------------------------------------------------------- /app/contents/llama_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/llama_api.png -------------------------------------------------------------------------------- /app/contents/model_selection_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/model_selection_demo.png -------------------------------------------------------------------------------- /app/contents/site_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/site_logo.png -------------------------------------------------------------------------------- /app/contents/ui_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/ui_demo.png -------------------------------------------------------------------------------- /app/contents/upload_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/upload_demo.gif -------------------------------------------------------------------------------- /app/database/__init__.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import declarative_base 2 | from sqlalchemy.orm.decl_api import DeclarativeMeta 3 | from typing import TypeVar 4 | 5 | Base: DeclarativeMeta = declarative_base() 6 | TableGeneric = TypeVar("TableGeneric") 7 | -------------------------------------------------------------------------------- /app/database/crud/api_keys.py: -------------------------------------------------------------------------------- 1 | 
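"""CRUD helpers for API keys. Creation enforces the per-user MAX_API_KEY quota and regenerates a key until its access_key is unique before committing."""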
from typing import Optional, Tuple 2 | from sqlalchemy import select, func, exists 3 | from app.models.base_models import AddApiKey 4 | from app.errors.api_exceptions import ( 5 | Responses_400, 6 | Responses_404, 7 | Responses_500, 8 | ) 9 | from app.common.config import MAX_API_KEY 10 | from app.database.connection import db 11 | from app.database.schemas.auth import ApiKeys, Users 12 | from app.utils.auth.api_keys import generate_new_api_key 13 | 14 | 15 | async def create_api_key( 16 | user_id: int, 17 | additional_key_info: AddApiKey, 18 | ) -> ApiKeys: 19 | if db.session is None: 20 | raise Responses_500.database_not_initialized 21 | async with db.session() as transaction: 22 | api_key_count_stmt = select(func.count(ApiKeys.id)).filter_by(user_id=user_id) 23 | api_key_count: int | None = await transaction.scalar(api_key_count_stmt) 24 | if api_key_count is not None and api_key_count >= MAX_API_KEY: 25 | raise Responses_400.max_key_count_exceed 26 | while True: 27 | new_api_key: ApiKeys = generate_new_api_key( 28 | user_id=user_id, additional_key_info=additional_key_info 29 | ) 30 | is_api_key_duplicate_stmt = select( 31 | exists().where(ApiKeys.access_key == new_api_key.access_key) 32 | ) 33 | is_api_key_duplicate: bool | None = await transaction.scalar( 34 | is_api_key_duplicate_stmt 35 | ) 36 | if not is_api_key_duplicate: 37 | break 38 | transaction.add(new_api_key) 39 | await transaction.commit() 40 | await transaction.refresh(new_api_key) 41 | return new_api_key 42 | 43 | 44 | async def get_api_keys(user_id: int) -> list[ApiKeys]: 45 | return await ApiKeys.fetchall_filtered_by(user_id=user_id) 46 | 47 | 48 | async def get_api_key_owner(access_key: str) -> Users: 49 | if db.session is None: 50 | raise Responses_500.database_not_initialized 51 | async with db.session() as transaction: 52 | matched_api_key: Optional[ApiKeys] = await transaction.scalar( 53 | select(ApiKeys).filter_by(access_key=access_key) 54 | ) 55 | if matched_api_key is None: 56 | raise Responses_404.not_found_access_key 57 | owner: Users = await Users.first_filtered_by(id=matched_api_key.user_id) 58 | if owner is None: 59 | raise Responses_404.not_found_user 60 | return owner 61 | 62 | 63 | async def get_api_key_and_owner(access_key: str) -> Tuple[ApiKeys, Users]: 64 | if db.session is None: 65 | raise Responses_500.database_not_initialized 66 | async with db.session() as transaction: 67 | matched_api_key: Optional[ApiKeys] = await transaction.scalar( 68 | select(ApiKeys).filter_by(access_key=access_key) 69 | ) 70 | if matched_api_key is None: 71 | raise Responses_404.not_found_access_key 72 | api_key_owner: Optional[Users] = await transaction.scalar( 73 | select(Users).filter_by(id=matched_api_key.user_id) 74 | ) 75 | if api_key_owner is None: 76 | raise Responses_404.not_found_user 77 | return matched_api_key, api_key_owner 78 | 79 | 80 | async def update_api_key( 81 | updated_key_info: dict, 82 | access_key_id: int, 83 | user_id: int, 84 | ) -> ApiKeys: 85 | if db.session is None: 86 | raise Responses_500.database_not_initialized 87 | async with db.session() as transaction: 88 | matched_api_key: Optional[ApiKeys] = await transaction.scalar( 89 | select(ApiKeys).filter_by(id=access_key_id, user_id=user_id) 90 | ) 91 | if matched_api_key is None: 92 | raise Responses_404.not_found_api_key 93 | matched_api_key.set_values_as(**updated_key_info) 94 | transaction.add(matched_api_key) 95 | await transaction.commit() 96 | await transaction.refresh(matched_api_key) 97 | return matched_api_key 98 | 99 | 100 | async def 
delete_api_key( 101 | access_key_id: int, 102 | access_key: str, 103 | user_id: int, 104 | ) -> None: 105 | if db.session is None: 106 | raise Responses_500.database_not_initialized 107 | async with db.session() as transaction: 108 | matched_api_key: Optional[ApiKeys] = await transaction.scalar( 109 | select(ApiKeys).filter_by( 110 | id=access_key_id, user_id=user_id, access_key=access_key 111 | ) 112 | ) 113 | if matched_api_key is None: 114 | raise Responses_404.not_found_api_key 115 | await transaction.delete(matched_api_key) 116 | await transaction.commit() 117 | -------------------------------------------------------------------------------- /app/database/crud/api_whitelists.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from sqlalchemy import select, func 3 | from app.errors.api_exceptions import ( 4 | Responses_400, 5 | Responses_404, 6 | Responses_500, 7 | ) 8 | from app.common.config import MAX_API_WHITELIST 9 | from app.database.connection import db 10 | from app.database.schemas.auth import ( 11 | ApiKeys, 12 | ApiWhiteLists, 13 | ) 14 | 15 | 16 | async def create_api_key_whitelist( 17 | ip_address: str, 18 | api_key_id: int, 19 | ) -> ApiWhiteLists: 20 | if db.session is None: 21 | raise Responses_500.database_not_initialized 22 | async with db.session() as transaction: 23 | whitelist_count_stmt = select(func.count(ApiWhiteLists.id)).filter_by(api_key_id=api_key_id) 24 | whitelist_count: int | None = await transaction.scalar(whitelist_count_stmt) 25 | if whitelist_count is not None and whitelist_count >= MAX_API_WHITELIST: 26 | raise Responses_400.max_whitekey_count_exceed 27 | ip_duplicated_whitelist_stmt = select(ApiWhiteLists).filter_by(api_key_id=api_key_id, ip_address=ip_address) 28 | ip_duplicated_whitelist = await transaction.scalar(ip_duplicated_whitelist_stmt) 29 | if ip_duplicated_whitelist is not None: 30 | return ip_duplicated_whitelist 31 | new_whitelist = ApiWhiteLists(api_key_id=api_key_id, ip_address=ip_address) 32 | transaction.add(new_whitelist) 33 | await transaction.commit() 34 | await transaction.refresh(new_whitelist) 35 | return new_whitelist 36 | 37 | 38 | async def get_api_key_whitelist(api_key_id: int) -> list[ApiWhiteLists]: 39 | return await ApiWhiteLists.fetchall_filtered_by(api_key_id=api_key_id) 40 | 41 | 42 | async def delete_api_key_whitelist( 43 | user_id: int, 44 | api_key_id: int, 45 | whitelist_id: int, 46 | ) -> None: 47 | if db.session is None: 48 | raise Responses_500.database_not_initialized 49 | async with db.session() as transaction: 50 | matched_api_key: Optional[ApiKeys] = await transaction.scalar( 51 | select(ApiKeys).filter_by(id=api_key_id, user_id=user_id) 52 | ) 53 | if matched_api_key is None: 54 | raise Responses_404.not_found_api_key 55 | matched_whitelist_stmt = select(ApiWhiteLists).filter_by(id=whitelist_id, api_key_id=api_key_id) 56 | matched_whitelist = await transaction.scalar(matched_whitelist_stmt) 57 | await transaction.delete(matched_whitelist) 58 | await transaction.commit() 59 | -------------------------------------------------------------------------------- /app/database/crud/deprecated_chatgpt.py: -------------------------------------------------------------------------------- 1 | # from typing import Optional 2 | 3 | # from sqlalchemy import select 4 | # from app.errors.api_exceptions import ( 5 | # Responses_404, 6 | # Responses_500, 7 | # ) 8 | # from app.database.connection import db 9 | # from app.database.schemas.auth import ( 10 | # 
ChatMessages, 11 | # ChatRooms, 12 | # GptPresets, 13 | # ) 14 | 15 | 16 | # async def create_chat_room( 17 | # chat_room_type: str, 18 | # name: str, 19 | # description: str | None, 20 | # user_id: int, 21 | # status: str = "active", 22 | # ) -> ChatRooms: 23 | # if db.session is None: 24 | # raise Responses_500.database_not_initialized 25 | # async with db.session() as transaction: 26 | # new_chat_room: ChatRooms = ChatRooms( 27 | # status=status, 28 | # chat_room_type=chat_room_type, 29 | # name=name, 30 | # description=description, 31 | # user_id=user_id, 32 | # ) 33 | # transaction.add(new_chat_room) 34 | # await transaction.commit() 35 | # await transaction.refresh(new_chat_room) 36 | # return new_chat_room 37 | 38 | 39 | # async def get_chat_all_rooms(user_id: int) -> list[ChatRooms]: 40 | # return await ChatRooms.fetchall_filtered_by(user_id=user_id) # type: ignore 41 | 42 | 43 | # async def create_chat_message( 44 | # role: str, 45 | # message: str, 46 | # chat_room_id: int, 47 | # user_id: int, 48 | # status: str = "active", 49 | # ) -> ChatMessages: 50 | # if db.session is None: 51 | # raise Responses_500.database_not_initialized 52 | # async with db.session() as transaction: 53 | # new_chat_message: ChatMessages = ChatMessages( 54 | # status=status, 55 | # role=role, 56 | # message=message, 57 | # user_id=user_id, 58 | # chat_room_id=chat_room_id, 59 | # ) 60 | # transaction.add(new_chat_message) 61 | # await transaction.commit() 62 | # await transaction.refresh(new_chat_message) 63 | # return new_chat_message 64 | 65 | 66 | # async def get_chat_all_messages( 67 | # chat_room_id: int, 68 | # user_id: int, 69 | # ) -> list[ChatMessages]: 70 | # return await ChatMessages.fetchall_filtered_by( 71 | # user_id=user_id, 72 | # chat_room_id=chat_room_id, 73 | # ) # type: ignore 74 | 75 | 76 | # async def create_gpt_preset( 77 | # user_id: int, 78 | # temperature: float, 79 | # top_p: float, 80 | # presence_penalty: float, 81 | # frequency_penalty: float, 82 | # ) -> GptPresets: 83 | # if db.session is None: 84 | # raise Responses_500.database_not_initialized 85 | # async with db.session() as transaction: 86 | # new_gpt_preset: GptPresets = GptPresets( 87 | # user_id=user_id, 88 | # temperature=temperature, 89 | # top_p=top_p, 90 | # presence_penalty=presence_penalty, 91 | # frequency_penalty=frequency_penalty, 92 | # ) 93 | # transaction.add(new_gpt_preset) 94 | # await transaction.commit() 95 | # await transaction.refresh(new_gpt_preset) 96 | # return new_gpt_preset 97 | 98 | 99 | # async def get_gpt_presets(user_id: int) -> list[GptPresets]: 100 | # return await GptPresets.fetchall_filtered_by(user_id=user_id) # type: ignore 101 | 102 | 103 | # async def get_gpt_preset(user_id: int, preset_id: int) -> GptPresets: 104 | # return await GptPresets.fetchone_filtered_by(user_id=user_id, id=preset_id) 105 | 106 | 107 | # async def update_gpt_preset( 108 | # user_id: int, 109 | # preset_id: int, 110 | # temperature: float, 111 | # top_p: float, 112 | # presence_penalty: float, 113 | # frequency_penalty: float, 114 | # status: str = "active", 115 | # ) -> GptPresets: 116 | # if db.session is None: 117 | # raise Responses_500.database_not_initialized 118 | # async with db.session() as transaction: 119 | # matched_preset: Optional[GptPresets] = await transaction.scalar( 120 | # select(GptPresets).filter_by(user_id=user_id, id=preset_id) 121 | # ) 122 | # if matched_preset is None: 123 | # raise Responses_404.not_found_preset 124 | # matched_preset.set_values_as( 125 | # 
temperature=temperature, 126 | # top_p=top_p, 127 | # presence_penalty=presence_penalty, 128 | # frequency_penalty=frequency_penalty, 129 | # status=status, 130 | # ) 131 | # transaction.add(matched_preset) 132 | # await transaction.commit() 133 | # await transaction.refresh(matched_preset) 134 | # return matched_preset 135 | 136 | 137 | # async def delete_gpt_preset(user_id: int, preset_id: int) -> None: 138 | # if db.session is None: 139 | # raise Responses_500.database_not_initialized 140 | # async with db.session() as transaction: 141 | # matched_preset: Optional[GptPresets] = await transaction.scalar( 142 | # select(GptPresets).filter_by(user_id=user_id, id=preset_id) 143 | # ) 144 | # if matched_preset is None: 145 | # raise Responses_404.not_found_preset 146 | # await transaction.delete(matched_preset) 147 | # await transaction.commit() 148 | # return 149 | -------------------------------------------------------------------------------- /app/database/crud/users.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import select, exists 2 | from app.database.connection import db 3 | from app.database.schemas.auth import ( 4 | Users, 5 | ApiKeys, 6 | ) 7 | 8 | 9 | async def is_email_exist(email: str) -> bool: 10 | return True if await db.scalar(select(exists().where(Users.email == email))) else False 11 | 12 | 13 | async def get_me(user_id: int): 14 | return await Users.first_filtered_by(id=user_id) 15 | 16 | 17 | async def is_valid_api_key(access_key: str) -> bool: 18 | return True if await db.scalar(select(exists().where(ApiKeys.access_key == access_key))) else False 19 | 20 | 21 | async def register_new_user( 22 | email: str, 23 | hashed_password: str, 24 | ip_address: str | None, 25 | ) -> Users: 26 | return ( 27 | await Users.add_one( 28 | autocommit=True, 29 | refresh=True, 30 | email=email, 31 | password=hashed_password, 32 | ip_address=ip_address, 33 | ) 34 | if ip_address 35 | else await Users.add_one( 36 | autocommit=True, 37 | refresh=True, 38 | email=email, 39 | password=hashed_password, 40 | ) 41 | ) 42 | 43 | 44 | async def find_matched_user(email: str) -> Users: 45 | return await Users.first_filtered_by(email=email) 46 | -------------------------------------------------------------------------------- /app/database/schemas/auth.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from sqlalchemy import ( 3 | String, 4 | Integer, 5 | Enum, 6 | Boolean, 7 | ForeignKey, 8 | ) 9 | from sqlalchemy.orm import ( 10 | relationship, 11 | Mapped, 12 | mapped_column, 13 | ) 14 | 15 | from app.viewmodels.status import ApiKeyStatus, UserStatus 16 | from .. import Base 17 | from . 
import TableMixin 18 | 19 | 20 | class Users(Base, TableMixin): 21 | __tablename__ = "users" 22 | status: Mapped[str] = mapped_column(Enum(UserStatus), default=UserStatus.active) 23 | email: Mapped[str] = mapped_column(String(length=50)) 24 | password: Mapped[str | None] = mapped_column(String(length=100)) 25 | name: Mapped[str | None] = mapped_column(String(length=20)) 26 | phone_number: Mapped[str | None] = mapped_column(String(length=20)) 27 | profile_img: Mapped[str | None] = mapped_column(String(length=100)) 28 | marketing_agree: Mapped[bool] = mapped_column(Boolean, default=True) 29 | api_keys: Mapped["ApiKeys"] = relationship( 30 | back_populates="users", cascade="all, delete-orphan", lazy=True 31 | ) 32 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True) 33 | # chat_messages: Mapped["ChatMessages"] = relationship( 34 | # back_populates="users", cascade="all, delete-orphan", lazy=True 35 | # ) 36 | # gpt_presets: Mapped["GptPresets"] = relationship( 37 | # back_populates="users", cascade="all, delete-orphan", lazy=True, uselist=False 38 | # ) 39 | 40 | 41 | class ApiKeys(Base, TableMixin): 42 | __tablename__ = "api_keys" 43 | status: Mapped[str] = mapped_column(Enum(ApiKeyStatus), default=ApiKeyStatus.active) 44 | access_key: Mapped[str] = mapped_column(String(length=64), index=True, unique=True) 45 | secret_key: Mapped[str] = mapped_column(String(length=64)) 46 | user_memo: Mapped[str | None] = mapped_column(String(length=40)) 47 | is_whitelisted: Mapped[bool] = mapped_column(default=False) 48 | user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) 49 | users: Mapped["Users"] = relationship(back_populates="api_keys") 50 | whitelists: Mapped["ApiWhiteLists"] = relationship( 51 | backref="api_keys", cascade="all, delete-orphan" 52 | ) 53 | 54 | 55 | class ApiWhiteLists(Base, TableMixin): 56 | __tablename__ = "api_whitelists" 57 | api_key_id: Mapped[int] = mapped_column( 58 | Integer, ForeignKey("api_keys.id", ondelete="CASCADE") 59 | ) 60 | ip_address: Mapped[str] = mapped_column(String(length=64)) 61 | 62 | 63 | # class ChatRooms(Base, Mixin): 64 | # __tablename__ = "chat_rooms" 65 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True) 66 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active") 67 | # chat_room_type: Mapped[str] = mapped_column(String(length=20), index=True) 68 | # name: Mapped[str] = mapped_column(String(length=20)) 69 | # description: Mapped[str | None] = mapped_column(String(length=100)) 70 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) 71 | # users: Mapped["Users"] = relationship(back_populates="chat_rooms") 72 | # chat_messages: Mapped["ChatMessages"] = relationship(back_populates="chat_rooms", cascade="all, delete-orphan") 73 | 74 | 75 | # class ChatMessages(Base, Mixin): 76 | # __tablename__ = "chat_messages" 77 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True) 78 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active") 79 | # role: Mapped[str] = mapped_column(String(length=20), default="user") 80 | # message: Mapped[str] = mapped_column(Text) 81 | # chat_room_id: Mapped[int] = mapped_column(ForeignKey("chat_rooms.id", ondelete="CASCADE")) 82 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="chat_messages") 83 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", 
ondelete="CASCADE")) 84 | # users: Mapped["Users"] = relationship(back_populates="chat_messages") 85 | 86 | 87 | # class GptPresets(Base, Mixin): 88 | # __tablename__ = "gpt_presets" 89 | # temperature: Mapped[float] = mapped_column(Float, default=0.9) 90 | # top_p: Mapped[float] = mapped_column(Float, default=1.0) 91 | # presence_penalty: Mapped[float] = mapped_column(Float, default=0) 92 | # frequency_penalty: Mapped[float] = mapped_column(Float, default=0) 93 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), unique=True) 94 | # users: Mapped["Users"] = relationship(back_populates="gpt_presets") 95 | -------------------------------------------------------------------------------- /app/database/schemas/deprecated_chatgpt.py: -------------------------------------------------------------------------------- 1 | # from sqlalchemy import ( 2 | # String, 3 | # Enum, 4 | # ForeignKey, 5 | # Float, 6 | # Text, 7 | # ) 8 | # from sqlalchemy.orm import ( 9 | # relationship, 10 | # Mapped, 11 | # mapped_column, 12 | # ) 13 | # from app.database.schemas.auth import Users 14 | # from .. import Base 15 | # from . import Mixin 16 | 17 | 18 | # class ChatRooms(Base, Mixin): 19 | # __tablename__ = "chat_rooms" 20 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True) 21 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active") 22 | # chat_room_type: Mapped[str] = mapped_column(String(length=20), index=True) 23 | # name: Mapped[str] = mapped_column(String(length=20)) 24 | # description: Mapped[str | None] = mapped_column(String(length=100)) 25 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) 26 | # users: Mapped["Users"] = relationship(back_populates="chat_rooms") 27 | # chat_messages: Mapped["ChatMessages"] = relationship(back_populates="chat_rooms", cascade="all, delete-orphan") 28 | 29 | 30 | # class ChatMessages(Base, Mixin): 31 | # __tablename__ = "chat_messages" 32 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True) 33 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active") 34 | # role: Mapped[str] = mapped_column(String(length=20), default="user") 35 | # message: Mapped[str] = mapped_column(Text) 36 | # chat_room_id: Mapped[int] = mapped_column(ForeignKey("chat_rooms.id")) 37 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="chat_messages") 38 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) 39 | # users: Mapped["Users"] = relationship(back_populates="chat_messages") 40 | 41 | 42 | # class GptPresets(Base, Mixin): 43 | # __tablename__ = "gpt_presets" 44 | # temperature: Mapped[float] = mapped_column(Float, default=0.9) 45 | # top_p: Mapped[float] = mapped_column(Float, default=1.0) 46 | # presence_penalty: Mapped[float] = mapped_column(Float, default=0) 47 | # frequency_penalty: Mapped[float] = mapped_column(Float, default=0) 48 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True) 49 | # users: Mapped["Users"] = relationship(back_populates="gpt_presets") 50 | -------------------------------------------------------------------------------- /app/dependencies.py: -------------------------------------------------------------------------------- 1 | from fastapi import Header, Query 2 | from fastapi.security import APIKeyHeader 3 | 4 | 5 | def api_service_dependency(secret: str = Header(...), key: str = Query(...), timestamp: str = Query(...)): 6 | ... 
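# (Illustrative assumption, not the project's actual scheme: a real implementation might check that `timestamp` is recent and that `secret` matches a signature computed from `key` and the key's secret.)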
# do some validation or processing with the headers 7 | 8 | 9 | USER_DEPENDENCY = APIKeyHeader(name="Authorization", auto_error=False) 10 | -------------------------------------------------------------------------------- /app/errors/chat_exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class ChatException(Exception): # Base exception for chat 5 | def __init__(self, *, msg: str | None = None) -> None: 6 | self.msg = msg 7 | super().__init__() 8 | 9 | 10 | class ChatConnectionException(ChatException): 11 | def __init__(self, *, msg: str | None = None) -> None: 12 | self.msg = msg 13 | super().__init__(msg=msg) 14 | 15 | 16 | class ChatLengthException(ChatException): 17 | def __init__(self, *, msg: str | None = None) -> None: 18 | self.msg = msg 19 | super().__init__(msg=msg) 20 | 21 | 22 | class ChatContentFilterException(ChatException): 23 | def __init__(self, *, msg: str | None = None) -> None: 24 | self.msg = msg 25 | super().__init__(msg=msg) 26 | 27 | 28 | class ChatTooMuchTokenException(ChatException): 29 | def __init__(self, *, msg: str | None = None) -> None: 30 | self.msg = msg 31 | super().__init__(msg=msg) 32 | 33 | 34 | class ChatTextGenerationException(ChatException): 35 | def __init__(self, *, msg: str | None = None) -> None: 36 | self.msg = msg 37 | super().__init__(msg=msg) 38 | 39 | 40 | class ChatOtherException(ChatException): 41 | def __init__(self, *, msg: str | None = None) -> None: 42 | self.msg = msg 43 | super().__init__(msg=msg) 44 | 45 | 46 | class ChatModelNotImplementedException(ChatException): 47 | def __init__(self, *, msg: str | None = None) -> None: 48 | self.msg = msg 49 | super().__init__(msg=msg) 50 | 51 | 52 | class ChatBreakException(ChatException): 53 | def __init__(self, *, msg: str | None = None) -> None: 54 | self.msg = msg 55 | super().__init__(msg=msg) 56 | 57 | 58 | class ChatContinueException(ChatException): 59 | def __init__(self, *, msg: str | None = None) -> None: 60 | self.msg = msg 61 | super().__init__(msg=msg) 62 | 63 | 64 | class ChatInterruptedException(ChatException): 65 | def __init__(self, *, msg: str | None = None) -> None: 66 | self.msg = msg 67 | super().__init__(msg=msg) 68 | 69 | 70 | class ChatFunctionCallException(ChatException): 71 | """Raised when function is called.""" 72 | 73 | def __init__(self, *, func_name: str, func_kwargs: dict[str, Any]) -> None: 74 | self.func_name = func_name 75 | self.func_kwargs = func_kwargs 76 | super().__init__(msg=f"Function {func_name}({func_kwargs}) is called.") 77 | -------------------------------------------------------------------------------- /app/middlewares/trusted_hosts.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from starlette.datastructures import URL, Headers 4 | from starlette.responses import PlainTextResponse, RedirectResponse 5 | from starlette.types import ASGIApp, Receive, Scope, Send 6 | 7 | from app.errors.api_exceptions import Responses_500 8 | 9 | 10 | class TrustedHostMiddleware: 11 | def __init__( 12 | self, 13 | app: ASGIApp, 14 | allowed_hosts: Sequence[str] | None = None, 15 | except_path: Sequence[str] | None = None, 16 | www_redirect: bool = True, 17 | ): 18 | self.app: ASGIApp = app 19 | self.allowed_hosts: list = ( 20 | ["*"] if allowed_hosts is None else list(allowed_hosts) 21 | ) 22 | self.allow_any: bool = ( 23 | "*" in allowed_hosts if allowed_hosts is not None else True 24 | ) 25 | 
self.www_redirect: bool = www_redirect 26 | self.except_path: list = ( 27 | [] if except_path is None else list(except_path) 28 | ) 29 | if allowed_hosts is not None: 30 | for allowed_host in allowed_hosts: 31 | if "*" in allowed_host[1:]: 32 | raise Responses_500.middleware_exception 33 | if ( 34 | allowed_host.startswith("*") and allowed_host != "*" 35 | ) and not allowed_host.startswith("*."): 36 | raise Responses_500.middleware_exception 37 | 38 | async def __call__( 39 | self, scope: Scope, receive: Receive, send: Send 40 | ) -> None: 41 | if self.allow_any or scope["type"] not in ( 42 | "http", 43 | "websocket", 44 | ): # pragma: no cover 45 | await self.app(scope, receive, send) 46 | return 47 | 48 | headers = Headers(scope=scope) 49 | host = headers.get("host", "").split(":")[0] 50 | is_valid_host = False 51 | found_www_redirect = False 52 | for pattern in self.allowed_hosts: 53 | if ( 54 | host == pattern 55 | or pattern.startswith("*") 56 | and host.endswith(pattern[1:]) 57 | or URL(scope=scope).path in self.except_path 58 | ): 59 | is_valid_host = True 60 | break 61 | elif "www." + host == pattern: 62 | found_www_redirect = True 63 | 64 | if is_valid_host: 65 | await self.app(scope, receive, send) 66 | else: 67 | if found_www_redirect and self.www_redirect: 68 | url = URL(scope=scope) 69 | redirect_url = url.replace(netloc="www." + url.netloc) 70 | response = RedirectResponse(url=str(redirect_url)) 71 | else: 72 | response = PlainTextResponse( 73 | "Invalid host header", status_code=400 74 | ) 75 | await response(scope, receive, send) 76 | --------------------------------------------------------------------------------
/app/models/chat_commands.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from app.utils.chat.commands.browsing import BrowsingCommands 4 | from app.utils.chat.commands.core import CoreCommands 5 | from app.utils.chat.commands.llm_parameter import LLMParameterCommands 6 | from app.utils.chat.commands.prompt import PromptCommands 7 | from app.utils.chat.commands.server import ServerCommands 8 | from app.utils.chat.commands.summarize import SummarizeCommands 9 | from app.utils.chat.commands.testing import TestingCommands 10 | from app.utils.chat.commands.vectorstore import VectorstoreCommands 11 | 12 | from .chat_models import command_response 13 | 14 | 15 | class ChatCommandsMetaClass(type): 16 | """Metaclass for ChatCommands class. 17 | It is used to automatically create a list of special commands.
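For example, if CoreCommands defined public methods `help` and `clear` (hypothetical names), `special_commands` would become ["clear", "help"]: every attribute of CoreCommands whose name does not start with an underscore is collected via dir().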
18 | """ 19 | 20 | special_commands: list[str] 21 | 22 | def __init__(cls, name, bases, dct): 23 | super().__init__(name, bases, dct) 24 | cls.special_commands = [ 25 | callback_name 26 | for callback_name in dir(CoreCommands) 27 | if not callback_name.startswith("_") 28 | ] 29 | 30 | 31 | class ChatCommands( 32 | CoreCommands, 33 | VectorstoreCommands, 34 | PromptCommands, 35 | BrowsingCommands, 36 | LLMParameterCommands, 37 | ServerCommands, 38 | TestingCommands, 39 | SummarizeCommands, 40 | metaclass=ChatCommandsMetaClass, 41 | ): 42 | @classmethod 43 | def find_callback_with_command(cls, command: str) -> Callable: 44 | found_callback: Callable = getattr(cls, command) 45 | if found_callback is None or not callable(found_callback): 46 | found_callback = cls.not_existing_callback 47 | return found_callback 48 | 49 | @staticmethod 50 | @command_response.send_message_and_stop 51 | def not_existing_callback() -> str: # callback for not existing command 52 | return "Sorry, I don't know what you mean by..." 53 | -------------------------------------------------------------------------------- /app/models/completion_models.py: -------------------------------------------------------------------------------- 1 | from sys import version_info 2 | from typing import Dict, List, Optional, Literal, TypedDict 3 | 4 | # If python version >= 3.11, use the built-in NotRequired type. 5 | # Otherwise, import it from typing_extensi 6 | if version_info >= (3, 11): 7 | from typing import NotRequired # type: ignore 8 | else: 9 | from typing_extensions import NotRequired 10 | 11 | from .function_calling.base import JsonTypes 12 | 13 | 14 | class FunctionCallParsed(TypedDict): 15 | name: str 16 | arguments: NotRequired[Dict[str, JsonTypes]] 17 | 18 | 19 | class FunctionCallUnparsed(TypedDict): 20 | name: NotRequired[str] 21 | arguments: NotRequired[str] 22 | 23 | 24 | class EmbeddingUsage(TypedDict): 25 | prompt_tokens: int 26 | total_tokens: int 27 | 28 | 29 | class EmbeddingData(TypedDict): 30 | index: int 31 | object: str 32 | embedding: List[float] 33 | 34 | 35 | class Embedding(TypedDict): 36 | object: Literal["list"] 37 | model: str 38 | data: List[EmbeddingData] 39 | usage: EmbeddingUsage 40 | 41 | 42 | class CompletionLogprobs(TypedDict): 43 | text_offset: List[int] 44 | token_logprobs: List[Optional[float]] 45 | tokens: List[str] 46 | top_logprobs: List[Optional[Dict[str, float]]] 47 | 48 | 49 | class CompletionChoice(TypedDict): 50 | text: str 51 | index: int 52 | logprobs: Optional[CompletionLogprobs] 53 | finish_reason: Optional[str] 54 | 55 | 56 | class CompletionUsage(TypedDict): 57 | prompt_tokens: int 58 | completion_tokens: int 59 | total_tokens: int 60 | 61 | 62 | class CompletionChunk(TypedDict): 63 | id: str 64 | object: Literal["text_completion"] 65 | created: int 66 | model: str 67 | choices: List[CompletionChoice] 68 | 69 | 70 | class Completion(TypedDict): 71 | id: str 72 | object: Literal["text_completion"] 73 | created: int 74 | model: str 75 | choices: List[CompletionChoice] 76 | usage: CompletionUsage 77 | 78 | 79 | class ChatCompletionMessage(TypedDict): 80 | role: Literal["assistant", "user", "system"] 81 | content: str 82 | user: NotRequired[str] 83 | function_call: NotRequired[FunctionCallUnparsed] 84 | 85 | 86 | class ChatCompletionChoice(TypedDict): 87 | index: int 88 | message: ChatCompletionMessage 89 | finish_reason: Optional[str] 90 | 91 | 92 | class ChatCompletion(TypedDict): 93 | id: str 94 | object: Literal["chat.completion"] 95 | created: int 96 | model: str 97 | choices: 
List[ChatCompletionChoice] 98 | usage: CompletionUsage 99 | 100 | 101 | class ChatCompletionChunkDelta(TypedDict): 102 | role: NotRequired[Literal["assistant"]] 103 | content: NotRequired[str] 104 | function_call: NotRequired[FunctionCallUnparsed] 105 | 106 | 107 | class ChatCompletionChunkChoice(TypedDict): 108 | index: int 109 | delta: ChatCompletionChunkDelta 110 | finish_reason: Optional[str] 111 | 112 | 113 | class ChatCompletionChunk(TypedDict): 114 | id: str 115 | model: str 116 | object: Literal["chat.completion.chunk"] 117 | created: int 118 | choices: List[ChatCompletionChunkChoice] 119 | 120 | 121 | class ModelData(TypedDict): 122 | id: str 123 | object: Literal["model"] 124 | owned_by: str 125 | permissions: List[str] 126 | 127 | 128 | class ModelList(TypedDict): 129 | object: Literal["list"] 130 | data: List[ModelData] 131 | --------------------------------------------------------------------------------
/app/models/function_calling/base.py: -------------------------------------------------------------------------------- 1 | """Helper classes for wrapping functions in OpenAI's API""" 2 | 3 | from dataclasses import dataclass 4 | from sys import version_info 5 | from types import NoneType 6 | from typing import ( 7 | Any, 8 | Generic, 9 | Literal, 10 | Optional, 11 | Type, 12 | TypeVar, 13 | TypedDict, 14 | Union, 15 | ) 16 | 17 | # If Python version >= 3.11, use the built-in NotRequired type. 18 | # Otherwise, import it from typing_extensions. 19 | if version_info >= (3, 11): 20 | from typing import NotRequired # type: ignore 21 | else: 22 | from typing_extensions import NotRequired 23 | 24 | # The types that can be used in JSON 25 | JsonTypes = Union[int, float, str, bool, dict, list, None] 26 | 27 | ParamType = TypeVar("ParamType", bound=JsonTypes) 28 | ReturnType = TypeVar("ReturnType") 29 | 30 | 31 | class ParameterProperty(TypedDict): 32 | type: str 33 | description: NotRequired[str] 34 | enum: NotRequired[list[JsonTypes]] 35 | 36 | 37 | class ParameterDefinition(TypedDict): 38 | type: Literal["object"] 39 | properties: dict[str, ParameterProperty] 40 | required: NotRequired[list[str]] 41 | 42 | 43 | class FunctionProperty(TypedDict): 44 | name: str 45 | description: NotRequired[str] 46 | parameters: NotRequired[ParameterDefinition] 47 | 48 | 49 | @dataclass 50 | class FunctionCallParameter(Generic[ParamType]): 51 | """A class for wrapping function parameters in OpenAI's API""" 52 | 53 | name: str 54 | type: Type[ParamType] 55 | description: Optional[str] = None 56 | enum: Optional[list[ParamType]] = None 57 | 58 | def to_dict(self) -> dict[str, ParameterProperty]: 59 | """Returns a dictionary representation of the parameter""" 60 | parameter_property: ParameterProperty = { 61 | "type": self._get_json_type(self.type) 62 | } # type: ignore 63 | if self.description: 64 | parameter_property["description"] = self.description 65 | if self.enum: 66 | parameter_property["enum"] = self.enum # type: ignore 67 | return {self.name: parameter_property} 68 | 69 | @staticmethod 70 | def _get_json_type(python_type: Type[JsonTypes]) -> str: 71 | """Returns the JSON type for a given python type""" 72 | if python_type is int: 73 | return "integer" 74 | elif python_type is float: 75 | return "number" 76 | elif python_type is str: 77 | return "string" 78 | elif python_type is bool: 79 | return "boolean" 80 | elif python_type is dict: 81 | return "object" 82 | elif python_type is list: 83 | return "array" 84 | elif python_type is NoneType or python_type is None: 85 | return "null" 86 | else: 87 |
raise ValueError( 88 | f"Invalid type {python_type} for JSON. " 89 | f"Permitted types are {JsonTypes}" 90 | ) 91 | 92 | 93 | @dataclass 94 | class FunctionCall: 95 | """A class for wrapping functions in OpenAI's API""" 96 | 97 | name: str 98 | description: Optional[str] = None 99 | parameters: Optional[list[FunctionCallParameter[Any]]] = None 100 | required: Optional[list[str]] = None 101 | 102 | def to_dict(self) -> FunctionProperty: 103 | """Returns a dictionary representation of the function""" 104 | function_property: FunctionProperty = FunctionProperty(name=self.name) # type: ignore 105 | if self.description: 106 | function_property["description"] = self.description 107 | if self.parameters: 108 | function_property["parameters"] = { 109 | "type": "object", 110 | "properties": { 111 | param.name: param.to_dict()[param.name] 112 | for param in self.parameters 113 | }, 114 | "required": [ 115 | param.name 116 | for param in self.parameters 117 | if param.name in (self.required or []) 118 | ], 119 | } 120 | return function_property 121 | 122 | 123 | if __name__ == "__main__": 124 | 125 | def test_callback(test: str) -> int: 126 | return len(test) 127 | 128 | param_1: FunctionCallParameter[str] = FunctionCallParameter( 129 | name="param_1", 130 | type=str, 131 | description="This is a test1", 132 | enum=["a", "b", "c", "d"], 133 | ) # `str` is the type of the parameter 134 | 135 | param_2: FunctionCallParameter[bool] = FunctionCallParameter( 136 | name="param_2", 137 | type=bool, 138 | description="This is a test2", 139 | enum=[True, False], 140 | ) # `bool` is the type of the parameter 141 | 142 | func: FunctionCall = FunctionCall( 143 | name="test_function", 144 | description="This is a test function", 145 | parameters=[param_1, param_2], 146 | required=["test"], 147 | ) # There's no type for the function 148 | -------------------------------------------------------------------------------- /app/routers/__init__.py: -------------------------------------------------------------------------------- 1 | from app.routers import auth, index, services, user_services, users, websocket 2 | 3 | __all__ = [ 4 | "auth", 5 | "index", 6 | "services", 7 | "user_services", 8 | "users", 9 | "websocket", 10 | ] 11 | -------------------------------------------------------------------------------- /app/routers/auth.py: -------------------------------------------------------------------------------- 1 | import bcrypt 2 | from fastapi import APIRouter, Response, Security, status 3 | from fastapi.requests import Request 4 | 5 | from app.common.config import TOKEN_EXPIRE_HOURS 6 | from app.database.crud.users import is_email_exist, register_new_user 7 | from app.database.schemas.auth import Users 8 | from app.dependencies import USER_DEPENDENCY 9 | from app.errors.api_exceptions import Responses_400, Responses_404 10 | from app.models.base_models import SnsType, Token, UserRegister, UserToken 11 | from app.utils.auth.register_validation import ( 12 | is_email_length_in_range, 13 | is_email_valid_format, 14 | is_password_length_in_range, 15 | ) 16 | from app.utils.auth.token import create_access_token, token_decode 17 | from app.utils.chat.managers.cache import CacheManager 18 | 19 | router = APIRouter(prefix="/auth") 20 | 21 | 22 | @router.post("/register/{sns_type}", status_code=201, response_model=Token) 23 | async def register( 24 | request: Request, 25 | response: Response, 26 | sns_type: SnsType, 27 | reg_info: UserRegister, 28 | ) -> Token: 29 | if sns_type == SnsType.EMAIL: 30 | if not (reg_info.email 
and reg_info.password): 31 | raise Responses_400.no_email_or_password 32 | 33 | if is_email_length_in_range(email=reg_info.email) is False: 34 | raise Responses_400.email_length_not_in_range 35 | 36 | if is_password_length_in_range(password=reg_info.password) is False: 37 | raise Responses_400.password_length_not_in_range 38 | 39 | if is_email_valid_format(email=reg_info.email) is False: 40 | raise Responses_400.invalid_email_format 41 | 42 | if await is_email_exist(email=reg_info.email): 43 | raise Responses_400.email_already_exists 44 | 45 | hashed_password: str = bcrypt.hashpw( 46 | password=reg_info.password.encode("utf-8"), 47 | salt=bcrypt.gensalt(), 48 | ).decode("utf-8") 49 | new_user: Users = await register_new_user( 50 | email=reg_info.email, 51 | hashed_password=hashed_password, 52 | ip_address=request.state.ip, 53 | ) 54 | data_to_be_tokenized: dict = UserToken.from_orm(new_user).dict( 55 | exclude={"password", "marketing_agree"} 56 | ) 57 | token: str = create_access_token( 58 | data=data_to_be_tokenized, expires_delta=TOKEN_EXPIRE_HOURS 59 | ) 60 | response.set_cookie( 61 | key="Authorization", 62 | value=f"Bearer {token}", 63 | max_age=TOKEN_EXPIRE_HOURS * 3600, 64 | secure=True, 65 | httponly=True, 66 | ) 67 | return Token(Authorization=f"Bearer {token}") 68 | raise Responses_400.not_supported_feature 69 | 70 | 71 | @router.delete("/register", status_code=status.HTTP_204_NO_CONTENT) 72 | async def unregister( 73 | authorization: str = Security(USER_DEPENDENCY), 74 | ): 75 | registered_user: UserToken = UserToken(**token_decode(authorization)) 76 | if registered_user.email is not None: 77 | await CacheManager.delete_user(user_id=registered_user.email) 78 | await Users.delete_filtered( 79 | Users.email == registered_user.email, autocommit=True 80 | ) 81 | 82 | 83 | @router.post("/login/{sns_type}", status_code=200, response_model=Token) 84 | async def login( 85 | response: Response, 86 | sns_type: SnsType, 87 | user_info: UserRegister, 88 | ) -> Token: 89 | if sns_type == SnsType.EMAIL: 90 | if not (user_info.email and user_info.password): 91 | raise Responses_400.no_email_or_password 92 | matched_user: Users = await Users.first_filtered_by( 93 | email=user_info.email 94 | ) 95 | if matched_user is None or matched_user.password is None: 96 | raise Responses_404.not_found_user 97 | if not bcrypt.checkpw( 98 | password=user_info.password.encode("utf-8"), 99 | hashed_password=matched_user.password.encode("utf-8"), 100 | ): 101 | raise Responses_404.not_found_user 102 | data_to_be_tokenized: dict = UserToken.from_orm(matched_user).dict( 103 | exclude={"password", "marketing_agree"} 104 | ) 105 | token: str = create_access_token( 106 | data=data_to_be_tokenized, expires_delta=TOKEN_EXPIRE_HOURS 107 | ) 108 | response.set_cookie( 109 | key="Authorization", 110 | value=f"Bearer {token}", 111 | max_age=TOKEN_EXPIRE_HOURS * 3600, 112 | secure=True, 113 | httponly=True, 114 | ) 115 | return Token(Authorization=f"Bearer {token}") 116 | else: 117 | raise Responses_400.not_supported_feature 118 | -------------------------------------------------------------------------------- /app/routers/index.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Request 2 | from fastapi.responses import FileResponse 3 | 4 | router = APIRouter() 5 | 6 | 7 | @router.get("/") 8 | async def index(): 9 | return FileResponse("app/web/index.html") 10 | 11 | 12 | @router.get("/favicon.ico", include_in_schema=False) 13 | async def favicon(): 14 | 
return FileResponse("app/contents/favicon.ico") 15 | 16 | 17 | @router.get("/test") 18 | async def test(request: Request): 19 | return {"username": request.session.get("username", None)} 20 | -------------------------------------------------------------------------------- /app/routers/user_services.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import APIRouter 4 | 5 | from app.models.base_models import MessageOk 6 | 7 | router = APIRouter(prefix="/user-services") 8 | 9 | 10 | @router.get("/", status_code=200) 11 | def test(query: Optional[str] = None): 12 | return MessageOk(message="Hello, World!" if query is None else query) 13 | -------------------------------------------------------------------------------- /app/routers/users.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | 3 | from fastapi import APIRouter 4 | from starlette.requests import Request 5 | 6 | from app.database.crud import api_keys, api_whitelists, users 7 | from app.errors.api_exceptions import Responses_400 8 | from app.models.base_models import ( 9 | AddApiKey, 10 | CreateApiWhiteList, 11 | GetApiKey, 12 | GetApiKeyFirstTime, 13 | GetApiWhiteList, 14 | MessageOk, 15 | UserMe, 16 | ) 17 | 18 | router = APIRouter(prefix="/user") 19 | 20 | 21 | @router.get("/me", response_model=UserMe) 22 | async def get_me(request: Request): 23 | return await users.get_me(user_id=request.state.user.id) 24 | 25 | 26 | @router.put("/me") 27 | async def put_me(request: Request): 28 | ... 29 | 30 | 31 | @router.delete("/me") 32 | async def delete_me(request: Request): 33 | ... 34 | 35 | 36 | @router.get("/apikeys", response_model=list[GetApiKey]) 37 | async def get_api_keys(request: Request): 38 | return await api_keys.get_api_keys(user_id=request.state.user.id) 39 | 40 | 41 | @router.post("/apikeys", response_model=GetApiKeyFirstTime, status_code=201) 42 | async def create_api_key(request: Request, api_key_info: AddApiKey): 43 | return await api_keys.create_api_key( 44 | user_id=request.state.user.id, additional_key_info=api_key_info 45 | ) 46 | 47 | 48 | @router.put("/apikeys/{api_key_id}", response_model=GetApiKey) 49 | async def update_api_key( 50 | request: Request, 51 | api_key_id: int, 52 | api_key_info: AddApiKey, 53 | ): 54 | """ 55 | Update the user memo attached to an API key. 56 | """ 57 | return await api_keys.update_api_key( 58 | updated_key_info=api_key_info.dict(), 59 | access_key_id=api_key_id, 60 | user_id=request.state.user.id, 61 | ) 62 | 63 | 64 | @router.delete("/apikeys/{key_id}") 65 | async def delete_api_key( 66 | request: Request, 67 | key_id: int, 68 | access_key: str, 69 | ): 70 | await api_keys.delete_api_key( 71 | access_key_id=key_id, 72 | access_key=access_key, 73 | user_id=request.state.user.id, 74 | ) 75 | return MessageOk() 76 | 77 | 78 | @router.get( 79 | "/apikeys/{api_key_id}/whitelists", response_model=list[GetApiWhiteList] 80 | ) 81 | async def get_api_keys_whitelist(api_key_id: int): 82 | return await api_whitelists.get_api_key_whitelist(api_key_id=api_key_id) 83 | 84 | 85 | @router.post("/apikeys/{api_key_id}/whitelists", response_model=GetApiWhiteList) 86 | async def create_api_keys_whitelist( 87 | api_key_id: int, 88 | ip: CreateApiWhiteList, 89 | ): 90 | ip_address: str = ip.ip_address 91 | try: 92 | ipaddress.ip_address(ip_address) 93 | except Exception as exception: 94 | raise Responses_400.invalid_ip( 95 | lazy_format={"ip": ip_address}, ex=exception 96 | ) 97 | return await api_whitelists.create_api_key_whitelist( 98 | ip_address=ip_address, api_key_id=api_key_id 99 | ) 100 | 101 | 102 | @router.delete("/apikeys/{api_key_id}/whitelists/{whitelist_id}") 103 | async def delete_api_keys_whitelist( 104 | request: Request, 105 | api_key_id: int, 106 | whitelist_id: int, 107 | ): 108 | await api_whitelists.delete_api_key_whitelist( 109 | user_id=request.state.user.id, 110 | api_key_id=api_key_id, 111 | whitelist_id=whitelist_id, 112 | ) 113 | return MessageOk() 114 | -------------------------------------------------------------------------------- /app/routers/websocket.py: -------------------------------------------------------------------------------- 1 | from asyncio import sleep 2 | 3 | from fastapi import APIRouter, WebSocket 4 | 5 | from app.common.config import API_ENV, HOST_MAIN, OPENAI_API_KEY 6 | from app.database.crud import api_keys 7 | from app.database.schemas.auth import Users 8 | from app.errors.api_exceptions import Responses_400, Responses_401 9 | from app.utils.chat.managers.stream import ChatStreamManager 10 | from app.utils.chat.managers.websocket import SendToWebsocket 11 | from app.viewmodels.status import ApiKeyStatus, UserStatus 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.websocket("/chat/{api_key}") 17 | async def ws_chat(websocket: WebSocket, api_key: str): 18 | if OPENAI_API_KEY is None: 19 | raise Responses_400.not_supported_feature 20 | await websocket.accept() # accept websocket 21 | if api_key != OPENAI_API_KEY and API_ENV != "test": 22 | _api_key, _user = await api_keys.get_api_key_and_owner( 23 | access_key=api_key 24 | ) 25 | if _user.status not in (UserStatus.active, UserStatus.admin): 26 | await SendToWebsocket.message( 27 | websocket=websocket, 28 | msg="Your account is not active", 29 | chat_room_id=" ", 30 | ) 31 | await sleep(60) 32 | raise Responses_401.not_authorized 33 | if _api_key.status is not ApiKeyStatus.active: 34 | await SendToWebsocket.message( 35 | websocket=websocket, 36 | msg="Your api key is not active", 37 | chat_room_id=" ", 38 | ) 39 | await sleep(60) 40 | raise Responses_401.not_authorized 41 | else: 42 | _user = Users(email=f"testaccount@{HOST_MAIN}") 43 | await ChatStreamManager.begin_chat( 44 | websocket=websocket, 45 | user=_user, 46 | ) 47 | -------------------------------------------------------------------------------- /app/utils/api/weather.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import orjson 3 | from typing import Any, Literal 4 | from fastapi import HTTPException 5 | 6 | 7 | async def fetch_weather_data( 8 | lat: float, 9 | lon: float, 10 | api_key: str, 11 | source: Literal["openweathermap", "weatherbit", "climacell"], 12 | ) -> Any: 13 | base_url = { 14 | "openweathermap": "https://api.openweathermap.org/data/2.5/weather", 15 | "weatherbit": "https://api.weatherbit.io/v2.0/current", 16 | "climacell": "https://api.climacell.co/v3/weather/realtime", 17 | }[source] 18 | query_params = { 19 | "openweathermap": f"lat={lat}&lon={lon}&appid={api_key}", 20 | "weatherbit": f"lat={lat}&lon={lon}&key={api_key}", 21 | "climacell": f"lat={lat}&lon={lon}&unit_system=metric&apikey={api_key}", 22 | }[source] 23 | 24 | async with httpx.AsyncClient() as client: 25 | response = await client.get(base_url + "?" 
+ query_params) 26 | if response.status_code == 200: 27 | weather_data = orjson.loads(response.content) 28 | print("weather_data:", weather_data) 29 | return weather_data 30 | else: 31 | raise HTTPException( 32 | status_code=response.status_code, 33 | detail=f"Error fetching data from {source}", 34 | ) 35 | 36 | 37 | def get_temperature( 38 | weather_data: dict, 39 | source: Literal["openweathermap", "weatherbit", "climacell"], 40 | ): 41 | if source == "openweathermap": 42 | temp = weather_data["main"]["temp"] - 273.15 # Convert from Kelvin to Celsius 43 | elif source == "weatherbit": 44 | temp = weather_data["data"][0]["temp"] 45 | elif source == "climacell": 46 | temp = weather_data["temp"]["value"] 47 | else: 48 | temp = None 49 | 50 | return temp 51 | -------------------------------------------------------------------------------- /app/utils/auth/api_keys.py: -------------------------------------------------------------------------------- 1 | from secrets import choice 2 | from string import ascii_letters, digits 3 | from uuid import uuid4 4 | 5 | from app.database.schemas.auth import ApiKeys 6 | from app.models.base_models import AddApiKey 7 | 8 | 9 | def generate_new_api_key( 10 | user_id: int, additional_key_info: AddApiKey 11 | ) -> ApiKeys: 12 | alnums = ascii_letters + digits 13 | secret_key = "".join(choice(alnums) for _ in range(40)) 14 | uid = f"{str(uuid4())[:-12]}{str(uuid4())}" 15 | new_api_key = ApiKeys( 16 | secret_key=secret_key, 17 | user_id=user_id, 18 | access_key=uid, 19 | **additional_key_info.dict(), 20 | ) 21 | return new_api_key 22 | -------------------------------------------------------------------------------- /app/utils/auth/register_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | # Make a regular expression 4 | # for validating an Email 5 | EMAIL_REGEX: re.Pattern = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{1,7}\b") 6 | 7 | 8 | # Define a function 9 | # for validating an Email 10 | def is_email_valid_format(email: str) -> bool: 11 | return EMAIL_REGEX.fullmatch(email) is not None 12 | 13 | 14 | def is_email_length_in_range(email: str) -> bool: 15 | return 6 <= len(email) <= 50 16 | 17 | 18 | def is_password_length_in_range(password: str) -> bool: 19 | return 6 <= len(password) <= 100 20 | 21 | 22 | # Driver Code 23 | if __name__ == "__main__": 24 | for email in ("ankitrai326@gmail.com", "my.ownsite@our-earth.org", "ankitrai326.com", "aa@a.a"): 25 | # calling run function 26 | print(is_email_valid_format(email)) 27 | -------------------------------------------------------------------------------- /app/utils/auth/token.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | from jwt import decode as jwt_decode 4 | from jwt import encode as jwt_encode 5 | from jwt.exceptions import DecodeError, ExpiredSignatureError 6 | 7 | from app.common.config import JWT_ALGORITHM, JWT_SECRET 8 | from app.errors.api_exceptions import Responses_401 9 | 10 | 11 | def create_access_token( 12 | *, data: dict, expires_delta: int | None = None 13 | ) -> str: 14 | to_encode: dict = data.copy() 15 | if expires_delta is not None and expires_delta != 0: 16 | to_encode.update( 17 | {"exp": datetime.utcnow() + timedelta(hours=expires_delta)} 18 | ) 19 | return jwt_encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) 20 | 21 | 22 | def 
token_decode(authorization: str) -> dict: 23 | if authorization is None: 24 | raise Responses_401.token_decode_failure 25 | try: 26 | authorization = authorization.replace("Bearer ", "") 27 | payload = jwt_decode( 28 | authorization, key=JWT_SECRET, algorithms=[JWT_ALGORITHM] 29 | ) 30 | except ExpiredSignatureError: 31 | raise Responses_401.token_expired 32 | except DecodeError: 33 | raise Responses_401.token_decode_failure 34 | return payload 35 | -------------------------------------------------------------------------------- /app/utils/chat/chat_rooms.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | from app.models.chat_models import UserChatContext 4 | 5 | from .buffer import BufferedUserContext 6 | from .managers.cache import CacheManager 7 | 8 | 9 | async def create_new_chat_room( 10 | user_id: str, 11 | new_chat_room_id: str | None = None, 12 | buffer: BufferedUserContext | None = None, 13 | ) -> UserChatContext: 14 | if buffer is not None: 15 | default: UserChatContext = UserChatContext.construct_default( 16 | user_id=user_id, 17 | chat_room_id=new_chat_room_id if new_chat_room_id else uuid4().hex, 18 | llm_model=buffer.current_llm_model, 19 | ) 20 | else: 21 | default: UserChatContext = UserChatContext.construct_default( 22 | user_id=user_id, 23 | chat_room_id=new_chat_room_id if new_chat_room_id else uuid4().hex, 24 | ) 25 | await CacheManager.create_context(user_chat_context=default) 26 | if buffer is not None: 27 | buffer.insert_context(user_chat_context=default) 28 | await buffer.change_context_to(index=0) 29 | return default 30 | 31 | 32 | async def delete_chat_room( 33 | chat_room_id_to_delete: str, 34 | buffer: BufferedUserContext, 35 | ) -> bool: 36 | await CacheManager.delete_chat_room( 37 | user_id=buffer.user_id, chat_room_id=chat_room_id_to_delete 38 | ) 39 | index: int | None = buffer.find_index_of_chatroom( 40 | chat_room_id=chat_room_id_to_delete 41 | ) 42 | if index is None: 43 | return False 44 | buffer.delete_context(index=index) 45 | if not buffer: 46 | await create_new_chat_room( 47 | user_id=buffer.user_id, 48 | buffer=buffer, 49 | ) 50 | if buffer.current_chat_room_id == chat_room_id_to_delete: 51 | await buffer.change_context_to(index=0) 52 | return True 53 | -------------------------------------------------------------------------------- /app/utils/chat/commands/browsing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from app.models.chat_models import ResponseType 4 | from app.models.function_calling.functions import FunctionCalls 5 | from app.models.llms import OpenAIModel 6 | from app.utils.chat.buffer import BufferedUserContext 7 | from app.utils.chat.messages.handler import MessageHandler 8 | from app.utils.function_calling.query import aget_query_to_search 9 | 10 | 11 | class BrowsingCommands: 12 | @staticmethod 13 | async def browse( 14 | user_query: str, /, buffer: BufferedUserContext, **kwargs 15 | ) -> tuple[Optional[str], ResponseType]: 16 | """Query LLM with duckduckgo browse results\n 17 | /browse """ 18 | if user_query.startswith("/"): 19 | # User is trying to invoke another command. 20 | # Give control back to the command handler, 21 | # and let it handle the command. 22 | # e.g. 
`/browse /help` will invoke `/help` command 23 | return user_query, ResponseType.REPEAT_COMMAND 24 | 25 | # Save user query to buffer and database 26 | await MessageHandler.user(msg=user_query, buffer=buffer) 27 | 28 | if not isinstance(buffer.current_llm_model.value, OpenAIModel): 29 | # Non-OpenAI models can't invoke function call, 30 | # so we force function calling here 31 | query_to_search: str = await aget_query_to_search( 32 | buffer=buffer, 33 | query=user_query, 34 | function=FunctionCalls.get_function_call( 35 | FunctionCalls.web_search 36 | ), 37 | ) 38 | await MessageHandler.function_call( 39 | callback_name=FunctionCalls.web_search.__name__, 40 | callback_kwargs={"query_to_search": query_to_search}, 41 | buffer=buffer, 42 | ) 43 | else: 44 | # OpenAI models can invoke function call, 45 | # so let the AI decide whether to invoke function call 46 | function = FunctionCalls.get_function_call( 47 | FunctionCalls.web_search 48 | ) 49 | buffer.optional_info["functions"] = [function] 50 | buffer.optional_info["function_call"] = function 51 | await MessageHandler.ai(buffer=buffer) 52 | 53 | # End of command 54 | return None, ResponseType.DO_NOTHING 55 | -------------------------------------------------------------------------------- /app/utils/chat/commands/llm_parameter.py: -------------------------------------------------------------------------------- 1 | from app.models.chat_models import UserChatContext, command_response 2 | from app.utils.chat.managers.cache import CacheManager 3 | 4 | 5 | class LLMParameterCommands: 6 | @staticmethod 7 | @command_response.send_message_and_stop 8 | async def temp( 9 | temp_to_change: float, user_chat_context: UserChatContext 10 | ) -> str: # set temperature of ai 11 | """Set temperature of ai\n 12 | /temp """ 13 | try: 14 | assert ( 15 | 0 <= temp_to_change <= 2 16 | ) # assert temperature is between 0 and 2 17 | except AssertionError: # if temperature is not between 0 and 2 18 | return "Temperature must be between 0 and 2" 19 | else: 20 | previous_temperature: str = str( 21 | user_chat_context.user_chat_profile.temperature 22 | ) 23 | user_chat_context.user_chat_profile.temperature = temp_to_change 24 | await CacheManager.update_profile_and_model( 25 | user_chat_context 26 | ) # update user_chat_context 27 | return f"I've changed temperature from {previous_temperature} to {temp_to_change}." # return success msg 28 | 29 | @staticmethod 30 | @command_response.send_message_and_stop 31 | async def topp( 32 | top_p_to_change: float, user_chat_context: UserChatContext 33 | ) -> str: # set top_p of ai 34 | """Set top_p of ai\n 35 | /topp """ 36 | try: 37 | assert 0 <= top_p_to_change <= 1 # assert top_p is between 0 and 1 38 | except AssertionError: # if top_p is not between 0 and 1 39 | return "Top_p must be between 0 and 1." # return fail message 40 | else: 41 | previous_top_p: str = str( 42 | user_chat_context.user_chat_profile.top_p 43 | ) 44 | user_chat_context.user_chat_profile.top_p = ( 45 | top_p_to_change # set top_p 46 | ) 47 | await CacheManager.update_profile_and_model( 48 | user_chat_context 49 | ) # update user_chat_context 50 | return f"I've changed top_p from {previous_top_p} to {top_p_to_change}." 
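# Note (illustrative, not from the original source): temperature rescales the
# token logits (higher values flatten the distribution), while top_p restricts
# sampling to the smallest token set whose cumulative probability reaches
# top_p; e.g. top_p=0.1 keeps only the top 10% of probability mass. The [0, 2]
# and [0, 1] bounds asserted above mirror OpenAI's documented ranges.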
# return success message 51 | -------------------------------------------------------------------------------- /app/utils/chat/commands/prompt.py: -------------------------------------------------------------------------------- 1 | from app.common.constants import SystemPrompts 2 | from app.models.base_models import MessageHistory 3 | from app.models.chat_models import ChatRoles, UserChatContext, command_response 4 | from app.utils.chat.managers.message import MessageManager 5 | 6 | 7 | class PromptCommands: 8 | @staticmethod 9 | @command_response.send_message_and_stop 10 | async def codex(user_chat_context: UserChatContext) -> str: 11 | """Let GPT act as CODEX("COding DEsign eXpert")\n 12 | /codex""" 13 | system_message = SystemPrompts.CODEX 14 | await MessageManager.clear_message_history_safely( 15 | user_chat_context=user_chat_context, role=ChatRoles.SYSTEM 16 | ) 17 | await MessageManager.add_message_history_safely( 18 | user_chat_context=user_chat_context, 19 | role=ChatRoles.SYSTEM, 20 | content=system_message, 21 | ) 22 | return "CODEX mode ON" 23 | 24 | @staticmethod 25 | @command_response.send_message_and_stop 26 | async def redx(user_chat_context: UserChatContext) -> str: 27 | """Let GPT reduce your message as much as possible\n 28 | /redx""" 29 | system_message = SystemPrompts.REDEX 30 | await MessageManager.clear_message_history_safely( 31 | user_chat_context=user_chat_context, role=ChatRoles.SYSTEM 32 | ) 33 | await MessageManager.add_message_history_safely( 34 | user_chat_context=user_chat_context, 35 | role=ChatRoles.SYSTEM, 36 | content=system_message, 37 | ) 38 | return "REDX mode ON" 39 | 40 | @staticmethod 41 | @command_response.send_message_and_stop 42 | async def system( 43 | system_message: str, /, user_chat_context: UserChatContext 44 | ) -> str: # add system message 45 | """Add system message\n 46 | /system """ 47 | await MessageManager.add_message_history_safely( 48 | user_chat_context=user_chat_context, 49 | content=system_message, 50 | role=ChatRoles.SYSTEM, 51 | ) 52 | return ( 53 | f"Added system message: {system_message}" # return success message 54 | ) 55 | 56 | @staticmethod 57 | @command_response.send_message_and_stop 58 | async def pop(role: str, user_chat_context: UserChatContext) -> str: 59 | """Pop last message (user or system or ai)\n 60 | /pop """ 61 | try: 62 | actual_role: ChatRoles = ChatRoles.get_static_member(role) 63 | except ValueError: 64 | return ( 65 | "Role must be one of user, system, ai" # return fail message 66 | ) 67 | last_message_history: MessageHistory | None = ( 68 | await MessageManager.pop_message_history_safely( 69 | user_chat_context=user_chat_context, 70 | role=actual_role, 71 | ) 72 | ) # type: ignore 73 | if last_message_history is None: # if last_message_history is None 74 | return f"There is no {role} message to pop." 
# return fail message 75 | return f"Pop {role} message: {last_message_history.content}" # return success message 76 | 77 | @staticmethod 78 | @command_response.send_message_and_stop 79 | async def set( 80 | role, new_message: str, /, user_chat_context: UserChatContext 81 | ) -> str: 82 | """Set last message (user or system or ai)\n 83 | /set """ 84 | try: 85 | actual_role: ChatRoles = ChatRoles.get_static_member(role) 86 | except ValueError: 87 | return ( 88 | "Role must be one of user, system, ai" # return fail message 89 | ) 90 | if ( 91 | await MessageManager.set_message_history_safely( 92 | user_chat_context=user_chat_context, 93 | role=actual_role, 94 | index=-1, 95 | new_content=new_message, 96 | ) 97 | is None 98 | ): # if set message history failed 99 | return f"There is no {role} message to set." 100 | return f"Set {role} message: {new_message}" # return success message 101 | -------------------------------------------------------------------------------- /app/utils/chat/commands/server.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | 3 | from app.common.app_settings_llama_cpp import ( 4 | shutdown_llama_cpp_server, 5 | start_llama_cpp_server, 6 | ) 7 | from app.common.config import config 8 | from app.models.chat_models import command_response 9 | from app.shared import Shared 10 | 11 | 12 | class ServerCommands: 13 | @staticmethod 14 | @command_response.send_message_and_stop 15 | async def free() -> str: 16 | """Free the process pool executor\n 17 | /free""" 18 | shared = Shared() 19 | shared.process_pool_executor.shutdown(wait=True) 20 | shared.process_pool_executor = ProcessPoolExecutor() 21 | return "Process pool executor freed!" 22 | 23 | @staticmethod 24 | @command_response.send_message_and_stop 25 | async def shutdownllamacpp() -> str: 26 | """Shutdown the llama cpp server\n 27 | /shutdownllamacpp""" 28 | shutdown_llama_cpp_server(shared=Shared()) 29 | return "Shutdown llama cpp server!" 30 | 31 | @staticmethod 32 | @command_response.send_message_and_stop 33 | async def startllamacpp() -> str: 34 | """Start the llama cpp server\n 35 | /startllamacpp""" 36 | start_llama_cpp_server(shared=Shared(), config=config) 37 | return "Started llama cpp server!" 
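# Illustrative sketch, not part of the original file: the /free command above
# works because shutdown(wait=True) tears down the old worker processes before
# a brand-new ProcessPoolExecutor spawns fresh ones, releasing whatever memory
# the workers had accumulated. The same pattern in isolation; the helper name
# below is hypothetical:

def _recycle_pool_sketch(pool: ProcessPoolExecutor) -> ProcessPoolExecutor:
    """Hypothetical helper shown only to illustrate the recycle pattern."""
    pool.shutdown(wait=True)  # block until running jobs finish, then reap workers
    return ProcessPoolExecutor()  # replacement pool with freshly spawned workers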
38 | -------------------------------------------------------------------------------- /app/utils/chat/commands/summarize.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import Optional 3 | 4 | from fastapi.concurrency import run_in_threadpool 5 | 6 | from app.models.chat_models import command_response 7 | from app.shared import Shared 8 | from app.utils.chat.buffer import BufferedUserContext 9 | from app.utils.chat.messages.converter import message_histories_to_str 10 | from app.utils.chat.text_generations.summarization import get_summarization 11 | 12 | 13 | class SummarizeCommands: 14 | @staticmethod 15 | @command_response.send_message_and_stop 16 | async def summarize( 17 | to_summarize: Optional[str], /, buffer: BufferedUserContext 18 | ) -> str: 19 | shared = Shared() 20 | if to_summarize is None: 21 | to_summarize = message_histories_to_str( 22 | user_chat_roles=buffer.current_user_chat_roles, 23 | user_message_histories=buffer.current_user_message_histories, 24 | ai_message_histories=buffer.current_ai_message_histories, 25 | system_message_histories=buffer.current_system_message_histories, 26 | ) 27 | to_summarize_tokens = buffer.current_user_chat_context.total_tokens 28 | start: float = time() 29 | summarized = await run_in_threadpool( 30 | get_summarization, 31 | to_summarize=to_summarize, 32 | to_summarize_tokens=to_summarize_tokens, 33 | ) 34 | original_tokens: int = len( 35 | shared.token_text_splitter._tokenizer.encode(to_summarize) 36 | ) 37 | summarized_tokens: int = len( 38 | shared.token_text_splitter._tokenizer.encode(summarized) 39 | ) 40 | end: float = time() 41 | return "\n".join( 42 | ( 43 | "# Summarization complete!", 44 | f"- original tokens: {original_tokens}", 45 | f"- summarized tokens: {summarized_tokens}", 46 | f"- summarization time: {end-start}s", 47 | "```", 48 | summarized, 49 | "```", 50 | ) 51 | ) 52 | -------------------------------------------------------------------------------- /app/utils/chat/commands/testing.py: -------------------------------------------------------------------------------- 1 | from fastapi import WebSocket 2 | 3 | from app.models.chat_models import ( 4 | ResponseType, 5 | UserChatContext, 6 | command_response, 7 | ) 8 | from app.models.llms import LLMModels 9 | from app.utils.chat.buffer import BufferedUserContext 10 | from app.utils.chat.managers.websocket import SendToWebsocket 11 | 12 | 13 | class TestingCommands: 14 | @staticmethod 15 | @command_response.send_message_and_stop 16 | def test( 17 | user_chat_context: UserChatContext, 18 | ) -> str: # test command showing user_chat_context 19 | """Test command showing user_chat_context\n 20 | /test""" 21 | return str(user_chat_context) 22 | 23 | @staticmethod 24 | @command_response.send_message_and_stop 25 | def ping() -> str: 26 | """Ping! 
Pong!\n 27 | /ping""" 28 | return "pong" 29 | 30 | @staticmethod 31 | @command_response.send_message_and_stop 32 | def echo(msg: str, /) -> str: 33 | """Echo your message\n 34 | /echo """ 35 | return msg 36 | 37 | @staticmethod 38 | @command_response.do_nothing 39 | async def sendtowebsocket( 40 | msg: str, /, websocket: WebSocket, user_chat_context: UserChatContext 41 | ) -> None: 42 | """Send all messages to websocket\n 43 | /sendtowebsocket""" 44 | await SendToWebsocket.message( 45 | websocket=websocket, 46 | msg=msg, 47 | chat_room_id=user_chat_context.chat_room_id, 48 | ) 49 | 50 | @staticmethod 51 | async def testchaining( 52 | chain_size: int, buffer: BufferedUserContext 53 | ) -> tuple[str, ResponseType]: 54 | """Test chains of commands\n 55 | /testchaining """ 56 | if chain_size <= 0: 57 | return "Chaining Complete!", ResponseType.SEND_MESSAGE_AND_STOP 58 | await SendToWebsocket.message( 59 | websocket=buffer.websocket, 60 | msg=f"Current Chaining: {chain_size}", 61 | chat_room_id=buffer.current_chat_room_id, 62 | ) 63 | return f"/testchaining {chain_size-1}", ResponseType.REPEAT_COMMAND 64 | 65 | @staticmethod 66 | @command_response.send_message_and_stop 67 | def codeblock(language, codes: str, /) -> str: 68 | """Send codeblock\n 69 | /codeblock """ 70 | return f"\n```{language.lower()}\n" + codes + "\n```\n" 71 | 72 | @staticmethod 73 | @command_response.send_message_and_stop 74 | def addmember(name: str, /, user_chat_context: UserChatContext) -> str: 75 | LLMModels.add_member(name, None) 76 | return f"Added {name} to members" 77 | -------------------------------------------------------------------------------- /app/utils/chat/commands/vectorstore.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from app.common.config import config 4 | from app.common.lotties import Lotties 5 | from app.models.chat_models import ResponseType, command_response 6 | from app.models.function_calling.functions import FunctionCalls 7 | from app.models.llms import OpenAIModel 8 | from app.utils.chat.buffer import BufferedUserContext 9 | from app.utils.chat.managers.vectorstore import VectorStoreManager 10 | from app.utils.chat.messages.handler import MessageHandler 11 | from app.utils.function_calling.query import aget_query_to_search 12 | from app.viewmodels.status import UserStatus 13 | 14 | 15 | class VectorstoreCommands: 16 | @staticmethod 17 | async def query( 18 | user_query: str, /, buffer: BufferedUserContext, **kwargs 19 | ) -> tuple[Optional[str], ResponseType]: 20 | """Query from redis vectorstore\n 21 | /query """ 22 | if user_query.startswith("/"): 23 | # User is trying to invoke another command. 24 | # Give control back to the command handler, 25 | # and let it handle the command. 26 | # e.g. 
`/query /help` will invoke `/help` command 27 | return user_query, ResponseType.REPEAT_COMMAND 28 | 29 | # Save user query to buffer and database 30 | await MessageHandler.user(msg=user_query, buffer=buffer) 31 | 32 | if not isinstance(buffer.current_llm_model.value, OpenAIModel): 33 | # Non-OpenAI models can't invoke function call, 34 | # so we force function calling here 35 | query_to_search: str = await aget_query_to_search( 36 | buffer=buffer, 37 | query=user_query, 38 | function=FunctionCalls.get_function_call( 39 | FunctionCalls.vectorstore_search 40 | ), 41 | ) 42 | await MessageHandler.function_call( 43 | callback_name=FunctionCalls.vectorstore_search.__name__, 44 | callback_kwargs={"query_to_search": query_to_search}, 45 | buffer=buffer, 46 | ) 47 | else: 48 | # OpenAI models can invoke function call, 49 | # so let the AI decide whether to invoke function call 50 | function = FunctionCalls.get_function_call( 51 | FunctionCalls.vectorstore_search 52 | ) 53 | buffer.optional_info["functions"] = [function] 54 | buffer.optional_info["function_call"] = function 55 | await MessageHandler.ai(buffer=buffer) 56 | 57 | # End of command 58 | return None, ResponseType.DO_NOTHING 59 | 60 | @staticmethod 61 | @command_response.send_message_and_stop 62 | async def embed(text_to_embed: str, /, buffer: BufferedUserContext) -> str: 63 | """Embed the text and save its vectors in the redis vectorstore.\n 64 | /embed """ 65 | await VectorStoreManager.create_documents( 66 | text=text_to_embed, collection_name=buffer.user_id 67 | ) 68 | return Lotties.OK.format("Embedding successful!") 69 | 70 | @staticmethod 71 | @command_response.send_message_and_stop 72 | async def share(text_to_embed: str, /) -> str: 73 | """Embed the text and save its vectors in the redis vectorstore. This index is shared for everyone.\n 74 | /share """ 75 | await VectorStoreManager.create_documents( 76 | text=text_to_embed, collection_name=config.shared_vectorestore_name 77 | ) 78 | return Lotties.OK.format( 79 | "Embedding successful!\nThis data will be shared for everyone." 80 | ) 81 | 82 | @staticmethod 83 | @command_response.send_message_and_stop 84 | async def drop(buffer: BufferedUserContext) -> str: 85 | """Drop the index from the redis vectorstore.\n 86 | /drop""" 87 | dropped_index: list[str] = [] 88 | if await VectorStoreManager.delete_collection( 89 | collection_name=buffer.user_id 90 | ): 91 | dropped_index.append(buffer.user_id) 92 | if ( 93 | buffer.user.status is UserStatus.admin 94 | and await VectorStoreManager.delete_collection( 95 | collection_name=config.shared_vectorestore_name, 96 | ) 97 | ): 98 | dropped_index.append(config.shared_vectorestore_name) 99 | if not dropped_index: 100 | return "No index dropped." 101 | return f"Index dropped: {', '.join(dropped_index)}" 102 | -------------------------------------------------------------------------------- /app/utils/chat/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | 5 | class BaseEmbeddingGenerator(ABC): 6 | @abstractmethod 7 | def __del__(self): 8 | """Clean up resources.""" 9 | ... 
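# Illustrative note, not from the original source: a concrete subclass must
# implement the four abstract members of this class. A minimal hedged sketch
# with a hypothetical class name and a stand-in "embedding", no real model:
#
#   class DummyEmbeddingGenerator(BaseEmbeddingGenerator):
#       def __del__(self) -> None:
#           pass  # nothing to release
#
#       @classmethod
#       def from_pretrained(cls, model_name: str) -> "DummyEmbeddingGenerator":
#           self = cls()
#           self._name = model_name  # no real weights are loaded here
#           return self
#
#       def generate_embeddings(self, texts: list[str], **kwargs) -> list[list[float]]:
#           return [[float(len(t))] for t in texts]  # text length as a fake 1-d embedding
#
#       @property
#       def model_name(self) -> str:
#           return self._name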
10 | 11 | @classmethod 12 | @abstractmethod 13 | def from_pretrained(cls, model_name: str) -> "BaseEmbeddingGenerator": 14 | """Load a pretrained model into RAM.""" 15 | return cls 16 | 17 | @abstractmethod 18 | def generate_embeddings( 19 | self, 20 | texts: list[str], 21 | **kwargs: Any, 22 | ) -> list[list[float]]: 23 | """Generate embeddings for a list of texts.""" 24 | ... 25 | 26 | @property 27 | @abstractmethod 28 | def model_name(self) -> str: 29 | """Identifier for the model used by this generator.""" 30 | ... 31 | -------------------------------------------------------------------------------- /app/utils/chat/embeddings/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable, Optional 2 | 3 | import numpy as np 4 | import tensorflow_hub as hub 5 | 6 | from app.utils.colorama import Fore 7 | from app.utils.logger import ApiLogger 8 | 9 | from . import BaseEmbeddingGenerator 10 | 11 | if TYPE_CHECKING: 12 | from tensorflow.python.framework.ops import Tensor 13 | 14 | 15 | class SentenceEncoderEmbeddingGenerator(BaseEmbeddingGenerator): 16 | """Generate embeddings using a sentence encoder model, 17 | automatically downloading the model from https://tfhub.dev/""" 18 | 19 | base_url: str = "https://tfhub.dev/google/" 20 | model: Optional[Callable[[list[str]], "Tensor"]] = None 21 | _model_name: Optional[str] = None 22 | 23 | def __del__(self) -> None: 24 | if self.model is not None: 25 | getattr(self.model, "__del__", lambda: None)() 26 | del self.model 27 | self.model = None 28 | print("🗑️ SentenceEncoderEmbedding deleted!") 29 | 30 | @classmethod 31 | def from_pretrained( 32 | cls, model_name: str 33 | ) -> "SentenceEncoderEmbeddingGenerator": 34 | self = cls() 35 | self._model_name = model_name 36 | url = f"{self.base_url.rstrip('/')}/{model_name.lstrip('/')}" 37 | self.model = hub.load(url) # type: ignore 38 | ApiLogger("||embeddings||").info( 39 | f"🤖 TFHub {Fore.WHITE}{model_name}{Fore.GREEN} loaded!", 40 | ) 41 | return self 42 | 43 | def generate_embeddings( 44 | self, 45 | texts: list[str], 46 | batch_size: int = 100, 47 | **kwargs, 48 | ) -> list[list[float]]: 49 | assert self.model is not None, "Please load the model first." 50 | embeddings: list["Tensor"] = [] 51 | for batch_idx in range(0, len(texts), batch_size): 52 | batch_texts = texts[batch_idx : (batch_idx + batch_size)] 53 | embeddings.append(self.model(batch_texts)) 54 | return np.vstack(embeddings).tolist() 55 | 56 | @property 57 | def model_name(self) -> str: 58 | return self._model_name or self.__class__.__name__ 59 | 60 | 61 | # class SemanticSearch: 62 | # use: Callable[[list[str]], "Tensor"] 63 | # nn: Optional[NearestNeighbors] = None 64 | # data: Optional[list[str]] = None 65 | # embeddings: Optional[np.ndarray] = None 66 | 67 | # def __init__(self): 68 | # self.use = hub.load( 69 | # "https://tfhub.dev/google/universal-sentence-encoder/4", 70 | # ) # type: ignore 71 | 72 | # def __call__(self, text: str, return_data: bool = True): 73 | # assert ( 74 | # self.nn is not None and self.data is not None 75 | # ), "Please fit the model first." 
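# Illustrative note on the commented-out lookup just below: kneighbors with
# return_distance=False returns an index array of shape
# (n_queries, n_neighbors), so the trailing [0] selects the row of neighbor
# indices for the single query embedding, e.g. [[12, 3, 47]] -> [12, 3, 47].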
76 | # query_embedding: "Tensor" = self.use([text]) 77 | # neighbors: np.ndarray = self.nn.kneighbors( 78 | # query_embedding, 79 | # return_distance=False, 80 | # )[0] 81 | 82 | # if return_data: 83 | # return [self.data[i] for i in neighbors] 84 | # else: 85 | # return neighbors 86 | 87 | # def query( 88 | # self, 89 | # text: str, 90 | # return_data: bool = True, 91 | # ): 92 | # return self(text, return_data=return_data) 93 | 94 | # def fit(self, data: list[str], batch: int = 1000, n_neighbors: int = 5): 95 | # self.data = data 96 | # self.embeddings = self._get_text_embedding(data, batch=batch) 97 | # self.nn = NearestNeighbors( 98 | # n_neighbors=min( 99 | # n_neighbors, 100 | # len(self.embeddings), 101 | # ) 102 | # ) 103 | # self.nn.fit(self.embeddings) 104 | 105 | # def _get_text_embedding( 106 | # self, 107 | # texts: list[str], 108 | # batch: int = 1000, 109 | # ) -> np.ndarray: 110 | # embeddings: list["Tensor"] = [] 111 | # for batch_idx in range(0, len(texts), batch): 112 | # text_batch = texts[batch_idx : (batch_idx + batch)] 113 | # embeddings.append(self.use(text_batch)) 114 | # return np.vstack(embeddings) 115 | -------------------------------------------------------------------------------- /app/utils/chat/embeddings/transformer.py: -------------------------------------------------------------------------------- 1 | from gc import collect 2 | from typing import Optional 3 | 4 | from torch import Tensor, cuda 5 | from transformers.modeling_outputs import ( 6 | BaseModelOutputWithPoolingAndCrossAttentions, 7 | ) 8 | from transformers.modeling_utils import PreTrainedModel 9 | from transformers.models.auto.modeling_auto import AutoModel 10 | from transformers.models.auto.tokenization_auto import AutoTokenizer 11 | from transformers.models.t5.modeling_t5 import T5Model 12 | from transformers.tokenization_utils import PreTrainedTokenizer 13 | from transformers.tokenization_utils_base import BatchEncoding 14 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 15 | 16 | from app.utils.colorama import Fore 17 | from app.utils.logger import ApiLogger 18 | 19 | from . 
import BaseEmbeddingGenerator 20 | 21 | device = "cuda" if cuda.is_available() else "cpu" 22 | 23 | 24 | class TransformerEmbeddingGenerator(BaseEmbeddingGenerator): 25 | """Generate embeddings using a transformer model, 26 | automatically downloading the model from https://huggingface.co/""" 27 | 28 | model: Optional[PreTrainedModel] = None 29 | tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast] = None 30 | encoder: Optional[PreTrainedModel] = None 31 | _model_name: Optional[str] = None 32 | 33 | def __del__(self) -> None: 34 | if self.model is not None: 35 | getattr(self.model, "__del__", lambda: None)() 36 | self.model = None 37 | print("🗑️ TransformerEmbedding model deleted!") 38 | if self.tokenizer is not None: 39 | getattr(self.tokenizer, "__del__", lambda: None)() 40 | self.tokenizer = None 41 | print("🗑️ TransformerEmbedding tokenizer deleted!") 42 | if self.encoder is not None: 43 | getattr(self.encoder, "__del__", lambda: None)() 44 | self.encoder = None 45 | print("🗑️ TransformerEmbedding encoder deleted!") 46 | 47 | @classmethod 48 | def from_pretrained( 49 | cls, model_name: str 50 | ) -> "TransformerEmbeddingGenerator": 51 | self = cls() 52 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 53 | self._model_name = model_name 54 | ApiLogger("||embeddings||").info( 55 | f"🤖 Huggingface tokenizer {Fore.WHITE}{model_name}{Fore.GREEN} loaded!", 56 | ) 57 | 58 | self.model = AutoModel.from_pretrained(model_name) 59 | ApiLogger("||embeddings||").info( 60 | f"🤖 Huggingface model {Fore.WHITE}{model_name}{Fore.GREEN} loaded!", 61 | ) 62 | return self 63 | 64 | def generate_embeddings( 65 | self, 66 | texts: list[str], 67 | context_length: int = 512, 68 | batch_size: int = 3, 69 | **kwargs, 70 | ) -> list[list[float]]: 71 | embeddings: list[list[float]] = [] 72 | for batch_idx in range(0, len(texts), batch_size): 73 | batch_texts = texts[batch_idx : (batch_idx + batch_size)] 74 | batch_embeddings, _ = self._generate_embeddings_and_n_tokens( 75 | texts=batch_texts, context_length=context_length 76 | ) 77 | embeddings.extend(batch_embeddings) 78 | return embeddings 79 | 80 | def _generate_embeddings_and_n_tokens( 81 | self, 82 | texts: list[str], 83 | context_length: int = 512, 84 | ) -> tuple[list[list[float]], int]: 85 | assert self.model is not None and self.tokenizer is not None 86 | 87 | def average_pool( 88 | last_hidden_states: Tensor, attention_mask: Tensor 89 | ) -> Tensor: 90 | last_hidden = last_hidden_states.masked_fill( 91 | ~attention_mask[..., None].bool(), 0.0 92 | ) 93 | return ( 94 | last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 95 | ) 96 | 97 | # Tokenize the input texts 98 | batch_dict: BatchEncoding = self.tokenizer( 99 | texts, 100 | max_length=context_length, 101 | padding="longest", 102 | truncation=True, 103 | return_tensors="pt", 104 | ) 105 | if self.encoder is None: 106 | # Get the encoder from the model 107 | if isinstance(self.model, T5Model): 108 | self.encoder = self.model.get_encoder() 109 | else: 110 | self.encoder = self.model 111 | 112 | if device == "cuda": 113 | # Load the encoder into VRAM 114 | self.encoder = self.encoder.to(device) 115 | batch_dict = batch_dict.to(device) 116 | outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.encoder( 117 | **batch_dict 118 | ) 119 | embeddings, tokens = ( 120 | average_pool( 121 | last_hidden_states=outputs.last_hidden_state, 122 | attention_mask=batch_dict["attention_mask"], # type: ignore 123 | ).tolist(), 124 | sum([len(encoding) for encoding in 
batch_dict["input_ids"]]), # type: ignore 125 | ) 126 | del batch_dict 127 | del outputs 128 | if device == "cuda": 129 | # Deallocate output tensors from VRAM 130 | cuda.empty_cache() 131 | collect() 132 | return embeddings, tokens 133 | 134 | @property 135 | def model_name(self) -> str: 136 | return self._model_name or self.__class__.__name__ 137 | -------------------------------------------------------------------------------- /app/utils/chat/file_loader.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import IO, Any 3 | 4 | from langchain.document_loaders.unstructured import UnstructuredBaseLoader 5 | from langchain.docstore.document import Document 6 | from unstructured.partition.auto import partition 7 | 8 | 9 | class UnstructuredFileIOLoader(UnstructuredBaseLoader): 10 | """Loader that uses unstructured to load file IO objects.""" 11 | 12 | def __init__(self, file: IO, filename: str, mode: str = "single", **unstructured_kwargs: Any): 13 | """Initialize with file path.""" 14 | self.file = file 15 | self.filename = filename 16 | super().__init__(mode=mode, **unstructured_kwargs) 17 | 18 | def _get_elements(self) -> list: 19 | return partition(file=self.file, file_filename=self.filename, **self.unstructured_kwargs) 20 | 21 | def _get_metadata(self) -> dict: 22 | return {} 23 | 24 | 25 | def read_bytes_to_documents(file: bytes, filename: str) -> list[Document]: 26 | return UnstructuredFileIOLoader(file=io.BytesIO(file), strategy="fast", filename=filename).load() 27 | 28 | 29 | def read_bytes_to_text(file: bytes, filename: str) -> str: 30 | return "\n\n".join([doc.page_content for doc in read_bytes_to_documents(file=file, filename=filename)]) 31 | 32 | 33 | if __name__ == "__main__": 34 | with open(r"test.pdf", "rb") as f: 35 | file = f.read() 36 | text = read_bytes_to_text(file, "test.pdf") 37 | print(text) 38 | -------------------------------------------------------------------------------- /app/utils/chat/messages/turn_templates.py: -------------------------------------------------------------------------------- 1 | """This module contains functions to extract information from chat turn templates.""" 2 | 3 | from re import DOTALL, compile 4 | from typing import TYPE_CHECKING, Optional 5 | 6 | if TYPE_CHECKING: 7 | from langchain import PromptTemplate 8 | 9 | 10 | def shatter_chat_turn_prompt( 11 | *keys: str, chat_turn_prompt: "PromptTemplate" 12 | ) -> tuple[str, ...]: 13 | """Identify the chat turn template and return the shatter result. 14 | e.g. If template of chat_turn_prompt is "### {role}: {content} " 15 | and keys are "role" and "content", 16 | then the result will be ('### ', "{role}", ': ', "{content}", ' ').""" 17 | pattern: str = "(.*)" 18 | kwargs: dict[str, str] = {} 19 | for key in keys: 20 | kwargs[key] = "{" + key + "}" 21 | pattern += f"({kwargs[key]})(.*)" 22 | search_result = compile(pattern, flags=DOTALL).match( 23 | chat_turn_prompt.format(**kwargs) 24 | ) 25 | if search_result is None: 26 | raise ValueError( 27 | f"Invalid chat turn prompt: {chat_turn_prompt.format(**kwargs)}" 28 | ) 29 | return search_result.groups() 30 | 31 | 32 | def identify_end_of_string(*keys, chat_turn_prompt: "PromptTemplate") -> Optional[str]: 33 | """Identify the end of string in the chat turn prompt. 34 | e.g. If template of chat_turn_prompt is "### {role}: {content} " 35 | then the result will be "". 
36 | If there is no end of string, then the result will be None.""" 37 | return shatter_chat_turn_prompt(*keys, chat_turn_prompt=chat_turn_prompt)[ 38 | -1 39 | ].strip() 40 | 41 | 42 | if __name__ == "__main__": 43 | from langchain import PromptTemplate 44 | 45 | input_variables = ["role", "content"] 46 | for template, template_format in ( 47 | ("### {role}: {content} ", "f-string"), 48 | ("### {{role}}: {{content}} ", "jinja2"), 49 | ): 50 | chat_turn_prompt = PromptTemplate( 51 | template=template, 52 | input_variables=input_variables, 53 | template_format=template_format, 54 | ) 55 | print( 56 | "Shattered String:", 57 | shatter_chat_turn_prompt( 58 | *input_variables, chat_turn_prompt=chat_turn_prompt 59 | ), 60 | ) # ('### ', '{role}', ': ', '{content}', ' ') 61 | print( 62 | "End-of-String:", 63 | identify_end_of_string(*input_variables, chat_turn_prompt=chat_turn_prompt), 64 | ) # 65 | -------------------------------------------------------------------------------- /app/utils/chat/text_generations/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import TYPE_CHECKING, Iterator 3 | 4 | from app.models.base_models import APIChatMessage, TextGenerationSettings 5 | from app.models.completion_models import ( 6 | ChatCompletion, 7 | ChatCompletionChunk, 8 | Completion, 9 | CompletionChunk, 10 | ) 11 | 12 | if TYPE_CHECKING: 13 | from app.models.llms import LLMModel 14 | 15 | 16 | class BaseCompletionGenerator(ABC): 17 | """Base class for all completion generators.""" 18 | 19 | user_role: str = "user" 20 | system_role: str = "system" 21 | 22 | user_input_role: str = "User" 23 | system_input_role: str = "System" 24 | 25 | ai_fallback_input_role: str = "Assistant" 26 | 27 | @abstractmethod 28 | def __del__(self): 29 | """Clean up resources.""" 30 | ... 31 | 32 | @classmethod 33 | @abstractmethod 34 | def from_pretrained( 35 | cls, llm_model: "LLMModel" 36 | ) -> "BaseCompletionGenerator": 37 | """Load a pretrained model into RAM.""" 38 | ... 39 | 40 | @abstractmethod 41 | def generate_completion( 42 | self, prompt: str, settings: TextGenerationSettings 43 | ) -> Completion: 44 | """Generate a completion for a given prompt.""" 45 | ... 46 | 47 | @abstractmethod 48 | def generate_completion_with_streaming( 49 | self, prompt: str, settings: TextGenerationSettings 50 | ) -> Iterator[CompletionChunk]: 51 | """Generate a completion for a given prompt, yielding chunks of text as they are generated.""" 52 | ... 53 | 54 | @abstractmethod 55 | def generate_chat_completion( 56 | self, messages: list[APIChatMessage], settings: TextGenerationSettings 57 | ) -> ChatCompletion: 58 | """Generate a completion for a given prompt.""" 59 | ... 60 | 61 | @abstractmethod 62 | def generate_chat_completion_with_streaming( 63 | self, messages: list[APIChatMessage], settings: TextGenerationSettings 64 | ) -> Iterator[ChatCompletionChunk]: 65 | """Generate a completion for a given prompt, yielding chunks of text as they are generated.""" 66 | ... 67 | 68 | @staticmethod 69 | def get_stop_strings(*roles: str) -> list[str]: 70 | """A helper method to generate stop strings for a given set of roles. 71 | Stop strings are required to stop text completion API from generating 72 | text that does not belong to the current chat turn. 73 | e.g. 
The common stop string is "### USER:", which can prevent ai from generating 74 | user's message itself.""" 75 | 76 | prompt_stop = set() 77 | for role in roles: 78 | avoids = ( 79 | f"{role}:", 80 | f"### {role}:", 81 | f"###{role}:", 82 | ) 83 | prompt_stop.update( 84 | avoids, 85 | map(str.capitalize, avoids), 86 | map(str.upper, avoids), 87 | map(str.lower, avoids), 88 | ) 89 | return list(prompt_stop) 90 | 91 | @classmethod 92 | def convert_messages_into_prompt( 93 | cls, messages: list[APIChatMessage], settings: TextGenerationSettings 94 | ) -> str: 95 | """A helper method to convert list of messages into one text prompt.""" 96 | ai_input_role: str = cls.ai_fallback_input_role 97 | chat_history: str = "" 98 | for message in messages: 99 | if message.role.lower() == cls.user_role: 100 | input_role = cls.user_input_role 101 | elif message.role.lower() == cls.system_role: 102 | input_role = cls.system_input_role 103 | else: 104 | input_role = ai_input_role = message.role 105 | chat_history += f"### {input_role}:{message.content}" 106 | 107 | prompt_stop: list[str] = cls.get_stop_strings( 108 | cls.user_input_role, cls.system_input_role, ai_input_role 109 | ) 110 | if isinstance(settings.stop, str): 111 | settings.stop = prompt_stop + [settings.stop] 112 | elif isinstance(settings.stop, list): 113 | settings.stop = prompt_stop + settings.stop 114 | else: 115 | settings.stop = prompt_stop 116 | return chat_history + f"### {ai_input_role}:" 117 | 118 | @staticmethod 119 | def is_possible_to_generate_stops( 120 | decoded_text: str, stops: list[str] 121 | ) -> bool: 122 | """A helper method to check if the decoded text contains any of the stop tokens.""" 123 | for stop in stops: 124 | if stop in decoded_text or any( 125 | [ 126 | decoded_text.endswith(stop[: i + 1]) 127 | for i in range(len(stop)) 128 | ] 129 | ): 130 | return True 131 | return False 132 | 133 | @property 134 | @abstractmethod 135 | def llm_model(self) -> "LLMModel": 136 | """The LLM model used by this generator.""" 137 | ... 
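# Illustrative sketch, not part of the original file: a standalone restatement
# of get_stop_strings, runnable without the class, to show what the
# case-variant expansion produces. The function name is hypothetical:

def _stop_strings_sketch(*roles: str) -> list[str]:
    stops: set[str] = set()
    for role in roles:
        for avoid in (f"{role}:", f"### {role}:", f"###{role}:"):
            # same variants the class builds: as-is, capitalized, upper, lower
            stops.update({avoid, avoid.capitalize(), avoid.upper(), avoid.lower()})
    return sorted(stops)

# _stop_strings_sketch("User") ->
# ['### USER:', '### User:', '### user:', '###USER:', '###User:',
#  '###user:', 'USER:', 'User:', 'user:']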
138 | -------------------------------------------------------------------------------- /app/utils/chat/text_generations/path.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from app.common.config import BASE_DIR 5 | 6 | 7 | def resolve_model_path_to_posix( 8 | model_path: str, default_relative_directory: Optional[str] = None 9 | ): 10 | """Resolve a model path to a POSIX path, relative to the BASE_DIR.""" 11 | path = Path(model_path) 12 | parent_directory: Path = ( 13 | Path(BASE_DIR) / Path(default_relative_directory) 14 | if default_relative_directory is not None 15 | else Path(BASE_DIR) 16 | if Path.cwd() == path.parent.resolve() 17 | else path.parent.resolve() 18 | ) 19 | filename: str = path.name 20 | return (parent_directory / filename).as_posix() 21 | -------------------------------------------------------------------------------- /app/utils/chat/text_generations/summarization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from app.common.config import ChatConfig 4 | from app.shared import Shared 5 | from app.utils.logger import ApiLogger 6 | 7 | 8 | def get_summarization( 9 | to_summarize: str, 10 | to_summarize_tokens: Optional[int] = None, 11 | ) -> str: 12 | shared = Shared() 13 | if to_summarize_tokens is None: 14 | to_summarize_tokens = len( 15 | shared.token_text_splitter._tokenizer.encode(to_summarize) 16 | ) 17 | 18 | if to_summarize_tokens < ChatConfig.summarization_token_limit: 19 | summarize_chain = shared.stuff_summarize_chain 20 | else: 21 | summarize_chain = shared.map_reduce_summarize_chain 22 | result: str = summarize_chain.run( 23 | shared.token_text_splitter.create_documents( 24 | [to_summarize], 25 | tokens_per_chunk=ChatConfig.summarization_chunk_size, 26 | chunk_overlap=ChatConfig.summarization_token_overlap, 27 | ) 28 | ) 29 | ApiLogger("||get_summarization||").info(result) 30 | return result 31 | -------------------------------------------------------------------------------- /app/utils/colorama.py: -------------------------------------------------------------------------------- 1 | # Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. 2 | """ 3 | This module generates ANSI character codes to printing colors to terminals. 4 | See: http://en.wikipedia.org/wiki/ANSI_escape_code 5 | """ 6 | 7 | CSI = "\033[" 8 | OSC = "\033]" 9 | BEL = "\a" 10 | 11 | 12 | def code_to_chars(code): 13 | return CSI + str(code) + "m" 14 | 15 | 16 | def set_title(title): 17 | return OSC + "2;" + title + BEL 18 | 19 | 20 | def clear_screen(mode=2): 21 | return CSI + str(mode) + "J" 22 | 23 | 24 | def clear_line(mode=2): 25 | return CSI + str(mode) + "K" 26 | 27 | 28 | class AnsiCodes(object): 29 | def __init__(self): 30 | # the subclasses declare class attributes which are numbers. 
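# (illustrative aside: e.g. AnsiFore.RED is the plain number 31 at
# class-definition time and becomes the escape string "\033[31m" once
# __init__ has passed it through code_to_chars)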
31 | # Upon instantiation we define instance attributes, which are the same 32 | # as the class attributes but wrapped with the ANSI escape sequence 33 | for name in dir(self): 34 | if not name.startswith("_"): 35 | value = getattr(self, name) 36 | setattr(self, name, code_to_chars(value)) 37 | 38 | 39 | class AnsiCursor(object): 40 | def UP(self, n=1): 41 | return CSI + str(n) + "A" 42 | 43 | def DOWN(self, n=1): 44 | return CSI + str(n) + "B" 45 | 46 | def FORWARD(self, n=1): 47 | return CSI + str(n) + "C" 48 | 49 | def BACK(self, n=1): 50 | return CSI + str(n) + "D" 51 | 52 | def POS(self, x=1, y=1): 53 | return CSI + str(y) + ";" + str(x) + "H" 54 | 55 | 56 | class AnsiFore(AnsiCodes): 57 | BLACK = 30 58 | RED = 31 59 | GREEN = 32 60 | YELLOW = 33 61 | BLUE = 34 62 | MAGENTA = 35 63 | CYAN = 36 64 | WHITE = 37 65 | RESET = 39 66 | 67 | # These are fairly well supported, but not part of the standard. 68 | LIGHTBLACK_EX = 90 69 | LIGHTRED_EX = 91 70 | LIGHTGREEN_EX = 92 71 | LIGHTYELLOW_EX = 93 72 | LIGHTBLUE_EX = 94 73 | LIGHTMAGENTA_EX = 95 74 | LIGHTCYAN_EX = 96 75 | LIGHTWHITE_EX = 97 76 | 77 | 78 | class AnsiBack(AnsiCodes): 79 | BLACK = 40 80 | RED = 41 81 | GREEN = 42 82 | YELLOW = 43 83 | BLUE = 44 84 | MAGENTA = 45 85 | CYAN = 46 86 | WHITE = 47 87 | RESET = 49 88 | 89 | # These are fairly well supported, but not part of the standard. 90 | LIGHTBLACK_EX = 100 91 | LIGHTRED_EX = 101 92 | LIGHTGREEN_EX = 102 93 | LIGHTYELLOW_EX = 103 94 | LIGHTBLUE_EX = 104 95 | LIGHTMAGENTA_EX = 105 96 | LIGHTCYAN_EX = 106 97 | LIGHTWHITE_EX = 107 98 | 99 | 100 | class AnsiStyle(AnsiCodes): 101 | BRIGHT = 1 102 | DIM = 2 103 | NORMAL = 22 104 | RESET_ALL = 0 105 | 106 | 107 | Fore = AnsiFore() 108 | Back = AnsiBack() 109 | Style = AnsiStyle() 110 | Cursor = AnsiCursor() 111 | -------------------------------------------------------------------------------- /app/utils/date_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, date, timedelta 2 | from re import match 3 | 4 | 5 | class UTC: 6 | def __init__(self): 7 | self.utc_now = datetime.utcnow() 8 | 9 | @classmethod 10 | def now(cls, hour_diff: int = 0) -> datetime: 11 | return cls().utc_now + timedelta(hours=hour_diff) 12 | 13 | @classmethod 14 | def now_isoformat(cls, hour_diff: int = 0) -> str: 15 | return (cls().utc_now + timedelta(hours=hour_diff)).isoformat() + "Z" 16 | 17 | @classmethod 18 | def date(cls, hour_diff: int = 0) -> date: 19 | return cls.now(hour_diff=hour_diff).date() 20 | 21 | @classmethod 22 | def timestamp(cls, hour_diff: int = 0) -> int: 23 | return int(cls.now(hour_diff=hour_diff).strftime("%Y%m%d%H%M%S")) 24 | 25 | @classmethod 26 | def timestamp_to_datetime(cls, timestamp: int, hour_diff: int = 0) -> datetime: 27 | return datetime.strptime(str(timestamp), "%Y%m%d%H%M%S") + timedelta(hours=hour_diff) 28 | 29 | @classmethod 30 | def date_code(cls, hour_diff: int = 0) -> int: 31 | return int(cls.date(hour_diff=hour_diff).strftime("%Y%m%d")) 32 | 33 | @staticmethod 34 | def check_string_valid(string: str) -> bool: 35 | """Check if a string is in ISO 8601 UTC datetime format. 36 | e.g. 
2023-05-22T05:08:29.087279Z -> True 37 | 2023-05-22T05:08:29Z -> True 38 | 2023-05-22T05:08:29 -> False 39 | 2023/05/22T05:08:29.087279Z -> False""" 40 | regex = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z$" 41 | return bool(match(regex, string)) 42 | 43 | 44 | if __name__ == "__main__": 45 | # Example usage 46 | print(UTC.check_string_valid("2023-05-22T05:08:29.087279Z")) # True 47 | print(UTC.check_string_valid("2023-05-22T05:08:29Z")) # True 48 | print(UTC.check_string_valid("2023-05-22T05:08:29")) # False 49 | print(UTC.check_string_valid("2023/05/22T05:08:29.087279Z")) # False 50 | print(UTC.timestamp_to_datetime(10000101000000)) # Prefix timestamp 51 | print(UTC.timestamp_to_datetime(99991231235959)) # Suffix timestamp 52 | -------------------------------------------------------------------------------- /app/utils/encoding_utils.py: -------------------------------------------------------------------------------- 1 | from base64 import b64encode, urlsafe_b64encode 2 | from os import urandom 3 | from os.path import exists 4 | from re import findall 5 | from typing import Any 6 | 7 | import json 8 | from cryptography.fernet import Fernet, InvalidToken 9 | from cryptography.hazmat.primitives.hashes import SHA256, HashAlgorithm 10 | from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC 11 | 12 | 13 | class SecretConfigSetup: 14 | """ 15 | Setup for secret configs encryption and decryption in JSON format. 16 | 17 | - Example: 18 | password_from_environ = environ.get("SECRET_CONFIGS_PASSWORD", None) 19 | secret_config_setup = SecretConfigSetup( 20 | password=password_from_environ 21 | if password_from_environ is not None 22 | else input("Enter Password: "), 23 | json_file_name="secret_configs.json", 24 | ) 25 | 26 | 27 | @dataclass(frozen=True) 28 | class SecretConfig(metaclass=SingletonMetaClass): 29 | secret_config: dict = field(default_factory=secret_config_setup.initialize) 30 | """ 31 | 32 | algorithm: HashAlgorithm = SHA256() 33 | length: int = 32 34 | iterations: int = 100000 35 | 36 | def __init__(self, password: str, json_file_name: str) -> None: 37 | self.password = password 38 | self.json_file_name = json_file_name 39 | 40 | @classmethod 41 | def _get_kdf(cls, salt: bytes) -> PBKDF2HMAC: 42 | return PBKDF2HMAC( 43 | salt=salt, 44 | algorithm=cls.algorithm, 45 | length=cls.length, 46 | iterations=cls.iterations, 47 | ) 48 | 49 | def encrypt(self) -> None: 50 | # Derive an encryption key from the password and salt using PBKDF2 51 | salt = urandom(16) 52 | fernet = Fernet( 53 | urlsafe_b64encode(SecretConfigSetup._get_kdf(salt=salt).derive(key_material=self.password.encode())) 54 | ) 55 | with open(self.json_file_name, "r") as f: 56 | plain_data: Any = json.load(f) 57 | 58 | # Encrypt the string using the encryption key 59 | encrypted_data: bytes = fernet.encrypt(json.dumps(plain_data).encode()) 60 | 61 | # Save the encrypted data and salt to file 62 | with open(f"{self.json_file_name}.enc", "wb") as f: 63 | f.write(salt) 64 | f.write(encrypted_data) 65 | 66 | def decrypt(self) -> Any: 67 | # Load the encrypted data and salt from file 68 | with open(f"{self.json_file_name}.enc", "rb") as f: 69 | salt = f.read(16) 70 | fernet = Fernet( 71 | urlsafe_b64encode(SecretConfigSetup._get_kdf(salt=salt).derive(key_material=self.password.encode())) 72 | ) 73 | encrypted_data: bytes = f.read() 74 | decrypted_data: bytes = fernet.decrypt(encrypted_data) 75 | return json.loads(decrypted_data) 76 | 77 | def initialize(self) -> Any: 78 | if not
exists(f"{self.json_file_name}.enc"): 79 | self.encrypt() 80 | 81 | while True: 82 | try: 83 | secret_config: Any = self.decrypt() 84 | except InvalidToken: 85 | self.password = input("Wrong password! Enter password again: ") 86 | else: 87 | return secret_config 88 | 89 | 90 | def encode_from_utf8(text: str) -> str: 91 | # Check if text contains any non-ASCII characters 92 | matches: list[Any] = findall(r"[^\x00-\x7F]+", text) 93 | if not matches: 94 | # Return text if it doesn't contain any non-ASCII characters 95 | return text 96 | else: 97 | # Encode and replace non-ASCII characters with UTF-8 formatted text 98 | for match in matches: 99 | encoded_text: str = b64encode(match.encode("utf-8")).decode("utf-8") 100 | text = text.replace(match, "=?UTF-8?B?" + encoded_text + "?=") 101 | return text 102 | -------------------------------------------------------------------------------- /app/utils/function_calling/callbacks/lite_browsing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi.concurrency import run_in_threadpool 4 | 5 | from app.common.lotties import Lotties 6 | from app.models.function_calling.functions import FunctionCalls 7 | from app.shared import Shared 8 | from app.utils.chat.buffer import BufferedUserContext 9 | from app.utils.chat.managers.websocket import SendToWebsocket 10 | from app.utils.logger import ApiLogger 11 | 12 | from ..query import aget_query_to_search 13 | 14 | 15 | async def lite_web_browsing_callback( 16 | buffer: BufferedUserContext, 17 | query: str, 18 | finish: bool, 19 | wait_next_query: bool, 20 | show_result: bool = False, 21 | ) -> Optional[str]: 22 | await SendToWebsocket.message( 23 | websocket=buffer.websocket, 24 | msg=Lotties.SEARCH_WEB.format("### Browsing web", end=False), 25 | chat_room_id=buffer.current_chat_room_id, 26 | finish=False, 27 | ) 28 | try: 29 | query_to_search: str = await aget_query_to_search( 30 | buffer=buffer, 31 | query=query, 32 | function=FunctionCalls.get_function_call(FunctionCalls.web_search), 33 | ) 34 | r = await run_in_threadpool(Shared().duckduckgo.run, query_to_search) 35 | # ApiLogger("||lite_web_browsing_chain||").info(r) 36 | 37 | await SendToWebsocket.message( 38 | websocket=buffer.websocket, 39 | msg=Lotties.OK.format("### Finished browsing") 40 | + (r if show_result else ""), 41 | chat_room_id=buffer.current_chat_room_id, 42 | finish=finish, 43 | wait_next_query=wait_next_query, 44 | ) 45 | return r 46 | except Exception as e: 47 | ApiLogger("||lite_web_browsing_chain||").exception(e) 48 | await SendToWebsocket.message( 49 | msg=Lotties.FAIL.format("### Failed to browse web"), 50 | websocket=buffer.websocket, 51 | chat_room_id=buffer.current_chat_room_id, 52 | finish=finish, 53 | wait_next_query=wait_next_query, 54 | ) 55 | -------------------------------------------------------------------------------- /app/utils/function_calling/callbacks/translate.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from app.common.lotties import Lotties 4 | from app.utils.api.translate import Translator 5 | from app.utils.chat.buffer import BufferedUserContext 6 | from app.utils.chat.managers.websocket import SendToWebsocket 7 | 8 | 9 | async def translate_callback( 10 | buffer: BufferedUserContext, 11 | query: str, 12 | finish: bool, 13 | wait_next_query: Optional[bool], 14 | show_result: bool = True, 15 | show_result_prefix: Optional[str] = " # 🌐 Translation 
Result\n---\n\n", 16 | src_lang: str = "en", 17 | trg_lang: str = "en", 18 | ) -> Optional[str]: 19 | await SendToWebsocket.message( 20 | msg=Lotties.TRANSLATE.format("### Translating"), 21 | websocket=buffer.websocket, 22 | chat_room_id=buffer.current_chat_room_id, 23 | finish=False, 24 | ) 25 | try: 26 | r = await Translator.translate(text=query, src_lang=src_lang, trg_lang=trg_lang) 27 | r_show = show_result_prefix + r if show_result_prefix is not None else r 28 | await SendToWebsocket.message( 29 | msg=Lotties.OK.format("### Finished translation") 30 | + (r_show if show_result else ""), 31 | websocket=buffer.websocket, 32 | chat_room_id=buffer.current_chat_room_id, 33 | finish=finish, 34 | wait_next_query=wait_next_query, 35 | ) 36 | return r 37 | except Exception: 38 | await SendToWebsocket.message( 39 | msg=Lotties.FAIL.format("### Failed translation"), 40 | websocket=buffer.websocket, 41 | chat_room_id=buffer.current_chat_room_id, 42 | finish=finish, 43 | wait_next_query=wait_next_query, 44 | ) 45 | return None 46 | -------------------------------------------------------------------------------- /app/utils/function_calling/callbacks/vectorstore_search.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from app.common.config import ChatConfig, config 4 | from app.common.lotties import Lotties 5 | from app.utils.chat.buffer import BufferedUserContext 6 | from app.utils.chat.managers.vectorstore import Document, VectorStoreManager 7 | from app.utils.chat.managers.websocket import SendToWebsocket 8 | from app.utils.logger import ApiLogger 9 | 10 | 11 | async def vectorstore_search_callback( 12 | buffer: BufferedUserContext, 13 | query_to_search: str, 14 | finish: bool, 15 | wait_next_query: bool, 16 | show_result: bool = True, 17 | k: int = ChatConfig.vectorstore_n_results_limit, 18 | ) -> Optional[str]: 19 | try: 20 | found_text_and_score: list[ 21 | Tuple[Document, float] 22 | ] = await VectorStoreManager.asimilarity_search_multiple_collections_with_score( 23 | query=query_to_search, 24 | collection_names=[buffer.user_id, config.shared_vectorestore_name], 25 | k=k, 26 | ) 27 | if not found_text_and_score: 28 | raise Exception("No result found") 29 | 30 | found_text: Optional[str] = "\n\n".join( 31 | [document.page_content for document, _ in found_text_and_score] 32 | ) 33 | send_message = "### Finished searching vectorstore" 34 | if show_result: 35 | send_message += f'\n---\n{found_text}' 36 | await SendToWebsocket.message( 37 | websocket=buffer.websocket, 38 | msg=Lotties.OK.format(send_message), 39 | chat_room_id=buffer.current_chat_room_id, 40 | finish=finish, 41 | wait_next_query=wait_next_query, 42 | ) 43 | return found_text 44 | 45 | except Exception as e: 46 | ApiLogger("||vectorstore_search_callback||").exception(e) 47 | await SendToWebsocket.message( 48 | msg=Lotties.FAIL.format("### Failed to search vectorstore"), 49 | websocket=buffer.websocket, 50 | chat_room_id=buffer.current_chat_room_id, 51 | finish=finish, 52 | wait_next_query=wait_next_query, 53 | ) 54 | return None 55 | -------------------------------------------------------------------------------- /app/utils/function_calling/parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | from inspect import signature 3 | from re import compile 4 | from typing import ( 5 | Annotated, 6 | Any, 7 | Callable, 8 | Iterable, 9 | get_args, 10 | get_origin, 11 | ) 12 | 13 | from 
app.models.completion_models import ( 14 | FunctionCallParsed, 15 | FunctionCallUnparsed, 16 | ) 17 | from app.models.function_calling.base import ( 18 | FunctionCall, 19 | FunctionCallParameter, 20 | JsonTypes, 21 | ) 22 | from app.utils.types import get_type_and_optional 23 | 24 | 25 | def make_function_call_parsed_from_dict( 26 | unparsed_dict: dict[str, JsonTypes] | FunctionCallParsed | FunctionCallUnparsed, 27 | ) -> FunctionCallParsed: 28 | """ 29 | Parse function call response from API into a FunctionCallParsed object. 30 | This is a helper method to identify what function will be called and what 31 | arguments will be passed to it.""" 32 | 33 | if "name" not in unparsed_dict: 34 | raise ValueError("Function call name is required.") 35 | 36 | function_call_parsed: FunctionCallParsed = FunctionCallParsed( 37 | name=str(unparsed_dict["name"]) 38 | ) 39 | arguments = unparsed_dict.get("arguments", {}) 40 | if isinstance(arguments, dict): 41 | function_call_parsed["arguments"] = arguments 42 | return function_call_parsed 43 | else: 44 | try: 45 | function_call_parsed["arguments"] = json.loads(str(arguments)) 46 | except json.JSONDecodeError: 47 | pass 48 | return function_call_parsed 49 | 50 | 51 | def parse_function_call_from_function(func: Callable) -> FunctionCall: 52 | """ 53 | Parse a function into a FunctionCall object. 54 | FunctionCall objects are used to represent the specification of a function 55 | """ 56 | json_types = get_args(JsonTypes) 57 | function_call_params: list[FunctionCallParameter] = [] 58 | required: list[str] = [] 59 | for name, param in signature(func).parameters.items(): 60 | annotation = param.annotation 61 | description: str = "" 62 | enum: list[Any] = [] 63 | 64 | if get_origin(annotation) is Annotated: 65 | # If the annotation is an Annotated type, 66 | # we need to parse the metadata 67 | _param_args = get_args(param.annotation) 68 | _param_type = _param_args[0] 69 | 70 | for metadata in _param_args[1:]: 71 | if isinstance(metadata, str): 72 | # If the metadata is a string, it's the description 73 | description += metadata 74 | elif isinstance(metadata, Iterable): 75 | # If the metadata is an iterable, it's the enum 76 | enum.extend(metadata) 77 | 78 | else: 79 | _param_type = annotation 80 | param_type, optional = get_type_and_optional(_param_type) 81 | if not optional: 82 | required.append(name) 83 | if param_type not in json_types: 84 | continue 85 | function_call_params.append( 86 | FunctionCallParameter( 87 | name=name, 88 | type=param_type, 89 | description=description or None, 90 | enum=enum or None, 91 | ) 92 | ) 93 | line_break_pattern = compile(r"\n\s*") 94 | return FunctionCall( 95 | name=func.__name__, 96 | description=line_break_pattern.sub(" ", func.__doc__) if func.__doc__ else None, 97 | parameters=function_call_params, 98 | required=required or None, 99 | ) 100 | -------------------------------------------------------------------------------- /app/utils/function_calling/query.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from langchain import PromptTemplate 4 | from orjson import loads as orjson_loads 5 | 6 | from app.common.config import ChatConfig 7 | from app.common.constants import JSON_PATTERN 8 | from app.models.chat_models import ChatRoles, MessageHistory 9 | from app.models.function_calling.base import FunctionCall 10 | from app.shared import Shared 11 | from app.utils.chat.buffer import BufferedUserContext 12 | from 
app.utils.chat.messages.converter import ( 13 | chat_completion_api_parse_method, 14 | message_histories_to_list, 15 | ) 16 | from app.utils.chat.tokens import cutoff_message_histories 17 | 18 | from .request import request_function_call 19 | 20 | 21 | async def aget_query_to_search( 22 | buffer: BufferedUserContext, 23 | query: str, 24 | function: FunctionCall, 25 | timeout: Optional[float] = 30, 26 | ) -> str: 27 | """Get query to search from user query and current context""" 28 | ( 29 | user_message_histories, 30 | ai_message_histories, 31 | system_message_histories, 32 | ) = cutoff_message_histories( 33 | user_chat_context=buffer.current_user_chat_context, 34 | ai_message_histories=buffer.current_ai_message_histories, 35 | user_message_histories=buffer.current_user_message_histories 36 | + [ 37 | MessageHistory( 38 | role=buffer.current_user_chat_roles.user, 39 | content=query, 40 | tokens=buffer.current_user_chat_context.get_tokens_of(query) 41 | + buffer.current_llm_model.value.token_margin, 42 | actual_role=ChatRoles.USER.value, 43 | ) 44 | ], 45 | system_message_histories=[], 46 | token_limit=ChatConfig.query_context_token_limit, 47 | ) 48 | try: 49 | function_call_parsed = await request_function_call( 50 | messages=message_histories_to_list( 51 | parse_method=chat_completion_api_parse_method, 52 | user_message_histories=user_message_histories, 53 | ai_message_histories=ai_message_histories, 54 | system_message_histories=system_message_histories, 55 | ), 56 | functions=[function], 57 | function_call=function, 58 | timeout=timeout, 59 | ) 60 | if "arguments" not in function_call_parsed: 61 | raise ValueError("No arguments returned") 62 | return str(function_call_parsed["arguments"]["query_to_search"]) 63 | except Exception: 64 | return query 65 | 66 | 67 | async def aget_json( 68 | query_template: PromptTemplate, **kwargs_to_format: str 69 | ) -> Optional[Any]: 70 | """Get json from query template and kwargs to format""" 71 | try: 72 | json_query = JSON_PATTERN.search( 73 | await Shared().llm.apredict(query_template.format(**kwargs_to_format)) 74 | ) 75 | if json_query is None: 76 | raise ValueError("Result is None") 77 | else: 78 | return orjson_loads(json_query.group()) 79 | except Exception: 80 | return None 81 | -------------------------------------------------------------------------------- /app/utils/function_calling/request.py: -------------------------------------------------------------------------------- 1 | from asyncio import wait_for 2 | from typing import Any, Literal, Optional 3 | 4 | from app.common.config import OPENAI_API_KEY, ChatConfig 5 | from app.models.completion_models import FunctionCallParsed 6 | from app.models.function_calling.base import FunctionCall 7 | from app.utils.api.completion import request_chat_completion 8 | 9 | from .parser import make_function_call_parsed_from_dict 10 | 11 | 12 | async def request_function_call( 13 | messages: list[dict[str, str]], 14 | functions: list[FunctionCall], 15 | function_call: Optional[FunctionCall | Literal["auto", "none"]] = "auto", 16 | model: str = ChatConfig.global_openai_model, 17 | api_base: str = "https://api.openai.com/v1", 18 | api_key: Optional[str] = OPENAI_API_KEY, 19 | timeout: Optional[float] = None, 20 | force_arguments: bool = False, 21 | **kwargs: Any, 22 | ) -> FunctionCallParsed: 23 | coro = request_chat_completion( 24 | messages=messages, 25 | model=model, 26 | api_base=api_base, 27 | api_key=api_key, 28 | functions=functions, 29 | function_call=function_call, 30 | **kwargs, 31 | ) 32 | if 
timeout is not None: 33 | coro = wait_for(coro, timeout=timeout) 34 | function_call_unparsed = (await coro)["choices"][0]["message"].get( 35 | "function_call" 36 | ) 37 | if function_call_unparsed is None: 38 | raise ValueError("No function call returned") 39 | function_call_parsed = make_function_call_parsed_from_dict( 40 | function_call_unparsed 41 | ) 42 | if force_arguments and "arguments" not in function_call_parsed: 43 | raise ValueError("No arguments returned") 44 | 45 | return function_call_parsed 46 | -------------------------------------------------------------------------------- /app/utils/function_calling/token_count.py: -------------------------------------------------------------------------------- 1 | from tiktoken import encoding_for_model, get_encoding 2 | 3 | from app.models.function_calling.base import FunctionCall 4 | 5 | 6 | def get_num_tokens_from_functions( 7 | functions: list[FunctionCall], model: str = "gpt-3.5-turbo" 8 | ) -> int: 9 | """Return the number of tokens used by a list of functions.""" 10 | try: 11 | encoding = encoding_for_model(model) 12 | except KeyError: 13 | print("Warning: model not found. Using cl100k_base encoding.") 14 | encoding = get_encoding("cl100k_base") 15 | 16 | num_tokens = 0 17 | for _function in functions: 18 | function = _function.to_dict() 19 | function_tokens = len(encoding.encode(function["name"])) 20 | if "description" in function: 21 | function_tokens += len(encoding.encode(function["description"])) 22 | 23 | if "parameters" in function: 24 | parameters = function["parameters"] 25 | if "properties" in parameters: 26 | for name, property in parameters["properties"].items(): 27 | function_tokens += len(encoding.encode(name)) 28 | for field in property: 29 | if field == "type": 30 | function_tokens += 2 31 | function_tokens += len( 32 | encoding.encode(property["type"]) 33 | ) 34 | elif ( 35 | field == "description" 36 | and "description" in property 37 | ): 38 | function_tokens += 2 39 | function_tokens += len( 40 | encoding.encode(property["description"]) 41 | ) 42 | elif field == "enum" and "enum" in property: 43 | function_tokens -= 3 44 | for o in property["enum"]: 45 | function_tokens += 3 46 | function_tokens += len(encoding.encode(str(o))) 47 | else: 48 | print(f"Warning: unsupported field {field}") 49 | function_tokens += 11 50 | 51 | num_tokens += function_tokens 52 | 53 | num_tokens += 12 54 | return num_tokens 55 | -------------------------------------------------------------------------------- /app/utils/js_initializer.py: -------------------------------------------------------------------------------- 1 | import re 2 | from app.common.config import config, ProdConfig 3 | 4 | 5 | def js_url_initializer(js_location: str = "app/web/main.dart.js") -> None: 6 | with open(file=js_location, mode="r") as file: 7 | filedata = file.read() 8 | 9 | for schema in ("http", "ws"): 10 | if isinstance(config, ProdConfig): 11 | from_url_local = rf"{schema}://localhost:\d+" 12 | to_url_prod = f"{schema}s://{config.host_main}" 13 | filedata = re.sub( 14 | from_url_local, 15 | to_url_prod, 16 | filedata, 17 | ) 18 | else: 19 | from_url_local = rf"{schema}://localhost:\d+" 20 | from_url_prod = rf"{schema}://{config.host_main}" 21 | to_url_local = f"{schema}://localhost:{config.port}" 22 | filedata = re.sub( 23 | from_url_local, 24 | to_url_local, 25 | filedata, 26 | ) 27 | filedata = re.sub( 28 | from_url_prod, 29 | to_url_local, 30 | filedata, 31 | ) 32 | 33 | with open(file=js_location, mode="w") as file: 34 | file.write(filedata) 35 |
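# Illustration (values hypothetical): with a ProdConfig whose host_main is
# "example.com", the production branch above rewrites occurrences such as
#     http://localhost:8000 -> https://example.com
#     ws://localhost:3000   -> wss://example.com
# In any other config, both prod URLs and stray localhost ports are rewritten
# back to http://localhost:{config.port} and ws://localhost:{config.port}.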
36 | 37 | if __name__ == "__main__": 38 | js_url_initializer() 39 | -------------------------------------------------------------------------------- /app/utils/langchain/chat_llama_cpp.py: -------------------------------------------------------------------------------- 1 | from langchain import LlamaCpp 2 | from pydantic import Field, root_validator 3 | 4 | 5 | class CustomLlamaCpp(LlamaCpp): 6 | low_vram: bool = Field(False, alias="low_vram") 7 | 8 | @root_validator() 9 | def validate_environment(cls, values: dict) -> dict: 10 | """Validate that llama-cpp-python library is installed.""" 11 | model_path = values["model_path"] 12 | model_param_names = [ 13 | "lora_path", 14 | "lora_base", 15 | "n_ctx", 16 | "n_parts", 17 | "seed", 18 | "f16_kv", 19 | "logits_all", 20 | "vocab_only", 21 | "use_mlock", 22 | "n_threads", 23 | "n_batch", 24 | "use_mmap", 25 | "last_n_tokens_size", 26 | ] 27 | model_params = {k: values[k] for k in model_param_names} 28 | # For backwards compatibility, only include if non-null. 29 | if values["n_gpu_layers"] is not None: 30 | model_params["n_gpu_layers"] = values["n_gpu_layers"] 31 | if values["low_vram"] is not None: 32 | model_params["low_vram"] = values["low_vram"] 33 | 34 | try: 35 | from llama_cpp import Llama 36 | 37 | values["client"] = Llama(model_path, **model_params) 38 | except ImportError: 39 | raise ModuleNotFoundError( 40 | "Could not import llama-cpp-python library. " 41 | "Please install the llama-cpp-python library to " 42 | "use this embedding model: pip install llama-cpp-python" 43 | ) 44 | except Exception as e: 45 | raise ValueError( 46 | f"Could not load Llama model from path: {model_path}. " 47 | f"Received error {e}" 48 | ) 49 | 50 | return values 51 | -------------------------------------------------------------------------------- /app/utils/langchain/embeddings_api.py: -------------------------------------------------------------------------------- 1 | """Wrapper around embedding API models.""" 2 | from __future__ import annotations 3 | 4 | import logging 5 | from typing import ( 6 | Any, 7 | List, 8 | Optional, 9 | Tuple, 10 | Union, 11 | ) 12 | 13 | import requests 14 | 15 | from pydantic import BaseModel, Extra 16 | from langchain.embeddings.base import Embeddings 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class APIEmbeddings(BaseModel, Embeddings): 22 | """Wrapper around embedding models from OpenAI-style API""" 23 | 24 | client: Any #: :meta private: 25 | model: str = "intfloat/e5-large-v2" 26 | embedding_api_url: str = "http://localhost:8002/v1/embeddings" 27 | request_timeout: Optional[Union[float, Tuple[float, float]]] = None 28 | """Timeout in seconds for the API request.""" 29 | headers: Any = None 30 | 31 | class Config: 32 | """Configuration for this pydantic object.""" 33 | 34 | extra = Extra.forbid 35 | 36 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 37 | """Call out to embedding endpoint for embedding search docs. 38 | 39 | Args: 40 | texts: The list of texts to embed. 41 | 42 | Returns: 43 | List of embeddings, one for each text. 44 | """ 45 | # NOTE: to keep things simple, we assume the list may contain texts longer 46 | # than the maximum context and use length-safe embedding function. 
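# Note: the request_timeout and headers fields declared on this model are not
# forwarded to the requests.post calls below as written; to honor them, pass
# timeout=self.request_timeout and headers=self.headers to each call.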
47 | response = requests.post( 48 | self.embedding_api_url, 49 | json={ 50 | "model": self.model, 51 | "input": texts, 52 | }, 53 | ) 54 | 55 | # Check if the request was successful 56 | if response.status_code == 200: 57 | # Parse the response 58 | response_data = response.json() 59 | 60 | # Extract the embeddings and total tokens 61 | embeddings: list[list[float]] = [ 62 | data["embedding"] for data in response_data["data"] 63 | ] 64 | total_tokens = response_data["usage"]["total_tokens"] 65 | return embeddings 66 | raise ConnectionError( 67 | f"Request to {self.embedding_api_url} failed with status code " 68 | f"{response.status_code}." 69 | ) 70 | 71 | def embed_query(self, text: str) -> List[float]: 72 | """Call out to the embedding endpoint for embedding query text. 73 | 74 | Args: 75 | text: The text to embed. 76 | 77 | Returns: 78 | Embedding for the text. 79 | """ 80 | response = requests.post( 81 | self.embedding_api_url, 82 | json={ 83 | "model": self.model, 84 | "input": text, 85 | }, 86 | ) 87 | 88 | # Check if the request was successful 89 | if response.status_code == 200: 90 | # Parse the response 91 | response_data = response.json() 92 | 93 | # Extract the embeddings and total tokens 94 | embeddings: list[list[float]] = [ 95 | data["embedding"] for data in response_data["data"] 96 | ] 97 | total_tokens = response_data["usage"]["total_tokens"] 98 | return embeddings[0] 99 | raise ConnectionError( 100 | f"Request to {self.embedding_api_url} failed with status code " 101 | f"{response.status_code}." 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | import argparse 107 | import json 108 | 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("--model", type=str, default="intfloat/e5-large-v2") 111 | parser.add_argument( 112 | "--embedding_api_url", type=str, default="http://localhost:8002/v1/embeddings" 113 | ) 114 | parser.add_argument("--request_timeout", type=float, default=None) 115 | parser.add_argument("--headers", type=str, default=None) 116 | parser.add_argument("--text", type=str, default="Hello, world!") 117 | args = parser.parse_args() 118 | 119 | print(args) 120 | 121 | # Create the API embeddings model 122 | api_embeddings = APIEmbeddings( 123 | client=None, 124 | model=args.model, 125 | embedding_api_url=args.embedding_api_url, 126 | request_timeout=args.request_timeout, 127 | headers=args.headers, 128 | ) 129 | 130 | # Embed the query 131 | query_embedding = api_embeddings.embed_query(args.text) 132 | 133 | # Print the query embedding 134 | print(json.dumps(query_embedding)) 135 | -------------------------------------------------------------------------------- /app/utils/langchain/structured_tool.py: -------------------------------------------------------------------------------- 1 | """Base implementation for tools or skills.""" 2 | from __future__ import annotations 3 | 4 | from inspect import signature, iscoroutinefunction 5 | from typing import Any, Awaitable, Callable, Optional, Type 6 | 7 | from langchain.callbacks.manager import ( 8 | AsyncCallbackManagerForToolRun, 9 | CallbackManagerForToolRun, 10 | ) 11 | from langchain.tools.base import BaseTool, create_schema_from_function 12 | from pydantic import ( 13 | BaseModel, 14 | Field, 15 | ) 16 | 17 | 18 | class StructuredTool(BaseTool): 19 | """Tool that can operate on any number of inputs.""" 20 | 21 | description: str = "" 22 | args_schema: Type[BaseModel] = Field(..., description="The tool schema.") 23 | """The input arguments' schema.""" 24 | func: Optional[Callable[..., Any]]
= None 25 | """The function to run when the tool is called.""" 26 | coroutine: Optional[Callable[..., Awaitable[Any]]] = None 27 | """The asynchronous version of the function.""" 28 | 29 | @property 30 | def args(self) -> dict: 31 | """The tool's input arguments.""" 32 | return self.args_schema.schema()["properties"] 33 | 34 | def _run( 35 | self, 36 | *args: Any, 37 | run_manager: Optional[CallbackManagerForToolRun] = None, 38 | **kwargs: Any, 39 | ) -> Any: 40 | """Use the tool.""" 41 | if self.func: 42 | new_argument_supported = signature(self.func).parameters.get("callbacks") 43 | return ( 44 | self.func( 45 | *args, 46 | callbacks=run_manager.get_child() if run_manager else None, 47 | **kwargs, 48 | ) 49 | if new_argument_supported 50 | else self.func(*args, **kwargs) 51 | ) 52 | raise NotImplementedError("Tool does not support sync") 53 | 54 | async def _arun( 55 | self, 56 | *args: Any, 57 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, 58 | **kwargs: Any, 59 | ) -> str: 60 | """Use the tool asynchronously.""" 61 | if self.coroutine: 62 | new_argument_supported = signature(self.coroutine).parameters.get( 63 | "callbacks" 64 | ) 65 | return ( 66 | await self.coroutine( 67 | *args, 68 | callbacks=run_manager.get_child() if run_manager else None, 69 | **kwargs, 70 | ) 71 | if new_argument_supported 72 | else await self.coroutine(*args, **kwargs) 73 | ) 74 | return self._run(*args, **kwargs) 75 | 76 | @classmethod 77 | def from_function( 78 | cls, 79 | func: Callable, 80 | name: Optional[str] = None, 81 | description: Optional[str] = None, 82 | return_direct: bool = False, 83 | args_schema: Optional[Type[BaseModel]] = None, 84 | infer_schema: bool = True, 85 | **kwargs: Any, 86 | ) -> StructuredTool: 87 | name = name or func.__name__ 88 | description = description or func.__doc__ 89 | assert ( 90 | description is not None 91 | ), "Function must have a docstring if description not provided." 92 | 93 | # Description example: 94 | # search_api(query: str) - Searches the API for the query. 
95 | description = f"{name}{signature(func)} - {description.strip()}" 96 | _args_schema = args_schema 97 | if _args_schema is None and infer_schema: 98 | _args_schema = create_schema_from_function(f"{name}Schema", func) 99 | is_coroutine_function = iscoroutinefunction(func) 100 | return cls( 101 | name=name, 102 | func=func if not is_coroutine_function else None, 103 | coroutine=func if is_coroutine_function else None, 104 | args_schema=_args_schema, 105 | description=description, 106 | return_direct=return_direct, 107 | **kwargs, 108 | ) 109 | -------------------------------------------------------------------------------- /app/utils/langchain/token_text_splitter.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import ( 3 | AbstractSet, 4 | Any, 5 | Collection, 6 | Iterable, 7 | List, 8 | Literal, 9 | Optional, 10 | Sequence, 11 | Union, 12 | ) 13 | 14 | from langchain.docstore.document import Document 15 | from langchain.text_splitter import Tokenizer, TokenTextSplitter, split_text_on_tokens 16 | 17 | 18 | class CustomTokenTextSplitter(TokenTextSplitter): 19 | """Implementation of splitting text that looks at tokens.""" 20 | 21 | def __init__( 22 | self, 23 | encoding_name: str = "gpt2", 24 | model_name: Optional[str] = None, 25 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 26 | disallowed_special: Union[Literal["all"], Collection[str]] = "all", 27 | **kwargs: Any, 28 | ): 29 | super().__init__( 30 | encoding_name=encoding_name, 31 | model_name=model_name, 32 | allowed_special=allowed_special, 33 | disallowed_special=disallowed_special, 34 | **kwargs, 35 | ) 36 | 37 | def split_text( 38 | self, 39 | text: str, 40 | tokens_per_chunk: Optional[int] = None, 41 | chunk_overlap: Optional[int] = None, 42 | ) -> List[str]: 43 | def _encode(_text: str) -> List[int]: 44 | return self._tokenizer.encode( 45 | _text, 46 | allowed_special=self._allowed_special, # type: ignore 47 | disallowed_special=self._disallowed_special, 48 | ) 49 | 50 | tokenizer = Tokenizer( 51 | chunk_overlap=self._chunk_overlap 52 | if chunk_overlap is None 53 | else chunk_overlap, 54 | tokens_per_chunk=self._chunk_size 55 | if tokens_per_chunk is None 56 | else tokens_per_chunk, 57 | decode=self._tokenizer.decode, 58 | encode=_encode, 59 | ) 60 | 61 | return split_text_on_tokens(text=text, tokenizer=tokenizer) 62 | 63 | def create_documents( 64 | self, 65 | texts: List[str], 66 | metadatas: Optional[List[dict]] = None, 67 | tokens_per_chunk: Optional[int] = None, 68 | chunk_overlap: Optional[int] = None, 69 | ) -> List[Document]: 70 | """Create documents from a list of texts.""" 71 | _metadatas = metadatas or [{}] * len(texts) 72 | documents = [] 73 | for i, text in enumerate(texts): 74 | index = -1 75 | for chunk in self.split_text( 76 | text, 77 | tokens_per_chunk=tokens_per_chunk, 78 | chunk_overlap=chunk_overlap, 79 | ): 80 | metadata = deepcopy(_metadatas[i]) 81 | if self._add_start_index: 82 | index = text.find(chunk, index + 1) 83 | metadata["start_index"] = index 84 | new_doc = Document(page_content=chunk, metadata=metadata) 85 | documents.append(new_doc) 86 | return documents 87 | 88 | def split_documents( 89 | self, 90 | documents: Iterable[Document], 91 | tokens_per_chunk: Optional[int] = None, 92 | chunk_overlap: Optional[int] = None, 93 | ) -> List[Document]: 94 | """Split documents.""" 95 | texts, metadatas = [], [] 96 | for doc in documents: 97 | texts.append(doc.page_content) 98 | 
metadatas.append(doc.metadata) 99 | return self.create_documents( 100 | texts, 101 | metadatas=metadatas, 102 | tokens_per_chunk=tokens_per_chunk, 103 | chunk_overlap=chunk_overlap, 104 | ) 105 | 106 | def transform_documents( 107 | self, 108 | documents: Sequence[Document], 109 | tokens_per_chunk: Optional[int] = None, 110 | chunk_overlap: Optional[int] = None, 111 | **kwargs: Any, 112 | ) -> Sequence[Document]: 113 | """Transform sequence of documents by splitting them.""" 114 | return self.split_documents( 115 | list(documents), 116 | tokens_per_chunk=tokens_per_chunk, 117 | chunk_overlap=chunk_overlap, 118 | ) 119 | -------------------------------------------------------------------------------- /app/utils/module_reloader.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from types import ModuleType 3 | 4 | 5 | class ModuleReloader: 6 | def __init__(self, module_name: str): 7 | self._module_name = module_name 8 | 9 | def reload(self) -> ModuleType: 10 | importlib.reload(self.module) 11 | return self.module 12 | 13 | @property 14 | def module(self) -> ModuleType: 15 | return importlib.import_module(self._module_name) 16 | -------------------------------------------------------------------------------- /app/utils/params_utils.py: -------------------------------------------------------------------------------- 1 | from hmac import HMAC, new 2 | from base64 import b64encode 3 | 4 | 5 | def parse_params(params: dict) -> str: 6 | return "&".join([f"{key}={value}" for key, value in params.items()]) 7 | 8 | 9 | def hash_params(query_params: str, secret_key: str) -> str: 10 | mac: HMAC = new( 11 | key=bytes(secret_key, encoding="utf-8"), 12 | msg=bytes(query_params, encoding="utf-8"), 13 | digestmod="sha256", 14 | ) 15 | return str(b64encode(mac.digest()).decode("utf-8")) 16 | -------------------------------------------------------------------------------- /app/utils/tests/random.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | 4 | def random_user_generator(**kwargs): 5 | random_8_digits = str(abs(hash(uuid4())))[:8] 6 | return { 7 | "email": f"{random_8_digits}@test.com", 8 | "password": "123456", 9 | "name": f"{random_8_digits}", 10 | "phone_number": f"010{random_8_digits}", 11 | } | kwargs 12 | -------------------------------------------------------------------------------- /app/utils/tests/timeit.py: -------------------------------------------------------------------------------- 1 | from asyncio import iscoroutinefunction 2 | from functools import wraps 3 | from time import time 4 | from typing import Any, AsyncGenerator, Callable, Generator 5 | 6 | from app.common.config import logging_config 7 | from app.utils.logger import CustomLogger 8 | 9 | logger = CustomLogger( 10 | name=__name__, 11 | logging_config=logging_config, 12 | ) 13 | 14 | 15 | def log_time(fn: Any, elapsed: float) -> None: 16 | formatted_time = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}:{int((elapsed - int(elapsed)) * 1000000):06d}" 17 | logger.info(f"Function {fn.__name__} execution time: {formatted_time}") 18 | 19 | 20 | def timeit(fn: Callable) -> Callable: 21 | """ 22 | Decorator to measure the execution time of a function or generator. 23 | Supports both synchronous and asynchronous functions and generators. 24 | :param fn: function to measure execution time of 25 | :return: function wrapper 26 | 27 | Usage: 28 | >>> @timeit 29 | >>> def my_function(): 30 | >>> ...
31 | """ 32 | 33 | @wraps(fn) 34 | async def coroutine_function_wrapper(*args: Any, **kwargs: Any) -> Any: 35 | start_time: float = time() 36 | fn_result = await fn(*args, **kwargs) 37 | elapsed: float = time() - start_time 38 | log_time(fn, elapsed) 39 | return fn_result 40 | 41 | @wraps(fn) 42 | def noncoroutine_function_wrapper(*args: Any, **kwargs: Any) -> Any: 43 | def sync_generator_wrapper() -> Generator: 44 | while True: 45 | try: 46 | start_time: float = time() 47 | item: Any = next(fn_result) 48 | elapsed: float = time() - start_time 49 | log_time(fn_result, elapsed) 50 | yield item 51 | except StopIteration: 52 | break 53 | 54 | async def async_generator_wrapper() -> AsyncGenerator: 55 | while True: 56 | try: 57 | start_time: float = time() 58 | item: Any = await anext(fn_result) 59 | elapsed: float = time() - start_time 60 | log_time(fn_result, elapsed) 61 | yield item 62 | except StopAsyncIteration: 63 | break 64 | 65 | start_time: float = time() 66 | fn_result: Any = fn(*args, **kwargs) 67 | elapsed: float = time() - start_time 68 | if isinstance(fn_result, Generator): 69 | return sync_generator_wrapper() 70 | elif isinstance(fn_result, AsyncGenerator): 71 | return async_generator_wrapper() 72 | log_time(fn, elapsed) 73 | return fn_result 74 | 75 | if iscoroutinefunction(fn): 76 | return coroutine_function_wrapper 77 | return noncoroutine_function_wrapper 78 | -------------------------------------------------------------------------------- /app/viewmodels/admin.py: -------------------------------------------------------------------------------- 1 | from starlette_admin.fields import ( 2 | EnumField, 3 | BooleanField, 4 | EmailField, 5 | PhoneField, 6 | StringField, 7 | NumberField, 8 | DateTimeField, 9 | PasswordField, 10 | ) 11 | from starlette_admin.contrib.sqla.view import ModelView 12 | 13 | from app.database.schemas.auth import Users 14 | 15 | USER_STATUS_TYPES = [ 16 | ("admin", "Admin"), 17 | ("active", "Active"), 18 | ("deleted", "Deleted"), 19 | ("blocked", "Blocked"), 20 | ] 21 | API_KEY_STATUS_TYPES = [ 22 | ("active", "Active"), 23 | ("deleted", "Deleted"), 24 | ("stopped", "Stopped"), 25 | ] 26 | 27 | 28 | 29 | 30 | 31 | class UserAdminView(ModelView): 32 | fields = [ 33 | EnumField("status", choices=USER_STATUS_TYPES, select2=False), 34 | NumberField("id"), 35 | EmailField("email"), 36 | PasswordField("password"), 37 | StringField("name"), 38 | PhoneField("phone_number"), 39 | BooleanField("marketing_agree"), 40 | DateTimeField("created_at"), 41 | DateTimeField("updated_at"), 42 | StringField("ip_address"), 43 | "api_keys", 44 | ] 45 | fields_default_sort = [Users.id, ("id", False)] 46 | 47 | 48 | class ApiKeyAdminView(ModelView): 49 | fields = [ 50 | EnumField("status", choices=API_KEY_STATUS_TYPES, select2=False), 51 | NumberField("id"), 52 | StringField("access_key"), 53 | StringField("user_memo"), 54 | BooleanField("is_whitelisted"), 55 | DateTimeField("created_at"), 56 | DateTimeField("updated_at"), 57 | "users", 58 | "whitelists", 59 | ] 60 | fields_default_sort = [Users.id, ("id", False)] 61 | -------------------------------------------------------------------------------- /app/viewmodels/status.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class UserStatus(str, enum.Enum): 5 | admin = "admin" 6 | active = "active" 7 | deleted = "deleted" 8 | blocked = "blocked" 9 | 10 | 11 | class ApiKeyStatus(str, enum.Enum): 12 | active = "active" 13 | stopped = "stopped" 14 | deleted = "deleted" 15 | 
-------------------------------------------------------------------------------- /app/web/.last_build_id: -------------------------------------------------------------------------------- 1 | 6ab7adef156d180eb3fbf8ed0392faf8 -------------------------------------------------------------------------------- /app/web/assets/AssetManifest.bin: -------------------------------------------------------------------------------- 1 | assets/images/ai_profile.png  assetassets/images/ai_profile.pngassets/images/favicon.ico  assetassets/images/favicon.ico assets/images/openai_profile.png  asset assets/images/openai_profile.pngassets/images/site_logo.png  assetassets/images/site_logo.pngassets/images/user_profile.png  assetassets/images/user_profile.png assets/images/vicuna_profile.jpg  asset assets/images/vicuna_profile.jpgassets/lotties/click.json  assetassets/lotties/click.jsonassets/lotties/fail.json  assetassets/lotties/fail.jsonassets/lotties/file-upload.json  assetassets/lotties/file-upload.jsonassets/lotties/go-back.json  assetassets/lotties/go-back.jsonassets/lotties/ok.json  assetassets/lotties/ok.jsonassets/lotties/read.json  assetassets/lotties/read.jsonassets/lotties/scroll-down.json  assetassets/lotties/scroll-down.jsonassets/lotties/search-doc.json  assetassets/lotties/search-doc.jsonassets/lotties/search-web.json  assetassets/lotties/search-web.jsonassets/lotties/translate.json  assetassets/lotties/translate.jsonassets/svgs/search-web.svg  assetassets/svgs/search-web.svg2packages/cupertino_icons/assets/CupertinoIcons.ttf  asset2packages/cupertino_icons/assets/CupertinoIcons.ttf -------------------------------------------------------------------------------- /app/web/assets/AssetManifest.json: -------------------------------------------------------------------------------- 1 | {"assets/images/ai_profile.png":["assets/images/ai_profile.png"],"assets/images/favicon.ico":["assets/images/favicon.ico"],"assets/images/openai_profile.png":["assets/images/openai_profile.png"],"assets/images/site_logo.png":["assets/images/site_logo.png"],"assets/images/user_profile.png":["assets/images/user_profile.png"],"assets/images/vicuna_profile.jpg":["assets/images/vicuna_profile.jpg"],"assets/lotties/click.json":["assets/lotties/click.json"],"assets/lotties/fail.json":["assets/lotties/fail.json"],"assets/lotties/file-upload.json":["assets/lotties/file-upload.json"],"assets/lotties/go-back.json":["assets/lotties/go-back.json"],"assets/lotties/ok.json":["assets/lotties/ok.json"],"assets/lotties/read.json":["assets/lotties/read.json"],"assets/lotties/scroll-down.json":["assets/lotties/scroll-down.json"],"assets/lotties/search-doc.json":["assets/lotties/search-doc.json"],"assets/lotties/search-web.json":["assets/lotties/search-web.json"],"assets/lotties/translate.json":["assets/lotties/translate.json"],"assets/svgs/search-web.svg":["assets/svgs/search-web.svg"],"packages/cupertino_icons/assets/CupertinoIcons.ttf":["packages/cupertino_icons/assets/CupertinoIcons.ttf"]} -------------------------------------------------------------------------------- /app/web/assets/FontManifest.json: -------------------------------------------------------------------------------- 1 | [{"family":"MaterialIcons","fonts":[{"asset":"fonts/MaterialIcons-Regular.otf"}]},{"family":"packages/cupertino_icons/CupertinoIcons","fonts":[{"asset":"packages/cupertino_icons/assets/CupertinoIcons.ttf"}]}] -------------------------------------------------------------------------------- /app/web/assets/assets/images/ai_profile.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/ai_profile.png -------------------------------------------------------------------------------- /app/web/assets/assets/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/favicon.ico -------------------------------------------------------------------------------- /app/web/assets/assets/images/openai_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/openai_profile.png -------------------------------------------------------------------------------- /app/web/assets/assets/images/site_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/site_logo.png -------------------------------------------------------------------------------- /app/web/assets/assets/images/user_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/user_profile.png -------------------------------------------------------------------------------- /app/web/assets/assets/images/vicuna_profile.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/vicuna_profile.jpg -------------------------------------------------------------------------------- /app/web/assets/assets/lotties/scroll-down.json: -------------------------------------------------------------------------------- 1 | {"v":"5.5.2","fr":120,"ip":0,"op":120,"w":192,"h":192,"nm":"Comp 2","ddd":0,"assets":[],"layers":[{"ddd":0,"ind":1,"ty":4,"nm":"Arrow-Down Outlines","sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":0,"s":[96,91,0],"to":[0,2.833,0],"ti":[0,3.5,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[96,108,0],"to":[0,-3.5,0],"ti":[0,1.333,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":60,"s":[96,70,0],"to":[0,-1.333,0],"ti":[0,-3.5,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.167,"y":0},"t":77,"s":[96,100,0],"to":[0,3.5,0],"ti":[0,1.5,0]},{"t":88,"s":[96,91,0]}],"ix":2},"a":{"a":0,"k":[12,12,0],"ix":1},"s":{"a":0,"k":[825,825,100],"ix":6}},"ao":0,"shapes":[{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[0,0],[0,0]],"o":[[0,0],[0,0]],"v":[[12,18],[12,6]],"c":false},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"tm","s":{"a":0,"k":0,"ix":1},"e":{"a":1,"k":[{"i":{"x":[0.667],"y":[1]},"o":{"x":[0.333],"y":[0]},"t":0,"s":[100]},{"i":{"x":[0.667],"y":[1]},"o":{"x":[0.333],"y":[0]},"t":38,"s":[32]},{"t":60,"s":[100]}],"ix":2},"o":{"a":0,"k":0,"ix":3},"m":1,"ix":2,"nm":"Trim Paths 1","mn":"ADBE Vector Filter - 
Trim","hd":false},{"ty":"st","c":{"a":0,"k":[0.6392156862745098,0.9333333333333333,0.9137254901960784,1],"ix":3},"o":{"a":0,"k":100,"ix":4},"w":{"a":0,"k":2,"ix":5},"lc":2,"lj":2,"bm":0,"nm":"Stroke 1","mn":"ADBE Vector Graphic - Stroke","hd":false},{"ty":"tr","p":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[0,0],"to":[0,-0.167],"ti":[0,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":60,"s":[0,-1],"to":[0,0],"ti":[0,-0.167]},{"t":77,"s":[0,0]}],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[100,100],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 1","np":3,"cix":2,"bm":0,"ix":1,"mn":"ADBE Vector Group","hd":false},{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":0,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5,-2.5],[0,2.5],[5,-2.5]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-2.612,-1.742],[-0.064,2.431],[2.484,-1.742]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.167,"y":0},"t":60,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5.966,-2.5],[-0.05,2.5],[5.865,-2.5]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":77,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-4.103,-2.5],[-0.016,2.5],[4.071,-2.5]],"c":false}]},{"t":88,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5,-2.5],[0,2.5],[5,-2.5]],"c":false}]}],"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"st","c":{"a":0,"k":[0.6392156862745098,0.9333333333333333,0.9137254901960784,1],"ix":3},"o":{"a":0,"k":100,"ix":4},"w":{"a":0,"k":2,"ix":5},"lc":2,"lj":2,"bm":0,"nm":"Stroke 1","mn":"ADBE Vector Graphic - Stroke","hd":false},{"ty":"tr","p":{"a":0,"k":[12,16.5],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[87.834,87.834],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 2","np":2,"cix":2,"bm":0,"ix":2,"mn":"ADBE Vector Group","hd":false}],"ip":0,"op":1215,"st":0,"bm":0}],"markers":[]} -------------------------------------------------------------------------------- /app/web/assets/assets/svgs/search-web.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 14 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /app/web/assets/fonts/MaterialIcons-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/fonts/MaterialIcons-Regular.otf -------------------------------------------------------------------------------- /app/web/assets/packages/cupertino_icons/assets/CupertinoIcons.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/packages/cupertino_icons/assets/CupertinoIcons.ttf -------------------------------------------------------------------------------- /app/web/canvaskit/canvaskit.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/canvaskit.wasm 
-------------------------------------------------------------------------------- /app/web/canvaskit/chromium/canvaskit.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/chromium/canvaskit.wasm -------------------------------------------------------------------------------- /app/web/canvaskit/skwasm.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/skwasm.wasm -------------------------------------------------------------------------------- /app/web/canvaskit/skwasm.worker.js: -------------------------------------------------------------------------------- 1 | "use strict";var Module={};if(typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string"){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",function(data){onmessage({data:data})});var nodeFS=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(nodeFS.readFileSync(f,"utf8"))},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=((info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports});self.onmessage=(e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob==="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}skwasm(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.threadInfoStruct,/*isMainBrowserThread=*/0,/*isMainRuntimeThread=*/0,/*canBlock=*/1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInit();try{var result=Module["invokeEntryPoint"](e.data.start_routine,e.data.arg);if(Module["keepRuntimeAlive"]()){Module["PThread"].setExitStatus(result)}else{Module["__emscripten_thread_exit"](result)}}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processThreadQueue"){if(Module["_pthread_self"]()){Module["_emscripten_current_thread_process_queued_calls"]()}}else{err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){err("worker.js onmessage() captured an uncaught exception: 
"+ex);if(ex&&ex.stack)err(ex.stack);throw ex}}); 2 | -------------------------------------------------------------------------------- /app/web/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/favicon.ico -------------------------------------------------------------------------------- /app/web/icons/Icon-192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-192.png -------------------------------------------------------------------------------- /app/web/icons/Icon-512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-512.png -------------------------------------------------------------------------------- /app/web/icons/Icon-maskable-192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-maskable-192.png -------------------------------------------------------------------------------- /app/web/icons/Icon-maskable-512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-maskable-512.png -------------------------------------------------------------------------------- /app/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | LLMChat 21 | 22 | 23 | 27 | 28 | 29 | 30 | 31 | 32 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /app/web/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "LLMChat", 3 | "short_name": "LLMChat", 4 | "start_url": ".", 5 | "display": "standalone", 6 | "background_color": "#0175C2", 7 | "theme_color": "#0175C2", 8 | "description": "Seamless chat experience", 9 | "orientation": "portrait-primary", 10 | "prefer_related_applications": false, 11 | "icons": [ 12 | { 13 | "src": "icons/Icon-192.png", 14 | "sizes": "192x192", 15 | "type": "image/png" 16 | }, 17 | { 18 | "src": "icons/Icon-512.png", 19 | "sizes": "512x512", 20 | "type": "image/png" 21 | }, 22 | { 23 | "src": "icons/Icon-maskable-192.png", 24 | "sizes": "192x192", 25 | "type": "image/png", 26 | "purpose": "maskable" 27 | }, 28 | { 29 | "src": "icons/Icon-maskable-512.png", 30 | "sizes": "512x512", 31 | "type": "image/png", 32 | "purpose": "maskable" 33 | } 34 | ] 35 | } -------------------------------------------------------------------------------- /app/web/site_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/site_logo.png -------------------------------------------------------------------------------- /app/web/version.json: -------------------------------------------------------------------------------- 1 | {"app_name":"flutter_web","version":"1.0.0","build_number":"1","package_name":"flutter_web"} 
-------------------------------------------------------------------------------- /build-web.bat: -------------------------------------------------------------------------------- 1 | cd frontend 2 | call flutter build web --base-href /chat/ --web-renderer canvaskit 3 | 4 | cd .. 5 | IF EXIST "app\web" ( 6 | rmdir /s /q "app\web" 7 | ) 8 | IF NOT EXIST "app\web" ( 9 | mkdir "app\web" 10 | ) 11 | xcopy /s /y "frontend\build\web\*" "app\web\" -------------------------------------------------------------------------------- /docker-compose-local.yaml: -------------------------------------------------------------------------------- 1 | version: '3.9' 2 | 3 | volumes: 4 | mysql: 5 | redis: 6 | qdrant: 7 | 8 | 9 | networks: 10 | reverse-proxy-public: 11 | driver: bridge 12 | ipam: 13 | driver: default 14 | config: 15 | - subnet: 172.16.0.0/24 # subnet for traefik and other services 16 | 17 | services: 18 | # No need to use traefik for local development 19 | 20 | db: 21 | image: mysql:latest 22 | restart: always 23 | ports: 24 | - "3306:3306" 25 | environment: 26 | MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD}" 27 | MYSQL_ROOT_HOST: "%" 28 | MYSQL_DATABASE: "${MYSQL_DATABASE}" 29 | MYSQL_USER: "${MYSQL_USER}" 30 | MYSQL_PASSWORD: "${MYSQL_PASSWORD}" 31 | TZ: "Asia/Seoul" 32 | volumes: 33 | - mysql:/var/lib/mysql 34 | networks: 35 | reverse-proxy-public: 36 | ipv4_address: 172.16.0.11 # static IP 37 | 38 | cache: 39 | image: redis/redis-stack-server:latest 40 | restart: always 41 | environment: 42 | - REDIS_ARGS=--requirepass ${REDIS_PASSWORD} --maxmemory 100mb --maxmemory-policy allkeys-lru --appendonly yes 43 | ports: 44 | - "6379:6379" 45 | volumes: 46 | - redis:/data 47 | networks: 48 | reverse-proxy-public: 49 | ipv4_address: 172.16.0.12 # static IP 50 | 51 | vectorstore: 52 | image: qdrant/qdrant 53 | restart: always 54 | ports: 55 | - "6333:6333" 56 | - "6334:6334" 57 | volumes: 58 | - qdrant:/qdrant/storage 59 | networks: 60 | reverse-proxy-public: 61 | ipv4_address: 172.16.0.13 # static IP 62 | 63 | api: 64 | image: cosogi/llmchat:230703 65 | restart: always 66 | env_file: 67 | - .env 68 | command: 69 | - "--host" 70 | - "0.0.0.0" 71 | - "--port" 72 | - "8000" 73 | ports: 74 | - "8000:8000" 75 | depends_on: 76 | - db 77 | - cache 78 | volumes: 79 | - .:/app 80 | networks: 81 | reverse-proxy-public: 82 | ipv4_address: 172.16.0.14 # static IP
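docker-compose-local.yaml above exposes MySQL (3306), Redis (6379) and Qdrant (6333/6334) directly on localhost for development. As a quick smoke test (an illustrative sketch, not part of the repository), the cache container can be checked from Python with redis-py, which requirements.txt already installs; this assumes the stack is up and REDIS_PASSWORD is exported from your .env:

    import os

    import redis  # redis-py, pulled in via requirements.txt

    r = redis.Redis(host="localhost", port=6379, password=os.environ["REDIS_PASSWORD"])
    print(r.ping())  # True once the cache service is reachable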
"--certificatesresolvers.myresolver.acme.caserver=https://acme-staging-v02.api.letsencrypt.org/directory" 33 | ports: 34 | - "80:80" 35 | - "443:443" 36 | - "3306:3306" 37 | - "6379:6379" 38 | volumes: 39 | - /var/run/docker.sock:/var/run/docker.sock 40 | - ./letsencrypt:/letsencrypt 41 | networks: 42 | - reverse-proxy-public 43 | 44 | db: 45 | image: mysql 46 | restart: always 47 | environment: 48 | MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD}" 49 | MYSQL_ROOT_HOST: "%" 50 | MYSQL_DATABASE: "${MYSQL_DATABASE}" 51 | MYSQL_USER: "${MYSQL_USER}" 52 | MYSQL_PASSWORD: "${MYSQL_PASSWORD}" 53 | TZ: "Asia/Seoul" 54 | volumes: 55 | - mysql:/var/lib/mysql 56 | # - ./my.cnf:/etc/mysql/conf.d/my.cnf 57 | labels: 58 | - "traefik.enable=true" 59 | - "traefik.tcp.routers.db.rule=HostSNI(`*`)" 60 | - "traefik.tcp.services.db.loadbalancer.server.port=3306" 61 | - "traefik.tcp.routers.db.entrypoints=mysql" 62 | networks: 63 | - reverse-proxy-public 64 | 65 | cache: 66 | image: redis/redis-stack-server:latest 67 | restart: always 68 | environment: 69 | - REDIS_ARGS=--requirepass ${REDIS_PASSWORD} --maxmemory 100mb --maxmemory-policy allkeys-lru --appendonly yes 70 | volumes: 71 | - redis:/data 72 | labels: 73 | - "traefik.enable=true" 74 | - "traefik.tcp.routers.cache.rule=HostSNI(`*`)" 75 | - "traefik.tcp.services.cache.loadbalancer.server.port=6379" 76 | - "traefik.tcp.routers.cache.entrypoints=redis" 77 | networks: 78 | - reverse-proxy-public 79 | 80 | api: 81 | image: cosogi/llmchat:230703 82 | restart: always 83 | env_file: 84 | - .env 85 | command: 86 | - "--host" 87 | - "0.0.0.0" 88 | - "--port" 89 | - "8000" 90 | labels: 91 | - "traefik.enable=true" 92 | - "traefik.docker.network=reverse-proxy-public" 93 | - "traefik.http.routers.api.rule=HostRegexp(`${HOST_MAIN}`, `{subdomain:[a-z]+}.${HOST_MAIN}`, `${HOST_IP}`)" 94 | - "traefik.http.routers.api.entrypoints=websecure" 95 | - "traefik.http.services.api.loadbalancer.server.scheme=http" 96 | - "traefik.http.services.api.loadbalancer.server.port=8000" 97 | - "traefik.http.routers.api.tls=true" 98 | - "traefik.http.routers.api.tls.certresolver=myresolver" 99 | - "traefik.http.routers.api.tls.domains[0].main=${HOST_MAIN}" 100 | - "traefik.http.routers.api.tls.domains[0].sans=${HOST_SUB}" 101 | depends_on: 102 | - proxy 103 | - db 104 | - cache 105 | - vectorstore 106 | volumes: 107 | - .:/app 108 | networks: 109 | - reverse-proxy-public 110 | 111 | vectorstore: 112 | image: qdrant/qdrant:latest 113 | restart: always 114 | volumes: 115 | - qdrant:/qdrant/storage 116 | networks: 117 | - reverse-proxy-public 118 | # search-api: 119 | # image: searxng/searxng:latest 120 | # restart: always 121 | 122 | # environment: 123 | # REDIS_PASSWORD: "${REDIS_PASSWORD}" 124 | # JWT_SECRET: "${JWT_SECRET}" 125 | # volumes: 126 | # - ./searxng:/etc/searxng 127 | # networks: 128 | # - api-private 129 | -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi 2 | RUN apt-get update -y 3 | RUN apt-get install poppler-utils -y 4 | RUN apt-get install tesseract-ocr -y 5 | COPY requirements.txt ./ 6 | RUN pip install --trusted-host pypi.python.org --no-cache-dir --upgrade -r requirements.txt 7 | ENTRYPOINT ["uvicorn", "main:app"] -------------------------------------------------------------------------------- /init-git.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 
2 | git init 3 | # git config core.sparseCheckout true 4 | git remote add -f origin git@github.com:c0sogi/LLMChat.git 5 | # echo "app" >> .git/info/sparse-checkout -------------------------------------------------------------------------------- /llama_models/ggml/llama_cpp_models_here.txt: -------------------------------------------------------------------------------- 1 | The llama.cpp GGML model must be put here as a single file. 2 | 3 | For example, if you downloaded a q4_0 quantized model from "https://huggingface.co/TheBloke/robin-7B-v2-GGML", 4 | the path of the model has to be "robin-7b.ggmlv3.q4_0.bin".
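To make the note above concrete, here is a hypothetical loading sketch using llama-cpp-python, which requirements.txt installs; the file name is the example from the note, and the app's actual model wiring may differ:

    from llama_cpp import Llama

    # Illustrative path only; place your own GGML file here as described above.
    llm = Llama(model_path="llama_models/ggml/robin-7b.ggmlv3.q4_0.bin")
    result = llm("Q: What is quantization? A:", max_tokens=32)
    print(result["choices"][0]["text"])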
-------------------------------------------------------------------------------- /llama_models/gptq/exllama_models_here.txt: -------------------------------------------------------------------------------- 1 | The Exllama GPTQ model must be put here as a folder. 2 | 3 | For example, if you downloaded 3 files from "https://huggingface.co/TheBloke/orca_mini_7B-GPTQ/tree/main": 4 | 5 | - orca-mini-7b-GPTQ-4bit-128g.no-act.order.safetensors 6 | - tokenizer.model 7 | - config.json 8 | 9 | Then you need to put them in a folder. 10 | The path of the model has to be the folder name. Let's say, "orca_mini_7b", which contains the 3 files. 11 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | from maintools import initialize_before_launch 3 | 4 | 5 | if __name__ == "__mp_main__": 6 | """Option 1: Skip section for multiprocess spawning 7 | This section will be skipped when running in multiprocessing mode""" 8 | pass 9 | 10 | elif __name__ == "__main__": 11 | # 12 | """Option 2: Debug mode 13 | Running this file directly to debug the app 14 | Run this section if you don't want to run the app in Docker""" 15 | 16 | if environ.get("API_ENV") != "test": 17 | environ["API_ENV"] = "local" 18 | environ["DOCKER_MODE"] = "False" 19 | 20 | import uvicorn 21 | 22 | from app.common.app_settings import create_app 23 | from app.common.config import config 24 | 25 | initialize_before_launch() 26 | app = create_app(config=config) 27 | uvicorn.run( 28 | app, 29 | host="0.0.0.0", 30 | port=config.port, 31 | ) 32 | else: 33 | """Option 3: Non-debug mode 34 | Docker will run this section as the main entrypoint 35 | This section will mostly be used. 36 | LLaMA may not work in Docker.""" 37 | from app.common.app_settings import create_app 38 | from app.common.config import config 39 | 40 | initialize_before_launch() 41 | app = create_app(config=config) -------------------------------------------------------------------------------- /maintools.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | def ensure_packages_installed(): 5 | import subprocess 6 | 7 | subprocess.call( 8 | [ 9 | "pip", 10 | "install", 11 | "--trusted-host", 12 | "pypi.python.org", 13 | "-r", 14 | "requirements.txt", 15 | ] 16 | ) 17 | 18 | 19 | def set_priority(pid: Optional[int] = None, priority: str = "high"): 20 | """Set the priority of a process. Priority is a string which can be 'low', 'below_normal', 21 | 'normal', 'above_normal', 'high', or 'realtime'. 'normal' is the default priority.""" 22 | import platform 23 | from os import getpid 24 | import psutil 25 | 26 | 27 | if platform.system() == "Windows": 28 | priorities = { 29 | "low": psutil.IDLE_PRIORITY_CLASS, 30 | "below_normal": psutil.BELOW_NORMAL_PRIORITY_CLASS, 31 | "normal": psutil.NORMAL_PRIORITY_CLASS, 32 | "above_normal": psutil.ABOVE_NORMAL_PRIORITY_CLASS, 33 | "high": psutil.HIGH_PRIORITY_CLASS, 34 | "realtime": psutil.REALTIME_PRIORITY_CLASS, 35 | } 36 | else: # Linux and other Unix systems 37 | priorities = { 38 | "low": 19, 39 | "below_normal": 10, 40 | "normal": 0, 41 | "above_normal": -5, 42 | "high": -11, 43 | "realtime": -20, 44 | } 45 | 46 | if pid is None: 47 | pid = getpid() 48 | p = psutil.Process(pid) 49 | p.nice(priorities[priority]) 50 | 51 | 52 | def initialize_before_launch(install_packages: bool = False): 53 | """Initialize the app""" 54 | import platform 55 | 56 | if platform.system() == "Windows": 57 | set_priority(priority="high") 58 | else: 59 | set_priority(priority="normal") 60 | if install_packages: 61 | ensure_packages_installed()
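For illustration, set_priority above resolves a named level to a psutil priority class on Windows or a nice value elsewhere; a hypothetical caller could raise its own priority like this (note that negative nice values on Linux normally require elevated privileges):

    from maintools import set_priority

    # Defaults to the current PID; "above_normal" becomes nice -5 on Linux
    # or ABOVE_NORMAL_PRIORITY_CLASS on Windows, per the tables above.
    set_priority(priority="above_normal")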
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | select = ["RUF100", "E501"] 3 | line-length = 120 -------------------------------------------------------------------------------- /qdrant/.lock: -------------------------------------------------------------------------------- 1 | tmp lock file -------------------------------------------------------------------------------- /qdrant/meta.json: -------------------------------------------------------------------------------- 1 | {"collections": {}, "aliases": {}} -------------------------------------------------------------------------------- /requirements-exllama.txt: -------------------------------------------------------------------------------- 1 | safetensors==0.3.1 2 | sentencepiece>=0.1.97 3 | ninja==1.11.1 4 | --find-links https://download.pytorch.org/whl/torch_stable.html 5 | torch==2.0.1+cu118 -------------------------------------------------------------------------------- /requirements-llama-cpp.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24 2 | sentencepiece==0.1.98 -------------------------------------------------------------------------------- /requirements-sentence-encoder.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | tensorflow>=2.0.0 4 | tensorflow-hub -------------------------------------------------------------------------------- /requirements-transformer-embedding-cpu.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==2.0.1 3 | -------------------------------------------------------------------------------- /requirements-transformer-embedding-gpu.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==2.0.1+cu118 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles 2 | aiomysql 3 | attrs 4 | bcrypt 5 | boto3 6 | botocore 7 | cachetools 8 | certifi 9 | cffi 10 | chardet 11 | click 12 | cssselect 13 | cssutils 14 | dnspython 15 | email-validator 16 | fastapi 17 | h11 18 | httptools 19 | httpx 20 | idna 21 | itsdangerous 22 | iniconfig 23 | jinja2 24 | jmespath 25 | langchain 26 | llama-cpp-python[server] 27 | lxml 28 | orjson 29 | openai 30 | packaging 31 | pdf2image 32 | pexpect 33 | pluggy 34 | premailer 35 | psutil 36 | py 37 | pycparser 38 | pydantic 39 | PyJWT 40 | PyMySQL 41 | pyparsing 42 | pytesseract 43 | pytest 44 | pytest-asyncio 45 | python-dateutil 46 | python-dotenv 47 | qdrant_client 48 | redis 49 | requests 50 | requests_html 51 | s3transfer 52 | sentencepiece 53 | six 54 | sqlalchemy 55 | sqlalchemy-utils 56 | starlette 57 | starlette_admin 58 | starlette-session 59 | tabulate 60 | tiktoken 61 | toml 62 | transformers 63 | unstructured 64 | urllib3 65 | uvicorn[standard] 66 | websockets 67 | yagmail 68 | cryptography -------------------------------------------------------------------------------- /run_local.bat: -------------------------------------------------------------------------------- 1 | @REM Virtual environment activation. You must have virtualenv installed. 2 | @REM If you don't have it, delete the next line and just run the last one. 3 | @REM You also need the packages from requirements.txt installed: 4 | @REM run "pip install -r requirements.txt" to install them. 5 | call .venv/Scripts/activate.bat 6 | python -m main -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/tests/__init__.py -------------------------------------------------------------------------------- /tests/__pycache__/test_completion_api.cpython-311-pytest-7.4.0.pyc.8996: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/tests/__pycache__/test_completion_api.cpython-311-pytest-7.4.0.pyc.8996 -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from typing import AsyncGenerator, Generator 4 | from fastapi import FastAPI 5 | from fastapi.testclient import TestClient 6 | import pytest 7 | import httpx 8 | import pytest_asyncio 9 | from app.utils.auth.token import create_access_token 10 | from app.utils.tests.random import random_user_generator 11 | from app.database.schemas.auth import Users 12 | from app.common.app_settings import create_app 13 | from app.common.config import Config, LoggingConfig 14 | from app.utils.logger import CustomLogger 15 | from app.models.base_models import UserToken 16 | from app.utils.chat.managers.cache import CacheManager, cache 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def config(): 21 | return Config.get(option="test") 22 | 23 | 24 | @pytest.fixture(scope="session") 25 | def cache_manager(): 26 | cache.start(config=Config.get(option="test")) 27 | return CacheManager 28 | 29 | 30 | @pytest.fixture(scope="session") 31 | def test_logger(): 32 | return CustomLogger( 33 | name="PyTest", logging_config=LoggingConfig(file_log_name="./logs/test.log") 34 | ) 35 | 36 | 37 | @pytest.fixture(scope="session", autouse=True) 38 | def event_loop(): 39 | loop = asyncio.get_event_loop() 40 | yield loop 41 | loop.close() 42 | 43 | 44 | @pytest.fixture(scope="session") 45 | def app(config) -> FastAPI: 46 | return create_app(config) 47 | 48 | 49 | 
@pytest.fixture(scope="session") 50 | def base_http_url() -> str: 51 | return "http://localhost" 52 | 53 | 54 | @pytest.fixture(scope="session") 55 | def base_websocket_url() -> str: 56 | return "ws://localhost" 57 | 58 | 59 | @pytest_asyncio.fixture(scope="session") 60 | async def async_client( 61 | app: FastAPI, base_http_url: str 62 | ) -> AsyncGenerator[httpx.AsyncClient, None]: 63 | async with httpx.AsyncClient(app=app, base_url=base_http_url) as ac: 64 | yield ac 65 | 66 | 67 | @pytest.fixture(scope="session") 68 | def client(app: FastAPI) -> Generator[TestClient, None, None]: 69 | with TestClient(app=app) as tc: 70 | yield tc 71 | 72 | 73 | @pytest_asyncio.fixture(scope="session") 74 | async def login_header(random_user: dict[str, str]) -> dict[str, str]: 75 | """ 76 | Register a user up front, before the tests run. 77 | """ 78 | new_user = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore 79 | access_token = create_access_token( 80 | data=UserToken.from_orm(new_user).dict(exclude={"password", "marketing_agree"}), 81 | expires_delta=24, 82 | ) 83 | return {"Authorization": f"Bearer {access_token}"} 84 | 85 | 86 | @pytest_asyncio.fixture(scope="session") 87 | async def api_key_dict(async_client: httpx.AsyncClient, login_header: dict[str, str]): 88 | api_key_memo: str = f"TESTING : {str(datetime.now())}" 89 | response: httpx.Response = await async_client.post( 90 | "/api/user/apikeys", 91 | json={"user_memo": api_key_memo}, 92 | headers=login_header, 93 | ) 94 | response_body = response.json() 95 | assert response.status_code == 201 96 | assert "access_key" in response_body 97 | assert "secret_key" in response_body 98 | apikey = { 99 | "access_key": response_body["access_key"], 100 | "secret_key": response_body["secret_key"], 101 | } 102 | 103 | response = await async_client.get("/api/user/apikeys", headers=login_header) 104 | response_body = response.json() 105 | assert response.status_code == 200 106 | assert api_key_memo in response_body[0]["user_memo"] 107 | return apikey 108 | 109 | 110 | @pytest.fixture(scope="session") 111 | def random_user(): 112 | return random_user_generator() 113 | 114 | 115 | def pytest_collection_modifyitems(items): 116 | app_tests = [] 117 | redis_tests = [] 118 | other_tests = [] 119 | for item in items: 120 | if "client" in item.fixturenames: 121 | app_tests.append(item) 122 | elif item.get_closest_marker("redistest") is not None: 123 | redis_tests.append(item) 124 | else: 125 | other_tests.append(item) 126 | items[:] = redis_tests + app_tests + other_tests 127 | 
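The api_key_dict fixture above issues an access/secret key pair, and request_service in tests/test_api_service.py below drives the signed-request handshake: the access key and a timestamp go into the query string, and a digest of that exact query string, produced with the secret key, goes into a "secret" header. A rough standalone client doing the same thing, sketched here by reusing the repository's own helpers rather than guessing at the hashing scheme, might look like:

    import httpx

    from app.utils.date_utils import UTC
    from app.utils.params_utils import hash_params, parse_params


    def call_service(base_url: str, access_key: str, secret_key: str, service_name: str = "") -> httpx.Response:
        # Key and timestamp in the query string; digest of that query string in the header.
        query_params = parse_params(params={"key": access_key, "timestamp": UTC.timestamp()})
        headers = {"secret": hash_params(query_params=query_params, secret_key=secret_key)}
        return httpx.get(f"{base_url}/api/services/{service_name}?{query_params}", headers=headers)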
"headers": { 29 | "secret": hash_params( 30 | query_params=all_query_params, 31 | secret_key=secret_key, 32 | ) 33 | } 34 | | required_headers 35 | } 36 | url: str = f"/api/services/{service_name}?{all_query_params}" 37 | 38 | if stream: 39 | response_body = "" 40 | async with async_client.stream(method=request_method.upper(), url=url, **method_options) as response: 41 | assert response.status_code in allowed_status_codes 42 | async for chunk in response.aiter_text(): 43 | response_body += chunk 44 | logger.info(f"Streamed data: {chunk}") if logger is not None else ... 45 | else: 46 | response: Response = await getattr(async_client, request_method.lower())(url=url, **method_options) 47 | response_body: Any = response.json() 48 | logger.info(f"response_body: {response_body}") if logger is not None else ... 49 | assert response.status_code in allowed_status_codes 50 | return response_body 51 | 52 | 53 | @pytest.mark.asyncio 54 | async def test_request_api(async_client: AsyncClient, api_key_dict: dict, test_logger): 55 | access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"] 56 | service_name: str = "" 57 | request_method: str = "get" 58 | required_query_params: dict[str, Any] = {} 59 | required_headers: dict[str, Any] = {} 60 | allowed_status_codes: tuple = (200, 201, 307) 61 | await request_service( 62 | async_client=async_client, 63 | access_key=access_key, 64 | secret_key=secret_key, 65 | request_method=request_method, 66 | allowed_status_codes=allowed_status_codes, 67 | service_name=service_name, 68 | required_query_params=required_query_params, 69 | required_headers=required_headers, 70 | stream=False, 71 | logger=test_logger, 72 | ) 73 | 74 | 75 | # @pytest.mark.asyncio 76 | # async def test_weather_api(async_client: AsyncClient, api_key_dict: dict): 77 | # access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"] 78 | # service_name: str = "weather" 79 | # request_method: str = "get" 80 | # required_query_params: dict[str, any] = { 81 | # "latitude": 37.0, 82 | # "longitude": 120.0, 83 | # } 84 | # required_headers: dict[str, any] = {} 85 | # allowed_status_codes: tuple[int] = (200, 307) 86 | # await request_service( 87 | # async_client=async_client, 88 | # access_key=access_key, 89 | # secret_key=secret_key, 90 | # request_method=request_method, 91 | # allowed_status_codes=allowed_status_codes, 92 | # service_name=service_name, 93 | # required_query_params=required_query_params, 94 | # required_headers=required_headers, 95 | # stream=False, 96 | # ) 97 | 98 | 99 | # @pytest.mark.asyncio 100 | # async def test_stream_api(async_client: AsyncClient, api_key_dict: dict): 101 | # access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"] 102 | # service_name: str = "stream" 103 | # request_method: str = "get" 104 | # required_query_params: dict[str, any] = {} 105 | # required_headers: dict[str, any] = {} 106 | # allowed_status_codes: tuple[int] = (200, 307) 107 | # await request_service( 108 | # async_client=async_client, 109 | # access_key=access_key, 110 | # secret_key=secret_key, 111 | # request_method=request_method, 112 | # allowed_status_codes=allowed_status_codes, 113 | # service_name=service_name, 114 | # required_query_params=required_query_params, 115 | # required_headers=required_headers, 116 | # stream=True, 117 | # ) 118 | -------------------------------------------------------------------------------- /tests/test_auth.py: -------------------------------------------------------------------------------- 1 | 
from httpx import AsyncClient 2 | import pytest 3 | from fastapi import status 4 | from app.errors.api_exceptions import Responses_400 5 | from app.utils.tests.random import random_user_generator 6 | 7 | 8 | @pytest.mark.asyncio 9 | async def test_register(async_client): 10 | valid_emails = ("ankitrai326@gmail.com", "my.ownsite@our-earth.org", "aa@a.a") 11 | invalid_emails = ("ankitrai326gmail.com", "ankitrai326.com") 12 | 13 | valid_users = [random_user_generator(email=email) for email in valid_emails] 14 | invalid_users = [random_user_generator(email=email) for email in invalid_emails] 15 | 16 | # Valid email 17 | for valid_user in valid_users: 18 | response = await async_client.post("api/auth/register/email", json=valid_user) 19 | response_body = response.json() 20 | assert response.status_code == status.HTTP_201_CREATED 21 | assert "Authorization" in response_body.keys(), response.content.decode("utf-8") 22 | 23 | # Invalid email 24 | for invalid_user in invalid_users: 25 | response = await async_client.post("api/auth/register/email", json=invalid_user) 26 | response_body = response.json() 27 | assert response.status_code != status.HTTP_201_CREATED 28 | assert "Authorization" not in response_body.keys(), response.content.decode("utf-8") 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_auth(async_client: AsyncClient): 33 | new_user = random_user_generator() 34 | 35 | # Register 36 | response = await async_client.post("api/auth/register/email", json=new_user) 37 | assert response.status_code == status.HTTP_201_CREATED, response.content.decode("utf-8") 38 | 39 | # Login 40 | response = await async_client.post("api/auth/login/email", json=new_user) 41 | response_body = response.json() 42 | assert response.status_code == status.HTTP_200_OK, "User login failure" 43 | assert "Authorization" in response_body.keys(), response.content.decode("utf-8") 44 | 45 | # Duplicate email check 46 | response = await async_client.post("api/auth/register/email", json=new_user) 47 | response_body = response.json() 48 | assert response.status_code == Responses_400.email_already_exists.status_code, "Email duplication check failure" 49 | assert response_body["detail"] == Responses_400.email_already_exists.detail, "Email duplication check failure" 50 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import select 3 | from uuid import uuid4 4 | from app.common.config import Config 5 | from app.models.base_models import AddApiKey, UserToken 6 | from app.database.connection import db 7 | from app.database.schemas.auth import Users, ApiKeys 8 | from app.database.crud.api_keys import create_api_key, get_api_key_and_owner 9 | from app.middlewares.token_validator import Validator 10 | from app.utils.date_utils import UTC 11 | from app.utils.params_utils import hash_params, parse_params 12 | 13 | db.start(config=Config.get(option="test")) 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_apikey_identification(random_user: dict[str, str]): 18 | user: Users = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore 19 | additional_key_info: AddApiKey = AddApiKey(user_memo="[Testing] test_apikey_query") 20 | new_api_key: ApiKeys = await create_api_key( 21 | user_id=user.id, additional_key_info=additional_key_info 22 | ) 23 | matched_api_key, matched_user = await get_api_key_and_owner( 24 | 
access_key=new_api_key.access_key 25 | ) 26 | assert matched_api_key.access_key == new_api_key.access_key 27 | assert random_user["email"] == matched_user.email 28 | 29 | 30 | @pytest.mark.asyncio 31 | async def test_api_key_validation(random_user: dict[str, str]): 32 | user: Users = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore 33 | additional_key_info: AddApiKey = AddApiKey(user_memo="[Testing] test_apikey_query") 34 | new_api_key: ApiKeys = await create_api_key( 35 | user_id=user.id, additional_key_info=additional_key_info 36 | ) 37 | timestamp: str = str(UTC.timestamp()) 38 | parsed_qs: str = parse_params( 39 | params={"key": new_api_key.access_key, "timestamp": timestamp} 40 | ) 41 | user_token: UserToken = await Validator.api_key( 42 | api_access_key=new_api_key.access_key, 43 | hashed_secret=hash_params( 44 | query_params=parsed_qs, secret_key=new_api_key.secret_key 45 | ), 46 | query_params=parsed_qs, 47 | timestamp=timestamp, 48 | ) 49 | assert user_token.id == user.id 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_crud(): 54 | total_users: int = 4 55 | users: list[dict[str, str]] = [ 56 | {"email": str(uuid4())[:18], "password": str(uuid4())[:18]} 57 | for _ in range(total_users) 58 | ] 59 | user_1, user_2, user_3, user_4 = users 60 | 61 | # C/U 62 | await Users.add_all(user_1, user_2, autocommit=True, refresh=True) 63 | await Users.add_one(autocommit=True, refresh=True, **user_3) 64 | await Users.update_filtered( 65 | Users.email == user_1["email"], 66 | Users.password == user_1["password"], 67 | updated={"email": "UPDATED", "password": "updated"}, 68 | autocommit=True, 69 | ) 70 | 71 | # R 72 | result_1_stmt = select(Users).filter( 73 | Users.email.in_([user_1["email"], "UPDATED", user_3["email"]]) 74 | ) 75 | result_1 = await db.scalars__fetchall(result_1_stmt) 76 | assert len(result_1) == 2 77 | assert result_1[0].email == "UPDATED" # type: ignore 78 | assert result_1[1].email == user_3["email"] # type: ignore 79 | result_2 = await Users.one_filtered_by(**user_2) 80 | assert result_2.email == user_2["email"] # type: ignore 81 | result_3 = await Users.fetchall_filtered_by(**user_4) 82 | assert len(result_3) == 0 83 | 84 | # D 85 | await db.delete(result_2, autocommit=True) 86 | result_4_stmt = select(Users).filter_by(**user_2) 87 | result_4 = (await db.scalars(stmt=result_4_stmt)).first() 88 | assert result_4 is None 89 | -------------------------------------------------------------------------------- /tests/test_dynamic_llm_load.py: -------------------------------------------------------------------------------- 1 | from app.models.llms import LLMModels 2 | 3 | 4 | def test_dynamic_llm_load(): 5 | print("- member_map keys -") 6 | print(LLMModels.member_map.keys()) 7 | print("- static member map keys -") 8 | print(LLMModels.static_member_map.keys()) 9 | print("- dynamic member map keys -") 10 | print(LLMModels.dynamic_member_map.keys()) 11 | 12 | LLMModels.add_member("foo", {"bar": "baz"}) 13 | 14 | print("- member_map keys -") 15 | print(LLMModels.member_map.keys()) 16 | print("- static member map keys -") 17 | print(LLMModels.static_member_map.keys()) 18 | print("- dynamic member map keys -") 19 | print(LLMModels.dynamic_member_map.keys()) 20 | 21 | print("- member_map class -") 22 | print(LLMModels.member_map) 23 | 24 | print("- is dynamic member instance of LLMModels ? 
-") 25 | print(isinstance(LLMModels.dynamic_member_map["foo"], LLMModels)) 26 | print(f"- {LLMModels.dynamic_member_map['foo']} -") 27 | -------------------------------------------------------------------------------- /tests/test_memory_deallocation.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import queue 3 | from collections import deque 4 | 5 | from app.utils.system import free_memory_of_first_item_from_container 6 | 7 | DEL_COUNT: int = 0 8 | 9 | 10 | class DummyObject: 11 | """Dummy object for testing.""" 12 | 13 | foo: bool 14 | bar: bool 15 | 16 | def __init__(self) -> None: 17 | """Initialize.""" 18 | self.foo = True 19 | self.bar = False 20 | 21 | def __del__(self) -> None: 22 | """Clean up resources.""" 23 | global DEL_COUNT 24 | DEL_COUNT += 1 25 | 26 | 27 | def test_deallocate_item_from_memory_by_reference(): 28 | """Tests if __del__ is called when item is removed from container.""" 29 | global DEL_COUNT 30 | 31 | # Test with a deque 32 | _deque = deque([DummyObject() for _ in range(10)]) 33 | _list = [DummyObject() for _ in range(10)] 34 | _dict = {i: DummyObject() for i in range(10)} 35 | _queue = queue.Queue() 36 | _asyncio_queue = asyncio.Queue() 37 | for _ in range(10): 38 | _queue.put(DummyObject()) 39 | _asyncio_queue.put_nowait(DummyObject()) 40 | 41 | # Test begin 42 | for container in [_deque, _list, _dict, _queue, _asyncio_queue]: 43 | print(f"- Testing {container.__class__.__name__}") 44 | DEL_COUNT = 0 45 | for _ in range(10): 46 | free_memory_of_first_item_from_container(container) 47 | print(f"- Finished testing {container.__class__.__name__}") 48 | assert ( 49 | DEL_COUNT >= 10 50 | ), f"At least {container} should have 10 items deleted" 51 | 52 | 53 | def test_dereference(): 54 | obj = DummyObject() 55 | 56 | def delete_obj(obj) -> None: 57 | del obj 58 | 59 | def delete_foo(obj) -> None: 60 | del obj.foo 61 | 62 | delete_obj(obj) 63 | assert obj is not None # Check if obj is still in memory 64 | delete_foo(obj) 65 | assert not hasattr(obj, "foo") # Check if obj.foo is deleted 66 | -------------------------------------------------------------------------------- /tests/test_translate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from app.common.config import ( 3 | GOOGLE_TRANSLATE_API_KEY, 4 | PAPAGO_CLIENT_ID, 5 | PAPAGO_CLIENT_SECRET, 6 | RAPID_API_KEY, 7 | CUSTOM_TRANSLATE_URL, 8 | ) 9 | from app.utils.api.translate import Translator 10 | 11 | 12 | @pytest.mark.asyncio 13 | async def test_custom_translate(test_logger): 14 | translated = await Translator.custom_translate_api( 15 | text="hello", 16 | src_lang="en", 17 | trg_lang="ko", 18 | api_url=str(CUSTOM_TRANSLATE_URL), 19 | ) 20 | test_logger.info(__name__ + ":" + translated) 21 | 22 | 23 | @pytest.mark.asyncio 24 | async def test_google_translate(test_logger): 25 | translated = await Translator.google( 26 | text="hello", 27 | src_lang="en", 28 | trg_lang="ko", 29 | api_key=str(GOOGLE_TRANSLATE_API_KEY), 30 | ) 31 | test_logger.info(__name__ + ":" + translated) 32 | 33 | 34 | @pytest.mark.asyncio 35 | async def test_papago(test_logger): 36 | translated = await Translator.papago( 37 | text="hello", 38 | src_lang="en", 39 | trg_lang="ko", 40 | client_id=str(PAPAGO_CLIENT_ID), 41 | client_secret=str(PAPAGO_CLIENT_SECRET), 42 | ) 43 | test_logger.info(__name__ + ":" + translated) 44 | 45 | 46 | @pytest.mark.asyncio 47 | async def test_deepl(test_logger): 48 | translated = await 
Translator.deepl_via_rapid_api( 49 | text="hello", 50 | src_lang="en", 51 | trg_lang="ko", 52 | api_key=str(RAPID_API_KEY), 53 | ) 54 | test_logger.info(__name__ + ":" + translated) 55 | 56 | 57 | @pytest.mark.asyncio 58 | async def test_auto_translate(test_logger): 59 | translated = await Translator.translate( 60 | text="hello", 61 | src_lang="en", 62 | trg_lang="ko", 63 | ) 64 | test_logger.info("FIRST:" + translated + str(Translator.cached_function) + str(Translator.cached_args)) 65 | translated = await Translator.translate( 66 | text="what is your name?", 67 | src_lang="en", 68 | trg_lang="ko", 69 | ) 70 | test_logger.info("SECOND:" + translated + str(Translator.cached_function) + str(Translator.cached_args)) 71 | -------------------------------------------------------------------------------- /update-git.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | git fetch 3 | git reset --hard origin/master 4 | git pull origin master --------------------------------------------------------------------------------