├── .dockerignore
├── .env-sample
├── .flake8
├── .gitattributes
├── .gitignore
├── .gitmodules
├── .idea
│   ├── .gitignore
│   ├── app.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .vscode
│   ├── launch.json
│   ├── redis-xplorer.redis
│   └── settings.json
├── LICENSE
├── app
│   ├── __init__.py
│   ├── auth
│   │   └── admin.py
│   ├── common
│   │   ├── app_settings.py
│   │   ├── app_settings_llama_cpp.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   └── lotties.py
│   ├── contents
│   │   ├── browsing_demo.gif
│   │   ├── browsing_demo.png
│   │   ├── code_demo.png
│   │   ├── edit_title_demo.png
│   │   ├── embed_file_demo.png
│   │   ├── favicon.ico
│   │   ├── llama_api.png
│   │   ├── model_selection_demo.png
│   │   ├── site_logo.png
│   │   ├── state_of_the_union.txt
│   │   ├── ui_demo.png
│   │   └── upload_demo.gif
│   ├── database
│   │   ├── __init__.py
│   │   ├── connection.py
│   │   ├── crud
│   │   │   ├── api_keys.py
│   │   │   ├── api_whitelists.py
│   │   │   ├── deprecated_chatgpt.py
│   │   │   └── users.py
│   │   └── schemas
│   │       ├── __init__.py
│   │       ├── auth.py
│   │       └── deprecated_chatgpt.py
│   ├── dependencies.py
│   ├── deprecated_chatgpt.html
│   ├── errors
│   │   ├── api_exceptions.py
│   │   └── chat_exceptions.py
│   ├── middlewares
│   │   ├── token_validator.py
│   │   └── trusted_hosts.py
│   ├── mixins
│   │   └── enum.py
│   ├── models
│   │   ├── base_models.py
│   │   ├── chat_commands.py
│   │   ├── chat_models.py
│   │   ├── completion_models.py
│   │   ├── function_calling
│   │   │   ├── base.py
│   │   │   └── functions.py
│   │   ├── llm_tokenizers.py
│   │   └── llms.py
│   ├── routers
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   ├── index.py
│   │   ├── services.py
│   │   ├── user_services.py
│   │   ├── users.py
│   │   ├── v1.py
│   │   └── websocket.py
│   ├── shared.py
│   ├── utils
│   │   ├── api
│   │   │   ├── completion.py
│   │   │   ├── duckduckgo.py
│   │   │   ├── translate.py
│   │   │   └── weather.py
│   │   ├── auth
│   │   │   ├── api_keys.py
│   │   │   ├── register_validation.py
│   │   │   └── token.py
│   │   ├── chat
│   │   │   ├── buffer.py
│   │   │   ├── build_llama_shared_lib.py
│   │   │   ├── chat_rooms.py
│   │   │   ├── commands
│   │   │   │   ├── browsing.py
│   │   │   │   ├── core.py
│   │   │   │   ├── llm_parameter.py
│   │   │   │   ├── prompt.py
│   │   │   │   ├── server.py
│   │   │   │   ├── summarize.py
│   │   │   │   ├── testing.py
│   │   │   │   └── vectorstore.py
│   │   │   ├── embeddings
│   │   │   │   ├── __init__.py
│   │   │   │   ├── sentence_encoder.py
│   │   │   │   └── transformer.py
│   │   │   ├── file_loader.py
│   │   │   ├── managers
│   │   │   │   ├── cache.py
│   │   │   │   ├── message.py
│   │   │   │   ├── stream.py
│   │   │   │   ├── vectorstore.py
│   │   │   │   └── websocket.py
│   │   │   ├── messages
│   │   │   │   ├── converter.py
│   │   │   │   ├── handler.py
│   │   │   │   └── turn_templates.py
│   │   │   ├── text_generations
│   │   │   │   ├── __init__.py
│   │   │   │   ├── completion_api.py
│   │   │   │   ├── converter.py
│   │   │   │   ├── exllama.py
│   │   │   │   ├── llama_cpp.py
│   │   │   │   ├── path.py
│   │   │   │   └── summarization.py
│   │   │   └── tokens.py
│   │   ├── colorama.py
│   │   ├── date_utils.py
│   │   ├── encoding_utils.py
│   │   ├── errors.py
│   │   ├── function_calling
│   │   │   ├── callbacks
│   │   │   │   ├── click_link.py
│   │   │   │   ├── full_browsing.py
│   │   │   │   ├── lite_browsing.py
│   │   │   │   ├── translate.py
│   │   │   │   └── vectorstore_search.py
│   │   │   ├── parser.py
│   │   │   ├── query.py
│   │   │   ├── request.py
│   │   │   └── token_count.py
│   │   ├── huggingface.py
│   │   ├── js_initializer.py
│   │   ├── langchain
│   │   │   ├── chat_llama_cpp.py
│   │   │   ├── embeddings_api.py
│   │   │   ├── qdrant_vectorstore.py
│   │   │   ├── redis_vectorstore.py
│   │   │   ├── structured_tool.py
│   │   │   ├── token_text_splitter.py
│   │   │   └── web_search.py
│   │   ├── logger.py
│   │   ├── module_reloader.py
│   │   ├── params_utils.py
│   │   ├── system.py
│   │   ├── tests
│   │   │   ├── random.py
│   │   │   └── timeit.py
│   │   └── types.py
│   ├── viewmodels
│   │   ├── admin.py
│   │   └── status.py
│   └── web
│       ├── .last_build_id
│       ├── assets
│       │   ├── AssetManifest.bin
│       │   ├── AssetManifest.json
│       │   ├── FontManifest.json
│       │   ├── NOTICES
│       │   ├── assets
│       │   │   ├── images
│       │   │   │   ├── ai_profile.png
│       │   │   │   ├── favicon.ico
│       │   │   │   ├── openai_profile.png
│       │   │   │   ├── site_logo.png
│       │   │   │   ├── user_profile.png
│       │   │   │   └── vicuna_profile.jpg
│       │   │   ├── lotties
│       │   │   │   ├── click.json
│       │   │   │   ├── fail.json
│       │   │   │   ├── file-upload.json
│       │   │   │   ├── go-back.json
│       │   │   │   ├── ok.json
│       │   │   │   ├── read.json
│       │   │   │   ├── scroll-down.json
│       │   │   │   ├── search-doc.json
│       │   │   │   ├── search-web.json
│       │   │   │   └── translate.json
│       │   │   └── svgs
│       │   │       └── search-web.svg
│       │   ├── fonts
│       │   │   └── MaterialIcons-Regular.otf
│       │   ├── packages
│       │   │   └── cupertino_icons
│       │   │       └── assets
│       │   │           └── CupertinoIcons.ttf
│       │   └── shaders
│       │       └── ink_sparkle.frag
│       ├── canvaskit
│       │   ├── canvaskit.js
│       │   ├── canvaskit.wasm
│       │   ├── chromium
│       │   │   ├── canvaskit.js
│       │   │   └── canvaskit.wasm
│       │   ├── skwasm.js
│       │   ├── skwasm.wasm
│       │   └── skwasm.worker.js
│       ├── favicon.ico
│       ├── flutter.js
│       ├── flutter_service_worker.js
│       ├── icons
│       │   ├── Icon-192.png
│       │   ├── Icon-512.png
│       │   ├── Icon-maskable-192.png
│       │   └── Icon-maskable-512.png
│       ├── index.html
│       ├── main.dart.js
│       ├── manifest.json
│       ├── site_logo.png
│       └── version.json
├── build-web.bat
├── docker-compose-local.yaml
├── docker-compose-prod.yaml
├── dockerfile
├── init-git.sh
├── llama_models
│   ├── ggml
│   │   └── llama_cpp_models_here.txt
│   └── gptq
│       └── exllama_models_here.txt
├── main.py
├── maintools.py
├── notebook
│   ├── function-calling-example.ipynb
│   └── langchain-summarization.ipynb
├── pyproject.toml
├── qdrant
│   ├── .lock
│   └── meta.json
├── readme.md
├── requirements-exllama.txt
├── requirements-llama-cpp.txt
├── requirements-sentence-encoder.txt
├── requirements-transformer-embedding-cpu.txt
├── requirements-transformer-embedding-gpu.txt
├── requirements.txt
├── run_local.bat
├── tests
│   ├── __init__.py
│   ├── __pycache__
│   │   └── test_completion_api.cpython-311-pytest-7.4.0.pyc.8996
│   ├── conftest.py
│   ├── test_api_service.py
│   ├── test_auth.py
│   ├── test_chat.py
│   ├── test_completion_api.py
│   ├── test_database.py
│   ├── test_dynamic_llm_load.py
│   ├── test_memory_deallocation.py
│   ├── test_translate.py
│   └── test_vectorstore.py
└── update-git.sh
/.dockerignore:
--------------------------------------------------------------------------------
1 | .env
2 | llama_models/*
--------------------------------------------------------------------------------
/.env-sample:
--------------------------------------------------------------------------------
1 |
2 | # API_ENV can be "local", "test", "prod"
3 | API_ENV="local"
4 |
5 | # Port for Docker. If you run this app without Docker, this is ignored and port 8001 is used instead.
6 | PORT=8000
7 |
8 | # Default LLM model for each chat.
9 | DEFAULT_LLM_MODEL="gpt_3_5_turbo"
10 |
11 | # Your MySQL DB info
12 | MYSQL_DATABASE="traffic"
13 | MYSQL_TEST_DATABASE="testing_db"
14 | MYSQL_ROOT_PASSWORD="YOUR_MYSQL_PASSWORD_HERE"
15 | MYSQL_USER="traffic_admin"
16 | MYSQL_PASSWORD="YOUR_DB_ADMIN_PASSWORD_HERE"
17 |
18 | # Your Redis DB info
19 | REDIS_DATABASE="0"
20 | REDIS_PASSWORD="YOUR_REDIS_PASSWORD_HERE"
21 |
22 | # Your JWT secret key
23 | JWT_SECRET="ANY_PASSWORD_FOR_JWT_TOKEN_GENERATION_HERE"
24 |
25 | # Your OpenAI API key
26 | OPENAI_API_KEY="sk-*************"
27 |
28 | # Chatbot settings
29 | # Summarize for chat: summarize messages longer than SUMMARIZATION_THRESHOLD tokens
30 | SUMMARIZE_FOR_CHAT=True
31 | SUMMARIZATION_THRESHOLD=512
32 |
33 | # Text to embed is split into chunks of EMBEDDING_TOKEN_CHUNK_SIZE tokens,
34 | # with EMBEDDING_TOKEN_CHUNK_OVERLAP tokens shared between consecutive chunks.
35 | EMBEDDING_TOKEN_CHUNK_SIZE=512
36 | EMBEDDING_TOKEN_CHUNK_OVERLAP=128
37 |
38 | # The shared vector collection name. This will be shared for all users.
39 | QDRANT_COLLECTION="SharedCollection"
40 |
41 | # If you want to add a prefix or suffix to every prompt sent to the LLM, set these.
42 | GLOBAL_PREFIX=""
43 | GLOBAL_SUFFIX=""
44 |
45 | # If you want to use local embedding instead of OpenAI's Ada-002,
46 | # set LOCAL_EMBEDDING_MODEL as "intfloat/e5-large-v2" or other huggingface embedding model repo.
47 | # Warning: Local embedding needs a lot of computing resources!!!
48 | LOCAL_EMBEDDING_MODEL=None
49 |
50 |
51 | # Define these if you want to run a production server with API_ENV="prod"
52 | HOST_IP="OPTIONAL_YOUR_IP_HERE e.g. 192.168.0.2"
53 | HOST_MAIN="OPTIONAL_YOUR_DOMAIN_HERE e.g. yourdomain.com; required for TLS certificate registration when running with API_ENV=prod"
54 | HOST_SUB="OPTIONAL_YOUR_SUB_DOMAIN_HERE e.g. mobile.yourdomain.com"
55 | MY_EMAIL="OPTIONAL_YOUR_EMAIL_HERE e.g. you@yourdomain.com; required for TLS certificate registration when running with API_ENV=prod"
56 |
57 | # Not used.
58 | AWS_ACCESS_KEY="OPTIONAL_IF_YOU_NEED"
59 | AWS_SECRET_KEY="OPTIONAL_IF_YOU_NEED"
60 | AWS_AUTHORIZED_EMAIL="OPTIONAL_IF_YOU_NEED"
61 | SAMPLE_JWT_TOKEN="OPTIONAL_IF_YOU_NEED_FOR_TESTING e.g. Bearer XXXXX"
62 | SAMPLE_ACCESS_KEY="OPTIONAL_IF_YOU_NEED_FOR_TESTING"
63 | SAMPLE_SECRET_KEY="OPTIONAL_IF_YOU_NEED_FOR_TESTING"
64 | KAKAO_RESTAPI_TOKEN="OPTIONAL_IF_YOU_NEED e.g. Bearer XXXXX"
65 | WEATHERBIT_API_KEY="OPTIONAL_IF_YOU_NEED"
66 | NASA_API_KEY="OPTIONAL_IF_YOU_NEED"
67 |
68 | # For translation. If you don't need translation, you can ignore these.
69 | PAPAGO_CLIENT_ID="OPTIONAL_FOR_TRANSLATION"
70 | PAPAGO_CLIENT_SECRET="OPTIONAL_FOR_TRANSLATION"
71 | GOOGLE_CLOUD_PROJECT_ID="OPTIONAL_FOR_TRANSLATION e.g. top-abcd-01234"
72 | GOOGLE_TRANSLATE_API_KEY="OPTIONAL_FOR_TRANSLATION"
73 | GOOGLE_TRANSLATE_OAUTH_ID="OPTIONAL_FOR_TRANSLATION"
74 | GOOGLE_TRANSLATE_OAUTH_SECRET="OPTIONAL_FOR_TRANSLATION"
75 | RAPIDAPI_KEY="OPTIONAL_FOR_TRANSLATION"
76 | CUSTOM_TRANSLATE_URL="OPTIONAL_FOR_TRANSLATION"
77 |
78 |
--------------------------------------------------------------------------------
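
The chunk-size/overlap settings above describe a sliding-window split. A minimal sketch of that behavior, illustrative only (the app's actual splitter lives in app/utils/langchain/token_text_splitter.py; split_tokens below is a hypothetical helper):

def split_tokens(tokens: list, chunk_size: int = 512, overlap: int = 128) -> list:
    # Slide a window of chunk_size tokens, stepping by (chunk_size - overlap),
    # so consecutive chunks share exactly `overlap` tokens.
    step = chunk_size - overlap
    return [tokens[i : i + chunk_size] for i in range(0, max(len(tokens) - overlap, 1), step)]

# 1024 tokens with the defaults -> 3 chunks: [0:512], [384:896], [768:1024]
print(len(split_tokens(list(range(1024)))))  # 3
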
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ; ignore = W503, E501
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.js linguist-detectable=false
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | acme.json
3 | PRIVATE_*
4 | venv/
5 | .cache/
6 | *.pyc
7 | *.log
8 | llama_models/*
9 | !llama_models/ggml/llama_cpp_models_here.txt
10 | !llama_models/gptq/exllama_models_here.txt
11 | deprecated_*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "frontend"]
2 | path = frontend
3 | url = https://github.com/c0sogi/LLMChat-frontend
4 | [submodule "repositories/llama_cpp"]
5 | path = repositories/llama_cpp
6 | url = https://github.com/abetlen/llama-cpp-python
7 | [submodule "repositories/exllama"]
8 | path = repositories/exllama
9 | url = https://github.com/turboderp/exllama
10 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/app.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "Flutter",
6 | "type": "dart",
7 | "request": "launch",
8 | "program": "frontend/lib/main.dart",
9 | "deviceId": "chrome"
10 | },
11 | {
12 | "name": "Python: FastAPI",
13 | "type": "python",
14 | "request": "launch",
15 | "module": "main",
16 | "jinja": true,
17 | "justMyCode": true,
18 | "console": "integratedTerminal",
19 | "env": {
20 | "PYTHONPATH": "${workspaceFolder}"
21 | },
22 | "args": [
23 | "uvicorn",
24 | "main:app",
25 | "--reload"
26 | ]
27 | }
28 | ]
29 | }
--------------------------------------------------------------------------------
/.vscode/redis-xplorer.redis:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/.vscode/redis-xplorer.redis
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "tests",
4 | "-s",
5 | "-v",
6 | "--capture=no"
7 | ],
8 | "python.testing.unittestEnabled": false,
9 | "python.testing.pytestEnabled": true,
10 | "files.exclude": {
11 | "**/.git": true,
12 | "**/.svn": true,
13 | "**/.hg": true,
14 | "**/CVS": true,
15 | "**/.DS_Store": true,
16 | "**/Thumbs.db": true,
17 | "**/__pycache__": true,
18 | "**/.pytest_cache": true,
19 | },
20 | "python.analysis.typeCheckingMode": "basic",
21 | "python.linting.flake8Enabled": true,
22 | "python.linting.mypyEnabled": false,
23 | "python.linting.enabled": true,
24 | "files.associations": {
25 | "*.inc": "cpp"
26 | },
27 | "python.linting.pylintEnabled": false,
28 | "pylint.path": [
29 | "-"
30 | ],
31 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Seongbin Choi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/__init__.py
--------------------------------------------------------------------------------
/app/auth/admin.py:
--------------------------------------------------------------------------------
1 | from starlette.requests import Request
2 | from starlette.responses import Response
3 | from starlette_admin.auth import AdminUser, AuthProvider
4 | from starlette_admin.exceptions import FormValidationError, LoginFailed
5 | from app.common.config import config
6 |
7 |
8 | class MyAuthProvider(AuthProvider):
9 | async def login(
10 | self,
11 | username: str,
12 | password: str,
13 | remember_me: bool,
14 | request: Request,
15 | response: Response,
16 | ) -> Response:
17 | if len(username) < 3:
18 | """Form data validation"""
19 | raise FormValidationError(
20 | {"username": "Ensure username has at least 03 characters"}
21 | )
22 |
23 | if username == config.mysql_user and password == config.mysql_password:
24 | """Save `username` in session"""
25 | request.session.update({"username": username})
26 | return response
27 |
28 | raise LoginFailed("Invalid username or password")
29 |
30 | async def is_authenticated(self, request) -> bool:
31 | if request.session.get("username", None) == config.mysql_user:
32 | """
33 | Save current `user` object in the request state. Can be used later
34 | to restrict access to connected user.
35 | """
36 | return True
37 |
38 | return False
39 |
40 | def get_admin_user(self, request: Request) -> AdminUser:
41 | return AdminUser(username=config.mysql_user)
42 |
43 | async def logout(self, request: Request, response: Response) -> Response:
44 | request.session.clear()
45 | return response
46 |
--------------------------------------------------------------------------------
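
For context, a minimal sketch of wiring a provider like MyAuthProvider into a starlette-admin panel. This assumes the usual starlette-admin setup (Admin(..., auth_provider=...) plus session middleware, which backs request.session); the engine and secret key are placeholders, not the app's real wiring:

from sqlalchemy import create_engine
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette_admin.contrib.sqla import Admin

from app.auth.admin import MyAuthProvider

app = Starlette()
admin = Admin(
    create_engine("sqlite://"),  # placeholder engine
    title="Admin",
    auth_provider=MyAuthProvider(),
    middlewares=[Middleware(SessionMiddleware, secret_key="CHANGE_ME")],
)
admin.mount_to(app)
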
/app/common/app_settings_llama_cpp.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 | from multiprocessing import Process
3 | from os import kill
4 | from signal import SIGINT
5 | from threading import Event
6 | from urllib import parse
7 |
8 | import requests
9 | from fastapi import FastAPI
10 | from starlette.middleware.cors import CORSMiddleware
11 |
12 | from app.shared import Shared
13 | from app.utils.logger import ApiLogger
14 |
15 | from .config import Config
16 |
17 |
18 | def check_health(url: str) -> bool:
19 |     """Check whether the given url's health endpoint responds with HTTP 200"""
20 |     try:
21 |         scheme = parse.urlparse(url).scheme
22 |         netloc = parse.urlparse(url).netloc
23 |         if requests.get(f"{scheme}://{netloc}/health").status_code != 200:
24 |             return False
25 |         return True
26 |     except Exception:
27 |         return False
28 |
29 |
30 | def start_llama_cpp_server(config: Config, shared: Shared):
31 | """Start Llama CPP server. if it is already running, terminate it first."""
32 |
33 | if shared.process.is_alive():
34 | ApiLogger.cwarning("Terminating existing Llama CPP server")
35 | shared.process.terminate()
36 | shared.process.join()
37 |
38 | if config.llama_server_port is None:
39 | raise NotImplementedError("Llama CPP server port is not set")
40 |
41 | ApiLogger.ccritical("Starting Llama CPP server")
42 | shared.process = Process(
43 | target=run_llama_cpp, args=(config.llama_server_port,), daemon=True
44 | )
45 | shared.process.start()
46 |
47 |
48 | def shutdown_llama_cpp_server(shared: Shared):
49 | """Shutdown Llama CPP server."""
50 | ApiLogger.ccritical("Shutting down Llama CPP server")
51 | if shared.process.is_alive() and shared.process.pid:
52 | kill(shared.process.pid, SIGINT)
53 | shared.process.join()
54 |
55 |
56 | def monitor_llama_cpp_server(
57 | config: Config,
58 | shared: Shared,
59 | ) -> None:
60 | """Monitors the Llama CPP server and handles server availability.
61 |
62 | Parameters:
63 | - `config: Config`: An object representing the server configuration.
64 | - `shared: Shared`: An object representing shared data."""
65 | thread_sigterm: Event = shared.thread_terminate_signal
66 | if not config.llama_completion_url:
67 | return
68 | while True:
69 | if not check_health(config.llama_completion_url):
70 | if thread_sigterm.is_set():
71 | break
72 | if config.is_llama_booting:
73 | continue
74 | ApiLogger.cerror("Llama CPP server is not available")
75 | config.is_llama_available = False
76 | config.is_llama_booting = True
77 | try:
78 | start_llama_cpp_server(config=config, shared=shared)
79 | except (ImportError, NotImplementedError):
80 | ApiLogger.cerror(
81 | "ImportError: Llama CPP server is not available"
82 | )
83 | return
84 | except Exception:
85 | ApiLogger.cexception(
86 | "Unknown error: Llama CPP server is not available"
87 | )
88 | config.is_llama_booting = False
89 | continue
90 | else:
91 | config.is_llama_booting = False
92 | config.is_llama_available = True
93 | shutdown_llama_cpp_server(shared)
94 |
95 |
96 | @asynccontextmanager
97 | async def lifespan_llama_cpp(app: FastAPI):
98 | ApiLogger.ccritical("🦙 Llama.cpp server is running")
99 | yield
100 | ApiLogger.ccritical("🦙 Shutting down llama.cpp server...")
101 |
102 |
103 | def create_app_llama_cpp():
104 | from app.routers import v1
105 |
106 | new_app = FastAPI(
107 | title="🦙 llama.cpp Python API",
108 | version="0.0.1",
109 | lifespan=lifespan_llama_cpp,
110 | )
111 | new_app.add_middleware(
112 | CORSMiddleware,
113 | allow_origins=["*"],
114 | allow_credentials=True,
115 | allow_methods=["*"],
116 | allow_headers=["*"],
117 | )
118 |
119 | @new_app.get("/health")
120 | async def health():
121 | return "ok"
122 |
123 | new_app.include_router(v1.router)
124 | return new_app
125 |
126 |
127 | def run_llama_cpp(port: int) -> None:
128 | from uvicorn import Config, Server
129 |
130 | from maintools import initialize_before_launch
131 |
132 | initialize_before_launch()
133 |
134 | Server(
135 | config=Config(
136 | create_app_llama_cpp(),
137 | host="0.0.0.0",
138 | port=port,
139 | log_level="warning",
140 | )
141 | ).run()
142 |
143 |
144 | if __name__ == "__main__":
145 | run_llama_cpp(port=8002)
146 |
--------------------------------------------------------------------------------
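
monitor_llama_cpp_server above is a blocking loop, so it has to run off the main thread. A hedged sketch of how it might be driven (the app's real startup wiring lives in app/common/app_settings.py, which is not shown here; start_monitoring is a hypothetical helper):

from threading import Thread

def start_monitoring(config, shared) -> Thread:
    # Daemon thread: never blocks interpreter shutdown
    thread = Thread(target=monitor_llama_cpp_server, args=(config, shared), daemon=True)
    thread.start()
    return thread

# To stop: shared.thread_terminate_signal.set(), then join the thread.
# Note that the loop only checks the event after a failed health check,
# so shutdown is detected on the next unhealthy poll.
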
/app/common/lotties.py:
--------------------------------------------------------------------------------
1 | from app.mixins.enum import EnumMixin
2 |
3 |
4 | class Lotties(EnumMixin):
5 | CLICK = "lottie-click"
6 | READ = "lottie-read"
7 | SCROLL_DOWN = "lottie-scroll-down"
8 | GO_BACK = "lottie-go-back"
9 | SEARCH_WEB = "lottie-search-web"
10 | SEARCH_DOC = "lottie-search-doc"
11 | OK = "lottie-ok"
12 | FAIL = "lottie-fail"
13 | TRANSLATE = "lottie-translate"
14 |
15 | def format(self, contents: str, end: bool = True) -> str:
16 | return f"\n```{self.get_value(self)}\n{contents}" + (
17 | "\n```\n" if end else ""
18 | )
19 |
20 |
21 | if __name__ == "__main__":
22 | print(Lotties.CLICK.format("hello"))
23 |
--------------------------------------------------------------------------------
/app/contents/browsing_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/browsing_demo.gif
--------------------------------------------------------------------------------
/app/contents/browsing_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/browsing_demo.png
--------------------------------------------------------------------------------
/app/contents/code_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/code_demo.png
--------------------------------------------------------------------------------
/app/contents/edit_title_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/edit_title_demo.png
--------------------------------------------------------------------------------
/app/contents/embed_file_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/embed_file_demo.png
--------------------------------------------------------------------------------
/app/contents/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/favicon.ico
--------------------------------------------------------------------------------
/app/contents/llama_api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/llama_api.png
--------------------------------------------------------------------------------
/app/contents/model_selection_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/model_selection_demo.png
--------------------------------------------------------------------------------
/app/contents/site_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/site_logo.png
--------------------------------------------------------------------------------
/app/contents/ui_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/ui_demo.png
--------------------------------------------------------------------------------
/app/contents/upload_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/contents/upload_demo.gif
--------------------------------------------------------------------------------
/app/database/__init__.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.orm import declarative_base
2 | from sqlalchemy.orm.decl_api import DeclarativeMeta
3 | from typing import TypeVar
4 |
5 | Base: DeclarativeMeta = declarative_base()
6 | TableGeneric = TypeVar("TableGeneric")
7 |
--------------------------------------------------------------------------------
/app/database/crud/api_keys.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | from sqlalchemy import select, func, exists
3 | from app.models.base_models import AddApiKey
4 | from app.errors.api_exceptions import (
5 | Responses_400,
6 | Responses_404,
7 | Responses_500,
8 | )
9 | from app.common.config import MAX_API_KEY
10 | from app.database.connection import db
11 | from app.database.schemas.auth import ApiKeys, Users
12 | from app.utils.auth.api_keys import generate_new_api_key
13 |
14 |
15 | async def create_api_key(
16 | user_id: int,
17 | additional_key_info: AddApiKey,
18 | ) -> ApiKeys:
19 | if db.session is None:
20 | raise Responses_500.database_not_initialized
21 | async with db.session() as transaction:
22 | api_key_count_stmt = select(func.count(ApiKeys.id)).filter_by(user_id=user_id)
23 | api_key_count: int | None = await transaction.scalar(api_key_count_stmt)
24 | if api_key_count is not None and api_key_count >= MAX_API_KEY:
25 | raise Responses_400.max_key_count_exceed
26 | while True:
27 | new_api_key: ApiKeys = generate_new_api_key(
28 | user_id=user_id, additional_key_info=additional_key_info
29 | )
30 | is_api_key_duplicate_stmt = select(
31 | exists().where(ApiKeys.access_key == new_api_key.access_key)
32 | )
33 | is_api_key_duplicate: bool | None = await transaction.scalar(
34 | is_api_key_duplicate_stmt
35 | )
36 | if not is_api_key_duplicate:
37 | break
38 | transaction.add(new_api_key)
39 | await transaction.commit()
40 | await transaction.refresh(new_api_key)
41 | return new_api_key
42 |
43 |
44 | async def get_api_keys(user_id: int) -> list[ApiKeys]:
45 | return await ApiKeys.fetchall_filtered_by(user_id=user_id)
46 |
47 |
48 | async def get_api_key_owner(access_key: str) -> Users:
49 | if db.session is None:
50 | raise Responses_500.database_not_initialized
51 | async with db.session() as transaction:
52 | matched_api_key: Optional[ApiKeys] = await transaction.scalar(
53 | select(ApiKeys).filter_by(access_key=access_key)
54 | )
55 | if matched_api_key is None:
56 | raise Responses_404.not_found_access_key
57 |         owner: Optional[Users] = await Users.first_filtered_by(id=matched_api_key.user_id)
58 | if owner is None:
59 | raise Responses_404.not_found_user
60 | return owner
61 |
62 |
63 | async def get_api_key_and_owner(access_key: str) -> Tuple[ApiKeys, Users]:
64 | if db.session is None:
65 | raise Responses_500.database_not_initialized
66 | async with db.session() as transaction:
67 | matched_api_key: Optional[ApiKeys] = await transaction.scalar(
68 | select(ApiKeys).filter_by(access_key=access_key)
69 | )
70 | if matched_api_key is None:
71 | raise Responses_404.not_found_access_key
72 | api_key_owner: Optional[Users] = await transaction.scalar(
73 | select(Users).filter_by(id=matched_api_key.user_id)
74 | )
75 | if api_key_owner is None:
76 | raise Responses_404.not_found_user
77 | return matched_api_key, api_key_owner
78 |
79 |
80 | async def update_api_key(
81 | updated_key_info: dict,
82 | access_key_id: int,
83 | user_id: int,
84 | ) -> ApiKeys:
85 | if db.session is None:
86 | raise Responses_500.database_not_initialized
87 | async with db.session() as transaction:
88 | matched_api_key: Optional[ApiKeys] = await transaction.scalar(
89 | select(ApiKeys).filter_by(id=access_key_id, user_id=user_id)
90 | )
91 | if matched_api_key is None:
92 | raise Responses_404.not_found_api_key
93 | matched_api_key.set_values_as(**updated_key_info)
94 | transaction.add(matched_api_key)
95 | await transaction.commit()
96 | await transaction.refresh(matched_api_key)
97 | return matched_api_key
98 |
99 |
100 | async def delete_api_key(
101 | access_key_id: int,
102 | access_key: str,
103 | user_id: int,
104 | ) -> None:
105 | if db.session is None:
106 | raise Responses_500.database_not_initialized
107 | async with db.session() as transaction:
108 | matched_api_key: Optional[ApiKeys] = await transaction.scalar(
109 | select(ApiKeys).filter_by(
110 | id=access_key_id, user_id=user_id, access_key=access_key
111 | )
112 | )
113 | if matched_api_key is None:
114 | raise Responses_404.not_found_api_key
115 | await transaction.delete(matched_api_key)
116 | await transaction.commit()
117 |
--------------------------------------------------------------------------------
/app/database/crud/api_whitelists.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from sqlalchemy import select, func
3 | from app.errors.api_exceptions import (
4 | Responses_400,
5 | Responses_404,
6 | Responses_500,
7 | )
8 | from app.common.config import MAX_API_WHITELIST
9 | from app.database.connection import db
10 | from app.database.schemas.auth import (
11 | ApiKeys,
12 | ApiWhiteLists,
13 | )
14 |
15 |
16 | async def create_api_key_whitelist(
17 | ip_address: str,
18 | api_key_id: int,
19 | ) -> ApiWhiteLists:
20 | if db.session is None:
21 | raise Responses_500.database_not_initialized
22 | async with db.session() as transaction:
23 | whitelist_count_stmt = select(func.count(ApiWhiteLists.id)).filter_by(api_key_id=api_key_id)
24 | whitelist_count: int | None = await transaction.scalar(whitelist_count_stmt)
25 | if whitelist_count is not None and whitelist_count >= MAX_API_WHITELIST:
26 | raise Responses_400.max_whitekey_count_exceed
27 | ip_duplicated_whitelist_stmt = select(ApiWhiteLists).filter_by(api_key_id=api_key_id, ip_address=ip_address)
28 | ip_duplicated_whitelist = await transaction.scalar(ip_duplicated_whitelist_stmt)
29 | if ip_duplicated_whitelist is not None:
30 | return ip_duplicated_whitelist
31 | new_whitelist = ApiWhiteLists(api_key_id=api_key_id, ip_address=ip_address)
32 | transaction.add(new_whitelist)
33 | await transaction.commit()
34 | await transaction.refresh(new_whitelist)
35 | return new_whitelist
36 |
37 |
38 | async def get_api_key_whitelist(api_key_id: int) -> list[ApiWhiteLists]:
39 | return await ApiWhiteLists.fetchall_filtered_by(api_key_id=api_key_id)
40 |
41 |
42 | async def delete_api_key_whitelist(
43 | user_id: int,
44 | api_key_id: int,
45 | whitelist_id: int,
46 | ) -> None:
47 | if db.session is None:
48 | raise Responses_500.database_not_initialized
49 | async with db.session() as transaction:
50 | matched_api_key: Optional[ApiKeys] = await transaction.scalar(
51 | select(ApiKeys).filter_by(id=api_key_id, user_id=user_id)
52 | )
53 | if matched_api_key is None:
54 | raise Responses_404.not_found_api_key
55 |         matched_whitelist = await transaction.scalar(
56 |             select(ApiWhiteLists).filter_by(id=whitelist_id, api_key_id=api_key_id)
57 |         )
58 |         if matched_whitelist is None:
59 |             # Guard against deleting a non-existent whitelist entry
60 |             raise Responses_404.not_found_api_key
61 |         await transaction.delete(matched_whitelist)
62 |         await transaction.commit()
63 |
--------------------------------------------------------------------------------
/app/database/crud/deprecated_chatgpt.py:
--------------------------------------------------------------------------------
1 | # from typing import Optional
2 |
3 | # from sqlalchemy import select
4 | # from app.errors.api_exceptions import (
5 | # Responses_404,
6 | # Responses_500,
7 | # )
8 | # from app.database.connection import db
9 | # from app.database.schemas.auth import (
10 | # ChatMessages,
11 | # ChatRooms,
12 | # GptPresets,
13 | # )
14 |
15 |
16 | # async def create_chat_room(
17 | # chat_room_type: str,
18 | # name: str,
19 | # description: str | None,
20 | # user_id: int,
21 | # status: str = "active",
22 | # ) -> ChatRooms:
23 | # if db.session is None:
24 | # raise Responses_500.database_not_initialized
25 | # async with db.session() as transaction:
26 | # new_chat_room: ChatRooms = ChatRooms(
27 | # status=status,
28 | # chat_room_type=chat_room_type,
29 | # name=name,
30 | # description=description,
31 | # user_id=user_id,
32 | # )
33 | # transaction.add(new_chat_room)
34 | # await transaction.commit()
35 | # await transaction.refresh(new_chat_room)
36 | # return new_chat_room
37 |
38 |
39 | # async def get_chat_all_rooms(user_id: int) -> list[ChatRooms]:
40 | # return await ChatRooms.fetchall_filtered_by(user_id=user_id) # type: ignore
41 |
42 |
43 | # async def create_chat_message(
44 | # role: str,
45 | # message: str,
46 | # chat_room_id: int,
47 | # user_id: int,
48 | # status: str = "active",
49 | # ) -> ChatMessages:
50 | # if db.session is None:
51 | # raise Responses_500.database_not_initialized
52 | # async with db.session() as transaction:
53 | # new_chat_message: ChatMessages = ChatMessages(
54 | # status=status,
55 | # role=role,
56 | # message=message,
57 | # user_id=user_id,
58 | # chat_room_id=chat_room_id,
59 | # )
60 | # transaction.add(new_chat_message)
61 | # await transaction.commit()
62 | # await transaction.refresh(new_chat_message)
63 | # return new_chat_message
64 |
65 |
66 | # async def get_chat_all_messages(
67 | # chat_room_id: int,
68 | # user_id: int,
69 | # ) -> list[ChatMessages]:
70 | # return await ChatMessages.fetchall_filtered_by(
71 | # user_id=user_id,
72 | # chat_room_id=chat_room_id,
73 | # ) # type: ignore
74 |
75 |
76 | # async def create_gpt_preset(
77 | # user_id: int,
78 | # temperature: float,
79 | # top_p: float,
80 | # presence_penalty: float,
81 | # frequency_penalty: float,
82 | # ) -> GptPresets:
83 | # if db.session is None:
84 | # raise Responses_500.database_not_initialized
85 | # async with db.session() as transaction:
86 | # new_gpt_preset: GptPresets = GptPresets(
87 | # user_id=user_id,
88 | # temperature=temperature,
89 | # top_p=top_p,
90 | # presence_penalty=presence_penalty,
91 | # frequency_penalty=frequency_penalty,
92 | # )
93 | # transaction.add(new_gpt_preset)
94 | # await transaction.commit()
95 | # await transaction.refresh(new_gpt_preset)
96 | # return new_gpt_preset
97 |
98 |
99 | # async def get_gpt_presets(user_id: int) -> list[GptPresets]:
100 | # return await GptPresets.fetchall_filtered_by(user_id=user_id) # type: ignore
101 |
102 |
103 | # async def get_gpt_preset(user_id: int, preset_id: int) -> GptPresets:
104 | # return await GptPresets.fetchone_filtered_by(user_id=user_id, id=preset_id)
105 |
106 |
107 | # async def update_gpt_preset(
108 | # user_id: int,
109 | # preset_id: int,
110 | # temperature: float,
111 | # top_p: float,
112 | # presence_penalty: float,
113 | # frequency_penalty: float,
114 | # status: str = "active",
115 | # ) -> GptPresets:
116 | # if db.session is None:
117 | # raise Responses_500.database_not_initialized
118 | # async with db.session() as transaction:
119 | # matched_preset: Optional[GptPresets] = await transaction.scalar(
120 | # select(GptPresets).filter_by(user_id=user_id, id=preset_id)
121 | # )
122 | # if matched_preset is None:
123 | # raise Responses_404.not_found_preset
124 | # matched_preset.set_values_as(
125 | # temperature=temperature,
126 | # top_p=top_p,
127 | # presence_penalty=presence_penalty,
128 | # frequency_penalty=frequency_penalty,
129 | # status=status,
130 | # )
131 | # transaction.add(matched_preset)
132 | # await transaction.commit()
133 | # await transaction.refresh(matched_preset)
134 | # return matched_preset
135 |
136 |
137 | # async def delete_gpt_preset(user_id: int, preset_id: int) -> None:
138 | # if db.session is None:
139 | # raise Responses_500.database_not_initialized
140 | # async with db.session() as transaction:
141 | # matched_preset: Optional[GptPresets] = await transaction.scalar(
142 | # select(GptPresets).filter_by(user_id=user_id, id=preset_id)
143 | # )
144 | # if matched_preset is None:
145 | # raise Responses_404.not_found_preset
146 | # await transaction.delete(matched_preset)
147 | # await transaction.commit()
148 | # return
149 |
--------------------------------------------------------------------------------
/app/database/crud/users.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import select, exists
2 | from app.database.connection import db
3 | from app.database.schemas.auth import (
4 | Users,
5 | ApiKeys,
6 | )
7 |
8 |
9 | async def is_email_exist(email: str) -> bool:
10 |     return bool(await db.scalar(select(exists().where(Users.email == email))))
11 |
12 |
13 | async def get_me(user_id: int):
14 | return await Users.first_filtered_by(id=user_id)
15 |
16 |
17 | async def is_valid_api_key(access_key: str) -> bool:
18 |     return bool(await db.scalar(select(exists().where(ApiKeys.access_key == access_key))))
19 |
20 |
21 | async def register_new_user(
22 |     email: str,
23 |     hashed_password: str,
24 |     ip_address: str | None,
25 | ) -> Users:
26 |     optional_fields = {"ip_address": ip_address} if ip_address else {}
27 |     return await Users.add_one(
28 |         autocommit=True,
29 |         refresh=True,
30 |         email=email,
31 |         password=hashed_password,
32 |         **optional_fields,
33 |     )
34 |
35 |
36 | async def find_matched_user(email: str) -> Users:
37 |     return await Users.first_filtered_by(email=email)
38 |
--------------------------------------------------------------------------------
/app/database/schemas/auth.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from sqlalchemy import (
3 | String,
4 | Integer,
5 | Enum,
6 | Boolean,
7 | ForeignKey,
8 | )
9 | from sqlalchemy.orm import (
10 | relationship,
11 | Mapped,
12 | mapped_column,
13 | )
14 |
15 | from app.viewmodels.status import ApiKeyStatus, UserStatus
16 | from .. import Base
17 | from . import TableMixin
18 |
19 |
20 | class Users(Base, TableMixin):
21 | __tablename__ = "users"
22 | status: Mapped[str] = mapped_column(Enum(UserStatus), default=UserStatus.active)
23 | email: Mapped[str] = mapped_column(String(length=50))
24 | password: Mapped[str | None] = mapped_column(String(length=100))
25 | name: Mapped[str | None] = mapped_column(String(length=20))
26 | phone_number: Mapped[str | None] = mapped_column(String(length=20))
27 | profile_img: Mapped[str | None] = mapped_column(String(length=100))
28 | marketing_agree: Mapped[bool] = mapped_column(Boolean, default=True)
29 | api_keys: Mapped["ApiKeys"] = relationship(
30 | back_populates="users", cascade="all, delete-orphan", lazy=True
31 | )
32 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True)
33 | # chat_messages: Mapped["ChatMessages"] = relationship(
34 | # back_populates="users", cascade="all, delete-orphan", lazy=True
35 | # )
36 | # gpt_presets: Mapped["GptPresets"] = relationship(
37 | # back_populates="users", cascade="all, delete-orphan", lazy=True, uselist=False
38 | # )
39 |
40 |
41 | class ApiKeys(Base, TableMixin):
42 | __tablename__ = "api_keys"
43 | status: Mapped[str] = mapped_column(Enum(ApiKeyStatus), default=ApiKeyStatus.active)
44 | access_key: Mapped[str] = mapped_column(String(length=64), index=True, unique=True)
45 | secret_key: Mapped[str] = mapped_column(String(length=64))
46 | user_memo: Mapped[str | None] = mapped_column(String(length=40))
47 | is_whitelisted: Mapped[bool] = mapped_column(default=False)
48 | user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
49 | users: Mapped["Users"] = relationship(back_populates="api_keys")
50 | whitelists: Mapped["ApiWhiteLists"] = relationship(
51 | backref="api_keys", cascade="all, delete-orphan"
52 | )
53 |
54 |
55 | class ApiWhiteLists(Base, TableMixin):
56 | __tablename__ = "api_whitelists"
57 | api_key_id: Mapped[int] = mapped_column(
58 | Integer, ForeignKey("api_keys.id", ondelete="CASCADE")
59 | )
60 | ip_address: Mapped[str] = mapped_column(String(length=64))
61 |
62 |
63 | # class ChatRooms(Base, Mixin):
64 | # __tablename__ = "chat_rooms"
65 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True)
66 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active")
67 | # chat_room_type: Mapped[str] = mapped_column(String(length=20), index=True)
68 | # name: Mapped[str] = mapped_column(String(length=20))
69 | # description: Mapped[str | None] = mapped_column(String(length=100))
70 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
71 | # users: Mapped["Users"] = relationship(back_populates="chat_rooms")
72 | # chat_messages: Mapped["ChatMessages"] = relationship(back_populates="chat_rooms", cascade="all, delete-orphan")
73 |
74 |
75 | # class ChatMessages(Base, Mixin):
76 | # __tablename__ = "chat_messages"
77 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True)
78 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active")
79 | # role: Mapped[str] = mapped_column(String(length=20), default="user")
80 | # message: Mapped[str] = mapped_column(Text)
81 | # chat_room_id: Mapped[int] = mapped_column(ForeignKey("chat_rooms.id", ondelete="CASCADE"))
82 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="chat_messages")
83 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
84 | # users: Mapped["Users"] = relationship(back_populates="chat_messages")
85 |
86 |
87 | # class GptPresets(Base, Mixin):
88 | # __tablename__ = "gpt_presets"
89 | # temperature: Mapped[float] = mapped_column(Float, default=0.9)
90 | # top_p: Mapped[float] = mapped_column(Float, default=1.0)
91 | # presence_penalty: Mapped[float] = mapped_column(Float, default=0)
92 | # frequency_penalty: Mapped[float] = mapped_column(Float, default=0)
93 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), unique=True)
94 | # users: Mapped["Users"] = relationship(back_populates="gpt_presets")
95 |
--------------------------------------------------------------------------------
/app/database/schemas/deprecated_chatgpt.py:
--------------------------------------------------------------------------------
1 | # from sqlalchemy import (
2 | # String,
3 | # Enum,
4 | # ForeignKey,
5 | # Float,
6 | # Text,
7 | # )
8 | # from sqlalchemy.orm import (
9 | # relationship,
10 | # Mapped,
11 | # mapped_column,
12 | # )
13 | # from app.database.schemas.auth import Users
14 | # from .. import Base
15 | # from . import Mixin
16 |
17 |
18 | # class ChatRooms(Base, Mixin):
19 | # __tablename__ = "chat_rooms"
20 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True)
21 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active")
22 | # chat_room_type: Mapped[str] = mapped_column(String(length=20), index=True)
23 | # name: Mapped[str] = mapped_column(String(length=20))
24 | # description: Mapped[str | None] = mapped_column(String(length=100))
25 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
26 | # users: Mapped["Users"] = relationship(back_populates="chat_rooms")
27 | # chat_messages: Mapped["ChatMessages"] = relationship(back_populates="chat_rooms", cascade="all, delete-orphan")
28 |
29 |
30 | # class ChatMessages(Base, Mixin):
31 | # __tablename__ = "chat_messages"
32 | # uuid: Mapped[str] = mapped_column(String(length=36), index=True, unique=True)
33 | # status: Mapped[str] = mapped_column(Enum("active", "deleted", "blocked"), default="active")
34 | # role: Mapped[str] = mapped_column(String(length=20), default="user")
35 | # message: Mapped[str] = mapped_column(Text)
36 | # chat_room_id: Mapped[int] = mapped_column(ForeignKey("chat_rooms.id"))
37 | # chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="chat_messages")
38 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
39 | # users: Mapped["Users"] = relationship(back_populates="chat_messages")
40 |
41 |
42 | # class GptPresets(Base, Mixin):
43 | # __tablename__ = "gpt_presets"
44 | # temperature: Mapped[float] = mapped_column(Float, default=0.9)
45 | # top_p: Mapped[float] = mapped_column(Float, default=1.0)
46 | # presence_penalty: Mapped[float] = mapped_column(Float, default=0)
47 | # frequency_penalty: Mapped[float] = mapped_column(Float, default=0)
48 | # user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), unique=True)
49 | # users: Mapped["Users"] = relationship(back_populates="gpt_presets")
50 |
--------------------------------------------------------------------------------
/app/dependencies.py:
--------------------------------------------------------------------------------
1 | from fastapi import Header, Query
2 | from fastapi.security import APIKeyHeader
3 |
4 |
5 | def api_service_dependency(secret: str = Header(...), key: str = Query(...), timestamp: str = Query(...)):
6 | ... # do some validation or processing with the headers
7 |
8 |
9 | USER_DEPENDENCY = APIKeyHeader(name="Authorization", auto_error=False)
10 |
--------------------------------------------------------------------------------
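
A short sketch of how these two dependencies attach to routes (the router and path below are illustrative, not the app's actual endpoints):

from fastapi import APIRouter, Depends

router = APIRouter(dependencies=[Depends(api_service_dependency)])

@router.get("/sample")  # hypothetical endpoint
async def sample(authorization: str | None = Depends(USER_DEPENDENCY)):
    # With auto_error=False, a missing "Authorization" header yields None
    # here instead of an automatic 403 from FastAPI.
    return {"authorized": authorization is not None}
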
/app/errors/chat_exceptions.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
4 | class ChatException(Exception): # Base exception for chat
5 | def __init__(self, *, msg: str | None = None) -> None:
6 | self.msg = msg
7 | super().__init__()
8 |
9 |
10 | class ChatConnectionException(ChatException):
11 | def __init__(self, *, msg: str | None = None) -> None:
12 | self.msg = msg
13 | super().__init__(msg=msg)
14 |
15 |
16 | class ChatLengthException(ChatException):
17 | def __init__(self, *, msg: str | None = None) -> None:
18 | self.msg = msg
19 | super().__init__(msg=msg)
20 |
21 |
22 | class ChatContentFilterException(ChatException):
23 | def __init__(self, *, msg: str | None = None) -> None:
24 | self.msg = msg
25 | super().__init__(msg=msg)
26 |
27 |
28 | class ChatTooMuchTokenException(ChatException):
29 | def __init__(self, *, msg: str | None = None) -> None:
30 | self.msg = msg
31 | super().__init__(msg=msg)
32 |
33 |
34 | class ChatTextGenerationException(ChatException):
35 | def __init__(self, *, msg: str | None = None) -> None:
36 | self.msg = msg
37 | super().__init__(msg=msg)
38 |
39 |
40 | class ChatOtherException(ChatException):
41 | def __init__(self, *, msg: str | None = None) -> None:
42 | self.msg = msg
43 | super().__init__(msg=msg)
44 |
45 |
46 | class ChatModelNotImplementedException(ChatException):
47 | def __init__(self, *, msg: str | None = None) -> None:
48 | self.msg = msg
49 | super().__init__(msg=msg)
50 |
51 |
52 | class ChatBreakException(ChatException):
53 | def __init__(self, *, msg: str | None = None) -> None:
54 | self.msg = msg
55 | super().__init__(msg=msg)
56 |
57 |
58 | class ChatContinueException(ChatException):
59 | def __init__(self, *, msg: str | None = None) -> None:
60 | self.msg = msg
61 | super().__init__(msg=msg)
62 |
63 |
64 | class ChatInterruptedException(ChatException):
65 | def __init__(self, *, msg: str | None = None) -> None:
66 | self.msg = msg
67 | super().__init__(msg=msg)
68 |
69 |
70 | class ChatFunctionCallException(ChatException):
71 | """Raised when function is called."""
72 |
73 | def __init__(self, *, func_name: str, func_kwargs: dict[str, Any]) -> None:
74 | self.func_name = func_name
75 | self.func_kwargs = func_kwargs
76 | super().__init__(msg=f"Function {func_name}({func_kwargs}) is called.")
77 |
--------------------------------------------------------------------------------
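
ChatFunctionCallException is the one exception here that carries a payload rather than just a message: it transports the parsed function call to whatever handler catches it. An illustrative dispatch (names are made up):

try:
    raise ChatFunctionCallException(
        func_name="web_search", func_kwargs={"query": "example"}
    )
except ChatFunctionCallException as e:
    # A handler can look up the callback by name and invoke it with the kwargs
    print(e.func_name, e.func_kwargs)  # -> web_search {'query': 'example'}
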
/app/middlewares/trusted_hosts.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from starlette.datastructures import URL, Headers
4 | from starlette.responses import PlainTextResponse, RedirectResponse
5 | from starlette.types import ASGIApp, Receive, Scope, Send
6 |
7 | from app.errors.api_exceptions import Responses_500
8 |
9 |
10 | class TrustedHostMiddleware:
11 | def __init__(
12 | self,
13 | app: ASGIApp,
14 | allowed_hosts: Sequence[str] | None = None,
15 | except_path: Sequence[str] | None = None,
16 | www_redirect: bool = True,
17 | ):
18 | self.app: ASGIApp = app
19 | self.allowed_hosts: list = (
20 | ["*"] if allowed_hosts is None else list(allowed_hosts)
21 | )
22 | self.allow_any: bool = (
23 | "*" in allowed_hosts if allowed_hosts is not None else True
24 | )
25 | self.www_redirect: bool = www_redirect
26 | self.except_path: list = (
27 | [] if except_path is None else list(except_path)
28 | )
29 | if allowed_hosts is not None:
30 | for allowed_host in allowed_hosts:
31 | if "*" in allowed_host[1:]:
32 | raise Responses_500.middleware_exception
33 | if (
34 | allowed_host.startswith("*") and allowed_host != "*"
35 | ) and not allowed_host.startswith("*."):
36 | raise Responses_500.middleware_exception
37 |
38 | async def __call__(
39 | self, scope: Scope, receive: Receive, send: Send
40 | ) -> None:
41 | if self.allow_any or scope["type"] not in (
42 | "http",
43 | "websocket",
44 | ): # pragma: no cover
45 | await self.app(scope, receive, send)
46 | return
47 |
48 | headers = Headers(scope=scope)
49 | host = headers.get("host", "").split(":")[0]
50 | is_valid_host = False
51 | found_www_redirect = False
52 | for pattern in self.allowed_hosts:
53 | if (
54 |                 host == pattern
55 |                 or (pattern.startswith("*")
56 |                     and host.endswith(pattern[1:]))
57 |                 or URL(scope=scope).path in self.except_path
58 | ):
59 | is_valid_host = True
60 | break
61 | elif "www." + host == pattern:
62 | found_www_redirect = True
63 |
64 | if is_valid_host:
65 | await self.app(scope, receive, send)
66 | else:
67 | if found_www_redirect and self.www_redirect:
68 | url = URL(scope=scope)
69 | redirect_url = url.replace(netloc="www." + url.netloc)
70 | response = RedirectResponse(url=str(redirect_url))
71 | else:
72 | response = PlainTextResponse(
73 | "Invalid host header", status_code=400
74 | )
75 | await response(scope, receive, send)
76 |
--------------------------------------------------------------------------------
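
A minimal sketch of registering this middleware on a FastAPI app (host values are placeholders; the app's real registration presumably happens in app/common/app_settings.py, not shown here):

from fastapi import FastAPI

app = FastAPI()
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["yourdomain.com", "*.yourdomain.com"],
    except_path=["/health"],  # requests to these paths pass regardless of Host header
)
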
/app/models/chat_commands.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from app.utils.chat.commands.browsing import BrowsingCommands
4 | from app.utils.chat.commands.core import CoreCommands
5 | from app.utils.chat.commands.llm_parameter import LLMParameterCommands
6 | from app.utils.chat.commands.prompt import PromptCommands
7 | from app.utils.chat.commands.server import ServerCommands
8 | from app.utils.chat.commands.summarize import SummarizeCommands
9 | from app.utils.chat.commands.testing import TestingCommands
10 | from app.utils.chat.commands.vectorstore import VectorstoreCommands
11 |
12 | from .chat_models import command_response
13 |
14 |
15 | class ChatCommandsMetaClass(type):
16 | """Metaclass for ChatCommands class.
17 | It is used to automatically create list of special commands.
18 | """
19 |
20 | special_commands: list[str]
21 |
22 | def __init__(cls, name, bases, dct):
23 | super().__init__(name, bases, dct)
24 | cls.special_commands = [
25 | callback_name
26 | for callback_name in dir(CoreCommands)
27 | if not callback_name.startswith("_")
28 | ]
29 |
30 |
31 | class ChatCommands(
32 | CoreCommands,
33 | VectorstoreCommands,
34 | PromptCommands,
35 | BrowsingCommands,
36 | LLMParameterCommands,
37 | ServerCommands,
38 | TestingCommands,
39 | SummarizeCommands,
40 | metaclass=ChatCommandsMetaClass,
41 | ):
42 | @classmethod
43 | def find_callback_with_command(cls, command: str) -> Callable:
44 |         found_callback: Callable | None = getattr(cls, command, None)
45 | if found_callback is None or not callable(found_callback):
46 | found_callback = cls.not_existing_callback
47 | return found_callback
48 |
49 | @staticmethod
50 | @command_response.send_message_and_stop
51 | def not_existing_callback() -> str: # callback for not existing command
52 | return "Sorry, I don't know what you mean by..."
53 |
--------------------------------------------------------------------------------
/app/models/completion_models.py:
--------------------------------------------------------------------------------
1 | from sys import version_info
2 | from typing import Dict, List, Optional, Literal, TypedDict
3 |
4 | # If python version >= 3.11, use the built-in NotRequired type.
5 | # Otherwise, import it from typing_extensions.
6 | if version_info >= (3, 11):
7 | from typing import NotRequired # type: ignore
8 | else:
9 | from typing_extensions import NotRequired
10 |
11 | from .function_calling.base import JsonTypes
12 |
13 |
14 | class FunctionCallParsed(TypedDict):
15 | name: str
16 | arguments: NotRequired[Dict[str, JsonTypes]]
17 |
18 |
19 | class FunctionCallUnparsed(TypedDict):
20 | name: NotRequired[str]
21 | arguments: NotRequired[str]
22 |
23 |
24 | class EmbeddingUsage(TypedDict):
25 | prompt_tokens: int
26 | total_tokens: int
27 |
28 |
29 | class EmbeddingData(TypedDict):
30 | index: int
31 | object: str
32 | embedding: List[float]
33 |
34 |
35 | class Embedding(TypedDict):
36 | object: Literal["list"]
37 | model: str
38 | data: List[EmbeddingData]
39 | usage: EmbeddingUsage
40 |
41 |
42 | class CompletionLogprobs(TypedDict):
43 | text_offset: List[int]
44 | token_logprobs: List[Optional[float]]
45 | tokens: List[str]
46 | top_logprobs: List[Optional[Dict[str, float]]]
47 |
48 |
49 | class CompletionChoice(TypedDict):
50 | text: str
51 | index: int
52 | logprobs: Optional[CompletionLogprobs]
53 | finish_reason: Optional[str]
54 |
55 |
56 | class CompletionUsage(TypedDict):
57 | prompt_tokens: int
58 | completion_tokens: int
59 | total_tokens: int
60 |
61 |
62 | class CompletionChunk(TypedDict):
63 | id: str
64 | object: Literal["text_completion"]
65 | created: int
66 | model: str
67 | choices: List[CompletionChoice]
68 |
69 |
70 | class Completion(TypedDict):
71 | id: str
72 | object: Literal["text_completion"]
73 | created: int
74 | model: str
75 | choices: List[CompletionChoice]
76 | usage: CompletionUsage
77 |
78 |
79 | class ChatCompletionMessage(TypedDict):
80 | role: Literal["assistant", "user", "system"]
81 | content: str
82 | user: NotRequired[str]
83 | function_call: NotRequired[FunctionCallUnparsed]
84 |
85 |
86 | class ChatCompletionChoice(TypedDict):
87 | index: int
88 | message: ChatCompletionMessage
89 | finish_reason: Optional[str]
90 |
91 |
92 | class ChatCompletion(TypedDict):
93 | id: str
94 | object: Literal["chat.completion"]
95 | created: int
96 | model: str
97 | choices: List[ChatCompletionChoice]
98 | usage: CompletionUsage
99 |
100 |
101 | class ChatCompletionChunkDelta(TypedDict):
102 | role: NotRequired[Literal["assistant"]]
103 | content: NotRequired[str]
104 | function_call: NotRequired[FunctionCallUnparsed]
105 |
106 |
107 | class ChatCompletionChunkChoice(TypedDict):
108 | index: int
109 | delta: ChatCompletionChunkDelta
110 | finish_reason: Optional[str]
111 |
112 |
113 | class ChatCompletionChunk(TypedDict):
114 | id: str
115 | model: str
116 | object: Literal["chat.completion.chunk"]
117 | created: int
118 | choices: List[ChatCompletionChunkChoice]
119 |
120 |
121 | class ModelData(TypedDict):
122 | id: str
123 | object: Literal["model"]
124 | owned_by: str
125 | permissions: List[str]
126 |
127 |
128 | class ModelList(TypedDict):
129 | object: Literal["list"]
130 | data: List[ModelData]
131 |
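For orientation, a value satisfying these TypedDicts looks like the following (a hand-written sketch; field values are illustrative only):

from app.models.completion_models import ChatCompletion

example: ChatCompletion = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1690000000,
    "model": "gpt-3.5-turbo",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {
        "prompt_tokens": 5,
        "completion_tokens": 2,
        "total_tokens": 7,
    },
}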
--------------------------------------------------------------------------------
/app/models/function_calling/base.py:
--------------------------------------------------------------------------------
1 | """Helper classes for wrapping functions in OpenAI's API"""
2 |
3 | from dataclasses import dataclass
4 | from sys import version_info
5 | from types import NoneType
6 | from typing import (
7 | Any,
8 | Generic,
9 | Literal,
10 | Optional,
11 | Type,
12 | TypeVar,
13 | TypedDict,
14 | Union,
15 | )
16 |
17 | # If python version >= 3.11, use the built-in NotRequired type.
18 | # Otherwise, import it from typing_extensions.
19 | if version_info >= (3, 11):
20 | from typing import NotRequired # type: ignore
21 | else:
22 | from typing_extensions import NotRequired
23 |
24 | # The types that can be used in JSON
25 | JsonTypes = Union[int, float, str, bool, dict, list, None]
26 |
27 | ParamType = TypeVar("ParamType", bound=JsonTypes)
28 | ReturnType = TypeVar("ReturnType")
29 |
30 |
31 | class ParameterProperty(TypedDict):
32 | type: str
33 | description: NotRequired[str]
34 | enum: NotRequired[list[JsonTypes]]
35 |
36 |
37 | class ParameterDefinition(TypedDict):
38 | type: Literal["object"]
39 | properties: dict[str, ParameterProperty]
40 | required: NotRequired[list[str]]
41 |
42 |
43 | class FunctionProperty(TypedDict):
44 | name: str
45 | description: NotRequired[str]
46 | parameters: NotRequired[ParameterDefinition]
47 |
48 |
49 | @dataclass
50 | class FunctionCallParameter(Generic[ParamType]):
51 | """A class for wrapping function parameters in OpenAI's API"""
52 |
53 | name: str
54 | type: Type[ParamType]
55 | description: Optional[str] = None
56 | enum: Optional[list[ParamType]] = None
57 |
58 | def to_dict(self) -> dict[str, ParameterProperty]:
59 | """Returns a dictionary representation of the parameter"""
60 | parameter_property: ParameterProperty = {
61 | "type": self._get_json_type(self.type)
62 | } # type: ignore
63 | if self.description:
64 | parameter_property["description"] = self.description
65 | if self.enum:
66 | parameter_property["enum"] = self.enum # type: ignore
67 | return {self.name: parameter_property}
68 |
69 | @staticmethod
70 | def _get_json_type(python_type: Type[JsonTypes]) -> str:
71 | """Returns the JSON type for a given python type"""
72 | if python_type is int:
73 | return "integer"
74 | elif python_type is float:
75 | return "number"
76 | elif python_type is str:
77 | return "string"
78 | elif python_type is bool:
79 | return "boolean"
80 | elif python_type is dict:
81 | return "object"
82 | elif python_type is list:
83 | return "array"
84 | elif python_type is NoneType or python_type is None:
85 | return "null"
86 | else:
87 | raise ValueError(
88 | f"Invalid type {python_type} for JSON. "
89 | f"Permitted types are {JsonTypes}"
90 | )
91 |
92 |
93 | @dataclass
94 | class FunctionCall:
95 | """A class for wrapping functions in OpenAI's API"""
96 |
97 | name: str
98 | description: Optional[str] = None
99 | parameters: Optional[list[FunctionCallParameter[Any]]] = None
100 | required: Optional[list[str]] = None
101 |
102 | def to_dict(self) -> FunctionProperty:
103 | """Returns a dictionary representation of the function"""
104 | function_property: FunctionProperty = FunctionProperty(name=self.name) # type: ignore
105 | if self.description:
106 | function_property["description"] = self.description
107 | if self.parameters:
108 | function_property["parameters"] = {
109 | "type": "object",
110 | "properties": {
111 | param.name: param.to_dict()[param.name]
112 | for param in self.parameters
113 | },
114 | "required": [
115 | param.name
116 | for param in self.parameters
117 | if param.name in (self.required or [])
118 | ],
119 | }
120 | return function_property
121 |
122 |
123 | if __name__ == "__main__":
124 |
125 | def test_callback(test: str) -> int:
126 | return len(test)
127 |
128 | param_1: FunctionCallParameter[str] = FunctionCallParameter(
129 | name="param_1",
130 | type=str,
131 | description="This is a test1",
132 | enum=["a", "b", "c", "d"],
133 | ) # `str` is the type of the parameter
134 |
135 | param_2: FunctionCallParameter[bool] = FunctionCallParameter(
136 | name="param_2",
137 | type=bool,
138 | description="This is a test2",
139 | enum=[True, False],
140 | ) # `bool` is the type of the parameter
141 |
142 | func: FunctionCall = FunctionCall(
143 | name="test_function",
144 | description="This is a test function",
145 | parameters=[param_1, param_2],
146 | required=["param_1"],
147 | ) # There's no type for the function
148 |
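For reference, `func.to_dict()` in the demo above should yield an OpenAI-style function schema along these lines (a sketch, not captured output):

print(func.to_dict())
# {
#     "name": "test_function",
#     "description": "This is a test function",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "param_1": {
#                 "type": "string",
#                 "description": "This is a test1",
#                 "enum": ["a", "b", "c", "d"],
#             },
#             "param_2": {
#                 "type": "boolean",
#                 "description": "This is a test2",
#                 "enum": [True, False],
#             },
#         },
#         "required": ["param_1"],
#     },
# }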
--------------------------------------------------------------------------------
/app/routers/__init__.py:
--------------------------------------------------------------------------------
1 | from app.routers import auth, index, services, user_services, users, websocket
2 |
3 | __all__ = [
4 | "auth",
5 | "index",
6 | "services",
7 | "user_services",
8 | "users",
9 | "websocket",
10 | ]
11 |
--------------------------------------------------------------------------------
/app/routers/auth.py:
--------------------------------------------------------------------------------
1 | import bcrypt
2 | from fastapi import APIRouter, Response, Security, status
3 | from fastapi.requests import Request
4 |
5 | from app.common.config import TOKEN_EXPIRE_HOURS
6 | from app.database.crud.users import is_email_exist, register_new_user
7 | from app.database.schemas.auth import Users
8 | from app.dependencies import USER_DEPENDENCY
9 | from app.errors.api_exceptions import Responses_400, Responses_404
10 | from app.models.base_models import SnsType, Token, UserRegister, UserToken
11 | from app.utils.auth.register_validation import (
12 | is_email_length_in_range,
13 | is_email_valid_format,
14 | is_password_length_in_range,
15 | )
16 | from app.utils.auth.token import create_access_token, token_decode
17 | from app.utils.chat.managers.cache import CacheManager
18 |
19 | router = APIRouter(prefix="/auth")
20 |
21 |
22 | @router.post("/register/{sns_type}", status_code=201, response_model=Token)
23 | async def register(
24 | request: Request,
25 | response: Response,
26 | sns_type: SnsType,
27 | reg_info: UserRegister,
28 | ) -> Token:
29 | if sns_type == SnsType.EMAIL:
30 | if not (reg_info.email and reg_info.password):
31 | raise Responses_400.no_email_or_password
32 |
33 | if is_email_length_in_range(email=reg_info.email) is False:
34 | raise Responses_400.email_length_not_in_range
35 |
36 | if is_password_length_in_range(password=reg_info.password) is False:
37 | raise Responses_400.password_length_not_in_range
38 |
39 | if is_email_valid_format(email=reg_info.email) is False:
40 | raise Responses_400.invalid_email_format
41 |
42 | if await is_email_exist(email=reg_info.email):
43 | raise Responses_400.email_already_exists
44 |
45 | hashed_password: str = bcrypt.hashpw(
46 | password=reg_info.password.encode("utf-8"),
47 | salt=bcrypt.gensalt(),
48 | ).decode("utf-8")
49 | new_user: Users = await register_new_user(
50 | email=reg_info.email,
51 | hashed_password=hashed_password,
52 | ip_address=request.state.ip,
53 | )
54 | data_to_be_tokenized: dict = UserToken.from_orm(new_user).dict(
55 | exclude={"password", "marketing_agree"}
56 | )
57 | token: str = create_access_token(
58 | data=data_to_be_tokenized, expires_delta=TOKEN_EXPIRE_HOURS
59 | )
60 | response.set_cookie(
61 | key="Authorization",
62 | value=f"Bearer {token}",
63 | max_age=TOKEN_EXPIRE_HOURS * 3600,
64 | secure=True,
65 | httponly=True,
66 | )
67 | return Token(Authorization=f"Bearer {token}")
68 | raise Responses_400.not_supported_feature
69 |
70 |
71 | @router.delete("/register", status_code=status.HTTP_204_NO_CONTENT)
72 | async def unregister(
73 | authorization: str = Security(USER_DEPENDENCY),
74 | ):
75 | registered_user: UserToken = UserToken(**token_decode(authorization))
76 | if registered_user.email is not None:
77 | await CacheManager.delete_user(user_id=registered_user.email)
78 | await Users.delete_filtered(
79 | Users.email == registered_user.email, autocommit=True
80 | )
81 |
82 |
83 | @router.post("/login/{sns_type}", status_code=200, response_model=Token)
84 | async def login(
85 | response: Response,
86 | sns_type: SnsType,
87 | user_info: UserRegister,
88 | ) -> Token:
89 | if sns_type == SnsType.EMAIL:
90 | if not (user_info.email and user_info.password):
91 | raise Responses_400.no_email_or_password
92 | matched_user: Users | None = await Users.first_filtered_by(
93 | email=user_info.email
94 | )
95 | if matched_user is None or matched_user.password is None:
96 | raise Responses_404.not_found_user
97 | if not bcrypt.checkpw(
98 | password=user_info.password.encode("utf-8"),
99 | hashed_password=matched_user.password.encode("utf-8"),
100 | ):
101 | raise Responses_404.not_found_user
102 | data_to_be_tokenized: dict = UserToken.from_orm(matched_user).dict(
103 | exclude={"password", "marketing_agree"}
104 | )
105 | token: str = create_access_token(
106 | data=data_to_be_tokenized, expires_delta=TOKEN_EXPIRE_HOURS
107 | )
108 | response.set_cookie(
109 | key="Authorization",
110 | value=f"Bearer {token}",
111 | max_age=TOKEN_EXPIRE_HOURS * 3600,
112 | secure=True,
113 | httponly=True,
114 | )
115 | return Token(Authorization=f"Bearer {token}")
116 | else:
117 | raise Responses_400.not_supported_feature
118 |
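A client-side sketch of the register/login flow, assuming a local dev server on port 8000 and that SnsType.EMAIL serializes to "email" in the path:

import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    # Register; the token is returned in the body and also set as a cookie.
    res = client.post(
        "/auth/register/email",
        json={"email": "user@example.com", "password": "secret123"},
    )
    token = res.json()["Authorization"]  # "Bearer <jwt>"

    # Log in later with the same credentials.
    res = client.post(
        "/auth/login/email",
        json={"email": "user@example.com", "password": "secret123"},
    )
    assert res.status_code == 200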
--------------------------------------------------------------------------------
/app/routers/index.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Request
2 | from fastapi.responses import FileResponse
3 |
4 | router = APIRouter()
5 |
6 |
7 | @router.get("/")
8 | async def index():
9 | return FileResponse("app/web/index.html")
10 |
11 |
12 | @router.get("/favicon.ico", include_in_schema=False)
13 | async def favicon():
14 | return FileResponse("app/contents/favicon.ico")
15 |
16 |
17 | @router.get("/test")
18 | async def test(request: Request):
19 | return {"username": request.session.get("username", None)}
20 |
--------------------------------------------------------------------------------
/app/routers/user_services.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from fastapi import APIRouter
4 |
5 | from app.models.base_models import MessageOk
6 |
7 | router = APIRouter(prefix="/user-services")
8 |
9 |
10 | @router.get("/", status_code=200)
11 | def test(query: Optional[str] = None):
12 | return MessageOk(message="Hello, World!" if query is None else query)
13 |
--------------------------------------------------------------------------------
/app/routers/users.py:
--------------------------------------------------------------------------------
1 | import ipaddress
2 |
3 | from fastapi import APIRouter
4 | from starlette.requests import Request
5 |
6 | from app.database.crud import api_keys, api_whitelists, users
7 | from app.errors.api_exceptions import Responses_400
8 | from app.models.base_models import (
9 | AddApiKey,
10 | CreateApiWhiteList,
11 | GetApiKey,
12 | GetApiKeyFirstTime,
13 | GetApiWhiteList,
14 | MessageOk,
15 | UserMe,
16 | )
17 |
18 | router = APIRouter(prefix="/user")
19 |
20 |
21 | @router.get("/me", response_model=UserMe)
22 | async def get_me(request: Request):
23 | return await users.get_me(user_id=request.state.user.id)
24 |
25 |
26 | @router.put("/me")
27 | async def put_me(request: Request):
28 | ...
29 |
30 |
31 | @router.delete("/me")
32 | async def delete_me(request: Request):
33 | ...
34 |
35 |
36 | @router.get("/apikeys", response_model=list[GetApiKey])
37 | async def get_api_keys(request: Request):
38 | return await api_keys.get_api_keys(user_id=request.state.user.id)
39 |
40 |
41 | @router.post("/apikeys", response_model=GetApiKeyFirstTime, status_code=201)
42 | async def create_api_key(request: Request, api_key_info: AddApiKey):
43 | return await api_keys.create_api_key(
44 | user_id=request.state.user.id, additional_key_info=api_key_info
45 | )
46 |
47 |
48 | @router.put("/apikeys/{key_id}", response_model=GetApiKey)
49 | async def update_api_key(
50 | request: Request,
51 | key_id: int,
52 | api_key_info: AddApiKey,
53 | ):
54 | """
55 | Update the user memo of an API key.
56 | """
57 | return await api_keys.update_api_key(
58 | updated_key_info=api_key_info.dict(),
59 | access_key_id=key_id,
60 | user_id=request.state.user.id,
61 | )
62 |
63 |
64 | @router.delete("/apikeys/{key_id}")
65 | async def delete_api_key(
66 | request: Request,
67 | key_id: int,
68 | access_key: str,
69 | ):
70 | await api_keys.delete_api_key(
71 | access_key_id=key_id,
72 | access_key=access_key,
73 | user_id=request.state.user.id,
74 | )
75 | return MessageOk()
76 |
77 |
78 | @router.get(
79 | "/apikeys/{key_id}/whitelists", response_model=list[GetApiWhiteList]
80 | )
81 | async def get_api_keys_whitelist(key_id: int):
82 | return await api_whitelists.get_api_key_whitelist(api_key_id=key_id)
83 |
84 |
85 | @router.post("/apikeys/{key_id}/whitelists", response_model=GetApiWhiteList)
86 | async def create_api_keys_whitelist(
87 | key_id: int,
88 | ip: CreateApiWhiteList,
89 | ):
90 | ip_address: str = ip.ip_address
91 | try:
92 | ipaddress.ip_address(ip_address)
93 | except Exception as exception:
94 | raise Responses_400.invalid_ip(
95 | lazy_format={"ip": ip_address}, ex=exception
96 | )
97 | return await api_whitelists.create_api_key_whitelist(
98 | ip_address=ip_address, api_key_id=key_id
99 | )
100 |
101 |
102 | @router.delete("/apikeys/{key_id}/whitelists/{list_id}")
103 | async def delete_api_keys_whitelist(
104 | request: Request,
105 | key_id: int,
106 | list_id: int,
107 | ):
108 | await api_whitelists.delete_api_key_whitelist(
109 | user_id=request.state.user.id,
110 | api_key_id=key_id,
111 | whitelist_id=list_id,
112 | )
113 | return MessageOk()
114 |
--------------------------------------------------------------------------------
/app/routers/websocket.py:
--------------------------------------------------------------------------------
1 | from asyncio import sleep
2 |
3 | from fastapi import APIRouter, WebSocket
4 |
5 | from app.common.config import API_ENV, HOST_MAIN, OPENAI_API_KEY
6 | from app.database.crud import api_keys
7 | from app.database.schemas.auth import Users
8 | from app.errors.api_exceptions import Responses_400, Responses_401
9 | from app.utils.chat.managers.stream import ChatStreamManager
10 | from app.utils.chat.managers.websocket import SendToWebsocket
11 | from app.viewmodels.status import ApiKeyStatus, UserStatus
12 |
13 | router = APIRouter()
14 |
15 |
16 | @router.websocket("/chat/{api_key}")
17 | async def ws_chat(websocket: WebSocket, api_key: str):
18 | if OPENAI_API_KEY is None:
19 | raise Responses_400.not_supported_feature
20 | await websocket.accept() # accept websocket
21 | if api_key != OPENAI_API_KEY and API_ENV != "test":
22 | _api_key, _user = await api_keys.get_api_key_and_owner(
23 | access_key=api_key
24 | )
25 | if _user.status not in (UserStatus.active, UserStatus.admin):
26 | await SendToWebsocket.message(
27 | websocket=websocket,
28 | msg="Your account is not active",
29 | chat_room_id=" ",
30 | )
31 | await sleep(60)
32 | raise Responses_401.not_authorized
33 | if _api_key.status is not ApiKeyStatus.active:
34 | await SendToWebsocket.message(
35 | websocket=websocket,
36 | msg="Your api key is not active",
37 | chat_room_id=" ",
38 | )
39 | await sleep(60)
40 | raise Responses_401.not_authorized
41 | else:
42 | _user = Users(email=f"testaccount@{HOST_MAIN}")
43 | await ChatStreamManager.begin_chat(
44 | websocket=websocket,
45 | user=_user,
46 | )
47 |
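A hypothetical client sketch; the path matches the decorator above (any app-level mount prefix would need to be prepended), and the actual message protocol is defined by ChatStreamManager:

import asyncio

import websockets  # third-party: pip install websockets


async def main() -> None:
    api_key = "YOUR-ACCESS-KEY"  # placeholder, not a real key
    async with websockets.connect(
        f"ws://localhost:8000/chat/{api_key}"
    ) as ws:
        await ws.send("Hello!")
        print(await ws.recv())


asyncio.run(main())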
--------------------------------------------------------------------------------
/app/utils/api/weather.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | import orjson
3 | from typing import Any, Literal
4 | from fastapi import HTTPException
5 |
6 |
7 | async def fetch_weather_data(
8 | lat: float,
9 | lon: float,
10 | api_key: str,
11 | source: Literal["openweathermap", "weatherbit", "climacell"],
12 | ) -> Any:
13 | base_url = {
14 | "openweathermap": "https://api.openweathermap.org/data/2.5/weather",
15 | "weatherbit": "https://api.weatherbit.io/v2.0/current",
16 | "climacell": "https://api.climacell.co/v3/weather/realtime",
17 | }[source]
18 | query_params = {
19 | "openweathermap": f"lat={lat}&lon={lon}&appid={api_key}",
20 | "weatherbit": f"lat={lat}&lon={lon}&key={api_key}",
21 | "climacell": f"lat={lat}&lon={lon}&unit_system=metric&apikey={api_key}",
22 | }[source]
23 |
24 | async with httpx.AsyncClient() as client:
25 | response = await client.get(base_url + "?" + query_params)
26 | if response.status_code == 200:
27 | weather_data = orjson.loads(response.content)
28 | print("weather_data:", weather_data)
29 | return weather_data
30 | else:
31 | raise HTTPException(
32 | status_code=response.status_code,
33 | detail=f"Error fetching data from {source}",
34 | )
35 |
36 |
37 | def get_temperature(
38 | weather_data: dict,
39 | source: Literal["openweathermap", "weatherbit", "climacell"],
40 | ):
41 | if source == "openweathermap":
42 | temp = weather_data["main"]["temp"] - 273.15 # Convert from Kelvin to Celsius
43 | elif source == "weatherbit":
44 | temp = weather_data["data"][0]["temp"]
45 | elif source == "climacell":
46 | temp = weather_data["temp"]["value"]
47 | else:
48 | temp = None
49 |
50 | return temp
51 |
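A minimal driver for these helpers, assuming a valid OpenWeatherMap API key:

import asyncio

from app.utils.api.weather import fetch_weather_data, get_temperature


async def main() -> None:
    data = await fetch_weather_data(
        lat=37.57,
        lon=126.98,
        api_key="YOUR-OPENWEATHERMAP-KEY",  # placeholder
        source="openweathermap",
    )
    print(f"{get_temperature(data, source='openweathermap'):.1f} °C")


asyncio.run(main())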
--------------------------------------------------------------------------------
/app/utils/auth/api_keys.py:
--------------------------------------------------------------------------------
1 | from secrets import choice
2 | from string import ascii_letters, digits
3 | from uuid import uuid4
4 |
5 | from app.database.schemas.auth import ApiKeys
6 | from app.models.base_models import AddApiKey
7 |
8 |
9 | def generate_new_api_key(
10 | user_id: int, additional_key_info: AddApiKey
11 | ) -> ApiKeys:
12 | alnums = ascii_letters + digits
13 | secret_key = "".join(choice(alnums) for _ in range(40))
14 | uid = f"{str(uuid4())[:-12]}{str(uuid4())}"
15 | new_api_key = ApiKeys(
16 | secret_key=secret_key,
17 | user_id=user_id,
18 | access_key=uid,
19 | **additional_key_info.dict(),
20 | )
21 | return new_api_key
22 |
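A usage sketch; `AddApiKey` is assumed to be a pydantic model here, so `construct()` is used purely for illustration:

from app.models.base_models import AddApiKey
from app.utils.auth.api_keys import generate_new_api_key

key = generate_new_api_key(user_id=1, additional_key_info=AddApiKey.construct())
print(len(key.secret_key))  # 40 random alphanumeric characters
print(key.access_key)       # a truncated UUID4 joined with a full UUID4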
--------------------------------------------------------------------------------
/app/utils/auth/register_validation.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | # Make a regular expression
4 | # for validating an Email
5 | EMAIL_REGEX: re.Pattern = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,7}\b")
6 |
7 |
8 | # Define a function
9 | # for validating an Email
10 | def is_email_valid_format(email: str) -> bool:
11 | return EMAIL_REGEX.fullmatch(email) is not None
12 |
13 |
14 | def is_email_length_in_range(email: str) -> bool:
15 | return 6 <= len(email) <= 50
16 |
17 |
18 | def is_password_length_in_range(password: str) -> bool:
19 | return 6 <= len(password) <= 100
20 |
21 |
22 | # Driver Code
23 | if __name__ == "__main__":
24 | for email in ("ankitrai326@gmail.com", "my.ownsite@our-earth.org", "ankitrai326.com", "aa@a.a"):
25 | # print whether each sample address is valid
26 | print(is_email_valid_format(email))
27 |
--------------------------------------------------------------------------------
/app/utils/auth/token.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 |
3 | from jwt import decode as jwt_decode
4 | from jwt import encode as jwt_encode
5 | from jwt.exceptions import DecodeError, ExpiredSignatureError
6 |
7 | from app.common.config import JWT_ALGORITHM, JWT_SECRET
8 | from app.errors.api_exceptions import Responses_401
9 |
10 |
11 | def create_access_token(
12 | *, data: dict, expires_delta: int | None = None
13 | ) -> str:
14 | to_encode: dict = data.copy()
15 | if expires_delta is not None and expires_delta != 0:
16 | to_encode.update(
17 | {"exp": datetime.utcnow() + timedelta(hours=expires_delta)}
18 | )
19 | return jwt_encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
20 |
21 |
22 | def token_decode(authorization: str) -> dict:
23 | if authorization is None:
24 | raise Responses_401.token_decode_failure
25 | try:
26 | authorization = authorization.replace("Bearer ", "")
27 | payload = jwt_decode(
28 | authorization, key=JWT_SECRET, algorithms=[JWT_ALGORITHM]
29 | )
30 | except ExpiredSignatureError:
31 | raise Responses_401.token_expired
32 | except DecodeError:
33 | raise Responses_401.token_decode_failure
34 | return payload
35 |
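A round-trip sketch (JWT_SECRET and JWT_ALGORITHM are read from app.common.config):

from app.utils.auth.token import create_access_token, token_decode

token = create_access_token(
    data={"id": 1, "email": "user@example.com"}, expires_delta=24
)
payload = token_decode(f"Bearer {token}")
assert payload["email"] == "user@example.com"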
--------------------------------------------------------------------------------
/app/utils/chat/chat_rooms.py:
--------------------------------------------------------------------------------
1 | from uuid import uuid4
2 |
3 | from app.models.chat_models import UserChatContext
4 |
5 | from .buffer import BufferedUserContext
6 | from .managers.cache import CacheManager
7 |
8 |
9 | async def create_new_chat_room(
10 | user_id: str,
11 | new_chat_room_id: str | None = None,
12 | buffer: BufferedUserContext | None = None,
13 | ) -> UserChatContext:
14 | if buffer is not None:
15 | default: UserChatContext = UserChatContext.construct_default(
16 | user_id=user_id,
17 | chat_room_id=new_chat_room_id if new_chat_room_id else uuid4().hex,
18 | llm_model=buffer.current_llm_model,
19 | )
20 | else:
21 | default: UserChatContext = UserChatContext.construct_default(
22 | user_id=user_id,
23 | chat_room_id=new_chat_room_id if new_chat_room_id else uuid4().hex,
24 | )
25 | await CacheManager.create_context(user_chat_context=default)
26 | if buffer is not None:
27 | buffer.insert_context(user_chat_context=default)
28 | await buffer.change_context_to(index=0)
29 | return default
30 |
31 |
32 | async def delete_chat_room(
33 | chat_room_id_to_delete: str,
34 | buffer: BufferedUserContext,
35 | ) -> bool:
36 | await CacheManager.delete_chat_room(
37 | user_id=buffer.user_id, chat_room_id=chat_room_id_to_delete
38 | )
39 | index: int | None = buffer.find_index_of_chatroom(
40 | chat_room_id=chat_room_id_to_delete
41 | )
42 | if index is None:
43 | return False
44 | buffer.delete_context(index=index)
45 | if not buffer:
46 | await create_new_chat_room(
47 | user_id=buffer.user_id,
48 | buffer=buffer,
49 | )
50 | if buffer.current_chat_room_id == chat_room_id_to_delete:
51 | await buffer.change_context_to(index=0)
52 | return True
53 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/browsing.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from app.models.chat_models import ResponseType
4 | from app.models.function_calling.functions import FunctionCalls
5 | from app.models.llms import OpenAIModel
6 | from app.utils.chat.buffer import BufferedUserContext
7 | from app.utils.chat.messages.handler import MessageHandler
8 | from app.utils.function_calling.query import aget_query_to_search
9 |
10 |
11 | class BrowsingCommands:
12 | @staticmethod
13 | async def browse(
14 | user_query: str, /, buffer: BufferedUserContext, **kwargs
15 | ) -> tuple[Optional[str], ResponseType]:
16 | """Query LLM with duckduckgo browse results\n
17 | /browse """
18 | if user_query.startswith("/"):
19 | # User is trying to invoke another command.
20 | # Give control back to the command handler,
21 | # and let it handle the command.
22 | # e.g. `/browse /help` will invoke `/help` command
23 | return user_query, ResponseType.REPEAT_COMMAND
24 |
25 | # Save user query to buffer and database
26 | await MessageHandler.user(msg=user_query, buffer=buffer)
27 |
28 | if not isinstance(buffer.current_llm_model.value, OpenAIModel):
29 | # Non-OpenAI models can't use the function calling API,
30 | # so we generate the search query explicitly here
31 | query_to_search: str = await aget_query_to_search(
32 | buffer=buffer,
33 | query=user_query,
34 | function=FunctionCalls.get_function_call(
35 | FunctionCalls.web_search
36 | ),
37 | )
38 | await MessageHandler.function_call(
39 | callback_name=FunctionCalls.web_search.__name__,
40 | callback_kwargs={"query_to_search": query_to_search},
41 | buffer=buffer,
42 | )
43 | else:
44 | # OpenAI models support function calling,
45 | # so let the AI decide whether to invoke the function
46 | function = FunctionCalls.get_function_call(
47 | FunctionCalls.web_search
48 | )
49 | buffer.optional_info["functions"] = [function]
50 | buffer.optional_info["function_call"] = function
51 | await MessageHandler.ai(buffer=buffer)
52 |
53 | # End of command
54 | return None, ResponseType.DO_NOTHING
55 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/llm_parameter.py:
--------------------------------------------------------------------------------
1 | from app.models.chat_models import UserChatContext, command_response
2 | from app.utils.chat.managers.cache import CacheManager
3 |
4 |
5 | class LLMParameterCommands:
6 | @staticmethod
7 | @command_response.send_message_and_stop
8 | async def temp(
9 | temp_to_change: float, user_chat_context: UserChatContext
10 | ) -> str: # set temperature of ai
11 | """Set temperature of ai\n
12 | /temp """
13 | try:
14 | assert (
15 | 0 <= temp_to_change <= 2
16 | ) # assert temperature is between 0 and 2
17 | except AssertionError: # if temperature is not between 0 and 2
18 | return "Temperature must be between 0 and 2"
19 | else:
20 | previous_temperature: str = str(
21 | user_chat_context.user_chat_profile.temperature
22 | )
23 | user_chat_context.user_chat_profile.temperature = temp_to_change
24 | await CacheManager.update_profile_and_model(
25 | user_chat_context
26 | ) # update user_chat_context
27 | return f"I've changed temperature from {previous_temperature} to {temp_to_change}." # return success msg
28 |
29 | @staticmethod
30 | @command_response.send_message_and_stop
31 | async def topp(
32 | top_p_to_change: float, user_chat_context: UserChatContext
33 | ) -> str: # set top_p of ai
34 | """Set top_p of ai\n
35 | /topp """
36 | try:
37 | assert 0 <= top_p_to_change <= 1 # assert top_p is between 0 and 1
38 | except AssertionError: # if top_p is not between 0 and 1
39 | return "Top_p must be between 0 and 1." # return fail message
40 | else:
41 | previous_top_p: str = str(
42 | user_chat_context.user_chat_profile.top_p
43 | )
44 | user_chat_context.user_chat_profile.top_p = (
45 | top_p_to_change # set top_p
46 | )
47 | await CacheManager.update_profile_and_model(
48 | user_chat_context
49 | ) # update user_chat_context
50 | return f"I've changed top_p from {previous_top_p} to {top_p_to_change}." # return success message
51 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/prompt.py:
--------------------------------------------------------------------------------
1 | from app.common.constants import SystemPrompts
2 | from app.models.base_models import MessageHistory
3 | from app.models.chat_models import ChatRoles, UserChatContext, command_response
4 | from app.utils.chat.managers.message import MessageManager
5 |
6 |
7 | class PromptCommands:
8 | @staticmethod
9 | @command_response.send_message_and_stop
10 | async def codex(user_chat_context: UserChatContext) -> str:
11 | """Let GPT act as CODEX("COding DEsign eXpert")\n
12 | /codex"""
13 | system_message = SystemPrompts.CODEX
14 | await MessageManager.clear_message_history_safely(
15 | user_chat_context=user_chat_context, role=ChatRoles.SYSTEM
16 | )
17 | await MessageManager.add_message_history_safely(
18 | user_chat_context=user_chat_context,
19 | role=ChatRoles.SYSTEM,
20 | content=system_message,
21 | )
22 | return "CODEX mode ON"
23 |
24 | @staticmethod
25 | @command_response.send_message_and_stop
26 | async def redx(user_chat_context: UserChatContext) -> str:
27 | """Let GPT reduce your message as much as possible\n
28 | /redx"""
29 | system_message = SystemPrompts.REDEX
30 | await MessageManager.clear_message_history_safely(
31 | user_chat_context=user_chat_context, role=ChatRoles.SYSTEM
32 | )
33 | await MessageManager.add_message_history_safely(
34 | user_chat_context=user_chat_context,
35 | role=ChatRoles.SYSTEM,
36 | content=system_message,
37 | )
38 | return "REDX mode ON"
39 |
40 | @staticmethod
41 | @command_response.send_message_and_stop
42 | async def system(
43 | system_message: str, /, user_chat_context: UserChatContext
44 | ) -> str: # add system message
45 | """Add system message\n
46 | /system """
47 | await MessageManager.add_message_history_safely(
48 | user_chat_context=user_chat_context,
49 | content=system_message,
50 | role=ChatRoles.SYSTEM,
51 | )
52 | return (
53 | f"Added system message: {system_message}" # return success message
54 | )
55 |
56 | @staticmethod
57 | @command_response.send_message_and_stop
58 | async def pop(role: str, user_chat_context: UserChatContext) -> str:
59 | """Pop last message (user or system or ai)\n
60 | /pop """
61 | try:
62 | actual_role: ChatRoles = ChatRoles.get_static_member(role)
63 | except ValueError:
64 | return (
65 | "Role must be one of user, system, ai" # return fail message
66 | )
67 | last_message_history: MessageHistory | None = (
68 | await MessageManager.pop_message_history_safely(
69 | user_chat_context=user_chat_context,
70 | role=actual_role,
71 | )
72 | ) # type: ignore
73 | if last_message_history is None: # if last_message_history is None
74 | return f"There is no {role} message to pop." # return fail message
75 | return f"Pop {role} message: {last_message_history.content}" # return success message
76 |
77 | @staticmethod
78 | @command_response.send_message_and_stop
79 | async def set(
80 | role, new_message: str, /, user_chat_context: UserChatContext
81 | ) -> str:
82 | """Set last message (user or system or ai)\n
83 | /set """
84 | try:
85 | actual_role: ChatRoles = ChatRoles.get_static_member(role)
86 | except ValueError:
87 | return (
88 | "Role must be one of user, system, ai" # return fail message
89 | )
90 | if (
91 | await MessageManager.set_message_history_safely(
92 | user_chat_context=user_chat_context,
93 | role=actual_role,
94 | index=-1,
95 | new_content=new_message,
96 | )
97 | is None
98 | ): # if set message history failed
99 | return f"There is no {role} message to set."
100 | return f"Set {role} message: {new_message}" # return success message
101 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/server.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ProcessPoolExecutor
2 |
3 | from app.common.app_settings_llama_cpp import (
4 | shutdown_llama_cpp_server,
5 | start_llama_cpp_server,
6 | )
7 | from app.common.config import config
8 | from app.models.chat_models import command_response
9 | from app.shared import Shared
10 |
11 |
12 | class ServerCommands:
13 | @staticmethod
14 | @command_response.send_message_and_stop
15 | async def free() -> str:
16 | """Free the process pool executor\n
17 | /free"""
18 | shared = Shared()
19 | shared.process_pool_executor.shutdown(wait=True)
20 | shared.process_pool_executor = ProcessPoolExecutor()
21 | return "Process pool executor freed!"
22 |
23 | @staticmethod
24 | @command_response.send_message_and_stop
25 | async def shutdownllamacpp() -> str:
26 | """Shutdown the llama cpp server\n
27 | /shutdownllamacpp"""
28 | shutdown_llama_cpp_server(shared=Shared())
29 | return "Shutdown llama cpp server!"
30 |
31 | @staticmethod
32 | @command_response.send_message_and_stop
33 | async def startllamacpp() -> str:
34 | """Start the llama cpp server\n
35 | /startllamacpp"""
36 | start_llama_cpp_server(shared=Shared(), config=config)
37 | return "Started llama cpp server!"
38 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/summarize.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | from typing import Optional
3 |
4 | from fastapi.concurrency import run_in_threadpool
5 |
6 | from app.models.chat_models import command_response
7 | from app.shared import Shared
8 | from app.utils.chat.buffer import BufferedUserContext
9 | from app.utils.chat.messages.converter import message_histories_to_str
10 | from app.utils.chat.text_generations.summarization import get_summarization
11 |
12 |
13 | class SummarizeCommands:
14 | @staticmethod
15 | @command_response.send_message_and_stop
16 | async def summarize(
17 | to_summarize: Optional[str], /, buffer: BufferedUserContext
18 | ) -> str:
19 | shared = Shared()
20 | if to_summarize is None:
21 | to_summarize = message_histories_to_str(
22 | user_chat_roles=buffer.current_user_chat_roles,
23 | user_message_histories=buffer.current_user_message_histories,
24 | ai_message_histories=buffer.current_ai_message_histories,
25 | system_message_histories=buffer.current_system_message_histories,
26 | )
27 | to_summarize_tokens = buffer.current_user_chat_context.total_tokens
28 | start: float = time()
29 | summarized = await run_in_threadpool(
30 | get_summarization,
31 | to_summarize=to_summarize,
32 | to_summarize_tokens=to_summarize_tokens,
33 | )
34 | original_tokens: int = len(
35 | shared.token_text_splitter._tokenizer.encode(to_summarize)
36 | )
37 | summarized_tokens: int = len(
38 | shared.token_text_splitter._tokenizer.encode(summarized)
39 | )
40 | end: float = time()
41 | return "\n".join(
42 | (
43 | "# Summarization complete!",
44 | f"- original tokens: {original_tokens}",
45 | f"- summarized tokens: {summarized_tokens}",
46 | f"- summarization time: {end-start}s",
47 | "```",
48 | summarized,
49 | "```",
50 | )
51 | )
52 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/testing.py:
--------------------------------------------------------------------------------
1 | from fastapi import WebSocket
2 |
3 | from app.models.chat_models import (
4 | ResponseType,
5 | UserChatContext,
6 | command_response,
7 | )
8 | from app.models.llms import LLMModels
9 | from app.utils.chat.buffer import BufferedUserContext
10 | from app.utils.chat.managers.websocket import SendToWebsocket
11 |
12 |
13 | class TestingCommands:
14 | @staticmethod
15 | @command_response.send_message_and_stop
16 | def test(
17 | user_chat_context: UserChatContext,
18 | ) -> str: # test command showing user_chat_context
19 | """Test command showing user_chat_context\n
20 | /test"""
21 | return str(user_chat_context)
22 |
23 | @staticmethod
24 | @command_response.send_message_and_stop
25 | def ping() -> str:
26 | """Ping! Pong!\n
27 | /ping"""
28 | return "pong"
29 |
30 | @staticmethod
31 | @command_response.send_message_and_stop
32 | def echo(msg: str, /) -> str:
33 | """Echo your message\n
34 | /echo """
35 | return msg
36 |
37 | @staticmethod
38 | @command_response.do_nothing
39 | async def sendtowebsocket(
40 | msg: str, /, websocket: WebSocket, user_chat_context: UserChatContext
41 | ) -> None:
42 | """Send all messages to websocket\n
43 | /sendtowebsocket"""
44 | await SendToWebsocket.message(
45 | websocket=websocket,
46 | msg=msg,
47 | chat_room_id=user_chat_context.chat_room_id,
48 | )
49 |
50 | @staticmethod
51 | async def testchaining(
52 | chain_size: int, buffer: BufferedUserContext
53 | ) -> tuple[str, ResponseType]:
54 | """Test chains of commands\n
55 | /testchaining """
56 | if chain_size <= 0:
57 | return "Chaining Complete!", ResponseType.SEND_MESSAGE_AND_STOP
58 | await SendToWebsocket.message(
59 | websocket=buffer.websocket,
60 | msg=f"Current Chaining: {chain_size}",
61 | chat_room_id=buffer.current_chat_room_id,
62 | )
63 | return f"/testchaining {chain_size-1}", ResponseType.REPEAT_COMMAND
64 |
65 | @staticmethod
66 | @command_response.send_message_and_stop
67 | def codeblock(language, codes: str, /) -> str:
68 | """Send codeblock\n
69 | /codeblock """
70 | return f"\n```{language.lower()}\n" + codes + "\n```\n"
71 |
72 | @staticmethod
73 | @command_response.send_message_and_stop
74 | def addmember(name: str, /, user_chat_context: UserChatContext) -> str:
75 | LLMModels.add_member(name, None)
76 | return f"Added {name} to members"
77 |
--------------------------------------------------------------------------------
/app/utils/chat/commands/vectorstore.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from app.common.config import config
4 | from app.common.lotties import Lotties
5 | from app.models.chat_models import ResponseType, command_response
6 | from app.models.function_calling.functions import FunctionCalls
7 | from app.models.llms import OpenAIModel
8 | from app.utils.chat.buffer import BufferedUserContext
9 | from app.utils.chat.managers.vectorstore import VectorStoreManager
10 | from app.utils.chat.messages.handler import MessageHandler
11 | from app.utils.function_calling.query import aget_query_to_search
12 | from app.viewmodels.status import UserStatus
13 |
14 |
15 | class VectorstoreCommands:
16 | @staticmethod
17 | async def query(
18 | user_query: str, /, buffer: BufferedUserContext, **kwargs
19 | ) -> tuple[Optional[str], ResponseType]:
20 | """Query from redis vectorstore\n
21 | /query """
22 | if user_query.startswith("/"):
23 | # User is trying to invoke another command.
24 | # Give control back to the command handler,
25 | # and let it handle the command.
26 | # e.g. `/query /help` will invoke `/help` command
27 | return user_query, ResponseType.REPEAT_COMMAND
28 |
29 | # Save user query to buffer and database
30 | await MessageHandler.user(msg=user_query, buffer=buffer)
31 |
32 | if not isinstance(buffer.current_llm_model.value, OpenAIModel):
33 | # Non-OpenAI models can't use the function calling API,
34 | # so we generate the search query explicitly here
35 | query_to_search: str = await aget_query_to_search(
36 | buffer=buffer,
37 | query=user_query,
38 | function=FunctionCalls.get_function_call(
39 | FunctionCalls.vectorstore_search
40 | ),
41 | )
42 | await MessageHandler.function_call(
43 | callback_name=FunctionCalls.vectorstore_search.__name__,
44 | callback_kwargs={"query_to_search": query_to_search},
45 | buffer=buffer,
46 | )
47 | else:
48 | # OpenAI models support function calling,
49 | # so let the AI decide whether to invoke the function
50 | function = FunctionCalls.get_function_call(
51 | FunctionCalls.vectorstore_search
52 | )
53 | buffer.optional_info["functions"] = [function]
54 | buffer.optional_info["function_call"] = function
55 | await MessageHandler.ai(buffer=buffer)
56 |
57 | # End of command
58 | return None, ResponseType.DO_NOTHING
59 |
60 | @staticmethod
61 | @command_response.send_message_and_stop
62 | async def embed(text_to_embed: str, /, buffer: BufferedUserContext) -> str:
63 | """Embed the text and save its vectors in the redis vectorstore.\n
64 | /embed """
65 | await VectorStoreManager.create_documents(
66 | text=text_to_embed, collection_name=buffer.user_id
67 | )
68 | return Lotties.OK.format("Embedding successful!")
69 |
70 | @staticmethod
71 | @command_response.send_message_and_stop
72 | async def share(text_to_embed: str, /) -> str:
73 | """Embed the text and save its vectors in the redis vectorstore. This index is shared for everyone.\n
74 | /share """
75 | await VectorStoreManager.create_documents(
76 | text=text_to_embed, collection_name=config.shared_vectorestore_name
77 | )
78 | return Lotties.OK.format(
79 | "Embedding successful!\nThis data will be shared for everyone."
80 | )
81 |
82 | @staticmethod
83 | @command_response.send_message_and_stop
84 | async def drop(buffer: BufferedUserContext) -> str:
85 | """Drop the index from the redis vectorstore.\n
86 | /drop"""
87 | dropped_index: list[str] = []
88 | if await VectorStoreManager.delete_collection(
89 | collection_name=buffer.user_id
90 | ):
91 | dropped_index.append(buffer.user_id)
92 | if (
93 | buffer.user.status is UserStatus.admin
94 | and await VectorStoreManager.delete_collection(
95 | collection_name=config.shared_vectorestore_name,
96 | )
97 | ):
98 | dropped_index.append(config.shared_vectorestore_name)
99 | if not dropped_index:
100 | return "No index dropped."
101 | return f"Index dropped: {', '.join(dropped_index)}"
102 |
--------------------------------------------------------------------------------
/app/utils/chat/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 |
5 | class BaseEmbeddingGenerator(ABC):
6 | @abstractmethod
7 | def __del__(self):
8 | """Clean up resources."""
9 | ...
10 |
11 | @classmethod
12 | @abstractmethod
13 | def from_pretrained(cls, model_name: str) -> "BaseEmbeddingGenerator":
14 | """Load a pretrained model into RAM."""
15 | ...
16 |
17 | @abstractmethod
18 | def generate_embeddings(
19 | self,
20 | texts: list[str],
21 | **kwargs: Any,
22 | ) -> list[list[float]]:
23 | """Generate embeddings for a list of texts."""
24 | ...
25 |
26 | @property
27 | @abstractmethod
28 | def model_name(self) -> str:
29 | """Identifier for the model used by this generator."""
30 | ...
31 |
--------------------------------------------------------------------------------
/app/utils/chat/embeddings/sentence_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Callable, Optional
2 |
3 | import numpy as np
4 | import tensorflow_hub as hub
5 |
6 | from app.utils.colorama import Fore
7 | from app.utils.logger import ApiLogger
8 |
9 | from . import BaseEmbeddingGenerator
10 |
11 | if TYPE_CHECKING:
12 | from tensorflow.python.framework.ops import Tensor
13 |
14 |
15 | class SentenceEncoderEmbeddingGenerator(BaseEmbeddingGenerator):
16 | """Generate embeddings using a sentence encoder model,
17 | automatically downloading the model from https://tfhub.dev/"""
18 |
19 | base_url: str = "https://tfhub.dev/google/"
20 | model: Optional[Callable[[list[str]], "Tensor"]] = None
21 | _model_name: Optional[str] = None
22 |
23 | def __del__(self) -> None:
24 | if self.model is not None:
25 | getattr(self.model, "__del__", lambda: None)()
26 | del self.model
27 | self.model = None
28 | print("🗑️ SentenceEncoderEmbedding deleted!")
29 |
30 | @classmethod
31 | def from_pretrained(
32 | cls, model_name: str
33 | ) -> "SentenceEncoderEmbeddingGenerator":
34 | self = cls()
35 | self._model_name = model_name
36 | url = f"{self.base_url.rstrip('/')}/{model_name.lstrip('/')}"
37 | self.model = hub.load(url) # type: ignore
38 | ApiLogger("||embeddings||").info(
39 | f"🤖 TFHub {Fore.WHITE}{model_name}{Fore.GREEN} loaded!",
40 | )
41 | return self
42 |
43 | def generate_embeddings(
44 | self,
45 | texts: list[str],
46 | batch_size: int = 100,
47 | **kwargs,
48 | ) -> list[list[float]]:
49 | assert self.model is not None, "Please load the model first."
50 | embeddings: list["Tensor"] = []
51 | for batch_idx in range(0, len(texts), batch_size):
52 | batch_texts = texts[batch_idx : (batch_idx + batch_size)]
53 | embeddings.append(self.model(batch_texts))
54 | return np.vstack(embeddings).tolist()
55 |
56 | @property
57 | def model_name(self) -> str:
58 | return self._model_name or self.__class__.__name__
59 |
60 |
61 | # class SemanticSearch:
62 | # use: Callable[[list[str]], "Tensor"]
63 | # nn: Optional[NearestNeighbors] = None
64 | # data: Optional[list[str]] = None
65 | # embeddings: Optional[np.ndarray] = None
66 |
67 | # def __init__(self):
68 | # self.use = hub.load(
69 | # "https://tfhub.dev/google/universal-sentence-encoder/4",
70 | # ) # type: ignore
71 |
72 | # def __call__(self, text: str, return_data: bool = True):
73 | # assert (
74 | # self.nn is not None and self.data is not None
75 | # ), "Please fit the model first."
76 | # query_embedding: "Tensor" = self.use([text])
77 | # neighbors: np.ndarray = self.nn.kneighbors(
78 | # query_embedding,
79 | # return_distance=False,
80 | # )[0]
81 |
82 | # if return_data:
83 | # return [self.data[i] for i in neighbors]
84 | # else:
85 | # return neighbors
86 |
87 | # def query(
88 | # self,
89 | # text: str,
90 | # return_data: bool = True,
91 | # ):
92 | # return self(text, return_data=return_data)
93 |
94 | # def fit(self, data: list[str], batch: int = 1000, n_neighbors: int = 5):
95 | # self.data = data
96 | # self.embeddings = self._get_text_embedding(data, batch=batch)
97 | # self.nn = NearestNeighbors(
98 | # n_neighbors=min(
99 | # n_neighbors,
100 | # len(self.embeddings),
101 | # )
102 | # )
103 | # self.nn.fit(self.embeddings)
104 |
105 | # def _get_text_embedding(
106 | # self,
107 | # texts: list[str],
108 | # batch: int = 1000,
109 | # ) -> np.ndarray:
110 | # embeddings: list["Tensor"] = []
111 | # for batch_idx in range(0, len(texts), batch):
112 | # text_batch = texts[batch_idx : (batch_idx + batch)]
113 | # embeddings.append(self.use(text_batch))
114 | # return np.vstack(embeddings)
115 |
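A usage sketch; "universal-sentence-encoder/4" is one valid TFHub model path, downloaded on first use:

generator = SentenceEncoderEmbeddingGenerator.from_pretrained(
    "universal-sentence-encoder/4"
)
vectors = generator.generate_embeddings(["hello world", "goodbye world"])
print(len(vectors), len(vectors[0]))  # 2 texts, 512 dimensions for this model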
--------------------------------------------------------------------------------
/app/utils/chat/embeddings/transformer.py:
--------------------------------------------------------------------------------
1 | from gc import collect
2 | from typing import Optional
3 |
4 | from torch import Tensor, cuda
5 | from transformers.modeling_outputs import (
6 | BaseModelOutputWithPoolingAndCrossAttentions,
7 | )
8 | from transformers.modeling_utils import PreTrainedModel
9 | from transformers.models.auto.modeling_auto import AutoModel
10 | from transformers.models.auto.tokenization_auto import AutoTokenizer
11 | from transformers.models.t5.modeling_t5 import T5Model
12 | from transformers.tokenization_utils import PreTrainedTokenizer
13 | from transformers.tokenization_utils_base import BatchEncoding
14 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
15 |
16 | from app.utils.colorama import Fore
17 | from app.utils.logger import ApiLogger
18 |
19 | from . import BaseEmbeddingGenerator
20 |
21 | device = "cuda" if cuda.is_available() else "cpu"
22 |
23 |
24 | class TransformerEmbeddingGenerator(BaseEmbeddingGenerator):
25 | """Generate embeddings using a transformer model,
26 | automatically downloading the model from https://huggingface.co/"""
27 |
28 | model: Optional[PreTrainedModel] = None
29 | tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast] = None
30 | encoder: Optional[PreTrainedModel] = None
31 | _model_name: Optional[str] = None
32 |
33 | def __del__(self) -> None:
34 | if self.model is not None:
35 | getattr(self.model, "__del__", lambda: None)()
36 | self.model = None
37 | print("🗑️ TransformerEmbedding model deleted!")
38 | if self.tokenizer is not None:
39 | getattr(self.tokenizer, "__del__", lambda: None)()
40 | self.tokenizer = None
41 | print("🗑️ TransformerEmbedding tokenizer deleted!")
42 | if self.encoder is not None:
43 | getattr(self.encoder, "__del__", lambda: None)()
44 | self.encoder = None
45 | print("🗑️ TransformerEmbedding encoder deleted!")
46 |
47 | @classmethod
48 | def from_pretrained(
49 | cls, model_name: str
50 | ) -> "TransformerEmbeddingGenerator":
51 | self = cls()
52 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
53 | self._model_name = model_name
54 | ApiLogger("||embeddings||").info(
55 | f"🤖 Huggingface tokenizer {Fore.WHITE}{model_name}{Fore.GREEN} loaded!",
56 | )
57 |
58 | self.model = AutoModel.from_pretrained(model_name)
59 | ApiLogger("||embeddings||").info(
60 | f"🤖 Huggingface model {Fore.WHITE}{model_name}{Fore.GREEN} loaded!",
61 | )
62 | return self
63 |
64 | def generate_embeddings(
65 | self,
66 | texts: list[str],
67 | context_length: int = 512,
68 | batch_size: int = 3,
69 | **kwargs,
70 | ) -> list[list[float]]:
71 | embeddings: list[list[float]] = []
72 | for batch_idx in range(0, len(texts), batch_size):
73 | batch_texts = texts[batch_idx : (batch_idx + batch_size)]
74 | batch_embeddings, _ = self._generate_embeddings_and_n_tokens(
75 | texts=batch_texts, context_length=context_length
76 | )
77 | embeddings.extend(batch_embeddings)
78 | return embeddings
79 |
80 | def _generate_embeddings_and_n_tokens(
81 | self,
82 | texts: list[str],
83 | context_length: int = 512,
84 | ) -> tuple[list[list[float]], int]:
85 | assert self.model is not None and self.tokenizer is not None
86 |
87 | def average_pool(
88 | last_hidden_states: Tensor, attention_mask: Tensor
89 | ) -> Tensor:
90 | last_hidden = last_hidden_states.masked_fill(
91 | ~attention_mask[..., None].bool(), 0.0
92 | )
93 | return (
94 | last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
95 | )
96 |
97 | # Tokenize the input texts
98 | batch_dict: BatchEncoding = self.tokenizer(
99 | texts,
100 | max_length=context_length,
101 | padding="longest",
102 | truncation=True,
103 | return_tensors="pt",
104 | )
105 | if self.encoder is None:
106 | # Get the encoder from the model
107 | if isinstance(self.model, T5Model):
108 | self.encoder = self.model.get_encoder()
109 | else:
110 | self.encoder = self.model
111 |
112 | if device == "cuda":
113 | # Load the encoder into VRAM
114 | self.encoder = self.encoder.to(device)
115 | batch_dict = batch_dict.to(device)
116 | outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.encoder(
117 | **batch_dict
118 | )
119 | embeddings, tokens = (
120 | average_pool(
121 | last_hidden_states=outputs.last_hidden_state,
122 | attention_mask=batch_dict["attention_mask"], # type: ignore
123 | ).tolist(),
124 | sum([len(encoding) for encoding in batch_dict["input_ids"]]), # type: ignore
125 | )
126 | del batch_dict
127 | del outputs
128 | if device == "cuda":
129 | # Deallocate output tensors from VRAM
130 | cuda.empty_cache()
131 | collect()
132 | return embeddings, tokens
133 |
134 | @property
135 | def model_name(self) -> str:
136 | return self._model_name or self.__class__.__name__
137 |
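A usage sketch; the checkpoint name is illustrative (any encoder-style Huggingface model compatible with average pooling should work):

generator = TransformerEmbeddingGenerator.from_pretrained("intfloat/e5-small-v2")
vectors = generator.generate_embeddings(
    ["passage: hello world"], context_length=512, batch_size=3
)
print(len(vectors[0]))  # embedding width of the chosen checkpoint (384 here)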
--------------------------------------------------------------------------------
/app/utils/chat/file_loader.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import IO, Any
3 |
4 | from langchain.document_loaders.unstructured import UnstructuredBaseLoader
5 | from langchain.docstore.document import Document
6 | from unstructured.partition.auto import partition
7 |
8 |
9 | class UnstructuredFileIOLoader(UnstructuredBaseLoader):
10 | """Loader that uses unstructured to load file IO objects."""
11 |
12 | def __init__(self, file: IO, filename: str, mode: str = "single", **unstructured_kwargs: Any):
13 | """Initialize with file path."""
14 | self.file = file
15 | self.filename = filename
16 | super().__init__(mode=mode, **unstructured_kwargs)
17 |
18 | def _get_elements(self) -> list:
19 | return partition(file=self.file, file_filename=self.filename, **self.unstructured_kwargs)
20 |
21 | def _get_metadata(self) -> dict:
22 | return {}
23 |
24 |
25 | def read_bytes_to_documents(file: bytes, filename: str) -> list[Document]:
26 | return UnstructuredFileIOLoader(file=io.BytesIO(file), strategy="fast", filename=filename).load()
27 |
28 |
29 | def read_bytes_to_text(file: bytes, filename: str) -> str:
30 | return "\n\n".join([doc.page_content for doc in read_bytes_to_documents(file=file, filename=filename)])
31 |
32 |
33 | if __name__ == "__main__":
34 | with open(r"test.pdf", "rb") as f:
35 | file = f.read()
36 | text = read_bytes_to_text(file, "test.pdf")
37 | print(text)
38 |
--------------------------------------------------------------------------------
/app/utils/chat/messages/turn_templates.py:
--------------------------------------------------------------------------------
1 | """This module contains functions to extract information from chat turn templates."""
2 |
3 | from re import DOTALL, compile
4 | from typing import TYPE_CHECKING, Optional
5 |
6 | if TYPE_CHECKING:
7 | from langchain import PromptTemplate
8 |
9 |
10 | def shatter_chat_turn_prompt(
11 | *keys: str, chat_turn_prompt: "PromptTemplate"
12 | ) -> tuple[str, ...]:
13 | """Identify the chat turn template and return the shatter result.
14 | e.g. If template of chat_turn_prompt is "### {role}: {content} "
15 | and keys are "role" and "content",
16 | then the result will be ('### ', "{role}", ': ', "{content}", ' ')."""
17 | pattern: str = "(.*)"
18 | kwargs: dict[str, str] = {}
19 | for key in keys:
20 | kwargs[key] = "{" + key + "}"
21 | pattern += f"({kwargs[key]})(.*)"
22 | search_result = compile(pattern, flags=DOTALL).match(
23 | chat_turn_prompt.format(**kwargs)
24 | )
25 | if search_result is None:
26 | raise ValueError(
27 | f"Invalid chat turn prompt: {chat_turn_prompt.format(**kwargs)}"
28 | )
29 | return search_result.groups()
30 |
31 |
32 | def identify_end_of_string(*keys, chat_turn_prompt: "PromptTemplate") -> Optional[str]:
33 | """Identify the end of string in the chat turn prompt.
34 | e.g. If template of chat_turn_prompt is "### {role}: {content} "
35 | then the result will be "".
36 | If there is no end of string, then the result will be None."""
37 | return shatter_chat_turn_prompt(*keys, chat_turn_prompt=chat_turn_prompt)[
38 | -1
39 | ].strip()
40 |
41 |
42 | if __name__ == "__main__":
43 | from langchain import PromptTemplate
44 |
45 | input_variables = ["role", "content"]
46 | for template, template_format in (
47 | ("### {role}: {content} ", "f-string"),
48 | ("### {{role}}: {{content}} ", "jinja2"),
49 | ):
50 | chat_turn_prompt = PromptTemplate(
51 | template=template,
52 | input_variables=input_variables,
53 | template_format=template_format,
54 | )
55 | print(
56 | "Shattered String:",
57 | shatter_chat_turn_prompt(
58 | *input_variables, chat_turn_prompt=chat_turn_prompt
59 | ),
60 | ) # ('### ', '{role}', ': ', '{content}', ' ')
61 | print(
62 | "End-of-String:",
63 | identify_end_of_string(*input_variables, chat_turn_prompt=chat_turn_prompt),
64 | ) #
65 |
--------------------------------------------------------------------------------
/app/utils/chat/text_generations/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import TYPE_CHECKING, Iterator
3 |
4 | from app.models.base_models import APIChatMessage, TextGenerationSettings
5 | from app.models.completion_models import (
6 | ChatCompletion,
7 | ChatCompletionChunk,
8 | Completion,
9 | CompletionChunk,
10 | )
11 |
12 | if TYPE_CHECKING:
13 | from app.models.llms import LLMModel
14 |
15 |
16 | class BaseCompletionGenerator(ABC):
17 | """Base class for all completion generators."""
18 |
19 | user_role: str = "user"
20 | system_role: str = "system"
21 |
22 | user_input_role: str = "User"
23 | system_input_role: str = "System"
24 |
25 | ai_fallback_input_role: str = "Assistant"
26 |
27 | @abstractmethod
28 | def __del__(self):
29 | """Clean up resources."""
30 | ...
31 |
32 | @classmethod
33 | @abstractmethod
34 | def from_pretrained(
35 | cls, llm_model: "LLMModel"
36 | ) -> "BaseCompletionGenerator":
37 | """Load a pretrained model into RAM."""
38 | ...
39 |
40 | @abstractmethod
41 | def generate_completion(
42 | self, prompt: str, settings: TextGenerationSettings
43 | ) -> Completion:
44 | """Generate a completion for a given prompt."""
45 | ...
46 |
47 | @abstractmethod
48 | def generate_completion_with_streaming(
49 | self, prompt: str, settings: TextGenerationSettings
50 | ) -> Iterator[CompletionChunk]:
51 | """Generate a completion for a given prompt, yielding chunks of text as they are generated."""
52 | ...
53 |
54 | @abstractmethod
55 | def generate_chat_completion(
56 | self, messages: list[APIChatMessage], settings: TextGenerationSettings
57 | ) -> ChatCompletion:
58 | """Generate a completion for a given prompt."""
59 | ...
60 |
61 | @abstractmethod
62 | def generate_chat_completion_with_streaming(
63 | self, messages: list[APIChatMessage], settings: TextGenerationSettings
64 | ) -> Iterator[ChatCompletionChunk]:
65 | """Generate a completion for a given prompt, yielding chunks of text as they are generated."""
66 | ...
67 |
68 | @staticmethod
69 | def get_stop_strings(*roles: str) -> list[str]:
70 | """A helper method to generate stop strings for a given set of roles.
71 | Stop strings are required to stop text completion API from generating
72 | text that does not belong to the current chat turn.
73 | e.g. The common stop string is "### USER:", which can prevent ai from generating
74 | user's message itself."""
75 |
76 | prompt_stop = set()
77 | for role in roles:
78 | avoids = (
79 | f"{role}:",
80 | f"### {role}:",
81 | f"###{role}:",
82 | )
83 | prompt_stop.update(
84 | avoids,
85 | map(str.capitalize, avoids),
86 | map(str.upper, avoids),
87 | map(str.lower, avoids),
88 | )
89 | return list(prompt_stop)
90 |
91 | @classmethod
92 | def convert_messages_into_prompt(
93 | cls, messages: list[APIChatMessage], settings: TextGenerationSettings
94 | ) -> str:
95 | """A helper method to convert list of messages into one text prompt."""
96 | ai_input_role: str = cls.ai_fallback_input_role
97 | chat_history: str = ""
98 | for message in messages:
99 | if message.role.lower() == cls.user_role:
100 | input_role = cls.user_input_role
101 | elif message.role.lower() == cls.system_role:
102 | input_role = cls.system_input_role
103 | else:
104 | input_role = ai_input_role = message.role
105 | chat_history += f"### {input_role}:{message.content}"
106 |
107 | prompt_stop: list[str] = cls.get_stop_strings(
108 | cls.user_input_role, cls.system_input_role, ai_input_role
109 | )
110 | if isinstance(settings.stop, str):
111 | settings.stop = prompt_stop + [settings.stop]
112 | elif isinstance(settings.stop, list):
113 | settings.stop = prompt_stop + settings.stop
114 | else:
115 | settings.stop = prompt_stop
116 | return chat_history + f"### {ai_input_role}:"
117 |
118 | @staticmethod
119 | def is_possible_to_generate_stops(
120 | decoded_text: str, stops: list[str]
121 | ) -> bool:
122 | """A helper method to check if the decoded text contains any of the stop tokens."""
123 | for stop in stops:
124 | if stop in decoded_text or any(
125 | [
126 | decoded_text.endswith(stop[: i + 1])
127 | for i in range(len(stop))
128 | ]
129 | ):
130 | return True
131 | return False
132 |
133 | @property
134 | @abstractmethod
135 | def llm_model(self) -> "LLMModel":
136 | """The LLM model used by this generator."""
137 | ...
138 |
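# A minimal sketch of how the prompt-building helpers above behave. The
# APIChatMessage / TextGenerationSettings constructor calls below are
# assumptions about those pydantic models; only the class methods come from
# this file.
from app.models.base_models import APIChatMessage, TextGenerationSettings

stops = BaseCompletionGenerator.get_stop_strings("User", "Assistant")
# e.g. ["User:", "### User:", "###User:", "USER:", ...] in arbitrary order

messages = [
    APIChatMessage(role="system", content="You are helpful."),
    APIChatMessage(role="user", content="Hi!"),
]
settings = TextGenerationSettings()
prompt = BaseCompletionGenerator.convert_messages_into_prompt(messages, settings)
# prompt ends with "### Assistant:" so the model continues the AI's turn,
# and settings.stop is now populated with the role-based stop strings.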
--------------------------------------------------------------------------------
/app/utils/chat/text_generations/path.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | from app.common.config import BASE_DIR
5 |
6 |
7 | def resolve_model_path_to_posix(
8 | model_path: str, default_relative_directory: Optional[str] = None
9 | ):
10 | """Resolve a model path to a POSIX path, relative to the BASE_DIR."""
11 | path = Path(model_path)
12 | parent_directory: Path = (
13 | Path(BASE_DIR) / Path(default_relative_directory)
14 | if default_relative_directory is not None
15 | else Path(BASE_DIR)
16 | if Path.cwd() == path.parent.resolve()
17 | else path.parent.resolve()
18 | )
19 | filename: str = path.name
20 | return (parent_directory / filename).as_posix()
21 |
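# A few illustrative resolutions, assuming BASE_DIR is "/app" and the current
# working directory is also "/app" (both hypothetical):
#   resolve_model_path_to_posix("model.bin", "llama_models")
#       -> "/app/llama_models/model.bin"  (default directory supplied)
#   resolve_model_path_to_posix("model.bin")
#       -> "/app/model.bin"               (bare filename; cwd equals BASE_DIR)
#   resolve_model_path_to_posix("/opt/models/model.bin")
#       -> "/opt/models/model.bin"        (explicit parent directories are kept)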
--------------------------------------------------------------------------------
/app/utils/chat/text_generations/summarization.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from app.common.config import ChatConfig
4 | from app.shared import Shared
5 | from app.utils.logger import ApiLogger
6 |
7 |
8 | def get_summarization(
9 | to_summarize: str,
10 | to_summarize_tokens: Optional[int] = None,
11 | ) -> str:
12 | shared = Shared()
13 | if to_summarize_tokens is None:
14 | to_summarize_tokens = len(
15 | shared.token_text_splitter._tokenizer.encode(to_summarize)
16 | )
17 |
18 | if to_summarize_tokens < ChatConfig.summarization_token_limit:
19 | summarize_chain = shared.stuff_summarize_chain
20 | else:
21 | summarize_chain = shared.map_reduce_summarize_chain
22 | result: str = summarize_chain.run(
23 | shared.token_text_splitter.create_documents(
24 | [to_summarize],
25 | tokens_per_chunk=ChatConfig.summarization_chunk_size,
26 | chunk_overlap=ChatConfig.summarization_token_overlap,
27 | )
28 | )
29 | ApiLogger("||get_summarization||").info(result)
30 | return result
31 |
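# A minimal usage sketch (requires the Shared() singleton and its summarize
# chains to be initialized, so this is not standalone-runnable). Which chain
# runs depends on ChatConfig.summarization_token_limit: shorter inputs go
# through the single-pass "stuff" chain, longer ones through the chunked
# map-reduce chain.
#
#   summary = get_summarization("...a long transcript to condense...")
#   print(summary)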
--------------------------------------------------------------------------------
/app/utils/colorama.py:
--------------------------------------------------------------------------------
1 | # Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
2 | """
3 | This module generates ANSI character codes for printing colors to terminals.
4 | See: http://en.wikipedia.org/wiki/ANSI_escape_code
5 | """
6 |
7 | CSI = "\033["
8 | OSC = "\033]"
9 | BEL = "\a"
10 |
11 |
12 | def code_to_chars(code):
13 | return CSI + str(code) + "m"
14 |
15 |
16 | def set_title(title):
17 | return OSC + "2;" + title + BEL
18 |
19 |
20 | def clear_screen(mode=2):
21 | return CSI + str(mode) + "J"
22 |
23 |
24 | def clear_line(mode=2):
25 | return CSI + str(mode) + "K"
26 |
27 |
28 | class AnsiCodes(object):
29 | def __init__(self):
30 | # the subclasses declare class attributes which are numbers.
31 | # Upon instantiation we define instance attributes, which are the same
32 | # as the class attributes but wrapped with the ANSI escape sequence
33 | for name in dir(self):
34 | if not name.startswith("_"):
35 | value = getattr(self, name)
36 | setattr(self, name, code_to_chars(value))
37 |
38 |
39 | class AnsiCursor(object):
40 | def UP(self, n=1):
41 | return CSI + str(n) + "A"
42 |
43 | def DOWN(self, n=1):
44 | return CSI + str(n) + "B"
45 |
46 | def FORWARD(self, n=1):
47 | return CSI + str(n) + "C"
48 |
49 | def BACK(self, n=1):
50 | return CSI + str(n) + "D"
51 |
52 | def POS(self, x=1, y=1):
53 | return CSI + str(y) + ";" + str(x) + "H"
54 |
55 |
56 | class AnsiFore(AnsiCodes):
57 | BLACK = 30
58 | RED = 31
59 | GREEN = 32
60 | YELLOW = 33
61 | BLUE = 34
62 | MAGENTA = 35
63 | CYAN = 36
64 | WHITE = 37
65 | RESET = 39
66 |
67 | # These are fairly well supported, but not part of the standard.
68 | LIGHTBLACK_EX = 90
69 | LIGHTRED_EX = 91
70 | LIGHTGREEN_EX = 92
71 | LIGHTYELLOW_EX = 93
72 | LIGHTBLUE_EX = 94
73 | LIGHTMAGENTA_EX = 95
74 | LIGHTCYAN_EX = 96
75 | LIGHTWHITE_EX = 97
76 |
77 |
78 | class AnsiBack(AnsiCodes):
79 | BLACK = 40
80 | RED = 41
81 | GREEN = 42
82 | YELLOW = 43
83 | BLUE = 44
84 | MAGENTA = 45
85 | CYAN = 46
86 | WHITE = 47
87 | RESET = 49
88 |
89 | # These are fairly well supported, but not part of the standard.
90 | LIGHTBLACK_EX = 100
91 | LIGHTRED_EX = 101
92 | LIGHTGREEN_EX = 102
93 | LIGHTYELLOW_EX = 103
94 | LIGHTBLUE_EX = 104
95 | LIGHTMAGENTA_EX = 105
96 | LIGHTCYAN_EX = 106
97 | LIGHTWHITE_EX = 107
98 |
99 |
100 | class AnsiStyle(AnsiCodes):
101 | BRIGHT = 1
102 | DIM = 2
103 | NORMAL = 22
104 | RESET_ALL = 0
105 |
106 |
107 | Fore = AnsiFore()
108 | Back = AnsiBack()
109 | Style = AnsiStyle()
110 | Cursor = AnsiCursor()
111 |
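# A quick self-contained demo of the helpers above; assumes a terminal that
# honors ANSI escape codes.
if __name__ == "__main__":
    print(Fore.GREEN + "green text" + Fore.RESET)
    print(Back.BLUE + Fore.WHITE + "white on blue" + Style.RESET_ALL)
    print(Style.BRIGHT + "bright" + Style.NORMAL + " then normal")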
--------------------------------------------------------------------------------
/app/utils/date_utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, date, timedelta
2 | from re import match
3 |
4 |
5 | class UTC:
6 | def __init__(self):
7 | self.utc_now = datetime.utcnow()
8 |
9 | @classmethod
10 | def now(cls, hour_diff: int = 0) -> datetime:
11 | return cls().utc_now + timedelta(hours=hour_diff)
12 |
13 | @classmethod
14 | def now_isoformat(cls, hour_diff: int = 0) -> str:
15 | return (cls().utc_now + timedelta(hours=hour_diff)).isoformat() + "Z"
16 |
17 | @classmethod
18 | def date(cls, hour_diff: int = 0) -> date:
19 | return cls.now(hour_diff=hour_diff).date()
20 |
21 | @classmethod
22 | def timestamp(cls, hour_diff: int = 0) -> int:
23 | return int(cls.now(hour_diff=hour_diff).strftime("%Y%m%d%H%M%S"))
24 |
25 | @classmethod
26 | def timestamp_to_datetime(cls, timestamp: int, hour_diff: int = 0) -> datetime:
27 | return datetime.strptime(str(timestamp), "%Y%m%d%H%M%S") + timedelta(hours=hour_diff)
28 |
29 | @classmethod
30 | def date_code(cls, hour_diff: int = 0) -> int:
31 | return int(cls.date(hour_diff=hour_diff).strftime("%Y%m%d"))
32 |
33 | @staticmethod
34 | def check_string_valid(string: str) -> bool:
35 | """Check if a string is in ISO 8601 UTC datetime format.
36 | e.g. "2023-05-22T05:08:29.087279Z" -> True
37 | "2023-05-22T05:08:29Z" -> True
38 | "2023-05-22T05:08:29" -> False
39 | "2023/05/22T05:08:29.087279Z" -> False"""
40 | regex = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z$"
41 | return True if match(regex, string) else False
42 |
43 |
44 | if __name__ == "__main__":
45 | # Example usage
46 | print(UTC.check_string_valid("2023-05-22T05:08:29.087279Z")) # True
47 | print(UTC.check_string_valid("2023-05-22T05:08:29Z")) # True
48 | print(UTC.check_string_valid("2023-05-22T05:08:29")) # False
49 | print(UTC.check_string_valid("2023/05/22T05:08:29.087279Z")) # False
50 | print(UTC.timestamp_to_datetime(10000101000000)) # Lower-bound timestamp (1000-01-01 00:00:00)
51 | print(UTC.timestamp_to_datetime(99991231235959)) # Upper-bound timestamp (9999-12-31 23:59:59)
52 |
--------------------------------------------------------------------------------
/app/utils/encoding_utils.py:
--------------------------------------------------------------------------------
1 | from base64 import b64encode, urlsafe_b64encode
2 | from os import urandom
3 | from os.path import exists
4 | from re import findall
5 | from typing import Any
6 |
7 | import json
8 | from cryptography.fernet import Fernet, InvalidToken
9 | from cryptography.hazmat.primitives.hashes import SHA256, HashAlgorithm
10 | from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
11 |
12 |
13 | class SecretConfigSetup:
14 | """
15 | Setup for secret configs encryption and decryption in JSON format.
16 |
17 | - Example:
18 | password_from_environ = environ.get("SECRET_CONFIGS_PASSWORD", None)
19 | secret_config_setup = SecretConfigSetup(
20 | password=password_from_environ
21 | if password_from_environ is not None
22 | else input("Enter Passwords:"),
23 | json_file_name="secret_configs.json",
24 | )
25 |
26 |
27 | @dataclass(frozen=True)
28 | class SecretConfig(metaclass=SingletonMetaClass):
29 | secret_config: dict = field(default_factory=secret_config_setup.initialize)
30 | """
31 |
32 | algorithm: HashAlgorithm = SHA256()
33 | length: int = 32
34 | iterations: int = 100000
35 |
36 | def __init__(self, password: str, json_file_name: str) -> None:
37 | self.password = password
38 | self.json_file_name = json_file_name
39 |
40 | @classmethod
41 | def _get_kdf(cls, salt: bytes) -> PBKDF2HMAC:
42 | return PBKDF2HMAC(
43 | salt=salt,
44 | algorithm=cls.algorithm,
45 | length=cls.length,
46 | iterations=cls.iterations,
47 | )
48 |
49 | def encrypt(self) -> None:
50 | # Derive an encryption key from the password and salt using PBKDF2
51 | salt = urandom(16)
52 | fernet = Fernet(
53 | urlsafe_b64encode(SecretConfigSetup._get_kdf(salt=salt).derive(key_material=self.password.encode()))
54 | )
55 | with open(self.json_file_name, "r") as f:
56 | plain_data: Any = json.load(f)
57 |
58 | # Encrypt the string using the encryption key
59 | encrypted_data: bytes = fernet.encrypt(json.dumps(plain_data).encode())
60 |
61 | # Save the encrypted data and salt to file
62 | with open(f"{self.json_file_name}.enc", "wb") as f:
63 | f.write(salt)
64 | f.write(encrypted_data)
65 |
66 | def decrypt(self) -> Any:
67 | # Load the encrypted data and salt from file
68 | with open(f"{self.json_file_name}.enc", "rb") as f:
69 | salt = f.read(16)
70 | fernet = Fernet(
71 | urlsafe_b64encode(SecretConfigSetup._get_kdf(salt=salt).derive(key_material=self.password.encode()))
72 | )
73 | encrypted_data: bytes = f.read()
74 | decrypted_data: bytes = fernet.decrypt(encrypted_data)
75 | return json.loads(decrypted_data)
76 |
77 | def initialize(self) -> Any:
78 | if not exists(f"{self.json_file_name}.enc"):
79 | self.encrypt()
80 |
81 | while True:
82 | try:
83 | secret_config: Any = self.decrypt()
84 | except InvalidToken:
85 | self.password = input("Wrong password! Enter password again: ")
86 | else:
87 | return secret_config
88 |
89 |
90 | def encode_from_utf8(text: str) -> str:
91 | # Check if text contains any non-ASCII characters
92 | matches: list[Any] = findall(r"[^\x00-\x7F]+", text)
93 | if not matches:
94 | # Return text if it doesn't contain any non-ASCII characters
95 | return text
96 | else:
97 | # Encode and replace non-ASCII characters with UTF-8 formatted text
98 | for match in matches:
99 | encoded_text: str = b64encode(match.encode("utf-8")).decode("utf-8")
100 | text = text.replace(match, "=?UTF-8?B?" + encoded_text + "?=")
101 | return text
102 |
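# Self-contained demo of the MIME "encoded-word" conversion above: ASCII-only
# input passes through untouched, while non-ASCII runs are wrapped as
# =?UTF-8?B?...?=.
if __name__ == "__main__":
    print(encode_from_utf8("plain ascii"))  # -> plain ascii
    print(encode_from_utf8("Hello, 世界"))  # -> Hello, =?UTF-8?B?5LiW55WM?=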
--------------------------------------------------------------------------------
/app/utils/function_calling/callbacks/lite_browsing.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from fastapi.concurrency import run_in_threadpool
4 |
5 | from app.common.lotties import Lotties
6 | from app.models.function_calling.functions import FunctionCalls
7 | from app.shared import Shared
8 | from app.utils.chat.buffer import BufferedUserContext
9 | from app.utils.chat.managers.websocket import SendToWebsocket
10 | from app.utils.logger import ApiLogger
11 |
12 | from ..query import aget_query_to_search
13 |
14 |
15 | async def lite_web_browsing_callback(
16 | buffer: BufferedUserContext,
17 | query: str,
18 | finish: bool,
19 | wait_next_query: bool,
20 | show_result: bool = False,
21 | ) -> Optional[str]:
22 | await SendToWebsocket.message(
23 | websocket=buffer.websocket,
24 | msg=Lotties.SEARCH_WEB.format("### Browsing web", end=False),
25 | chat_room_id=buffer.current_chat_room_id,
26 | finish=False,
27 | )
28 | try:
29 | query_to_search: str = await aget_query_to_search(
30 | buffer=buffer,
31 | query=query,
32 | function=FunctionCalls.get_function_call(FunctionCalls.web_search),
33 | )
34 | r = await run_in_threadpool(Shared().duckduckgo.run, query_to_search)
35 | # ApiLogger("||lite_web_browsing_chain||").info(r)
36 |
37 | await SendToWebsocket.message(
38 | websocket=buffer.websocket,
39 | msg=Lotties.OK.format("### Finished browsing")
40 | + (r if show_result else ""),
41 | chat_room_id=buffer.current_chat_room_id,
42 | finish=finish,
43 | wait_next_query=wait_next_query,
44 | )
45 | return r
46 | except Exception as e:
47 | ApiLogger("||lite_web_browsing_chain||").exception(e)
48 | await SendToWebsocket.message(
49 | msg=Lotties.FAIL.format("### Failed to browse web"),
50 | websocket=buffer.websocket,
51 | chat_room_id=buffer.current_chat_room_id,
52 | finish=finish,
53 | wait_next_query=wait_next_query,
54 | )
55 |
--------------------------------------------------------------------------------
/app/utils/function_calling/callbacks/translate.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from app.common.lotties import Lotties
4 | from app.utils.api.translate import Translator
5 | from app.utils.chat.buffer import BufferedUserContext
6 | from app.utils.chat.managers.websocket import SendToWebsocket
7 |
8 |
9 | async def translate_callback(
10 | buffer: BufferedUserContext,
11 | query: str,
12 | finish: bool,
13 | wait_next_query: Optional[bool],
14 | show_result: bool = True,
15 | show_result_prefix: Optional[str] = " # 🌐 Translation Result\n---\n\n",
16 | src_lang: str = "en",
17 | trg_lang: str = "en",
18 | ) -> Optional[str]:
19 | await SendToWebsocket.message(
20 | msg=Lotties.TRANSLATE.format("### Translating"),
21 | websocket=buffer.websocket,
22 | chat_room_id=buffer.current_chat_room_id,
23 | finish=False,
24 | )
25 | try:
26 | r = await Translator.translate(text=query, src_lang=src_lang, trg_lang=trg_lang)
27 | r_show = show_result_prefix + r if show_result_prefix is not None else r
28 | await SendToWebsocket.message(
29 | msg=Lotties.OK.format("### Finished translation")
30 | + (r_show if show_result else ""),
31 | websocket=buffer.websocket,
32 | chat_room_id=buffer.current_chat_room_id,
33 | finish=finish,
34 | wait_next_query=wait_next_query,
35 | )
36 | return r
37 | except Exception:
38 | await SendToWebsocket.message(
39 | msg=Lotties.FAIL.format("### Failed translation"),
40 | websocket=buffer.websocket,
41 | chat_room_id=buffer.current_chat_room_id,
42 | finish=finish,
43 | wait_next_query=wait_next_query,
44 | )
45 | return None
46 |
--------------------------------------------------------------------------------
/app/utils/function_calling/callbacks/vectorstore_search.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | from app.common.config import ChatConfig, config
4 | from app.common.lotties import Lotties
5 | from app.utils.chat.buffer import BufferedUserContext
6 | from app.utils.chat.managers.vectorstore import Document, VectorStoreManager
7 | from app.utils.chat.managers.websocket import SendToWebsocket
8 | from app.utils.logger import ApiLogger
9 |
10 |
11 | async def vectorstore_search_callback(
12 | buffer: BufferedUserContext,
13 | query_to_search: str,
14 | finish: bool,
15 | wait_next_query: bool,
16 | show_result: bool = True,
17 | k: int = ChatConfig.vectorstore_n_results_limit,
18 | ) -> Optional[str]:
19 | try:
20 | found_text_and_score: list[
21 | Tuple[Document, float]
22 | ] = await VectorStoreManager.asimilarity_search_multiple_collections_with_score(
23 | query=query_to_search,
24 | collection_names=[buffer.user_id, config.shared_vectorestore_name],
25 | k=k,
26 | )
27 | if not found_text_and_score:
28 | raise Exception("No result found")
29 |
30 | found_text: Optional[str] = "\n\n".join(
31 | [document.page_content for document, _ in found_text_and_score]
32 | )
33 | send_message = "### Finished searching vectorstore"
34 | if show_result:
35 | send_message += f"\n---\n{found_text}"
36 | await SendToWebsocket.message(
37 | websocket=buffer.websocket,
38 | msg=Lotties.OK.format(send_message),
39 | chat_room_id=buffer.current_chat_room_id,
40 | finish=finish,
41 | wait_next_query=wait_next_query,
42 | )
43 | return found_text
44 |
45 | except Exception as e:
46 | ApiLogger("||vectorstore_search_callback||").exception(e)
47 | await SendToWebsocket.message(
48 | msg=Lotties.FAIL.format("### Failed to search vectorstore"),
49 | websocket=buffer.websocket,
50 | chat_room_id=buffer.current_chat_room_id,
51 | finish=finish,
52 | wait_next_query=wait_next_query,
53 | )
54 | return None
55 |
--------------------------------------------------------------------------------
/app/utils/function_calling/parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | from inspect import signature
3 | from re import compile
4 | from typing import (
5 | Annotated,
6 | Any,
7 | Callable,
8 | Iterable,
9 | get_args,
10 | get_origin,
11 | )
12 |
13 | from app.models.completion_models import (
14 | FunctionCallParsed,
15 | FunctionCallUnparsed,
16 | )
17 | from app.models.function_calling.base import (
18 | FunctionCall,
19 | FunctionCallParameter,
20 | JsonTypes,
21 | )
22 | from app.utils.types import get_type_and_optional
23 |
24 |
25 | def make_function_call_parsed_from_dict(
26 | unparsed_dict: dict[str, JsonTypes] | FunctionCallParsed | FunctionCallUnparsed,
27 | ) -> FunctionCallParsed:
28 | """
29 | Parse function call response from API into a FunctionCallParsed object.
30 | This is a helper method to identify what function will be called and what
31 | arguments will be passed to it."""
32 |
33 | if "name" not in unparsed_dict:
34 | raise ValueError("Function call name is required.")
35 |
36 | function_call_parsed: FunctionCallParsed = FunctionCallParsed(
37 | name=str(unparsed_dict["name"])
38 | )
39 | arguments = unparsed_dict.get("arguments", {})
40 | if isinstance(arguments, dict):
41 | function_call_parsed["arguments"] = arguments
42 | return function_call_parsed
43 | else:
44 | try:
45 | function_call_parsed["arguments"] = json.loads(str(arguments))
46 | except json.JSONDecodeError:
47 | pass
48 | return function_call_parsed
49 |
50 |
51 | def parse_function_call_from_function(func: Callable) -> FunctionCall:
52 | """
53 | Parse a function into a FunctionCall object.
54 | FunctionCall objects are used to represent the specification of a function
55 | """
56 | json_types = get_args(JsonTypes)
57 | function_call_params: list[FunctionCallParameter] = []
58 | required: list[str] = []
59 | for name, param in signature(func).parameters.items():
60 | annotation = param.annotation
61 | description: str = ""
62 | enum: list[Any] = []
63 |
64 | if get_origin(annotation) is Annotated:
65 | # If the annotation is an Annotated type,
66 | # we need to parse the metadata
67 | _param_args = get_args(param.annotation)
68 | _param_type = _param_args[0]
69 |
70 | for metadata in _param_args[1:]:
71 | if isinstance(metadata, str):
72 | # If the metadata is a string, it's the description
73 | description += metadata
74 | elif isinstance(metadata, Iterable):
75 | # If the metadata is an iterable, it's the enum
76 | enum.extend(metadata)
77 |
78 | else:
79 | _param_type = annotation
80 | param_type, optional = get_type_and_optional(_param_type)
81 | if not optional:
82 | required.append(name)
83 | if param_type not in json_types:
84 | continue
85 | function_call_params.append(
86 | FunctionCallParameter(
87 | name=name,
88 | type=param_type,
89 | description=description or None,
90 | enum=enum or None,
91 | )
92 | )
93 | line_break_pattern = compile(r"\n\s*")
94 | return FunctionCall(
95 | name=func.__name__,
96 | description=line_break_pattern.sub(" ", func.__doc__) if func.__doc__ else None,
97 | parameters=function_call_params,
98 | required=required or None,
99 | )
100 |
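# A minimal sketch of how Annotated metadata is lifted into a FunctionCall
# spec. The get_weather function below is hypothetical; only
# parse_function_call_from_function comes from this module.
from typing import Annotated, Optional

def get_weather(
    location: Annotated[str, "The city to look up"],
    unit: Annotated[
        Optional[str], "Temperature unit", ["celsius", "fahrenheit"]
    ] = None,
) -> str:
    """Get the current weather for a location."""
    return ""

spec = parse_function_call_from_function(get_weather)
# spec.name == "get_weather"; "location" is required while "unit" is optional
# with enum ["celsius", "fahrenheit"], and the docstring becomes the
# description (attribute access assumes FunctionCall exposes its fields).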
--------------------------------------------------------------------------------
/app/utils/function_calling/query.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | from langchain import PromptTemplate
4 | from orjson import loads as orjson_loads
5 |
6 | from app.common.config import ChatConfig
7 | from app.common.constants import JSON_PATTERN
8 | from app.models.chat_models import ChatRoles, MessageHistory
9 | from app.models.function_calling.base import FunctionCall
10 | from app.shared import Shared
11 | from app.utils.chat.buffer import BufferedUserContext
12 | from app.utils.chat.messages.converter import (
13 | chat_completion_api_parse_method,
14 | message_histories_to_list,
15 | )
16 | from app.utils.chat.tokens import cutoff_message_histories
17 |
18 | from .request import request_function_call
19 |
20 |
21 | async def aget_query_to_search(
22 | buffer: BufferedUserContext,
23 | query: str,
24 | function: FunctionCall,
25 | timeout: Optional[float] = 30,
26 | ) -> str:
27 | """Get query to search from user query and current context"""
28 | (
29 | user_message_histories,
30 | ai_message_histories,
31 | system_message_histories,
32 | ) = cutoff_message_histories(
33 | user_chat_context=buffer.current_user_chat_context,
34 | ai_message_histories=buffer.current_ai_message_histories,
35 | user_message_histories=buffer.current_user_message_histories
36 | + [
37 | MessageHistory(
38 | role=buffer.current_user_chat_roles.user,
39 | content=query,
40 | tokens=buffer.current_user_chat_context.get_tokens_of(query)
41 | + buffer.current_llm_model.value.token_margin,
42 | actual_role=ChatRoles.USER.value,
43 | )
44 | ],
45 | system_message_histories=[],
46 | token_limit=ChatConfig.query_context_token_limit,
47 | )
48 | try:
49 | function_call_parsed = await request_function_call(
50 | messages=message_histories_to_list(
51 | parse_method=chat_completion_api_parse_method,
52 | user_message_histories=user_message_histories,
53 | ai_message_histories=ai_message_histories,
54 | system_message_histories=system_message_histories,
55 | ),
56 | functions=[function],
57 | function_call=function,
58 | timeout=timeout,
59 | )
60 | if "arguments" not in function_call_parsed:
61 | raise ValueError("No arguments returned")
62 | return str(function_call_parsed["arguments"]["query_to_search"])
63 | except Exception:
64 | return query
65 |
66 |
67 | async def aget_json(
68 | query_template: PromptTemplate, **kwargs_to_format: str
69 | ) -> Optional[Any]:
70 | """Get json from query template and kwargs to format"""
71 | try:
72 | json_query = JSON_PATTERN.search(
73 | await Shared().llm.apredict(query_template.format(**kwargs_to_format))
74 | )
75 | if json_query is None:
76 | raise ValueError("Result is None")
77 | else:
78 | return orjson_loads(json_query.group())
79 | except Exception:
80 | return None
81 |
--------------------------------------------------------------------------------
/app/utils/function_calling/request.py:
--------------------------------------------------------------------------------
1 | from asyncio import wait_for
2 | from typing import Any, Literal, Optional
3 |
4 | from app.common.config import OPENAI_API_KEY, ChatConfig
5 | from app.models.completion_models import FunctionCallParsed
6 | from app.models.function_calling.base import FunctionCall
7 | from app.utils.api.completion import request_chat_completion
8 |
9 | from .parser import make_function_call_parsed_from_dict
10 |
11 |
12 | async def request_function_call(
13 | messages: list[dict[str, str]],
14 | functions: list[FunctionCall],
15 | function_call: Optional[FunctionCall | Literal["auto", "none"]] = "auto",
16 | model: str = ChatConfig.global_openai_model,
17 | api_base: str = "https://api.openai.com/v1",
18 | api_key: Optional[str] = OPENAI_API_KEY,
19 | timeout: Optional[float] = None,
20 | force_arguments: bool = False,
21 | **kwargs: Any,
22 | ) -> FunctionCallParsed:
23 | coro = request_chat_completion(
24 | messages=messages,
25 | model=model,
26 | api_base=api_base,
27 | api_key=api_key,
28 | functions=functions,
29 | function_call=function_call,
30 | **kwargs,
31 | )
32 | if timeout is not None:
33 | coro = wait_for(coro, timeout=timeout)
34 | function_call_unparsed = (await coro)["choices"][0]["message"].get(
35 | "function_call"
36 | )
37 | if function_call_unparsed is None:
38 | raise ValueError("No function call returned")
39 | function_call_parsed = make_function_call_parsed_from_dict(
40 | function_call_unparsed
41 | )
42 | if force_arguments and "arguments" not in function_call_parsed:
43 | raise ValueError("No arguments returned")
44 |
45 | return function_call_parsed
46 |
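# A hedged usage sketch (requires a valid OPENAI_API_KEY and network access);
# weather_function is a hypothetical FunctionCall spec, e.g. built with
# parse_function_call_from_function in .parser.
#
#   parsed = await request_function_call(
#       messages=[{"role": "user", "content": "What's the weather in Paris?"}],
#       functions=[weather_function],
#       function_call=weather_function,  # force this specific function
#       timeout=30,
#       force_arguments=True,
#   )
#   parsed["name"], parsed["arguments"]  # e.g. ("get_weather", {"location": "Paris"})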
--------------------------------------------------------------------------------
/app/utils/function_calling/token_count.py:
--------------------------------------------------------------------------------
1 | from tiktoken import encoding_for_model, get_encoding
2 |
3 | from app.models.function_calling.base import FunctionCall
4 |
5 |
6 | def get_num_tokens_from_functions(
7 | functions: list[FunctionCall], model: str = "gpt-3.5-turbo"
8 | ) -> int:
9 | """Return the number of tokens used by a list of functions."""
10 | try:
11 | encoding = encoding_for_model(model)
12 | except KeyError:
13 | print("Warning: model not found. Using cl100k_base encoding.")
14 | encoding = get_encoding("cl100k_base")
15 |
16 | num_tokens = 0
17 | for _function in functions:
18 | function = _function.to_dict()
19 | function_tokens = len(encoding.encode(function["name"]))
20 | if "description" in function:
21 | function_tokens += len(encoding.encode(function["description"]))
22 |
23 | if "parameters" in function:
24 | parameters = function["parameters"]
25 | if "properties" in parameters:
26 | for name, property in parameters["properties"].items():
27 | function_tokens += len(encoding.encode(name))
28 | for field in property:
29 | if field == "type":
30 | function_tokens += 2
31 | function_tokens += len(
32 | encoding.encode(property["type"])
33 | )
34 | elif (
35 | field == "description"
36 | and "description" in property
37 | ):
38 | function_tokens += 2
39 | function_tokens += len(
40 | encoding.encode(property["description"])
41 | )
42 | elif field == "enum" and "enum" in property:
43 | function_tokens -= 3
44 | for o in property["enum"]:
45 | function_tokens += 3
46 | function_tokens += len(encoding.encode(str(o)))
47 | else:
48 | print(f"Warning: not supported field {field}")
49 | function_tokens += 11
50 |
51 | num_tokens += function_tokens
52 |
53 | num_tokens += 12
54 | return num_tokens
55 |
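# A rough usage sketch; weather_function is a hypothetical FunctionCall whose
# fields mirror what to_dict() serializes above.
#
#   n = get_num_tokens_from_functions([weather_function], model="gpt-3.5-turbo")
#   # n approximates the prompt-token overhead the function schema adds; the
#   # +11/+12 constants are empirical fudge factors, so treat the result as an
#   # estimate rather than an exact count.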
--------------------------------------------------------------------------------
/app/utils/js_initializer.py:
--------------------------------------------------------------------------------
1 | import re
2 | from app.common.config import config, ProdConfig
3 |
4 |
5 | def js_url_initializer(js_location: str = "app/web/main.dart.js") -> None:
6 | with open(file=js_location, mode="r") as file:
7 | filedata = file.read()
8 |
9 | for schema in ("http", "ws"):
10 | if isinstance(config, ProdConfig):
11 | from_url_local = rf"{schema}://localhost:\d+"
12 | to_url_prod = f"{schema}s://{config.host_main}"
13 | filedata = re.sub(
14 | from_url_local,
15 | to_url_prod,
16 | filedata,
17 | )
18 | else:
19 | from_url_local = rf"{schema}://localhost:\d+"
20 | from_url_prod = rf"{schema}://{config.host_main}"
21 | to_url_local = f"{schema}://localhost:{config.port}"
22 | filedata = re.sub(
23 | from_url_local,
24 | to_url_local,
25 | filedata,
26 | )
27 | filedata = re.sub(
28 | from_url_prod,
29 | to_url_local,
30 | filedata,
31 | )
32 |
33 | with open(file=js_location, mode="w") as file:
34 | file.write(filedata)
35 |
36 |
37 | if __name__ == "__main__":
38 | js_url_initializer()
39 |
--------------------------------------------------------------------------------
/app/utils/langchain/chat_llama_cpp.py:
--------------------------------------------------------------------------------
1 | from langchain import LlamaCpp
2 | from pydantic import Field, root_validator
3 |
4 |
5 | class CustomLlamaCpp(LlamaCpp):
6 | low_vram: bool = Field(False, alias="low_vram")
7 |
8 | @root_validator()
9 | def validate_environment(cls, values: dict) -> dict:
10 | """Validate that llama-cpp-python library is installed."""
11 | model_path = values["model_path"]
12 | model_param_names = [
13 | "lora_path",
14 | "lora_base",
15 | "n_ctx",
16 | "n_parts",
17 | "seed",
18 | "f16_kv",
19 | "logits_all",
20 | "vocab_only",
21 | "use_mlock",
22 | "n_threads",
23 | "n_batch",
24 | "use_mmap",
25 | "last_n_tokens_size",
26 | ]
27 | model_params = {k: values[k] for k in model_param_names}
28 | # For backwards compatibility, only include if non-null.
29 | if values["n_gpu_layers"] is not None:
30 | model_params["n_gpu_layers"] = values["n_gpu_layers"]
31 | if values["low_vram"] is not None:
32 | model_params["low_vram"] = values["low_vram"]
33 |
34 | try:
35 | from llama_cpp import Llama
36 |
37 | values["client"] = Llama(model_path, **model_params)
38 | except ImportError:
39 | raise ModuleNotFoundError(
40 | "Could not import llama-cpp-python library. "
41 | "Please install the llama-cpp-python library to "
42 | "use this embedding model: pip install llama-cpp-python"
43 | )
44 | except Exception as e:
45 | raise ValueError(
46 | f"Could not load Llama model from path: {model_path}. "
47 | f"Received error {e}"
48 | )
49 |
50 | return values
51 |
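# A minimal construction sketch; the model path is hypothetical and the
# remaining parameters fall back to LlamaCpp's defaults.
#
#   llm = CustomLlamaCpp(
#       model_path="./llama_models/ggml/model.bin",
#       n_ctx=2048,
#       n_gpu_layers=32,
#       low_vram=True,  # the extra flag this subclass forwards to llama_cpp
#   )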
--------------------------------------------------------------------------------
/app/utils/langchain/embeddings_api.py:
--------------------------------------------------------------------------------
1 | """Wrapper around embedding API models."""
2 | from __future__ import annotations
3 |
4 | import logging
5 | from typing import (
6 | Any,
7 | List,
8 | Optional,
9 | Tuple,
10 | Union,
11 | )
12 |
13 | import requests
14 |
15 | from pydantic import BaseModel, Extra
16 | from langchain.embeddings.base import Embeddings
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class APIEmbeddings(BaseModel, Embeddings):
22 | """Wrapper around embedding models from OpenAI-style API"""
23 |
24 | client: Any #: :meta private:
25 | model: str = "intfloat/e5-large-v2"
26 | embedding_api_url: str = "http://localhost:8002/v1/embeddings"
27 | request_timeout: Optional[Union[float, Tuple[float, float]]] = None
28 | """Timeout in seconds for the API request."""
29 | headers: Any = None
30 |
31 | class Config:
32 | """Configuration for this pydantic object."""
33 |
34 | extra = Extra.forbid
35 |
36 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
37 | """Call out to embedding endpoint for embedding search docs.
38 |
39 | Args:
40 | texts: The list of texts to embed.
41 |
42 | Returns:
43 | List of embeddings, one for each text.
44 | """
45 | # NOTE: to keep things simple, we assume the list may contain texts longer
46 | # than the maximum context and use length-safe embedding function.
47 | response = requests.post(
48 | self.embedding_api_url,
49 | json={"model": self.model, "input": texts},
50 | # pass the declared headers and timeout fields
51 | headers=self.headers,
52 | timeout=self.request_timeout,
53 | )
54 |
55 | # Check if the request was successful
56 | if response.status_code == 200:
57 | # Parse the response
58 | response_data = response.json()
59 |
60 | # Extract the embeddings and total tokens
61 | embeddings: list[list[float]] = [
62 | data["embedding"] for data in response_data["data"]
63 | ]
64 | total_tokens = response_data["usage"]["total_tokens"]
65 | return embeddings
66 | raise ConnectionError(
67 | f"Request to {self.embedding_api_url} failed with status code "
68 | f"{response.status_code}."
69 | )
70 |
71 | def embed_query(self, text: str) -> List[float]:
72 | """Call out to OpenAI's embedding endpoint for embedding query text.
73 |
74 | Args:
75 | text: The text to embed.
76 |
77 | Returns:
78 | Embedding for the text.
79 | """
80 | response = requests.post(
81 | self.embedding_api_url,
82 | json={"model": self.model, "input": text},
83 | # pass the declared headers and timeout fields
84 | headers=self.headers,
85 | timeout=self.request_timeout,
86 | )
87 |
88 | # Check if the request was successful
89 | if response.status_code == 200:
90 | # Parse the response
91 | response_data = response.json()
92 |
93 | # Extract the embeddings and total tokens
94 | embeddings: list[list[float]] = [
95 | data["embedding"] for data in response_data["data"]
96 | ]
97 | total_tokens = response_data["usage"]["total_tokens"]
98 | return embeddings[0]
99 | raise ConnectionError(
100 | f"Request to {self.embedding_api_url} failed with status code "
101 | f"{response.status_code}."
102 | )
103 |
104 |
105 | if __name__ == "__main__":
106 | import argparse
107 | import json
108 |
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument("--model", type=str, default="intfloat/e5-large-v2")
111 | parser.add_argument(
112 | "--embedding_api_url", type=str, default="http://localhost:8002/v1/embeddings"
113 | )
114 | parser.add_argument("--request_timeout", type=float, default=None)
115 | parser.add_argument("--headers", type=str, default=None)
116 | parser.add_argument("--text", type=str, default="Hello, world!")
117 | args = parser.parse_args()
118 |
119 | print(args)
120 |
121 | # Create the API embeddings model
122 | api_embeddings = APIEmbeddings(
123 | client=None,
124 | model=args.model,
125 | embedding_api_url=args.embedding_api_url,
126 | request_timeout=args.request_timeout,
127 | headers=args.headers,
128 | )
129 |
130 | # Embed the query
131 | query_embedding = api_embeddings.embed_query(args.text)
132 |
133 | # Print the query embedding
134 | print(json.dumps(query_embedding))
135 |
--------------------------------------------------------------------------------
/app/utils/langchain/structured_tool.py:
--------------------------------------------------------------------------------
1 | """Base implementation for tools or skills."""
2 | from __future__ import annotations
3 |
4 | from inspect import signature, iscoroutinefunction
5 | from typing import Any, Awaitable, Callable, Optional, Type
6 |
7 | from langchain.callbacks.manager import (
8 | AsyncCallbackManagerForToolRun,
9 | CallbackManagerForToolRun,
10 | )
11 | from langchain.tools.base import BaseTool, create_schema_from_function
12 | from pydantic import (
13 | BaseModel,
14 | Field,
15 | )
16 |
17 |
18 | class StructuredTool(BaseTool):
19 | """Tool that can operate on any number of inputs."""
20 |
21 | description: str = ""
22 | args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
23 | """The input arguments' schema."""
24 | func: Optional[Callable[..., Any]] = None
25 | """The function to run when the tool is called."""
26 | coroutine: Optional[Callable[..., Awaitable[Any]]] = None
27 | """The asynchronous version of the function."""
28 |
29 | @property
30 | def args(self) -> dict:
31 | """The tool's input arguments."""
32 | return self.args_schema.schema()["properties"]
33 |
34 | def _run(
35 | self,
36 | *args: Any,
37 | run_manager: Optional[CallbackManagerForToolRun] = None,
38 | **kwargs: Any,
39 | ) -> Any:
40 | """Use the tool."""
41 | if self.func:
42 | new_argument_supported = signature(self.func).parameters.get("callbacks")
43 | return (
44 | self.func(
45 | *args,
46 | callbacks=run_manager.get_child() if run_manager else None,
47 | **kwargs,
48 | )
49 | if new_argument_supported
50 | else self.func(*args, **kwargs)
51 | )
52 | raise NotImplementedError("Tool does not support sync")
53 |
54 | async def _arun(
55 | self,
56 | *args: Any,
57 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
58 | **kwargs: Any,
59 | ) -> str:
60 | """Use the tool asynchronously."""
61 | if self.coroutine:
62 | new_argument_supported = signature(self.coroutine).parameters.get(
63 | "callbacks"
64 | )
65 | return (
66 | await self.coroutine(
67 | *args,
68 | callbacks=run_manager.get_child() if run_manager else None,
69 | **kwargs,
70 | )
71 | if new_argument_supported
72 | else await self.coroutine(*args, **kwargs)
73 | )
74 | return self._run(*args, **kwargs)
75 |
76 | @classmethod
77 | def from_function(
78 | cls,
79 | func: Callable,
80 | name: Optional[str] = None,
81 | description: Optional[str] = None,
82 | return_direct: bool = False,
83 | args_schema: Optional[Type[BaseModel]] = None,
84 | infer_schema: bool = True,
85 | **kwargs: Any,
86 | ) -> StructuredTool:
87 | name = name or func.__name__
88 | description = description or func.__doc__
89 | assert (
90 | description is not None
91 | ), "Function must have a docstring if description not provided."
92 |
93 | # Description example:
94 | # search_api(query: str) - Searches the API for the query.
95 | description = f"{name}{signature(func)} - {description.strip()}"
96 | _args_schema = args_schema
97 | if _args_schema is None and infer_schema:
98 | _args_schema = create_schema_from_function(f"{name}Schema", func)
99 | is_coroutine_function = iscoroutinefunction(func)
100 | return cls(
101 | name=name,
102 | func=func if not is_coroutine_function else None,
103 | coroutine=func if is_coroutine_function else None,
104 | args_schema=_args_schema,
105 | description=description,
106 | return_direct=return_direct,
107 | **kwargs,
108 | )
109 |
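# A small sketch of building a tool from a plain function; infer_schema
# derives args_schema from the signature. Exact attribute shapes depend on the
# installed langchain version.
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

tool = StructuredTool.from_function(multiply)
# tool.description == "multiply(a: int, b: int) -> int - Multiply two integers."
# tool.args lists the "a" and "b" properties from the inferred schema.
tool.run({"a": 6, "b": 7})  # -> 42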
--------------------------------------------------------------------------------
/app/utils/langchain/token_text_splitter.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import (
3 | AbstractSet,
4 | Any,
5 | Collection,
6 | Iterable,
7 | List,
8 | Literal,
9 | Optional,
10 | Sequence,
11 | Union,
12 | )
13 |
14 | from langchain.docstore.document import Document
15 | from langchain.text_splitter import Tokenizer, TokenTextSplitter, split_text_on_tokens
16 |
17 |
18 | class CustomTokenTextSplitter(TokenTextSplitter):
19 | """Implementation of splitting text that looks at tokens."""
20 |
21 | def __init__(
22 | self,
23 | encoding_name: str = "gpt2",
24 | model_name: Optional[str] = None,
25 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
26 | disallowed_special: Union[Literal["all"], Collection[str]] = "all",
27 | **kwargs: Any,
28 | ):
29 | super().__init__(
30 | encoding_name=encoding_name,
31 | model_name=model_name,
32 | allowed_special=allowed_special,
33 | disallowed_special=disallowed_special,
34 | **kwargs,
35 | )
36 |
37 | def split_text(
38 | self,
39 | text: str,
40 | tokens_per_chunk: Optional[int] = None,
41 | chunk_overlap: Optional[int] = None,
42 | ) -> List[str]:
43 | def _encode(_text: str) -> List[int]:
44 | return self._tokenizer.encode(
45 | _text,
46 | allowed_special=self._allowed_special, # type: ignore
47 | disallowed_special=self._disallowed_special,
48 | )
49 |
50 | tokenizer = Tokenizer(
51 | chunk_overlap=self._chunk_overlap
52 | if chunk_overlap is None
53 | else chunk_overlap,
54 | tokens_per_chunk=self._chunk_size
55 | if tokens_per_chunk is None
56 | else tokens_per_chunk,
57 | decode=self._tokenizer.decode,
58 | encode=_encode,
59 | )
60 |
61 | return split_text_on_tokens(text=text, tokenizer=tokenizer)
62 |
63 | def create_documents(
64 | self,
65 | texts: List[str],
66 | metadatas: Optional[List[dict]] = None,
67 | tokens_per_chunk: Optional[int] = None,
68 | chunk_overlap: Optional[int] = None,
69 | ) -> List[Document]:
70 | """Create documents from a list of texts."""
71 | _metadatas = metadatas or [{}] * len(texts)
72 | documents = []
73 | for i, text in enumerate(texts):
74 | index = -1
75 | for chunk in self.split_text(
76 | text,
77 | tokens_per_chunk=tokens_per_chunk,
78 | chunk_overlap=chunk_overlap,
79 | ):
80 | metadata = deepcopy(_metadatas[i])
81 | if self._add_start_index:
82 | index = text.find(chunk, index + 1)
83 | metadata["start_index"] = index
84 | new_doc = Document(page_content=chunk, metadata=metadata)
85 | documents.append(new_doc)
86 | return documents
87 |
88 | def split_documents(
89 | self,
90 | documents: Iterable[Document],
91 | tokens_per_chunk: Optional[int] = None,
92 | chunk_overlap: Optional[int] = None,
93 | ) -> List[Document]:
94 | """Split documents."""
95 | texts, metadatas = [], []
96 | for doc in documents:
97 | texts.append(doc.page_content)
98 | metadatas.append(doc.metadata)
99 | return self.create_documents(
100 | texts,
101 | metadatas=metadatas,
102 | tokens_per_chunk=tokens_per_chunk,
103 | chunk_overlap=chunk_overlap,
104 | )
105 |
106 | def transform_documents(
107 | self,
108 | documents: Sequence[Document],
109 | tokens_per_chunk: Optional[int] = None,
110 | chunk_overlap: Optional[int] = None,
111 | **kwargs: Any,
112 | ) -> Sequence[Document]:
113 | """Transform sequence of documents by splitting them."""
114 | return self.split_documents(
115 | list(documents),
116 | tokens_per_chunk=tokens_per_chunk,
117 | chunk_overlap=chunk_overlap,
118 | )
119 |
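# A runnable sketch (requires tiktoken); the per-call overrides take
# precedence over the chunk size and overlap configured at construction time.
splitter = CustomTokenTextSplitter(
    encoding_name="gpt2", chunk_size=512, chunk_overlap=0
)
chunks = splitter.split_text(
    "some long text " * 200,
    tokens_per_chunk=64,  # overrides chunk_size for this call only
    chunk_overlap=8,      # overrides the configured overlap
)
print(len(chunks), chunks[0][:40])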
--------------------------------------------------------------------------------
/app/utils/module_reloader.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from types import ModuleType
3 |
4 |
5 | class ModuleReloader:
6 | def __init__(self, module_name: str):
7 | self._module_name = module_name
8 |
9 | def reload(self) -> ModuleType:
10 | importlib.reload(self.module)
11 | return self.module
12 |
13 | @property
14 | def module(self) -> ModuleType:
15 | return importlib.import_module(self._module_name)
16 |
--------------------------------------------------------------------------------
/app/utils/params_utils.py:
--------------------------------------------------------------------------------
1 | from hmac import HMAC, new
2 | from base64 import b64encode
3 |
4 |
5 | def parse_params(params: dict) -> str:
6 | return "&".join([f"{key}={value}" for key, value in params.items()])
7 |
8 |
9 | def hash_params(query_params: str, secret_key: str) -> str:
10 | mac: HMAC = new(
11 | key=bytes(secret_key, encoding="utf-8"),
12 | msg=bytes(query_params, encoding="utf-8"),
13 | digestmod="sha256",
14 | )
15 | return b64encode(mac.digest()).decode("utf-8")
16 |
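# Self-contained demo: serialize query params (note: values are not
# URL-encoded), then sign the string with HMAC-SHA256. The secret key is a
# placeholder.
if __name__ == "__main__":
    params = {"access_key": "abc", "timestamp": 1693000000}
    qs = parse_params(params)  # "access_key=abc&timestamp=1693000000"
    signature = hash_params(qs, secret_key="my-secret")
    print(qs, signature)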
--------------------------------------------------------------------------------
/app/utils/tests/random.py:
--------------------------------------------------------------------------------
1 | from uuid import uuid4
2 |
3 |
4 | def random_user_generator(**kwargs):
5 | random_8_digits = str(abs(hash(uuid4())))[:8]  # abs() keeps the slice digits-only
6 | return {
7 | "email": f"{random_8_digits}@test.com",
8 | "password": "123456",
9 | "name": f"{random_8_digits}",
10 | "phone_number": f"010{random_8_digits}",
11 | } | kwargs
12 |
--------------------------------------------------------------------------------
/app/utils/tests/timeit.py:
--------------------------------------------------------------------------------
1 | from asyncio import iscoroutinefunction
2 | from functools import wraps
3 | from time import time
4 | from typing import Any, AsyncGenerator, Callable, Generator
5 |
6 | from app.common.config import logging_config
7 | from app.utils.logger import CustomLogger
8 |
9 | logger = CustomLogger(
10 | name=__name__,
11 | logging_config=logging_config,
12 | )
13 |
14 |
15 | def log_time(fn: Any, elapsed: float) -> None:
16 | formatted_time = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}:{int((elapsed - int(elapsed)) * 1000000):06d}"
17 | logger.info(f"Function {fn.__name__} execution time: {formatted_time}")
18 |
19 |
20 | def timeit(fn: Callable) -> Callable:
21 | """
22 | Decorator to measure execution time of a function, or generator.
23 | Supports both synchronous and asynchronous functions and generators.
24 | :param fn: function to measure execution time of
25 | :return: function wrapper
26 |
27 | Usage:
28 | >>> @timeit
29 | >>> def my_function():
30 | >>> ...
31 | """
32 |
33 | @wraps(fn)
34 | async def coroutine_function_wrapper(*args: Any, **kwargs: Any) -> Any:
35 | start_time: float = time()
36 | fn_result = await fn(*args, **kwargs)
37 | elapsed: float = time() - start_time
38 | log_time(fn, elapsed)
39 | return fn_result
40 |
41 | @wraps(fn)
42 | def noncoroutine_function_wrapper(*args: Any, **kwargs: Any) -> Any:
43 | def sync_generator_wrapper() -> Generator:
44 | while True:
45 | try:
46 | start_time: float = time()
47 | item: Any = next(fn_result)
48 | elapsed: float = time() - start_time
49 | log_time(fn_result, elapsed)
50 | yield item
51 | except StopIteration:
52 | break
53 |
54 | async def async_generator_wrapper() -> AsyncGenerator:
55 | while True:
56 | try:
57 | start_time: float = time()
58 | item: Any = await anext(fn_result)
59 | elapsed: float = time() - start_time
60 | log_time(fn_result, elapsed)
61 | yield item
62 | except StopAsyncIteration:
63 | break
64 |
65 | start_time: float = time()
66 | fn_result: Any = fn(*args, **kwargs)
67 | elapsed: float = time() - start_time
68 | if isinstance(fn_result, Generator):
69 | return sync_generator_wrapper()
70 | elif isinstance(fn_result, AsyncGenerator):
71 | return async_generator_wrapper()
72 | log_time(fn, elapsed)
73 | return fn_result
74 |
75 | if iscoroutinefunction(fn):
76 | return coroutine_function_wrapper
77 | return noncoroutine_function_wrapper
78 |
--------------------------------------------------------------------------------
/app/viewmodels/admin.py:
--------------------------------------------------------------------------------
1 | from starlette_admin.fields import (
2 | EnumField,
3 | BooleanField,
4 | EmailField,
5 | PhoneField,
6 | StringField,
7 | NumberField,
8 | DateTimeField,
9 | PasswordField,
10 | )
11 | from starlette_admin.contrib.sqla.view import ModelView
12 |
13 | from app.database.schemas.auth import Users
14 |
15 | USER_STATUS_TYPES = [
16 | ("admin", "Admin"),
17 | ("active", "Active"),
18 | ("deleted", "Deleted"),
19 | ("blocked", "Blocked"),
20 | ]
21 | API_KEY_STATUS_TYPES = [
22 | ("active", "Active"),
23 | ("deleted", "Deleted"),
24 | ("stopped", "Stopped"),
25 | ]
26 |
27 |
28 |
29 |
30 |
31 | class UserAdminView(ModelView):
32 | fields = [
33 | EnumField("status", choices=USER_STATUS_TYPES, select2=False),
34 | NumberField("id"),
35 | EmailField("email"),
36 | PasswordField("password"),
37 | StringField("name"),
38 | PhoneField("phone_number"),
39 | BooleanField("marketing_agree"),
40 | DateTimeField("created_at"),
41 | DateTimeField("updated_at"),
42 | StringField("ip_address"),
43 | "api_keys",
44 | ]
45 | fields_default_sort = [Users.id, ("id", False)]
46 |
47 |
48 | class ApiKeyAdminView(ModelView):
49 | fields = [
50 | EnumField("status", choices=API_KEY_STATUS_TYPES, select2=False),
51 | NumberField("id"),
52 | StringField("access_key"),
53 | StringField("user_memo"),
54 | BooleanField("is_whitelisted"),
55 | DateTimeField("created_at"),
56 | DateTimeField("updated_at"),
57 | "users",
58 | "whitelists",
59 | ]
60 | fields_default_sort = [Users.id, ("id", False)]
61 |
--------------------------------------------------------------------------------
/app/viewmodels/status.py:
--------------------------------------------------------------------------------
1 | import enum
2 |
3 |
4 | class UserStatus(str, enum.Enum):
5 | admin = "admin"
6 | active = "active"
7 | deleted = "deleted"
8 | blocked = "blocked"
9 |
10 |
11 | class ApiKeyStatus(str, enum.Enum):
12 | active = "active"
13 | stopped = "stopped"
14 | deleted = "deleted"
15 |
--------------------------------------------------------------------------------
/app/web/.last_build_id:
--------------------------------------------------------------------------------
1 | 6ab7adef156d180eb3fbf8ed0392faf8
--------------------------------------------------------------------------------
/app/web/assets/AssetManifest.bin:
--------------------------------------------------------------------------------
1 | (binary asset manifest, not reproducible as text; its entries mirror AssetManifest.json below)
--------------------------------------------------------------------------------
/app/web/assets/AssetManifest.json:
--------------------------------------------------------------------------------
1 | {"assets/images/ai_profile.png":["assets/images/ai_profile.png"],"assets/images/favicon.ico":["assets/images/favicon.ico"],"assets/images/openai_profile.png":["assets/images/openai_profile.png"],"assets/images/site_logo.png":["assets/images/site_logo.png"],"assets/images/user_profile.png":["assets/images/user_profile.png"],"assets/images/vicuna_profile.jpg":["assets/images/vicuna_profile.jpg"],"assets/lotties/click.json":["assets/lotties/click.json"],"assets/lotties/fail.json":["assets/lotties/fail.json"],"assets/lotties/file-upload.json":["assets/lotties/file-upload.json"],"assets/lotties/go-back.json":["assets/lotties/go-back.json"],"assets/lotties/ok.json":["assets/lotties/ok.json"],"assets/lotties/read.json":["assets/lotties/read.json"],"assets/lotties/scroll-down.json":["assets/lotties/scroll-down.json"],"assets/lotties/search-doc.json":["assets/lotties/search-doc.json"],"assets/lotties/search-web.json":["assets/lotties/search-web.json"],"assets/lotties/translate.json":["assets/lotties/translate.json"],"assets/svgs/search-web.svg":["assets/svgs/search-web.svg"],"packages/cupertino_icons/assets/CupertinoIcons.ttf":["packages/cupertino_icons/assets/CupertinoIcons.ttf"]}
--------------------------------------------------------------------------------
/app/web/assets/FontManifest.json:
--------------------------------------------------------------------------------
1 | [{"family":"MaterialIcons","fonts":[{"asset":"fonts/MaterialIcons-Regular.otf"}]},{"family":"packages/cupertino_icons/CupertinoIcons","fonts":[{"asset":"packages/cupertino_icons/assets/CupertinoIcons.ttf"}]}]
--------------------------------------------------------------------------------
/app/web/assets/assets/images/ai_profile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/ai_profile.png
--------------------------------------------------------------------------------
/app/web/assets/assets/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/favicon.ico
--------------------------------------------------------------------------------
/app/web/assets/assets/images/openai_profile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/openai_profile.png
--------------------------------------------------------------------------------
/app/web/assets/assets/images/site_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/site_logo.png
--------------------------------------------------------------------------------
/app/web/assets/assets/images/user_profile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/user_profile.png
--------------------------------------------------------------------------------
/app/web/assets/assets/images/vicuna_profile.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/assets/images/vicuna_profile.jpg
--------------------------------------------------------------------------------
/app/web/assets/assets/lotties/scroll-down.json:
--------------------------------------------------------------------------------
1 | {"v":"5.5.2","fr":120,"ip":0,"op":120,"w":192,"h":192,"nm":"Comp 2","ddd":0,"assets":[],"layers":[{"ddd":0,"ind":1,"ty":4,"nm":"Arrow-Down Outlines","sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":0,"s":[96,91,0],"to":[0,2.833,0],"ti":[0,3.5,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[96,108,0],"to":[0,-3.5,0],"ti":[0,1.333,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":60,"s":[96,70,0],"to":[0,-1.333,0],"ti":[0,-3.5,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.167,"y":0},"t":77,"s":[96,100,0],"to":[0,3.5,0],"ti":[0,1.5,0]},{"t":88,"s":[96,91,0]}],"ix":2},"a":{"a":0,"k":[12,12,0],"ix":1},"s":{"a":0,"k":[825,825,100],"ix":6}},"ao":0,"shapes":[{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[0,0],[0,0]],"o":[[0,0],[0,0]],"v":[[12,18],[12,6]],"c":false},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"tm","s":{"a":0,"k":0,"ix":1},"e":{"a":1,"k":[{"i":{"x":[0.667],"y":[1]},"o":{"x":[0.333],"y":[0]},"t":0,"s":[100]},{"i":{"x":[0.667],"y":[1]},"o":{"x":[0.333],"y":[0]},"t":38,"s":[32]},{"t":60,"s":[100]}],"ix":2},"o":{"a":0,"k":0,"ix":3},"m":1,"ix":2,"nm":"Trim Paths 1","mn":"ADBE Vector Filter - Trim","hd":false},{"ty":"st","c":{"a":0,"k":[0.6392156862745098,0.9333333333333333,0.9137254901960784,1],"ix":3},"o":{"a":0,"k":100,"ix":4},"w":{"a":0,"k":2,"ix":5},"lc":2,"lj":2,"bm":0,"nm":"Stroke 1","mn":"ADBE Vector Graphic - Stroke","hd":false},{"ty":"tr","p":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[0,0],"to":[0,-0.167],"ti":[0,0]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":60,"s":[0,-1],"to":[0,0],"ti":[0,-0.167]},{"t":77,"s":[0,0]}],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[100,100],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 1","np":3,"cix":2,"bm":0,"ix":1,"mn":"ADBE Vector Group","hd":false},{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":1,"k":[{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":0,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5,-2.5],[0,2.5],[5,-2.5]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":38,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-2.612,-1.742],[-0.064,2.431],[2.484,-1.742]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.167,"y":0},"t":60,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5.966,-2.5],[-0.05,2.5],[5.865,-2.5]],"c":false}]},{"i":{"x":0.667,"y":1},"o":{"x":0.333,"y":0},"t":77,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-4.103,-2.5],[-0.016,2.5],[4.071,-2.5]],"c":false}]},{"t":88,"s":[{"i":[[0,0],[0,0],[0,0]],"o":[[0,0],[0,0],[0,0]],"v":[[-5,-2.5],[0,2.5],[5,-2.5]],"c":false}]}],"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"st","c":{"a":0,"k":[0.6392156862745098,0.9333333333333333,0.9137254901960784,1],"ix":3},"o":{"a":0,"k":100,"ix":4},"w":{"a":0,"k":2,"ix":5},"lc":2,"lj":2,"bm":0,"nm":"Stroke 1","mn":"ADBE Vector Graphic - Stroke","hd":false},{"ty":"tr","p":{"a":0,"k":[12,16.5],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[87.834,87.834],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 2","np":2,"cix":2,"bm":0,"ix":2,"mn":"ADBE Vector Group","hd":false}],"ip":0,"op":1215,"st":0,"bm":0}],"markers":[]}
--------------------------------------------------------------------------------
/app/web/assets/assets/svgs/search-web.svg:
--------------------------------------------------------------------------------
(SVG markup not captured in this export.)
--------------------------------------------------------------------------------
/app/web/assets/fonts/MaterialIcons-Regular.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/fonts/MaterialIcons-Regular.otf
--------------------------------------------------------------------------------
/app/web/assets/packages/cupertino_icons/assets/CupertinoIcons.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/assets/packages/cupertino_icons/assets/CupertinoIcons.ttf
--------------------------------------------------------------------------------
/app/web/canvaskit/canvaskit.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/canvaskit.wasm
--------------------------------------------------------------------------------
/app/web/canvaskit/chromium/canvaskit.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/chromium/canvaskit.wasm
--------------------------------------------------------------------------------
/app/web/canvaskit/skwasm.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/canvaskit/skwasm.wasm
--------------------------------------------------------------------------------
/app/web/canvaskit/skwasm.worker.js:
--------------------------------------------------------------------------------
1 | "use strict";var Module={};if(typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string"){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",function(data){onmessage({data:data})});var nodeFS=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(nodeFS.readFileSync(f,"utf8"))},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=((info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports});self.onmessage=(e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob==="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}skwasm(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.threadInfoStruct,/*isMainBrowserThread=*/0,/*isMainRuntimeThread=*/0,/*canBlock=*/1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInit();try{var result=Module["invokeEntryPoint"](e.data.start_routine,e.data.arg);if(Module["keepRuntimeAlive"]()){Module["PThread"].setExitStatus(result)}else{Module["__emscripten_thread_exit"](result)}}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processThreadQueue"){if(Module["_pthread_self"]()){Module["_emscripten_current_thread_process_queued_calls"]()}}else{err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){err("worker.js onmessage() captured an uncaught exception: "+ex);if(ex&&ex.stack)err(ex.stack);throw ex}});
2 |
--------------------------------------------------------------------------------
/app/web/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/favicon.ico
--------------------------------------------------------------------------------
/app/web/icons/Icon-192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-192.png
--------------------------------------------------------------------------------
/app/web/icons/Icon-512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-512.png
--------------------------------------------------------------------------------
/app/web/icons/Icon-maskable-192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-maskable-192.png
--------------------------------------------------------------------------------
/app/web/icons/Icon-maskable-512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/icons/Icon-maskable-512.png
--------------------------------------------------------------------------------
/app/web/index.html:
--------------------------------------------------------------------------------
(HTML markup not captured in this export; only the page title "LLMChat" survives.)
--------------------------------------------------------------------------------
/app/web/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "LLMChat",
3 | "short_name": "LLMChat",
4 | "start_url": ".",
5 | "display": "standalone",
6 | "background_color": "#0175C2",
7 | "theme_color": "#0175C2",
8 | "description": "Seamless chat experience",
9 | "orientation": "portrait-primary",
10 | "prefer_related_applications": false,
11 | "icons": [
12 | {
13 | "src": "icons/Icon-192.png",
14 | "sizes": "192x192",
15 | "type": "image/png"
16 | },
17 | {
18 | "src": "icons/Icon-512.png",
19 | "sizes": "512x512",
20 | "type": "image/png"
21 | },
22 | {
23 | "src": "icons/Icon-maskable-192.png",
24 | "sizes": "192x192",
25 | "type": "image/png",
26 | "purpose": "maskable"
27 | },
28 | {
29 | "src": "icons/Icon-maskable-512.png",
30 | "sizes": "512x512",
31 | "type": "image/png",
32 | "purpose": "maskable"
33 | }
34 | ]
35 | }
--------------------------------------------------------------------------------
/app/web/site_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/app/web/site_logo.png
--------------------------------------------------------------------------------
/app/web/version.json:
--------------------------------------------------------------------------------
1 | {"app_name":"flutter_web","version":"1.0.0","build_number":"1","package_name":"flutter_web"}
--------------------------------------------------------------------------------
/build-web.bat:
--------------------------------------------------------------------------------
1 | cd frontend
2 | call flutter build web --base-href /chat/ --web-renderer canvaskit
3 |
4 | cd ..
5 | IF EXIST "app\web" (
6 | rmdir /s /q "app\web"
7 | )
8 | IF NOT EXIST "app\web" (
9 | mkdir "app\web"
10 | )
11 | xcopy /s /y "frontend\build\web\*" "app\web\"
12 |
--------------------------------------------------------------------------------
/docker-compose-local.yaml:
--------------------------------------------------------------------------------
1 | version: '3.9'
2 |
3 | volumes:
4 | mysql:
5 | redis:
6 | qdrant:
7 |
8 |
9 | networks:
10 | reverse-proxy-public:
11 | driver: bridge
12 | ipam:
13 | driver: default
14 | config:
15 | - subnet: 172.16.0.0/24 # subnet for traefik and other services
16 |
17 | services:
18 | # No need to use traefik for local development
19 |
20 | db:
21 | image: mysql:latest
22 | restart: always
23 | ports:
24 | - "3306:3306"
25 | environment:
26 | MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD}"
27 | MYSQL_ROOT_HOST: "%"
28 | MYSQL_DATABASE: "${MYSQL_DATABASE}"
29 | MYSQL_USER: "${MYSQL_USER}"
30 | MYSQL_PASSWORD: "${MYSQL_PASSWORD}"
31 | TZ: "Asia/Seoul"
32 | volumes:
33 | - mysql:/var/lib/mysql
34 | networks:
35 | reverse-proxy-public:
36 | ipv4_address: 172.16.0.11 # static IP
37 |
38 | cache:
39 | image: redis/redis-stack-server:latest
40 | restart: always
41 | environment:
42 | - REDIS_ARGS=--requirepass ${REDIS_PASSWORD} --maxmemory 100mb --maxmemory-policy allkeys-lru --appendonly yes
43 | ports:
44 | - "6379:6379"
45 | volumes:
46 | - redis:/data
47 | networks:
48 | reverse-proxy-public:
49 | ipv4_address: 172.16.0.12 # static IP
50 |
51 | vectorstore:
52 | image: qdrant/qdrant
53 | restart: always
54 | ports:
55 | - "6333:6333"
56 | - "6334:6334"
57 | volumes:
58 | - qdrant:/qdrant/storage
59 | networks:
60 | reverse-proxy-public:
61 | ipv4_address: 172.16.0.13 # static IP
62 |
63 | api:
64 | image: cosogi/llmchat:230703
65 | restart: always
66 | env_file:
67 | - .env
68 | command:
69 | - "--host"
70 | - "0.0.0.0"
71 | - "--port"
72 | - "8000"
73 | ports:
74 | - "8000:8000"
75 |     depends_on:
76 |       - db
77 |       - cache
78 |       - vectorstore
79 |     volumes:
80 |       - .:/app
81 |     networks:
82 |       reverse-proxy-public:
83 |         ipv4_address: 172.16.0.14 # static IP
84 |
--------------------------------------------------------------------------------
/docker-compose-prod.yaml:
--------------------------------------------------------------------------------
1 | version: '3.9'
2 |
3 | volumes:
4 | mysql:
5 | redis:
6 | qdrant:
7 |
8 |
9 | networks:
10 | reverse-proxy-public:
11 | driver: bridge
12 | ipam:
13 | driver: default
14 |
15 | services:
16 | proxy:
17 | image: traefik
18 | command:
19 | - "--entrypoints.web.address=:80"
20 | - "--entrypoints.websecure.address=:443"
21 | - "--entrypoints.mysql.address=:3306"
22 | - "--entrypoints.redis.address=:6379"
23 | - "--entryPoints.web.http.redirections.entryPoint.to=websecure"
24 | - "--entryPoints.web.http.redirections.entryPoint.scheme=https"
25 | - "--providers.docker"
26 | - "--providers.docker.exposedbydefault=false"
27 | - "--api.insecure=false"
28 | - "--certificatesresolvers.myresolver.acme.tlschallenge=true"
29 | - "--certificatesresolvers.myresolver.acme.email=${MY_EMAIL}"
30 | - "--certificatesresolvers.myresolver.acme.storage=/letsencrypt/acme.json"
31 | # - "--log.level=DEBUG"
32 | # - "--certificatesresolvers.myresolver.acme.caserver=https://acme-staging-v02.api.letsencrypt.org/directory"
33 | ports:
34 | - "80:80"
35 | - "443:443"
36 | - "3306:3306"
37 | - "6379:6379"
38 | volumes:
39 | - /var/run/docker.sock:/var/run/docker.sock
40 | - ./letsencrypt:/letsencrypt
41 | networks:
42 | - reverse-proxy-public
43 |
44 | db:
45 | image: mysql
46 | restart: always
47 | environment:
48 | MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD}"
49 | MYSQL_ROOT_HOST: "%"
50 | MYSQL_DATABASE: "${MYSQL_DATABASE}"
51 | MYSQL_USER: "${MYSQL_USER}"
52 | MYSQL_PASSWORD: "${MYSQL_PASSWORD}"
53 | TZ: "Asia/Seoul"
54 | volumes:
55 | - mysql:/var/lib/mysql
56 | # - ./my.cnf:/etc/mysql/conf.d/my.cnf
57 | labels:
58 | - "traefik.enable=true"
59 | - "traefik.tcp.routers.db.rule=HostSNI(`*`)"
60 | - "traefik.tcp.services.db.loadbalancer.server.port=3306"
61 | - "traefik.tcp.routers.db.entrypoints=mysql"
62 | networks:
63 | - reverse-proxy-public
64 |
65 | cache:
66 | image: redis/redis-stack-server:latest
67 | restart: always
68 | environment:
69 | - REDIS_ARGS=--requirepass ${REDIS_PASSWORD} --maxmemory 100mb --maxmemory-policy allkeys-lru --appendonly yes
70 | volumes:
71 | - redis:/data
72 | labels:
73 | - "traefik.enable=true"
74 | - "traefik.tcp.routers.cache.rule=HostSNI(`*`)"
75 | - "traefik.tcp.services.cache.loadbalancer.server.port=6379"
76 | - "traefik.tcp.routers.cache.entrypoints=redis"
77 | networks:
78 | - reverse-proxy-public
79 |
80 | api:
81 | image: cosogi/llmchat:230703
82 | restart: always
83 | env_file:
84 | - .env
85 | command:
86 | - "--host"
87 | - "0.0.0.0"
88 | - "--port"
89 | - "8000"
90 | labels:
91 | - "traefik.enable=true"
92 | - "traefik.docker.network=reverse-proxy-public"
93 | - "traefik.http.routers.api.rule=HostRegexp(`${HOST_MAIN}`, `{subdomain:[a-z]+}.${HOST_MAIN}`, `${HOST_IP}`)"
94 | - "traefik.http.routers.api.entrypoints=websecure"
95 | - "traefik.http.services.api.loadbalancer.server.scheme=http"
96 | - "traefik.http.services.api.loadbalancer.server.port=8000"
97 | - "traefik.http.routers.api.tls=true"
98 | - "traefik.http.routers.api.tls.certresolver=myresolver"
99 | - "traefik.http.routers.api.tls.domains[0].main=${HOST_MAIN}"
100 | - "traefik.http.routers.api.tls.domains[0].sans=${HOST_SUB}"
101 | depends_on:
102 | - proxy
103 | - db
104 | - cache
105 | - vectorstore
106 | volumes:
107 | - .:/app
108 | networks:
109 | - reverse-proxy-public
110 |
111 | vectorstore:
112 | image: qdrant/qdrant:latest
113 | restart: always
114 | volumes:
115 | - qdrant:/qdrant/storage
116 | networks:
117 | - reverse-proxy-public
118 | # search-api:
119 | # image: searxng/searxng:latest
120 | # restart: always
121 |
122 | # environment:
123 | # REDIS_PASSWORD: "${REDIS_PASSWORD}"
124 | # JWT_SECRET: "${JWT_SECRET}"
125 | # volumes:
126 | # - ./searxng:/etc/searxng
127 | # networks:
128 | # - api-private
129 |
--------------------------------------------------------------------------------
/dockerfile:
--------------------------------------------------------------------------------
1 | FROM tiangolo/uvicorn-gunicorn-fastapi
2 | RUN apt-get update -y && apt-get install -y poppler-utils tesseract-ocr
3 | COPY requirements.txt ./
4 | RUN pip install --trusted-host pypi.python.org --no-cache-dir --upgrade -r requirements.txt
5 | ENTRYPOINT ["uvicorn", "main:app"]
--------------------------------------------------------------------------------
/init-git.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | git init
3 | # git config core.sparseCheckout true
4 | git remote add -f origin git@github.com:c0sogi/LLMChat.git
5 | # echo "app" >> .git/info/sparse-checkout
--------------------------------------------------------------------------------
/llama_models/ggml/llama_cpp_models_here.txt:
--------------------------------------------------------------------------------
1 | The llama.cpp GGML model must be put here as a single file.
2 |
3 | For example, if you downloaded a q4_0 quantized model from "https://huggingface.co/TheBloke/robin-7B-v2-GGML",
4 | the path of the model has to be "robin-7b.ggmlv3.q4_0.bin".
--------------------------------------------------------------------------------
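A minimal sketch of the file-placement convention above, assuming the repo layout's "llama_models/ggml" directory; the helper name and the default file name are illustrative, not part of the repository:

    from pathlib import Path

    def resolve_ggml_model(filename: str = "robin-7b.ggmlv3.q4_0.bin") -> Path:
        """Return the path of a GGML model file placed in llama_models/ggml."""
        path = Path("llama_models/ggml") / filename
        if not path.is_file():
            raise FileNotFoundError(f"Expected a GGML model file at {path}")
        return path
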
/llama_models/gptq/exllama_models_here.txt:
--------------------------------------------------------------------------------
1 | The Exllama GPTQ model must be put here as a folder.
2 |
3 | For example, if you downloaded these 3 files from "https://huggingface.co/TheBloke/orca_mini_7B-GPTQ/tree/main":
4 |
5 | - orca-mini-7b-GPTQ-4bit-128g.no-act.order.safetensors
6 | - tokenizer.model
7 | - config.json
8 |
9 | then you need to put them all in a single folder.
10 | The model path is that folder's name, e.g. "orca_mini_7b", which contains the 3 files.
11 |
--------------------------------------------------------------------------------
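Likewise, a hypothetical sketch of the folder convention above; only the folder layout comes from this note, and the helper and its checks are illustrative:

    from pathlib import Path

    def resolve_gptq_model(folder_name: str = "orca_mini_7b") -> Path:
        """Return the path of a GPTQ model folder placed in llama_models/gptq."""
        path = Path("llama_models/gptq") / folder_name
        if not path.is_dir():
            raise FileNotFoundError(f"Expected a GPTQ model folder at {path}")
        if not any(path.glob("*.safetensors")):
            raise FileNotFoundError(f"No .safetensors weights found in {path}")
        return path
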
/main.py:
--------------------------------------------------------------------------------
1 | from os import environ
2 | from maintools import initialize_before_launch
3 |
4 |
5 | if __name__ == "__mp_main__":
6 | """Option 1: Skip section for multiprocess spawning
7 | This section will be skipped when running in multiprocessing mode"""
8 | pass
9 |
10 | elif __name__ == "__main__":
11 |     """Option 2: Debug mode
12 |     Running this file directly to debug the app.
13 |     Use this section if you don't want to run the app in Docker.
14 |     """
15 |
16 | if environ.get("API_ENV") != "test":
17 | environ["API_ENV"] = "local"
18 | environ["DOCKER_MODE"] = "False"
19 |
20 | import uvicorn
21 |
22 | from app.common.app_settings import create_app
23 | from app.common.config import config
24 |
25 | initialize_before_launch()
26 | app = create_app(config=config)
27 | uvicorn.run(
28 | app,
29 | host="0.0.0.0",
30 | port=config.port,
31 | )
32 | else:
33 | """Option 3: Non-debug mode
34 | Docker will run this section as the main entrypoint
35 | This section will mostly be used.
36 | Maybe LLaMa won't work in Docker."""
37 | from app.common.app_settings import create_app
38 | from app.common.config import config
39 |
40 | initialize_before_launch()
41 | app = create_app(config=config)
42 |
--------------------------------------------------------------------------------
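For reference, the Docker entrypoint in /dockerfile combined with the compose-file arguments amounts to "uvicorn main:app --host 0.0.0.0 --port 8000"; uvicorn imports this module, so __name__ is "main" and the Option 3 branch above builds the app. A sketch of the equivalent programmatic launch (the host and port are the compose defaults, not values hard-coded in main.py):

    import uvicorn

    # Serve the "app" object that the Option 3 branch of main.py creates.
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
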
/maintools.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 |
4 | def ensure_packages_installed():
5 | import subprocess
6 |
7 | subprocess.call(
8 | [
9 | "pip",
10 | "install",
11 | "--trusted-host",
12 | "pypi.python.org",
13 | "-r",
14 | "requirements.txt",
15 | ]
16 | )
17 |
18 |
19 | def set_priority(pid: Optional[int] = None, priority: str = "high"):
20 |     """Set the priority of a process. Priority is a string which can be 'low', 'below_normal',
21 |     'normal', 'above_normal', 'high', or 'realtime'. 'normal' is the default priority."""
22 |     import platform
23 |     from os import getpid
24 |
25 |     import psutil
26 |
27 | if platform.system() == "Windows":
28 | priorities = {
29 | "low": psutil.IDLE_PRIORITY_CLASS,
30 | "below_normal": psutil.BELOW_NORMAL_PRIORITY_CLASS,
31 | "normal": psutil.NORMAL_PRIORITY_CLASS,
32 | "above_normal": psutil.ABOVE_NORMAL_PRIORITY_CLASS,
33 | "high": psutil.HIGH_PRIORITY_CLASS,
34 | "realtime": psutil.REALTIME_PRIORITY_CLASS,
35 | }
36 | else: # Linux and other Unix systems
37 | priorities = {
38 | "low": 19,
39 | "below_normal": 10,
40 | "normal": 0,
41 | "above_normal": -5,
42 | "high": -11,
43 | "realtime": -20,
44 | }
45 |
46 | if pid is None:
47 | pid = getpid()
48 | p = psutil.Process(pid)
49 | p.nice(priorities[priority])
50 |
51 |
52 | def initialize_before_launch(install_packages: bool = False):
53 | """Initialize the app"""
54 | import platform
55 |
56 | if platform.system() == "Windows":
57 | set_priority(priority="high")
58 | else:
59 | set_priority(priority="normal")
60 | if install_packages:
61 | ensure_packages_installed()
62 |
--------------------------------------------------------------------------------
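Usage note: main.py calls initialize_before_launch() with the default install_packages=False. A sketch of opting into dependency installation as well; note that on Unix, the negative nice values used for 'above_normal' and higher generally require elevated privileges, which is presumably why the non-Windows branch falls back to 'normal':

    from maintools import initialize_before_launch

    # Also run "pip install -r requirements.txt" before starting the app.
    initialize_before_launch(install_packages=True)
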
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | select = ["RUF100", "E501"]
3 | line-length = 120
4 |
--------------------------------------------------------------------------------
/qdrant/.lock:
--------------------------------------------------------------------------------
1 | tmp lock file
--------------------------------------------------------------------------------
/qdrant/meta.json:
--------------------------------------------------------------------------------
1 | {"collections": {}, "aliases": {}}
--------------------------------------------------------------------------------
/requirements-exllama.txt:
--------------------------------------------------------------------------------
1 | safetensors==0.3.1
2 | sentencepiece>=0.1.97
3 | ninja==1.11.1
4 | --find-links https://download.pytorch.org/whl/torch_stable.html
5 | torch==2.0.1+cu118
--------------------------------------------------------------------------------
/requirements-llama-cpp.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24
2 | sentencepiece==0.1.98
--------------------------------------------------------------------------------
/requirements-sentence-encoder.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn
3 | tensorflow>=2.0.0
4 | tensorflow-hub
--------------------------------------------------------------------------------
/requirements-transformer-embedding-cpu.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 | torch==2.0.1
3 |
--------------------------------------------------------------------------------
/requirements-transformer-embedding-gpu.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 | torch==2.0.1+cu118
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiofiles
2 | aiomysql
3 | attrs
4 | bcrypt
5 | boto3
6 | botocore
7 | cachetools
8 | certifi
9 | cffi
10 | chardet
11 | click
12 | cssselect
13 | cssutils
14 | dnspython
15 | email-validator
16 | fastapi
17 | h11
18 | httptools
19 | httpx
20 | idna
21 | itsdangerous
22 | iniconfig
23 | jinja2
24 | jmespath
25 | langchain
26 | llama-cpp-python[server]
27 | lxml
28 | orjson
29 | openai
30 | packaging
31 | pdf2image
32 | pexpect
33 | pluggy
34 | premailer
35 | psutil
36 | py
37 | pycparser
38 | pydantic
39 | PyJWT
40 | PyMySQL
41 | pyparsing
42 | pytesseract
43 | pytest
44 | pytest-asyncio
45 | python-dateutil
46 | python-dotenv
47 | qdrant_client
48 | redis
49 | requests
50 | requests_html
51 | s3transfer
52 | sentencepiece
53 | six
54 | sqlalchemy
55 | sqlalchemy-utils
56 | starlette
57 | starlette_admin
58 | starlette-session
59 | tabulate
60 | tiktoken
61 | toml
62 | transformers
63 | unstructured
64 | urllib3
65 | uvicorn[standard]
66 | websockets
67 | yagmail
68 | cryptography
--------------------------------------------------------------------------------
/run_local.bat:
--------------------------------------------------------------------------------
1 | @REM Virtual environment activation. You must have virtualenv installed.
2 | @REM If you don't have it, delete the next line and just run the last one.
3 | @REM You must also have the packages from requirements.txt installed;
4 | @REM use "pip install -r requirements.txt" to install them.
5 | call .venv/Scripts/activate.bat
6 | python -m main
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c0sogi/LLMChat/01f53de6630e7302fc58f103b3fad382d8d270bc/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from datetime import datetime
3 | from typing import AsyncGenerator, Generator
4 | from fastapi import FastAPI
5 | from fastapi.testclient import TestClient
6 | import pytest
7 | import httpx
8 | import pytest_asyncio
9 | from app.utils.auth.token import create_access_token
10 | from app.utils.tests.random import random_user_generator
11 | from app.database.schemas.auth import Users
12 | from app.common.app_settings import create_app
13 | from app.common.config import Config, LoggingConfig
14 | from app.utils.logger import CustomLogger
15 | from app.models.base_models import UserToken
16 | from app.utils.chat.managers.cache import CacheManager, cache
17 |
18 |
19 | @pytest.fixture(scope="session")
20 | def config():
21 | return Config.get(option="test")
22 |
23 |
24 | @pytest.fixture(scope="session")
25 | def cache_manager():
26 | cache.start(config=Config.get(option="test"))
27 | return CacheManager
28 |
29 |
30 | @pytest.fixture(scope="session")
31 | def test_logger():
32 | return CustomLogger(
33 | name="PyTest", logging_config=LoggingConfig(file_log_name="./logs/test.log")
34 | )
35 |
36 |
37 | @pytest.fixture(scope="session", autouse=True)
38 | def event_loop():
39 |     loop = asyncio.new_event_loop()
40 | yield loop
41 | loop.close()
42 |
43 |
44 | @pytest.fixture(scope="session")
45 | def app(config) -> FastAPI:
46 | return create_app(config)
47 |
48 |
49 | @pytest.fixture(scope="session")
50 | def base_http_url() -> str:
51 | return "http://localhost"
52 |
53 |
54 | @pytest.fixture(scope="session")
55 | def base_websocket_url() -> str:
56 | return "ws://localhost"
57 |
58 |
59 | @pytest_asyncio.fixture(scope="session")
60 | async def async_client(
61 | app: FastAPI, base_http_url: str
62 | ) -> AsyncGenerator[httpx.AsyncClient, None]:
63 | async with httpx.AsyncClient(app=app, base_url=base_http_url) as ac:
64 | yield ac
65 |
66 |
67 | @pytest.fixture(scope="session")
68 | def client(app: FastAPI) -> Generator[TestClient, None, None]:
69 | with TestClient(app=app) as tc:
70 | yield tc
71 |
72 |
73 | @pytest_asyncio.fixture(scope="session")
74 | async def login_header(random_user: dict[str, str]) -> dict[str, str]:
75 | """
76 |     Register a user in advance, before the tests run.
77 | """
78 | new_user = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore
79 | access_token = create_access_token(
80 | data=UserToken.from_orm(new_user).dict(exclude={"password", "marketing_agree"}),
81 | expires_delta=24,
82 | )
83 | return {"Authorization": f"Bearer {access_token}"}
84 |
85 |
86 | @pytest_asyncio.fixture(scope="session")
87 | async def api_key_dict(async_client: httpx.AsyncClient, login_header: dict[str, str]):
88 | api_key_memo: str = f"TESTING : {str(datetime.now())}"
89 | response: httpx.Response = await async_client.post(
90 | "/api/user/apikeys",
91 | json={"user_memo": api_key_memo},
92 | headers=login_header,
93 | )
94 | response_body = response.json()
95 | assert response.status_code == 201
96 | assert "access_key" in response_body
97 | assert "secret_key" in response_body
98 | apikey = {
99 | "access_key": response_body["access_key"],
100 | "secret_key": response_body["secret_key"],
101 | }
102 |
103 | response = await async_client.get("/api/user/apikeys", headers=login_header)
104 | response_body = response.json()
105 | assert response.status_code == 200
106 | assert api_key_memo in response_body[0]["user_memo"]
107 | return apikey
108 |
109 |
110 | @pytest.fixture(scope="session")
111 | def random_user():
112 | return random_user_generator()
113 |
114 |
115 | def pytest_collection_modifyitems(items):
116 | app_tests = []
117 | redis_tests = []
118 | other_tests = []
119 | for item in items:
120 | if "client" in item.fixturenames:
121 | app_tests.append(item)
122 | elif item.get_closest_marker("redistest") is not None:
123 | redis_tests.append(item)
124 | else:
125 | other_tests.append(item)
126 | items[:] = redis_tests + app_tests + other_tests
127 |
--------------------------------------------------------------------------------
/tests/test_api_service.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal
2 | from httpx import AsyncClient, Response
3 | import pytest
4 | from app.utils.date_utils import UTC
5 | from app.utils.params_utils import hash_params, parse_params
6 |
7 |
8 | async def request_service(
9 | async_client: AsyncClient,
10 | access_key: str,
11 | secret_key: str,
12 | request_method: Literal["get", "post", "put", "delete", "options"],
13 | allowed_status_codes: tuple = (200, 201, 307),
14 | service_name: str = "",
15 | required_query_params: dict[str, Any] = {},
16 | required_headers: dict[str, Any] = {},
17 | stream: bool = False,
18 | logger=None,
19 | ) -> Any:
20 | all_query_params: str = parse_params(
21 | params={
22 | "key": access_key,
23 | "timestamp": UTC.timestamp(),
24 | }
25 | | required_query_params
26 | )
27 | method_options: dict = {
28 | "headers": {
29 | "secret": hash_params(
30 | query_params=all_query_params,
31 | secret_key=secret_key,
32 | )
33 | }
34 | | required_headers
35 | }
36 | url: str = f"/api/services/{service_name}?{all_query_params}"
37 |
38 | if stream:
39 | response_body = ""
40 | async with async_client.stream(method=request_method.upper(), url=url, **method_options) as response:
41 | assert response.status_code in allowed_status_codes
42 | async for chunk in response.aiter_text():
43 | response_body += chunk
44 |                 if logger is not None: logger.info(f"Streamed data: {chunk}")
45 | else:
46 | response: Response = await getattr(async_client, request_method.lower())(url=url, **method_options)
47 | response_body: Any = response.json()
48 |         if logger is not None: logger.info(f"response_body: {response_body}")
49 | assert response.status_code in allowed_status_codes
50 | return response_body
51 |
52 |
53 | @pytest.mark.asyncio
54 | async def test_request_api(async_client: AsyncClient, api_key_dict: dict, test_logger):
55 | access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"]
56 | service_name: str = ""
57 | request_method: str = "get"
58 | required_query_params: dict[str, Any] = {}
59 | required_headers: dict[str, Any] = {}
60 | allowed_status_codes: tuple = (200, 201, 307)
61 | await request_service(
62 | async_client=async_client,
63 | access_key=access_key,
64 | secret_key=secret_key,
65 | request_method=request_method,
66 | allowed_status_codes=allowed_status_codes,
67 | service_name=service_name,
68 | required_query_params=required_query_params,
69 | required_headers=required_headers,
70 | stream=False,
71 | logger=test_logger,
72 | )
73 |
74 |
75 | # @pytest.mark.asyncio
76 | # async def test_weather_api(async_client: AsyncClient, api_key_dict: dict):
77 | # access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"]
78 | # service_name: str = "weather"
79 | # request_method: str = "get"
80 | # required_query_params: dict[str, any] = {
81 | # "latitude": 37.0,
82 | # "longitude": 120.0,
83 | # }
84 | # required_headers: dict[str, any] = {}
85 | # allowed_status_codes: tuple[int] = (200, 307)
86 | # await request_service(
87 | # async_client=async_client,
88 | # access_key=access_key,
89 | # secret_key=secret_key,
90 | # request_method=request_method,
91 | # allowed_status_codes=allowed_status_codes,
92 | # service_name=service_name,
93 | # required_query_params=required_query_params,
94 | # required_headers=required_headers,
95 | # stream=False,
96 | # )
97 |
98 |
99 | # @pytest.mark.asyncio
100 | # async def test_stream_api(async_client: AsyncClient, api_key_dict: dict):
101 | # access_key, secret_key = api_key_dict["access_key"], api_key_dict["secret_key"]
102 | # service_name: str = "stream"
103 | # request_method: str = "get"
104 | # required_query_params: dict[str, any] = {}
105 | # required_headers: dict[str, any] = {}
106 | # allowed_status_codes: tuple[int] = (200, 307)
107 | # await request_service(
108 | # async_client=async_client,
109 | # access_key=access_key,
110 | # secret_key=secret_key,
111 | # request_method=request_method,
112 | # allowed_status_codes=allowed_status_codes,
113 | # service_name=service_name,
114 | # required_query_params=required_query_params,
115 | # required_headers=required_headers,
116 | # stream=True,
117 | # )
118 |
--------------------------------------------------------------------------------
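Outside the test suite, the same signing scheme would apply to any client: build the query string with parse_params, derive a digest from it with hash_params and the secret key, and send that digest in the "secret" header. A minimal sketch assuming a server at localhost:8000 (the URL is illustrative; the helper signatures come from request_service above):

    import httpx

    from app.utils.date_utils import UTC
    from app.utils.params_utils import hash_params, parse_params

    def call_service(access_key: str, secret_key: str, service_name: str = "") -> httpx.Response:
        # Same signing steps as request_service() above.
        query = parse_params(params={"key": access_key, "timestamp": UTC.timestamp()})
        secret = hash_params(query_params=query, secret_key=secret_key)
        return httpx.get(
            f"http://localhost:8000/api/services/{service_name}?{query}",
            headers={"secret": secret},
        )
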
/tests/test_auth.py:
--------------------------------------------------------------------------------
1 | from httpx import AsyncClient
2 | import pytest
3 | from fastapi import status
4 | from app.errors.api_exceptions import Responses_400
5 | from app.utils.tests.random import random_user_generator
6 |
7 |
8 | @pytest.mark.asyncio
9 | async def test_register(async_client):
10 | valid_emails = ("ankitrai326@gmail.com", "my.ownsite@our-earth.org", "aa@a.a")
11 | invalid_emails = ("ankitrai326gmail.com", "ankitrai326.com")
12 |
13 | valid_users = [random_user_generator(email=email) for email in valid_emails]
14 | invalid_users = [random_user_generator(email=email) for email in invalid_emails]
15 |
16 | # Valid email
17 | for valid_user in valid_users:
18 | response = await async_client.post("api/auth/register/email", json=valid_user)
19 | response_body = response.json()
20 | assert response.status_code == status.HTTP_201_CREATED
21 | assert "Authorization" in response_body.keys(), response.content.decode("utf-8")
22 |
23 | # Invalid email
24 | for invalid_user in invalid_users:
25 | response = await async_client.post("api/auth/register/email", json=invalid_user)
26 | response_body = response.json()
27 | assert response.status_code != status.HTTP_201_CREATED
28 | assert "Authorization" not in response_body.keys(), response.content.decode("utf-8")
29 |
30 |
31 | @pytest.mark.asyncio
32 | async def test_auth(async_client: AsyncClient):
33 | new_user = random_user_generator()
34 |
35 | # Register
36 | response = await async_client.post("api/auth/register/email", json=new_user)
37 | assert response.status_code == status.HTTP_201_CREATED, response.content.decode("utf-8")
38 |
39 | # Login
40 | response = await async_client.post("api/auth/login/email", json=new_user)
41 | response_body = response.json()
42 |     assert response.status_code == status.HTTP_200_OK, "User login failure"
43 | assert "Authorization" in response_body.keys(), response.content.decode("utf-8")
44 |
45 | # Duplicate email check
46 | response = await async_client.post("api/auth/register/email", json=new_user)
47 | response_body = response.json()
48 | assert response.status_code == Responses_400.email_already_exists.status_code, "Email duplication check failure"
49 | assert response_body["detail"] == Responses_400.email_already_exists.detail, "Email duplication check failure"
50 |
--------------------------------------------------------------------------------
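For comparison, a client-side sketch of the register-then-login flow these tests exercise. The base URL is an assumption, and the user dict must carry whatever fields registration expects (the tests build one with random_user_generator); the endpoints and the "Authorization" response field come from the tests above:

    import httpx

    def register_and_login(user: dict[str, str]) -> str:
        """Return the bearer token issued on login."""
        with httpx.Client(base_url="http://localhost:8000") as client:
            client.post("api/auth/register/email", json=user)
            response = client.post("api/auth/login/email", json=user)
            return response.json()["Authorization"]
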
/tests/test_database.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from sqlalchemy import select
3 | from uuid import uuid4
4 | from app.common.config import Config
5 | from app.models.base_models import AddApiKey, UserToken
6 | from app.database.connection import db
7 | from app.database.schemas.auth import Users, ApiKeys
8 | from app.database.crud.api_keys import create_api_key, get_api_key_and_owner
9 | from app.middlewares.token_validator import Validator
10 | from app.utils.date_utils import UTC
11 | from app.utils.params_utils import hash_params, parse_params
12 |
13 | db.start(config=Config.get(option="test"))
14 |
15 |
16 | @pytest.mark.asyncio
17 | async def test_apikey_identification(random_user: dict[str, str]):
18 | user: Users = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore
19 | additional_key_info: AddApiKey = AddApiKey(user_memo="[Testing] test_apikey_query")
20 | new_api_key: ApiKeys = await create_api_key(
21 | user_id=user.id, additional_key_info=additional_key_info
22 | )
23 | matched_api_key, matched_user = await get_api_key_and_owner(
24 | access_key=new_api_key.access_key
25 | )
26 | assert matched_api_key.access_key == new_api_key.access_key
27 | assert random_user["email"] == matched_user.email
28 |
29 |
30 | @pytest.mark.asyncio
31 | async def test_api_key_validation(random_user: dict[str, str]):
32 | user: Users = await Users.add_one(autocommit=True, refresh=True, **random_user) # type: ignore
33 | additional_key_info: AddApiKey = AddApiKey(user_memo="[Testing] test_apikey_query")
34 | new_api_key: ApiKeys = await create_api_key(
35 | user_id=user.id, additional_key_info=additional_key_info
36 | )
37 | timestamp: str = str(UTC.timestamp())
38 | parsed_qs: str = parse_params(
39 | params={"key": new_api_key.access_key, "timestamp": timestamp}
40 | )
41 | user_token: UserToken = await Validator.api_key(
42 | api_access_key=new_api_key.access_key,
43 | hashed_secret=hash_params(
44 | query_params=parsed_qs, secret_key=new_api_key.secret_key
45 | ),
46 | query_params=parsed_qs,
47 | timestamp=timestamp,
48 | )
49 | assert user_token.id == user.id
50 |
51 |
52 | @pytest.mark.asyncio
53 | async def test_crud():
54 | total_users: int = 4
55 | users: list[dict[str, str]] = [
56 | {"email": str(uuid4())[:18], "password": str(uuid4())[:18]}
57 | for _ in range(total_users)
58 | ]
59 | user_1, user_2, user_3, user_4 = users
60 |
61 | # C/U
62 | await Users.add_all(user_1, user_2, autocommit=True, refresh=True)
63 | await Users.add_one(autocommit=True, refresh=True, **user_3)
64 | await Users.update_filtered(
65 | Users.email == user_1["email"],
66 | Users.password == user_1["password"],
67 | updated={"email": "UPDATED", "password": "updated"},
68 | autocommit=True,
69 | )
70 |
71 | # R
72 | result_1_stmt = select(Users).filter(
73 | Users.email.in_([user_1["email"], "UPDATED", user_3["email"]])
74 | )
75 | result_1 = await db.scalars__fetchall(result_1_stmt)
76 | assert len(result_1) == 2
77 | assert result_1[0].email == "UPDATED" # type: ignore
78 | assert result_1[1].email == user_3["email"] # type: ignore
79 | result_2 = await Users.one_filtered_by(**user_2)
80 | assert result_2.email == user_2["email"] # type: ignore
81 | result_3 = await Users.fetchall_filtered_by(**user_4)
82 | assert len(result_3) == 0
83 |
84 | # D
85 | await db.delete(result_2, autocommit=True)
86 | result_4_stmt = select(Users).filter_by(**user_2)
87 | result_4 = (await db.scalars(stmt=result_4_stmt)).first()
88 | assert result_4 is None
89 |
--------------------------------------------------------------------------------
/tests/test_dynamic_llm_load.py:
--------------------------------------------------------------------------------
1 | from app.models.llms import LLMModels
2 |
3 |
4 | def test_dynamic_llm_load():
5 | print("- member_map keys -")
6 | print(LLMModels.member_map.keys())
7 | print("- static member map keys -")
8 | print(LLMModels.static_member_map.keys())
9 | print("- dynamic member map keys -")
10 | print(LLMModels.dynamic_member_map.keys())
11 |
12 | LLMModels.add_member("foo", {"bar": "baz"})
13 |
14 | print("- member_map keys -")
15 | print(LLMModels.member_map.keys())
16 | print("- static member map keys -")
17 | print(LLMModels.static_member_map.keys())
18 | print("- dynamic member map keys -")
19 | print(LLMModels.dynamic_member_map.keys())
20 |
21 | print("- member_map class -")
22 | print(LLMModels.member_map)
23 |
24 | print("- is dynamic member instance of LLMModels ? -")
25 | print(isinstance(LLMModels.dynamic_member_map["foo"], LLMModels))
26 | print(f"- {LLMModels.dynamic_member_map['foo']} -")
27 |
--------------------------------------------------------------------------------
/tests/test_memory_deallocation.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import queue
3 | from collections import deque
4 |
5 | from app.utils.system import free_memory_of_first_item_from_container
6 |
7 | DEL_COUNT: int = 0
8 |
9 |
10 | class DummyObject:
11 | """Dummy object for testing."""
12 |
13 | foo: bool
14 | bar: bool
15 |
16 | def __init__(self) -> None:
17 | """Initialize."""
18 | self.foo = True
19 | self.bar = False
20 |
21 | def __del__(self) -> None:
22 | """Clean up resources."""
23 | global DEL_COUNT
24 | DEL_COUNT += 1
25 |
26 |
27 | def test_deallocate_item_from_memory_by_reference():
28 | """Tests if __del__ is called when item is removed from container."""
29 | global DEL_COUNT
30 |
31 |     # Build one container of each supported type
32 | _deque = deque([DummyObject() for _ in range(10)])
33 | _list = [DummyObject() for _ in range(10)]
34 | _dict = {i: DummyObject() for i in range(10)}
35 | _queue = queue.Queue()
36 | _asyncio_queue = asyncio.Queue()
37 | for _ in range(10):
38 | _queue.put(DummyObject())
39 | _asyncio_queue.put_nowait(DummyObject())
40 |
41 | # Test begin
42 | for container in [_deque, _list, _dict, _queue, _asyncio_queue]:
43 | print(f"- Testing {container.__class__.__name__}")
44 | DEL_COUNT = 0
45 | for _ in range(10):
46 | free_memory_of_first_item_from_container(container)
47 | print(f"- Finished testing {container.__class__.__name__}")
48 |         assert (
49 |             DEL_COUNT >= 10
50 |         ), f"All 10 items should have been deleted from {container.__class__.__name__}"
51 |
52 |
53 | def test_dereference():
54 | obj = DummyObject()
55 |
56 | def delete_obj(obj) -> None:
57 | del obj
58 |
59 | def delete_foo(obj) -> None:
60 | del obj.foo
61 |
62 | delete_obj(obj)
63 | assert obj is not None # Check if obj is still in memory
64 | delete_foo(obj)
65 | assert not hasattr(obj, "foo") # Check if obj.foo is deleted
66 |
--------------------------------------------------------------------------------
/tests/test_translate.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from app.common.config import (
3 | GOOGLE_TRANSLATE_API_KEY,
4 | PAPAGO_CLIENT_ID,
5 | PAPAGO_CLIENT_SECRET,
6 | RAPID_API_KEY,
7 | CUSTOM_TRANSLATE_URL,
8 | )
9 | from app.utils.api.translate import Translator
10 |
11 |
12 | @pytest.mark.asyncio
13 | async def test_custom_translate(test_logger):
14 | translated = await Translator.custom_translate_api(
15 | text="hello",
16 | src_lang="en",
17 | trg_lang="ko",
18 | api_url=str(CUSTOM_TRANSLATE_URL),
19 | )
20 | test_logger.info(__name__ + ":" + translated)
21 |
22 |
23 | @pytest.mark.asyncio
24 | async def test_google_translate(test_logger):
25 | translated = await Translator.google(
26 | text="hello",
27 | src_lang="en",
28 | trg_lang="ko",
29 | api_key=str(GOOGLE_TRANSLATE_API_KEY),
30 | )
31 | test_logger.info(__name__ + ":" + translated)
32 |
33 |
34 | @pytest.mark.asyncio
35 | async def test_papago(test_logger):
36 | translated = await Translator.papago(
37 | text="hello",
38 | src_lang="en",
39 | trg_lang="ko",
40 | client_id=str(PAPAGO_CLIENT_ID),
41 | client_secret=str(PAPAGO_CLIENT_SECRET),
42 | )
43 | test_logger.info(__name__ + ":" + translated)
44 |
45 |
46 | @pytest.mark.asyncio
47 | async def test_deepl(test_logger):
48 | translated = await Translator.deepl_via_rapid_api(
49 | text="hello",
50 | src_lang="en",
51 | trg_lang="ko",
52 | api_key=str(RAPID_API_KEY),
53 | )
54 | test_logger.info(__name__ + ":" + translated)
55 |
56 |
57 | @pytest.mark.asyncio
58 | async def test_auto_translate(test_logger):
59 | translated = await Translator.translate(
60 | text="hello",
61 | src_lang="en",
62 | trg_lang="ko",
63 | )
64 | test_logger.info("FIRST:" + translated + str(Translator.cached_function) + str(Translator.cached_args))
65 | translated = await Translator.translate(
66 | text="what is your name?",
67 | src_lang="en",
68 | trg_lang="ko",
69 | )
70 | test_logger.info("SECOND:" + translated + str(Translator.cached_function) + str(Translator.cached_args))
71 |
--------------------------------------------------------------------------------
/update-git.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | git fetch
3 | git reset --hard origin/master
4 | git pull origin master
--------------------------------------------------------------------------------