├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yaml ├── locust ├── data │ ├── chat_post.json │ ├── completion_parameters_post.json │ ├── message_post.json │ ├── prompt_post.json │ └── prompt_template_post.json └── locust_test.py ├── poetry.lock ├── pyproject.toml ├── src └── restllm │ ├── __init__.py │ ├── cli.py │ ├── cryptography │ ├── __init__.py │ ├── authentication.py │ ├── keys.py │ └── secure_url.py │ ├── dependencies.py │ ├── endpoints │ ├── __init__.py │ ├── authentication.py │ ├── completion.py │ └── crud.py │ ├── exceptions.py │ ├── main.py │ ├── middleware.py │ ├── models │ ├── __init__.py │ ├── authentication.py │ ├── base.py │ ├── chat.py │ ├── completion.py │ ├── events.py │ ├── functions.py │ ├── prompts.py │ ├── share.py │ └── validators.py │ ├── redis │ ├── __init__.py │ ├── commands.py │ ├── events.py │ ├── index.py │ ├── keys.py │ ├── queries.py │ ├── ratelimit.py │ └── search.py │ ├── routers │ ├── __init__.py │ ├── authentication.py │ ├── chats.py │ ├── completion.py │ ├── completion_parameters.py │ ├── events.py │ ├── functions.py │ ├── messages.py │ ├── prompts.py │ ├── share.py │ └── users.py │ ├── settings.py │ ├── tasks │ └── email.py │ └── types │ ├── __init__.py │ ├── paths.py │ └── queries.py └── tests └── unittests ├── test_cryptography.py ├── test_model_authentication.py └── test_model_prompts.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Redis data 163 | redis-data/ 164 | .vscode/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-buster 2 | 3 | WORKDIR /app 4 | 5 | COPY ./pyproject.toml ./poetry.lock /app/ 6 | COPY ./src /app/src 7 | 8 | RUN pip install poetry && \ 9 | poetry install 10 | 11 | CMD ["poetry", "run", "uvicorn", "restllm.main:app", "--reload", "--host", "0.0.0.0", "--port", "8000"] 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REST LLM 2 | This repository is a work in progress! 3 | 4 | The repository contains a functioning REST API based on LiteLLM and the interface from OpenAI to instruct and chat models using their familiar completion API. The REST API is built with FastAPI and wraps the LLM functionalities together with a Redis backend for CRUD operations. 5 | 6 | ## Authentication 7 | 8 | The API uses the OAuth2 implementation in FastAPI for authentication. Currently there is a login endpoint `/v1/authentication/token` and a signup endpoint `/v1/authentication/signup`. The password hashing algorithm defaults to `argon2`. 9 | 10 | ## Setup development environment 11 | To set up for development, use poetry to install the package. 12 | ```bash 13 | poetry install 14 | ``` 15 | 16 | You can also run the full application with backends using docker compose. Run `docker compose build` to build the local image. Then run `docker compose up` to run the application.
17 | 18 | Check endpoint http://localhost:8000/docs for Swagger UI. 19 | 20 | ## Roadmap 21 | Below are some of the progressions in the repo 22 | 23 | | Feature | Status | 24 | |-------------------------------------------------|-------------| 25 | | CRUD endpoints for Chat | ✅ Done | 26 | | CRUD endpoints for Prompts and PromptTemplates | ✅ Done | 27 | | CRUD endpoints for custom functions | ✅ Done | 28 | | Endpoints for premade function calls | ✅ Done | 29 | | Events endpoint for listening to events | 🟥 In progress | 30 | | User authentication | 🟥 In progress | 31 | | Reset password | 🟥 In progress | 32 | | Email verification | ✅ Done | 33 | | Setup GitHub actions pipeline | 🟥 Coming | 34 | | Acceptance tests and unittests | 🟥 Coming | 35 | | Helm Chart for deployment on Kubernetes | 🟥 Coming | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | redis: 5 | image: redis/redis-stack:7.2.0-v6 6 | ports: 7 | - "6379:6379" 8 | volumes: 9 | - ./redis-data:/data 10 | environment: 11 | REDIS_ARGS: "--appendonly yes" 12 | 13 | restllm: 14 | build: 15 | context: . 
16 | dockerfile: Dockerfile 17 | command: ["poetry", "run", "uvicorn", "restllm.main:app", "--reload", "--host", "0.0.0.0", "--port", "8000", "--workers", "2", "--log-level", "info"] 18 | volumes: 19 | - ./src:/app/src 20 | - ./poetry.lock:/app/poetry.lock 21 | - ./pyproject.toml:/app/pyproject.toml 22 | ports: 23 | - "8000:8000" 24 | - "11434:11434" 25 | depends_on: 26 | - redis 27 | environment: 28 | REDIS_DSN: redis://redis:6379/0 29 | OLLAMA_BASE_URL: http://localhost:11434 30 | OPENAI_API_KEY: ${OPENAI_API_KEY} 31 | COHERE_API_KEY: ${COHERE_API_KEY} 32 | EMAIL_USERNAME: ${EMAIL_USERNAME} 33 | EMAIL_PASSWORD: ${EMAIL_PASSWORD} 34 | EMAIL_HOSTNAME: ${EMAIL_HOSTNAME} 35 | 36 | schematest: 37 | image: schemathesis/schemathesis:stable 38 | profiles: 39 | - excluded 40 | command: > 41 | run http://restllm:8000/openapi.json 42 | depends_on: 43 | - restllm -------------------------------------------------------------------------------- /locust/data/chat_post.json: -------------------------------------------------------------------------------- 1 | { 2 | "completion_parameters": { 3 | "model": "gpt-3.5-turbo", 4 | "functions": null, 5 | "temperature": 0.2, 6 | "top_p": 0, 7 | "n": 1, 8 | "stop": "string", 9 | "max_tokens": null, 10 | "presence_penalty": null, 11 | "frequency_penalty": null, 12 | "logit_bias": null, 13 | "user": null 14 | }, 15 | "messages": [ 16 | { 17 | "role": "user", 18 | "content": "Can you write a function in Python that adds two numbers together?", 19 | "name": "Alice", 20 | "function_call": { 21 | "name": "get_weather_status", 22 | "args": {} 23 | } 24 | } 25 | ] 26 | } -------------------------------------------------------------------------------- /locust/data/completion_parameters_post.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "gpt-3.5-turbo", 3 | "functions": null, 4 | "temperature": 0.2, 5 | "top_p": 0, 6 | "n": 1, 7 | "stop": "string", 8 | "max_tokens": null, 9 | 
"presence_penalty": null, 10 | "frequency_penalty": null, 11 | "logit_bias": null, 12 | "user": null 13 | } -------------------------------------------------------------------------------- /locust/data/message_post.json: -------------------------------------------------------------------------------- 1 | { 2 | "role": "user", 3 | "content": "Can you write a function in Python that adds two numbers together?", 4 | "name": "Alice", 5 | "function_call": { 6 | "name": "get_weather_status", 7 | "args": {} 8 | } 9 | } -------------------------------------------------------------------------------- /locust/data/prompt_post.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EditPythonCodePrompt", 3 | "description": "Prompt to edit python code according to Clean Code principles.", 4 | "language": { 5 | "iso639_3": "eng" 6 | }, 7 | "tags": [ 8 | "Zero-shot Prompting" 9 | ], 10 | "messages": [ 11 | { 12 | "role": "system", 13 | "content": "You are an expert Python programmer that values Clean Code and simplicity." 
class ApiUser(HttpUser):
    """Simulated client that exercises the ``/v1/chat`` CRUD and search endpoints.

    ``id_set`` and ``index_set`` are class attributes, so they are shared by
    every simulated user: a chat created by one user may be fetched, updated
    or deleted by another.
    """

    # Ids of chats created so far; shared by all simulated users.
    id_set = set()
    # Message indexes that ``patch_chat_message`` picks from.
    index_set = {0}
    wait_time = between(1, 5)
    sorting_fields = ["updated_at", "created_at", "id"]

    @task
    def get_chat(self):
        """GET one previously created chat; no-op when none exists yet."""
        chat_id = self.get_id()
        if chat_id is None:
            return
        self.client.get(
            f"/v1/chat/{chat_id}",
            name="/v1/chat/{id}",
        )

    @task
    def post_chat(self):
        """POST a new chat and remember its id when the server answers 201."""
        response = self.client.post(
            "/v1/chat",
            data=get_payload(path_mapping.get("chat")),
            name="/v1/chat",
        )
        if response.status_code == 201:
            chat_id = response.json().get("id")
            # A 201 body without an "id" must not poison the pool with None.
            if chat_id is not None:
                self.id_set.add(chat_id)

    @task
    def patch_chat_message(self):
        """PATCH one message of a previously created chat."""
        chat_id = self.get_id()
        if chat_id is None:
            return
        self.client.patch(
            f"/v1/chat/{chat_id}/messages/{self.get_index()}",
            data=get_payload(path_mapping.get("message")),
            name="/v1/chat/{id}/messages/{index}",
        )

    @task
    def delete_chat(self):
        """DELETE a previously created chat and forget its id on 204."""
        chat_id = self.get_id()
        if chat_id is None:
            return
        response = self.client.delete(
            f"/v1/chat/{chat_id}",
            name="/v1/chat/{id}",
        )
        if response.status_code == 204:
            # ``id_set`` is shared: another simulated user may have removed
            # the same id while this request was in flight, so ``remove``
            # could raise KeyError. ``discard`` is a no-op in that case.
            self.id_set.discard(chat_id)

    @task
    def put_chat(self):
        """PUT the chat fixture over a previously created chat."""
        chat_id = self.get_id()
        if chat_id is None:
            return
        self.client.put(
            f"/v1/chat/{chat_id}",
            data=get_payload(path_mapping.get("chat")),
            name="/v1/chat/{id}",
        )

    @task
    def search_chat(self):
        """GET the chat listing with random paging and sorting parameters."""
        self.client.get(
            "/v1/chat",
            params={
                "offset": random.randint(1, 20),
                "limit": random.randint(1, 20),
                "sorting_field": random.choice(self.sorting_fields),
                "ascending": random.choice([False, True]),
            },
            name="/v1/chat",
        )

    def get_id(self) -> int | None:
        """Return a random known chat id, or ``None`` when none is known."""
        if not self.id_set:
            return None
        return random.choice(tuple(self.id_set))

    def get_index(self) -> int:
        """Return a random message index from ``index_set``."""
        return random.choice(tuple(self.index_set))
class IndexAlreadyExists(redis.exceptions.ResponseError):
    """Response error type for an index that already exists.

    Nothing in this module raises or catches it.
    """


@click.group()
def cli():
    """Root click group; the commands below are registered on it."""


@cli.command(name="create_index")
@click.option(
    "--redis-url",
    default=str(settings.redis_dsn),
    help="Redis URL",
)
@click.option(
    "--class-name",
    required=True,
    help="Class name for which index needs to be created",
)
def create_index(redis_url: str, class_name: str):
    """Create the search index for the model class named ``class_name``."""
    redis_client: redis.Redis = redis.from_url(redis_url)
    meta_model_schema = get_meta_model_schema()

    _class = get_class_from_class_name(class_name)

    create_index_on_meta_model(redis_client, meta_model_schema, _class)


@cli.command(name="migrate_all_index")
@click.option(
    "--redis-url",
    default=str(settings.redis_dsn),
    help="Redis URL",
)
def migrate_all_index(redis_url: str):
    """Create the index for every class in ``string_to_class_mapping``.

    Classes whose index already exists are reported and skipped; any other
    Redis response error is propagated.
    """
    redis_client: redis.Redis = redis.from_url(redis_url)
    meta_model_schema = get_meta_model_schema()

    for _class in string_to_class_mapping.values():
        try:
            create_index_on_meta_model(redis_client, meta_model_schema, _class)
        except redis.exceptions.ResponseError as error:
            if str(error) != "Index already exists":
                # Bare ``raise`` keeps the original traceback intact.
                raise
            click.echo(
                f"Index for model class '{_class.__name__}' already exists. Skipping index creation!",
                color=True,
            )


@cli.command(name="delete_data")
@click.option(
    "--redis-url",
    default=str(settings.redis_dsn),
    help="Redis URL",
)
def delete_data(redis_url: str):
    """Flush ALL data from the Redis server after interactive confirmation."""
    response = input(
        f"Are you sure you want to DELETE ALL data in redis database '{redis_url}'?: Choices: (yes/y, no/n)\nAnswer: "
    )
    # Normalise the answer so stray whitespace or capitalisation ("YES",
    # " y ") is accepted; anything else aborts.
    if response.strip().lower() in {"yes", "y"}:
        redis_client: redis.Redis = redis.from_url(redis_url)
        redis_client.flushall()
        click.echo("Deleted all data")
    else:
        click.echo("Aborted deletion")


@cli.command(name="connection_count")
@click.option(
    "--redis-url",
    default=str(settings.redis_dsn),
    help="Redis URL",
)
def connection_count(redis_url: str):
    """Print the number of Redis client connections once per second.

    Runs until interrupted with Ctrl-C.
    """
    redis_client: redis.Redis = redis.from_url(redis_url)
    try:
        while True:
            client_total = len(redis_client.client_list())
            # One is subtracted, presumably to exclude this command's own
            # connection from the count.
            sys.stdout.write(f"\rCurrent connection count: {client_total - 1}")
            sys.stdout.flush()
            time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl-C is the only way out of the loop; finish the status line
        # instead of leaving the cursor on the "\r"-rewritten line.
        sys.stdout.write("\n")


if __name__ == "__main__":
    cli()
async def get_fernet(
    redis_client: redis.Redis,
    key_name: str = "fernet_crypto_key",
    expiration_time: int = 3600,
) -> Fernet:
    """Return a ``Fernet`` built from the key stored in Redis under *key_name*.

    If no key is stored, a new one is generated and stored with a TTL of
    *expiration_time* seconds. If a key is already stored, its TTL is reset
    to *expiration_time*.

    Args:
        redis_client: Async Redis client used to read and store the key.
        key_name: Redis key under which the Fernet key is kept.
        expiration_time: TTL in seconds applied to the stored key.

    Returns:
        A ``Fernet`` instance using the stored (or newly stored) key.
    """
    fernet_key = await redis_client.get(key_name)

    if fernet_key is None:
        candidate = Fernet.generate_key()
        # NX makes the write conditional: if another worker stored a key
        # between our GET and this SET, theirs is kept. A plain SETEX would
        # overwrite it and invalidate payloads already encrypted with it.
        await redis_client.set(key_name, candidate, ex=expiration_time, nx=True)
        # Read back whichever key won. Fall back to our candidate if the key
        # vanished again in between (expired or deleted).
        fernet_key = await redis_client.get(key_name)
        if fernet_key is None:
            fernet_key = candidate
    else:
        await redis_client.expire(key_name, expiration_time)

    # Fernet accepts both bytes and str, so no decode step is needed.
    return Fernet(fernet_key)
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import hmac 3 | import json 4 | 5 | from cryptography.fernet import Fernet 6 | 7 | from ..settings import settings 8 | 9 | 10 | def encrypt_payload(fernet: Fernet, payload: dict) -> bytes: 11 | payload_json = json.dumps(payload) 12 | return fernet.encrypt(payload_json.encode()) 13 | 14 | 15 | def decrypt_payload(fernet: Fernet, encrypted_payload: bytes | str) -> dict: 16 | decrypted_payload = fernet.decrypt(encrypted_payload) 17 | return json.loads(decrypted_payload) 18 | 19 | 20 | def sign_data(data: str) -> str: 21 | return hmac.new(settings.secret_key.encode(), data, hashlib.sha256).hexdigest() 22 | 23 | 24 | def signature_is_valid(encrypted_payload: str, received_signature: str) -> bool: 25 | return ( 26 | hmac.new( 27 | settings.secret_key.encode(), encrypted_payload, hashlib.sha256 28 | ).hexdigest() 29 | == received_signature 30 | ) 31 | 32 | 33 | def generate_secure_url(fernet: Fernet, payload: dict) -> dict[str, str]: 34 | encrypted_payload = encrypt_payload(fernet, payload) 35 | signature = sign_data(encrypted_payload) 36 | return {"payload": encrypted_payload.decode(), "signature": signature} 37 | 38 | 39 | def payload_is_valid(encrypted_payload: str, received_signature: str): 40 | return signature_is_valid(encrypted_payload.encode(), received_signature) 41 | -------------------------------------------------------------------------------- /src/restllm/dependencies.py: -------------------------------------------------------------------------------- 1 | import redis.asyncio as redis 2 | 3 | from fastapi import Depends, Form 4 | from jose import JWTError, ExpiredSignatureError, jwt 5 | from typing import Type 6 | from pydantic import SecretStr, EmailStr 7 | 8 | from .models.base import User 9 | from .settings import settings 10 | from .types import paths 11 | from .models.share import ShareableClass 12 | from .redis.keys import get_class_name 13 | from 
async def decode_user_token(token: str) -> dict:
    """Decode a JWT and return its claims without the ``exp`` entry.

    Args:
        token: Encoded JWT, verified with ``settings.secret_key`` and
            ``settings.jwt_algorithm``.

    Returns:
        The remaining claims of the token.

    Raises:
        TokenExpiredException: The token's expiry has passed.
        InvalidCredentialsException: The token cannot be decoded or verified,
            or it carries no claims besides ``exp``.
    """
    # Only the decode call can raise JWT errors, so only it sits in the try.
    try:
        payload = jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.jwt_algorithm],
        )
    except ExpiredSignatureError as error:
        raise TokenExpiredException from error
    except JWTError as error:
        raise InvalidCredentialsException from error

    # A validly signed token without an "exp" claim must not surface as an
    # unhandled KeyError, so the removal tolerates a missing entry.
    payload.pop("exp", None)
    if not payload:
        raise InvalidCredentialsException
    return payload
| 84 | 85 | async def get_shareable_key( 86 | object: ShareableClass, 87 | id: int = paths.id_path, 88 | user: User = Depends(get_user), 89 | ): 90 | return f"{object.value}:{user.id}:{id}" 91 | 92 | 93 | async def authenticate_user( 94 | email: str, 95 | password: str, 96 | redis_client: redis.Redis, 97 | ) -> UserWithPasswordHash | bool: 98 | user = await get_user_with_password_hash( 99 | email, 100 | redis_client, 101 | ) 102 | if not user or not verify_password(password, user.get("hashed_password")): 103 | raise InvalidCredentialsException 104 | return UserWithPasswordHash(**user) 105 | 106 | 107 | async def signup_form( 108 | first_name: str = Form(...), 109 | last_name: str = Form(...), 110 | email: EmailStr = Form(...), 111 | password: SecretStr = Form(...), 112 | confirm_password: SecretStr = Form(...), 113 | ) -> UserSignUp: 114 | return UserSignUp( 115 | first_name=first_name, 116 | last_name=last_name, 117 | email=email, 118 | password=password, 119 | confirm_password=confirm_password, 120 | ) 121 | 122 | 123 | async def change_password_form( 124 | old_password: SecretStr = Form(...), 125 | new_password: SecretStr = Form(...), 126 | confirm_new_password: SecretStr = Form(...), 127 | ) -> ChangePassword: 128 | return ChangePassword( 129 | old_password=old_password, 130 | new_password=new_password, 131 | confirm_new_password=confirm_new_password, 132 | ) 133 | 134 | 135 | def get_key_with_id(class_name: str, owner: User, instance_id: int): 136 | return f"{class_name}:{owner.id}:{instance_id}" 137 | 138 | 139 | def get_single_key(class_name: str, owner: User): 140 | return f"{class_name}:{owner.id}" 141 | 142 | 143 | def build_get_new_instance_key(cls: Type): 144 | async def get_new_instance_key( 145 | user: User = Depends(get_user), 146 | redis_client: redis.Redis = Depends(get_redis_client), 147 | ) -> tuple[str, int]: 148 | class_name = get_class_name(cls) 149 | instance_id = await create_instance_id(redis_client, class_name) 150 | return 
get_key_with_id(class_name, user, instance_id), instance_id 151 | 152 | return get_new_instance_key 153 | 154 | 155 | def build_get_instance_key(cls: Type): 156 | async def get_instance_key( 157 | id: int = paths.id_path, 158 | user: User = Depends(get_user), 159 | ): 160 | class_name = get_class_name(cls) 161 | return get_key_with_id(class_name, user, id) 162 | 163 | return get_instance_key 164 | 165 | 166 | def build_get_new_class_user_key(cls: Type): 167 | async def get_class_user_key( 168 | user: User = Depends(get_user), 169 | redis_client: redis.Redis = Depends(get_redis_client), 170 | ) -> tuple[str, int]: 171 | class_name = get_class_name(cls) 172 | instance_id = await create_instance_id(redis_client, class_name) 173 | return get_single_key(class_name, user), instance_id 174 | 175 | return get_class_user_key 176 | 177 | 178 | def build_get_class_user_key(cls: Type): 179 | async def get_class_user_key( 180 | user: User = Depends(get_user), 181 | ): 182 | class_name = get_class_name(cls) 183 | return get_single_key(class_name, user) 184 | 185 | return get_class_user_key 186 | -------------------------------------------------------------------------------- /src/restllm/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | from .crud import ( 2 | build_get_instance_endpoint, 3 | build_create_instance_endpoint, 4 | build_update_instance_endpoint, 5 | build_delete_instance_endpoint, 6 | add_crud_route, 7 | ) 8 | -------------------------------------------------------------------------------- /src/restllm/endpoints/authentication.py: -------------------------------------------------------------------------------- 1 | import redis.asyncio as redis 2 | from ..models.authentication import UserWithPasswordHash 3 | from redis.commands.json.path import Path 4 | 5 | 6 | async def create_user_instance( 7 | redis_client: redis.Redis, 8 | instance: UserWithPasswordHash, 9 | email_key: str, 10 | ) -> tuple[bool, dict]: 11 
| key = f"User:{instance.id}" 12 | 13 | async with redis_client.pipeline() as pipeline: 14 | pipeline.multi() 15 | ( 16 | pipeline.json().set( 17 | key, 18 | Path.root_path(), 19 | instance.model_dump(mode="json"), 20 | nx=True, 21 | ), 22 | pipeline.set( 23 | email_key, 24 | instance.id, 25 | nx=True, 26 | ), 27 | ) 28 | result = await pipeline.execute() 29 | return all(result), instance 30 | -------------------------------------------------------------------------------- /src/restllm/endpoints/completion.py: -------------------------------------------------------------------------------- 1 | import redis.asyncio as redis 2 | 3 | from litellm import acompletion 4 | 5 | from ..models import ChatMessage, ChatWithMeta 6 | from ..redis.commands import append_chat_message 7 | 8 | 9 | async def chat_acompletion_call( 10 | chat_with_meta: ChatWithMeta, 11 | redis_client: redis.Redis, 12 | key: str, 13 | ): 14 | kwargs = chat_with_meta.object.dump_json_for_completion() 15 | response = acompletion( 16 | **kwargs, 17 | stream=True, 18 | ) 19 | chat_message = ChatMessage(role="assistant", content="") 20 | async for chunk in await response: 21 | next_token = chunk["choices"][0]["delta"].get("content") 22 | if not next_token: 23 | continue 24 | chat_message.content += next_token 25 | yield next_token 26 | await append_chat_message( 27 | redis_client=redis_client, 28 | instance=chat_message, 29 | key=key, 30 | ) 31 | -------------------------------------------------------------------------------- /src/restllm/endpoints/crud.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | import redis.asyncio as redis 4 | import redis.exceptions 5 | from fastapi import APIRouter, Depends, HTTPException, Response, status 6 | from pydantic import BaseModel 7 | 8 | from ..dependencies import ( 9 | get_redis_client, 10 | get_user, 11 | build_get_new_instance_key, 12 | build_get_instance_key, 13 | ) 14 | from 
..exceptions import IndexNotImplemented, ObjectNotFoundException 15 | from ..models import User 16 | from ..redis.commands import ( 17 | create_instance, 18 | delete_instance, 19 | get_instance, 20 | update_instance, 21 | ) 22 | from ..redis.keys import get_class_name 23 | from ..redis.search import SortingField, list_instances 24 | from ..types import queries 25 | 26 | 27 | def build_get_instance_endpoint(cls: Type): 28 | async def get_instance_endpoint( 29 | redis_client: redis.Redis = Depends(get_redis_client), 30 | key: str = Depends(build_get_instance_key(cls)), 31 | ): 32 | instance = await get_instance( 33 | redis_client=redis_client, 34 | key=key, 35 | ) 36 | if not instance: 37 | raise ObjectNotFoundException(cls) 38 | return instance 39 | 40 | return get_instance_endpoint 41 | 42 | 43 | def build_create_instance_endpoint(cls: Type): 44 | async def create_instance_endpoint( 45 | instance: cls, 46 | redis_client: redis.Redis = Depends(get_redis_client), 47 | new_key: tuple[str, int] = Depends(build_get_new_instance_key(cls)), 48 | user: User = Depends(get_user), 49 | ): 50 | created, instance = await create_instance( 51 | redis_client=redis_client, 52 | owner=user, 53 | instance=instance, 54 | key=new_key[0], 55 | instance_id=new_key[1], 56 | ) 57 | if not created: 58 | raise HTTPException( 59 | status_code=422, 60 | detail=f"Ressource of type '{get_class_name(cls)}' could not be created.", 61 | ) 62 | return instance 63 | 64 | return create_instance_endpoint 65 | 66 | 67 | def build_update_instance_endpoint(cls: Type): 68 | async def update_instance_endpoint( 69 | instance: cls, 70 | redis_client: redis.Redis = Depends(get_redis_client), 71 | key: str = Depends(build_get_instance_key(cls)), 72 | ): 73 | updated, updated_at, instance = await update_instance( 74 | redis_client=redis_client, 75 | instance=instance, 76 | key=key, 77 | ) 78 | if not updated and not updated_at: 79 | raise ObjectNotFoundException(cls) 80 | return instance 81 | 82 | return 
def build_list_instances_endpoint(cls: Type):
    """Build the endpoint that lists the current user's instances of ``cls``.

    Args:
        cls: Model class whose name, via ``get_class_name``, selects what to
            list.

    Returns:
        An async endpoint function that takes paging and sorting query
        parameters, a Redis client and the authenticated user, and returns
        the result of ``list_instances``.
    """

    async def list_instances_endpoint(
        offset: Optional[int] = queries.offset_query,
        limit: Optional[int] = queries.limit_query,
        sorting_field: Optional[SortingField] = queries.sorting_field_query,
        ascending: Optional[bool] = queries.ascending_query,
        redis_client: redis.Redis = Depends(get_redis_client),
        user: User = Depends(get_user),
    ):
        try:
            return await list_instances(
                redis_client,
                class_name=get_class_name(cls),
                owner=user,
                offset=offset,
                limit=limit,
                sorting_field=sorting_field,
                ascending=ascending,
            )
        except redis.exceptions.ResponseError as error:
            # Match the message case-insensitively and anywhere in the text,
            # so the check does not depend on the server's exact wording
            # (presumably it differs between search module versions -- TODO
            # confirm against the deployed Redis Stack).
            if "no such index" in str(error).lower():
                raise IndexNotImplemented(cls) from error
            # Any other Redis error propagates with its original traceback.
            raise

    return list_instances_endpoint
"/{id}", 151 | build_update_instance_endpoint(instance_model), 152 | response_model=response_model, 153 | methods=["PUT"], 154 | ) 155 | router.add_api_route( 156 | prefix + "/{id}", 157 | build_delete_instance_endpoint(instance_model), 158 | methods=["DELETE"], 159 | status_code=status.HTTP_204_NO_CONTENT, 160 | ) 161 | router.add_api_route( 162 | prefix, 163 | build_list_instances_endpoint(instance_model), 164 | response_model=list[response_model], 165 | methods=["GET"], 166 | ) 167 | -------------------------------------------------------------------------------- /src/restllm/exceptions.py: -------------------------------------------------------------------------------- 1 | import litellm 2 | 3 | from typing import Type 4 | 5 | from fastapi import status, HTTPException, Request 6 | from fastapi.responses import JSONResponse 7 | from pydantic import ValidationError 8 | from .redis.keys import get_class_name 9 | 10 | 11 | class ObjectNotFoundException(HTTPException): 12 | def __init__(self, cls: Type): 13 | super().__init__( 14 | status_code=status.HTTP_404_NOT_FOUND, 15 | detail=f"Ressource of type '{get_class_name(cls)}' not found", 16 | ) 17 | 18 | 19 | class IndexNotImplemented(HTTPException): 20 | def __init__(self, cls: Type): 21 | super().__init__( 22 | status_code=status.HTTP_501_NOT_IMPLEMENTED, 23 | detail=f"Search index not implemented for type '{get_class_name(cls)}'", 24 | ) 25 | 26 | 27 | class ObjectAlreadyExistsException(HTTPException): 28 | def __init__(self, cls: Type): 29 | super().__init__( 30 | status_code=status.HTTP_409_CONFLICT, 31 | detail=f"Ressource of type '{get_class_name(cls)}' allready exists", 32 | ) 33 | 34 | 35 | class InvalidCredentialsException(HTTPException): 36 | def __init__(self): 37 | super().__init__( 38 | status_code=status.HTTP_401_UNAUTHORIZED, 39 | detail="Could not validate credentials", 40 | headers={"WWW-Authenticate": "Bearer"}, 41 | ) 42 | 43 | 44 | class TokenExpiredException(HTTPException): 45 | def 
__init__(self): 46 | super().__init__( 47 | status_code=status.HTTP_400_BAD_REQUEST, 48 | detail="Token expired", 49 | headers={"WWW-Authenticate": "Bearer"}, 50 | ) 51 | 52 | 53 | async def validation_exception_handler(request: Request, exc: ValidationError): 54 | errors = exc.errors() 55 | formatted_errors = [ 56 | { 57 | "loc": error["loc"], 58 | "msg": error["msg"], 59 | "type": error["type"], 60 | } 61 | for error in errors 62 | ] 63 | return JSONResponse( 64 | status_code=422, 65 | content={"detail": formatted_errors}, 66 | ) 67 | 68 | async def litellm_badrequest_handler( 69 | request: Request, exception: litellm.exceptions.BadRequestError 70 | ): 71 | return JSONResponse( 72 | status_code=exception.status_code, 73 | content={ 74 | "detail": { 75 | "message": exception.message, 76 | "error_type": type(exception).__name__, 77 | "model": exception.model, 78 | "llm_provider": exception.llm_provider, 79 | } 80 | }, 81 | ) -------------------------------------------------------------------------------- /src/restllm/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from .dependencies import shutdown, startup 3 | from .middleware import AccessLogMiddleware 4 | from .routers import ( 5 | chats, 6 | events, 7 | prompts, 8 | share, 9 | completion_parameters, 10 | users, 11 | functions, 12 | authentication, 13 | ) 14 | from .exceptions import validation_exception_handler, litellm_badrequest_handler 15 | from pydantic import ValidationError 16 | import litellm 17 | 18 | app = FastAPI( 19 | title="REST LLM", 20 | description="REST API for interacting with Large Language Models. Runs on RedisStack (https://redis.io/docs/about/about-stack/) and LiteLLM (https://litellm.ai). 
class AccessLogMiddleware(BaseHTTPMiddleware):
    """Record each request's path and time in the ``access_log_stream`` stream.

    Requests whose path starts with ``/docs``, ``/openapi.json`` or ``/redoc``
    are forwarded without being recorded.
    """

    async def dispatch(
        self,
        request: Request,
        call_next,
    ):
        """Log the request if applicable, then pass it on to the next handler.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The response produced by ``call_next``.
        """
        if not request.url.path.startswith(("/docs", "/openapi.json", "/redoc")):
            # get_redis_client is an async-generator dependency; outside of
            # FastAPI's dependency injection it has to be advanced and closed
            # by hand.
            client_generator = get_redis_client()
            try:
                redis_client = await anext(client_generator)
                await redis_client.xadd(
                    "access_log_stream",
                    {
                        "path": request.url.path,
                        "timestamp": datetime.datetime.now(
                            tz=datetime.timezone.utc
                        ).timestamp(),
                    },
                )
            finally:
                # Finalize the generator instead of leaving it suspended.
                await client_generator.aclose()
        return await call_next(request)
class UserWithPasswordHash(User):
    """A ``User`` that additionally carries the stored password hash."""

    # Hash of the user's password; never meant to leave the server.
    hashed_password: str

    def get_user_data(self) -> dict:
        """Return the user's fields as a dict without ``hashed_password``."""
        # pydantic documents ``exclude`` as a set or mapping, so a set is
        # passed rather than a list.
        return self.model_dump(exclude={"hashed_password"})
-> str: 60 | return get_user_email_key(self.email) 61 | 62 | 63 | class ChangePassword(BaseModel): 64 | old_password: SecretStr 65 | new_password: SecretStr 66 | confirm_new_password: SecretStr 67 | 68 | @model_validator(mode="after") 69 | def check_passwords_match(self) -> "ChangePassword": 70 | pw1 = self.new_password 71 | pw2 = self.confirm_new_password 72 | if pw1 is not None and pw2 is not None and pw1 != pw2: 73 | raise ValueError("Both fields for new password must match") 74 | return self 75 | 76 | def get_new_password_hash(self) -> str: 77 | return get_password_hash(self.new_password.get_secret_value()) 78 | -------------------------------------------------------------------------------- /src/restllm/models/base.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from typing import Any, Optional 4 | 5 | from pydantic import BaseModel, Field, EmailStr, computed_field 6 | 7 | 8 | class CustomInstructions(BaseModel): 9 | response_instruction: Optional[str] = Field( 10 | default="", description="Custom instructions how the LLM should respond" 11 | ) 12 | preference_instruction: Optional[str] = Field( 13 | default="", description="Custom instructions for user preferences" 14 | ) 15 | enabled: bool = Field( 16 | default=False, description="Whether custom instructions are enabled for chats" 17 | ) 18 | 19 | 20 | class UserProfile(BaseModel): 21 | custom_instructions: CustomInstructions = Field( 22 | default_factory=CustomInstructions, 23 | description="Custom instructions to use in chat.", 24 | ) 25 | 26 | 27 | class User(BaseModel): 28 | id: int = Field(gt=0, examples=[1, 2, 3], frozen=True) 29 | first_name: str = Field( 30 | description="First name of the user", examples=["Alice"], frozen=True 31 | ) 32 | last_name: str = Field( 33 | description="Last name of the user", examples=["Wonderer"], frozen=True 34 | ) 35 | email: EmailStr = Field(description="Valid email for the user", frozen=True) 36 | 
verified: bool = Field( 37 | description="Wether a users email is verified or not", 38 | frozen=True, 39 | default=False, 40 | ) 41 | 42 | @computed_field(return_type=str) 43 | @property 44 | def full_name(self) -> str: 45 | return f"{self.first_name} {self.last_name}" 46 | 47 | def get_key(self) -> str: 48 | return f"{self.__class__.__name__}:{self.id}" 49 | 50 | 51 | class Datetime(BaseModel): 52 | datetime_iso: datetime.datetime = Field( 53 | default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc), 54 | description="Datetime string in iso format", 55 | ) 56 | timezone: str = Field( 57 | default_factory=lambda: str(datetime.timezone.utc), 58 | frozen=True, 59 | examples=["UTC"], 60 | ) 61 | 62 | @computed_field(return_type=float) 63 | @property 64 | def timestamp(self) -> float: 65 | return self.datetime_iso.timestamp() 66 | 67 | 68 | class MetaModel(BaseModel): 69 | id: int = Field(gt=0, examples=[1, 2, 3]) 70 | class_name: str 71 | owner: int 72 | object: Any 73 | created_at: Datetime = Field(default_factory=Datetime) 74 | updated_at: Datetime = Field(default_factory=Datetime) 75 | 76 | 77 | class UserProfileWithMeta(MetaModel): 78 | object: UserProfile 79 | -------------------------------------------------------------------------------- /src/restllm/models/chat.py: -------------------------------------------------------------------------------- 1 | from enum import auto, UNIQUE, verify, StrEnum 2 | from typing import Optional 3 | from pydantic import BaseModel, Field 4 | 5 | from .base import MetaModel 6 | from .functions import FunctionCall 7 | from .completion import CompletionParameters 8 | from ..models.functions import get_function_schemas 9 | 10 | @verify(UNIQUE) 11 | class RoleTypes(StrEnum): 12 | USER = auto() 13 | SYSTEM = auto() 14 | ASSISTANT = auto() 15 | FUNCTION = auto() 16 | 17 | 18 | @verify(UNIQUE) 19 | class ModelTypes(StrEnum): 20 | GPT3_TURBO = "gpt-3.5-turbo" 21 | GPT3_TURBO_16K = "gpt-3.5-turbo-16k" 22 | GPT4 = "gpt-4" 23 
| GPT4_32K = "gpt-4-32k" 24 | 25 | 26 | class ChatMessage(BaseModel): 27 | role: RoleTypes = Field( 28 | description="The role of the message's author. Roles can be: system, user, assistant, or function.", 29 | examples=[RoleTypes.USER, RoleTypes.SYSTEM], 30 | ) 31 | content: str = Field( 32 | description="The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.", 33 | examples=["Can you write a function in Python that adds two numbers together?"], 34 | ) 35 | name: Optional[str] = Field( 36 | default=None, 37 | max_length=64, 38 | pattern="^[a-zA-Z0-9_]*$", 39 | description="The name of the author of the message. It is required if the role is 'function'. The name should match the name of the function represented in the content. It can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters.", 40 | examples=["Alice", "AI Assistant"], 41 | ) 42 | function_call: Optional[FunctionCall] = Field( 43 | default=None, 44 | description="The name and arguments of a function that should be called, as generated by the model.", 45 | examples=[None] 46 | ) 47 | 48 | 49 | class Chat(BaseModel): 50 | completion_parameters: CompletionParameters = Field( 51 | description="Set of parameters for chat completion. Check litellm docs for more detail: https://docs.litellm.ai/docs/completion/input" 52 | ) 53 | messages: list[ChatMessage] = Field( 54 | description="A list of messages comprising the conversation so far." 
55 | ) 56 | 57 | def last_message_is_user(self) -> bool: 58 | return self.messages[-1].role == RoleTypes.USER 59 | 60 | def dump_json_for_completion(self) -> dict: 61 | completion_kwargs = self.model_dump(mode="json", exclude_none=True) 62 | completion_parameters: dict = completion_kwargs.pop("completion_parameters") 63 | completion_kwargs.update(completion_parameters) 64 | 65 | function_names = completion_parameters.get("functions") 66 | if function_names: 67 | function_schemas = get_function_schemas(function_names) 68 | completion_kwargs.update({"functions":function_schemas}) 69 | return completion_kwargs 70 | 71 | class ChatWithMeta(MetaModel): 72 | object: Chat 73 | 74 | 75 | class ChatMessageWithMeta(MetaModel): 76 | object: ChatMessage 77 | -------------------------------------------------------------------------------- /src/restllm/models/completion.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from enum import UNIQUE, verify, StrEnum 3 | 4 | from typing import Optional, Union, Any 5 | from .functions import FunctionName 6 | from .base import MetaModel 7 | 8 | 9 | @verify(UNIQUE) 10 | class ModelTypes(StrEnum): 11 | GPT3_TURBO = "gpt-3.5-turbo" 12 | GPT3_TURBO_16K = "gpt-3.5-turbo-16k" 13 | GPT4 = "gpt-4" 14 | GPT4_32K = "gpt-4-32k" 15 | 16 | 17 | class CompletionParameters(BaseModel): 18 | model: ModelTypes 19 | functions: Optional[list[FunctionName]] = Field( 20 | default=None, 21 | description="A list of functions that the model may use to generate JSON inputs. Each function should have the following properties", 22 | examples=[None,["SearchArticlesFunction"]], 23 | ) 24 | temperature: Optional[Union[float, None]] = Field( 25 | default=0.2, 26 | ge=0, 27 | le=2, 28 | description="The sampling temperature to be used, between 0 and 2. 
Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic.", 29 | ) 30 | top_p: Optional[Union[float, None]] = Field( 31 | default=None, 32 | description="An alternative to sampling with temperature. It instructs the model to consider the results of the tokens with top_p probability. For example, 0.1 means only the tokens comprising the top 10% probability mass are considered.", 33 | examples=[None, 0.1] 34 | ) 35 | n: Optional[int] = Field( 36 | default=None, 37 | description="The number of chat completion choices to generate for each input message.", 38 | examples=[None, 2], 39 | ) 40 | stop: Optional[Union[str, list[str], None]] = Field( 41 | default=None, 42 | description="Up to 4 sequences where the API will stop generating further tokens.", 43 | examples=[None] 44 | ) 45 | max_tokens: Optional[int] = Field( 46 | default=None, 47 | description="The maximum number of tokens to generate in the chat completion.", 48 | examples=[None], 49 | ) 50 | presence_penalty: Optional[Union[float, None]] = Field( 51 | default=None, 52 | description="It is used to penalize new tokens based on their existence in the text so far.", 53 | examples=[None], 54 | ) 55 | frequency_penalty: Optional[Union[float, None]] = Field( 56 | default=None, 57 | description="It is used to penalize new tokens based on their frequency in the text so far.", 58 | examples=[None], 59 | ) 60 | logit_bias: Optional[dict[str, float]] = Field( 61 | default={}, 62 | description="Used to modify the probability of specific tokens appearing in the completion.", 63 | examples=[None], 64 | ) 65 | user: Optional[str] = Field( 66 | default=None, 67 | description="A unique identifier representing your end-user. 
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    from datetime import timezone

    return datetime.now(tz=timezone.utc)


class EventWithMeta(BaseModel):
    """An event plus the metadata needed to route it to its owner's channel."""

    # Unique identifier of the event; generated when not supplied.
    uuid: str | UUID4 = Field(default_factory=uuid.uuid4)
    # Identifier of the owner; forms the second part of the channel name.
    owner: int | str
    type: EventType
    event: Event
    # Timezone-aware UTC creation time, so subscribers never have to guess
    # which local zone a naive timestamp was taken in.
    created_at: datetime = Field(default_factory=_utc_now)

    async def publish(self, redis_client: redis.Redis) -> int:
        """Publish this event as JSON on its channel.

        Args:
            redis_client: Async Redis client used for publishing.

        Returns:
            The value returned by the client's ``publish`` call.
        """
        return await redis_client.publish(
            self.get_channel(),
            self.model_dump_json(),
        )

    def get_channel(self) -> str:
        """Return the channel name, ``"<event type value>:<owner>"``."""
        return f"{self.type.value}:{self.owner}"
class Function(BaseModel):
    # Description of a callable function: a name, a free-text description and
    # a JSON-Schema object describing its parameters.

    # Restricted to at most 64 characters drawn from letters, digits, "_" and "-".
    name: str = Field(
        max_length=64,
        pattern="^[a-zA-Z0-9_-]*$",
        description="The name of the function to be called. It should contain a-z, A-Z, 0-9, underscores and dashes, with a maximum length of 64 characters.",
        examples=["get_weather_status", None],
    )
    # NOTE(review): no default is given, so with pydantic v2 this field is
    # presumably still required even though it is Optional - confirm intent.
    description: Optional[str] = Field(
        description="A description explaining what the function does. It helps the model to decide when and how to call the function.",
        examples=["Function to get current weather status"],
    )
    # Free-form mapping; nothing here validates that it is a well-formed JSON Schema.
    parameters: dict[str, Any] = Field(
        description="The parameters that the function accepts, described as a JSON Schema object.",
        examples=[
            {
                "type": "object",
                "properties": {
                    "name": "location",
                    "description": "Location for the weather status",
                },
            },
            None,
        ],
    )
def _remove_keys(d: dict, remove_keys: list[str]) -> None:
    """Remove metadata keys from a JSON-schema tree, recursively and in place.

    A key listed in ``remove_keys`` is deleted from a mapping only when that
    mapping also has a ``"type"`` key. This is the heuristic that tells a
    schema node apart from a container such as ``properties``, where a key
    named ``title`` would be a user-defined field rather than metadata.

    Lists are traversed as well, so schema nodes nested in ``anyOf`` /
    ``allOf`` / ``oneOf`` style arrays are cleaned too. Values that are
    neither a dict nor a list are ignored.

    Args:
        d: The (sub)tree to clean; mutated in place.
        remove_keys: Names of the keys to delete.
    """
    if isinstance(d, dict):
        # Iterate over a snapshot: keys are deleted while walking.
        for key in list(d):
            if key in remove_keys and "type" in d:
                del d[key]
            else:
                _remove_keys(d[key], remove_keys)
    elif isinstance(d, list):
        for item in d:
            _remove_keys(item, remove_keys)
@classmethod
def from_response(
    cls,
    completion: dict[str, Any],
    context: dict[str, Any] | None = None,
    strict: bool | None = None,
):
    """Build an instance of this function model from a chat completion.

    Reads the message of the first choice, checks that it carries a function
    call addressed to this class, and validates the JSON arguments.

    Raises:
        NotAFunctionCall: The message has no ``function_call`` entry.
        FunctionNameMismatch: The called name differs from this schema's name.
    """
    first_choice = completion["choices"][0]
    reply = first_choice["message"]

    if "function_call" not in reply:
        raise NotAFunctionCall("No function call detected in message")

    requested_name = reply["function_call"]["name"]
    if requested_name != cls.function_schema["name"]:
        raise FunctionNameMismatch("Function name does not match")

    raw_arguments = reply["function_call"]["arguments"]
    return cls.model_validate_json(raw_arguments, context=context, strict=strict)
def pascal_to_upper(class_name: str) -> str:
    """Convert a PascalCase identifier to UPPER_SNAKE_CASE.

    Every upper-case character starts a new ``_``-separated word. Only the
    single separator produced by a leading upper-case character is dropped,
    so a name that starts with a lower-case letter or a digit keeps its first
    character (an unconditional ``[1:]`` slice used to cut it off).

    Args:
        class_name: Identifier to convert; may be empty.

    Returns:
        The upper-snake-case form, e.g. ``"SEARCH_ARTICLES_FUNCTION"``.
    """
    converted = "".join(get_character(char) for char in class_name)
    return converted.removeprefix("_")


def get_character(character: str) -> str:
    """Upper-case one character, prefixing ``_`` if it was already upper case."""
    return f"_{character.upper()}" if character.isupper() else character.upper()
class VariableType(Enum):
    """Names of the builtin types a prompt-template argument may declare."""

    STRING = "str"
    INTEGER = "int"
    FLOAT = "float"
    BOOLEAN = "bool"
    LIST = "list"
    DICT = "dict"

    @property
    def type(self):
        """Return the builtin class named by this member (e.g. ``str``).

        Uses an explicit lookup table instead of ``eval`` so that adding a
        member can never lead to arbitrary expression evaluation; a member
        without a table entry fails with ``KeyError``.
        """
        builtin_classes = {
            "str": str,
            "int": int,
            "float": float,
            "bool": bool,
            "list": list,
            "dict": dict,
        }
        return builtin_classes[self.value]
"The model is provided with a prompt and is expected to generate a relevant response without any prior examples.", 91 | PromptTagName.FEWSHOT: "Providing a few examples along with the prompt to guide the model towards the desired output.", 92 | PromptTagName.MANYSHOT: "Providing a larger number of examples along with the prompt to further guide the model.", 93 | PromptTagName.CURRICULUMLEARNING: "Arranging prompts in an order of increasing complexity, training the model progressively.", 94 | PromptTagName.META: "Designing prompts that instruct the model to consider certain variables or conditions while generating a response.", 95 | PromptTagName.CONTINUOUS: "Employing a sequence of prompts in a continuous manner, where the model’s response to one prompt serves as a part of the prompt for the next task.", 96 | PromptTagName.ADAPTIVE: "Dynamically adjusting the prompt based on the model’s previous responses to better guide it towards the desired output.", 97 | PromptTagName.COMPARATIVE: "Providing comparisons within the prompt to guide the model towards generating more accurate or nuanced responses.", 98 | PromptTagName.CHAIN: "Creating a chain of interlinked prompts where the output of one task serves as the prompt for the subsequent task.", 99 | PromptTagName.HIERARCHICAL: "Structuring prompts in a hierarchical manner, where higher-level prompts guide the overall narrative and lower-level prompts guide the details.", 100 | } 101 | 102 | @classmethod 103 | def get_description(cls, prompt_tag: PromptTagName): 104 | return cls._mapping.get(prompt_tag, "Technique not found") 105 | 106 | 107 | class PromptTag(BaseModel): 108 | name: PromptTagName 109 | 110 | @property 111 | def description(self) -> str: 112 | return PromptTagDescriptionMapping.get_description(self.name) 113 | 114 | 115 | class BasePrompt(BaseModel): 116 | name: str = Field( 117 | description="Name of the prompt", 118 | pattern=get_name_pattern(), 119 | examples=[ 120 | "EditPythonCodePrompt", 121 | 
class Prompt(BasePrompt):
    # A concrete prompt: one message, or a system message followed by a user
    # message (the list length is limited to 1..2 by the field constraints).
    messages: list[PromptMessage] = Field(
        description="List of prompt messages. Role System must preceed user",
        max_length=2,
        min_length=1,
    )

    # mode="after": the validator receives parsed PromptMessage objects.
    # In "before" mode it was handed the raw input, which is a list of dicts
    # for JSON request bodies, so `.role` raised AttributeError instead of
    # producing a validation error.
    @field_validator("messages", mode="after")
    @classmethod
    def validate_messages(cls, value):
        """Check role ordering when exactly two messages are supplied.

        Raises:
            ValueError: The first message has the user role, or both messages
                share the same role.
        """
        if len(value) == 2:
            first, second = value
            if first.role == PromptRole.USER:
                raise ValueError("First role must be system when two messages is used")
            if first.role == second.role:
                raise ValueError("Consecutive roles cannot be the same")
        return value
@model_validator(mode="after")
def check_valid_template(self) -> "PromptTemplate":
    """Validate the combined message text against the declared arguments.

    The joined text must parse as a Jinja2 template, and the variables it
    uses must be exactly the declared argument names.

    Raises:
        ValueError: The text is not a valid template, or names do not match.
    """
    source = self._get_template_text()
    if is_valid_jinja2_template(source):
        declared_names = self._get_variable_names()
        if names_and_variables_match(source, declared_names):
            return self
        raise ValueError("Parameter keys and template variables must match.")
    raise ValueError(f"String is invalid Jinja2 template: {source}")
def render(self, parameters: dict) -> dict:
    """Render every template message with ``parameters``.

    The parameters are first validated strictly against a model generated
    from the declared arguments. Each message's content is then rendered and
    the result is returned as this template's ``model_dump()`` with the
    ``messages`` entry replaced by the rendered messages.

    Template text is supplied by API users, so rendering happens in Jinja2's
    sandboxed environment rather than through a plain ``Template``: ordinary
    variable substitution is unchanged, while templates that reach for
    unsafe attributes are rejected by the sandbox.

    Args:
        parameters: Values for the declared template arguments.

    Returns:
        The dumped template with rendered message contents.
    """
    # Local import: the surrounding module only imports jinja2.Template.
    from jinja2.sandbox import SandboxedEnvironment

    template_model = self.create_model()
    parameter_instance = template_model.model_validate(parameters, strict=True)
    render_context = parameter_instance.model_dump()
    environment = SandboxedEnvironment()
    messages = [
        {
            "role": message.role,
            "content": environment.from_string(message.content).render(
                render_context
            ),
        }
        for message in self.messages
    ]
    prompt_dict = self.model_dump()
    prompt_dict.update({"messages": messages})
    return prompt_dict
async def get_multiple_instances(
    redis_client: redis.Redis,
    keys: list[str],
) -> list[dict | None]:
    """Fetch the JSON documents stored under several keys in one round trip.

    ``JSON.MGET`` needs a path in addition to the keys; the root path is
    passed explicitly so whole documents are returned, matching what
    ``get_instance`` returns for a single key.

    Args:
        redis_client: Async Redis client.
        keys: Keys of the documents to read.

    Returns:
        One entry per key, in the order of ``keys``.
        NOTE(review): missing keys presumably come back as ``None`` - confirm.
    """
    return await redis_client.json().mget(keys, Path.root_path())
async def copy_instance(
    redis_client: redis.Redis,
    source_key: str,
    expire_time: int,
) -> tuple[bool, bool, str]:
    """Copy a stored value to a new random key and give the copy an expiry.

    Both commands are queued in one MULTI transaction. The destination key is
    a URL-safe random token generated here.

    Returns:
        ``(copied, expired, token)``: the replies of the COPY and EXPIRE
        commands followed by the destination key.
    """
    share_token = secrets.token_urlsafe(32)
    async with redis_client.pipeline() as transaction:
        transaction.multi()
        transaction.copy(source=source_key, destination=share_token)
        transaction.expire(share_token, time=expire_time, nx=True)
        replies = await transaction.execute()
    was_copied, expiry_set = replies
    return was_copied, expiry_set, share_token
async def subscribe_event(
    events_list: list[str],
    redis_client: redis.Redis,
) -> str:
    """Yield newline-terminated payloads received on the given channels.

    This is an async generator: it subscribes to every channel in
    ``events_list`` and yields each message's decoded data followed by
    ``"\\n"``. It finishes when a payload equal to ``STOPWORD`` arrives.

    NOTE(review): ``.decode()`` assumes the client delivers bytes, i.e. it
    was not created with response decoding enabled - confirm.
    """
    async with redis_client.pubsub() as subscription:
        await subscription.subscribe(*events_list)
        while True:
            incoming = await subscription.get_message(
                ignore_subscribe_messages=True, timeout=None
            )
            if incoming is not None:
                payload = incoming["data"].decode()
                if payload == STOPWORD:
                    return
                yield payload + "\n"
def get_class_from_class_name(class_name: str) -> Type:
    """Return the model class registered for indexing under ``class_name``.

    Args:
        class_name: A key of ``string_to_class_mapping``.

    Returns:
        The registered class itself (not a mapping).

    Raises:
        IndexClassDoesNotExist: No class is registered under that name.
    """
    try:
        return string_to_class_mapping[class_name]
    except KeyError as error:
        raise IndexClassDoesNotExist(
            f"The class '{class_name}' is not registeret for indexing"
        ) from error
def create_index_on_meta_model(
    redis_client: redis.Redis,
    meta_model_schema: tuple,
    _class: Type,
) -> Any:
    """Create the search index for ``_class`` over its JSON documents.

    The index covers keys that start with the class's key prefix and uses
    ``meta_model_schema`` as its field definitions.

    Returns:
        Whatever the index's ``create_index`` call returns.
    """
    search_index = get_index(_class, redis_client)
    key_prefixes = [get_index_prefix(_class)]
    index_definition = IndexDefinition(
        prefix=key_prefixes,
        index_type=IndexType.JSON,
    )
    return search_index.create_index(meta_model_schema, definition=index_definition)
def get_class_name(_class: type[BaseModel]) -> str:
    """Return the Redis key name registered for a model class.

    The argument is the class object itself, which is looked up in
    ``model_to_key_mapping``.

    Raises:
        KeyError: The class has no entry in ``model_to_key_mapping``.
    """
    return model_to_key_mapping[_class]
async def check_rate_limit(
    redis_client: redis.Redis,
    user: User,
    api_path: str,
    route_weight_score: int = 1,
) -> bool:
    """Record one call by ``user`` to ``api_path`` and report if it is over the limit.

    Calls are kept in a sorted set scored by their epoch-millisecond
    timestamp. In one transaction, entries older than a minute are dropped,
    the new call is added, the remaining entries are read and the key's
    expiry is refreshed.

    Args:
        redis_client: Async Redis client.
        user: Caller; ``user.id`` is part of the rate-limit key.
        api_path: Route being called; part of the rate-limit key.
        route_weight_score: Cost of this call towards the per-minute budget.

    Returns:
        ``True`` when the summed weights in the window exceed the limit.
    """
    epoch_ms = get_current_epoch_milliseconds()
    rate_limit_key = f"ratelimit:{user.id}:{api_path}"
    async with redis_client.pipeline() as pipeline:
        pipeline.multi()

        remove_old_entries(pipeline, rate_limit_key, epoch_ms - MILLISECONDS_IN_MINUTE)
        add_new_entry(pipeline, rate_limit_key, epoch_ms, route_weight_score)
        get_all_entries(pipeline, rate_limit_key)
        set_expiration(pipeline, rate_limit_key, MILLISECONDS_IN_MINUTE)

        # execute() returns one reply per queued command, in queue order.
        # Only the third reply (ZRANGE) holds the entries to be scored.
        _removed, _added, entries, _expiry_set = await pipeline.execute()
    return is_rate_limited(calculate_rate_limit_score(entries))


def get_current_epoch_milliseconds():
    """Return the current Unix time in whole milliseconds."""
    return int(time.time() * 1000)


def remove_old_entries(
    pipeline: redis.Redis,
    key: str,
    old_score_threshold: int,
):
    """Queue removal of entries whose score is at or below the threshold."""
    pipeline.zremrangebyscore(key, 0, old_score_threshold)


def add_new_entry(
    pipeline: redis.Redis,
    key: str,
    epoch_ms: int,
    route_weight_score: int,
):
    """Queue insertion of a ``"<epoch_ms>:<weight>"`` member scored by ``epoch_ms``.

    NOTE(review): two calls in the same millisecond with the same weight
    produce the same member and are therefore counted once.
    """
    value = f"{epoch_ms}:{route_weight_score}"
    pipeline.zadd(key, {value: epoch_ms})


def get_all_entries(
    pipeline: redis.Redis,
    key: str,
):
    """Queue a read of every member of the sorted set."""
    pipeline.zrange(key, 0, -1)


def set_expiration(
    pipeline: redis.Redis,
    key: str,
    expiration_ms: int,
):
    """Queue a key expiry given in milliseconds.

    PEXPIRE is used because the value is in milliseconds; EXPIRE would read
    the same number as seconds and keep the key alive far too long.
    """
    pipeline.pexpire(key, expiration_ms)


def calculate_rate_limit_score(entries: list[bytes | str]):
    """Sum the weights encoded after the last ``:`` of each entry.

    Accepts entries as bytes or as already-decoded strings.
    """
    total = 0
    for entry in entries:
        text = entry.decode("utf-8") if isinstance(entry, bytes) else entry
        total += int(text.rsplit(":", 1)[-1])
    return total


def is_rate_limited(score: int):
    """Return ``True`` when ``score`` exceeds ``CALLS_PER_MINUTE_LIMIT``."""
    return score > CALLS_PER_MINUTE_LIMIT
async def list_instances(
    redis_client: redis.Redis,
    class_name: str,
    owner: User,
    offset: int | None = None,
    limit: int | None = None,
    sorting_field: SortingField | None = None,
    ascending: bool = True,
) -> list[dict]:
    """List the stored instances of ``class_name`` that belong to ``owner``.

    Pagination is applied only when both ``offset`` and ``limit`` are given,
    with ``offset >= 0`` and ``limit > 0``. Leaving either at its default of
    ``None`` skips pagination instead of failing on a ``None`` comparison.
    Sorting is applied only when ``sorting_field`` is set.

    Args:
        redis_client: Async Redis client.
        class_name: Name used to select the search index.
        owner: Only documents owned by this user are returned.
        offset: Index of the first result, or ``None`` for no pagination.
        limit: Maximum number of results, or ``None`` for no pagination.
        sorting_field: Field to sort by, if any.
        ascending: Sort direction used together with ``sorting_field``.

    Returns:
        The matching documents as dictionaries.
    """
    query = create_privat_query(owner)
    pagination_requested = offset is not None and limit is not None
    if pagination_requested and offset >= 0 and limit > 0:
        query = add_pagination_to_query(query, offset, limit)
    if sorting_field:
        query = add_sorting_to_query(query, sorting_field, ascending)
    return await search_index(redis_client, query, class_name)
# Exchange username/password form credentials for an access/refresh token pair.
# Responds with 401 when authentication fails or the account is not verified.
@router.post("/token", response_model=Token)
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    redis_client: redis.Redis = Depends(get_redis_client),
):
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    account = await authenticate_user(
        form_data.username, form_data.password, redis_client
    )
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers=bearer_challenge,
        )
    if not account.verified:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Account is not yet verified",
            headers=bearer_challenge,
        )

    token_pair = await create_tokens(account.get_user_data())
    access_token, refresh_token = token_pair
    return dict(
        access_token=access_token,
        token_type="bearer",
        refresh_token=refresh_token,
    )
@router.post("/signup", response_model=User)
async def user_signup(
    background_tasks: BackgroundTasks,
    signup_data: UserSignUp = Depends(signup_form),
    redis_client: redis.Redis = Depends(get_redis_client),
):
    """Register a new user and schedule the email-verification message.

    Raises:
        HTTPException: 400 when the email key already exists or when the user
            instance could not be created.
    """
    # Imported locally so this handler does not depend on a change to the
    # module's import block.
    from ..settings import settings

    if await redis_client.exists(signup_data.email_key):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User already registered",
        )

    instance_id = await create_instance_id(redis_client, get_class_name(User))

    created, instance = await create_user_instance(
        redis_client,
        signup_data.create_user(instance_id),
        signup_data.email_key,
    )
    if not created:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Registration failed",
        )

    fernet = await get_fernet(redis_client)
    signed_data = generate_secure_url(fernet, {"user_id": instance_id})
    # Build the link from the configured base URL instead of a hard-coded
    # localhost address, so emailed links work outside local development.
    # str() of a URL setting may end with "/", hence the rstrip.
    base_url = str(settings.base_url).rstrip("/")
    verification_url = (
        f"{base_url}/v1/authentication/verify-email/"
        f"{signed_data.get('payload')}/{signed_data.get('signature')}"
    )
    background_tasks.add_task(
        send_email_verification_url,
        signup_data.email,
        verification_url,
    )

    return instance
@router.post("/password-change")
async def change_password(
    password_change: ChangePassword = Depends(change_password_form),
    user: User = Depends(get_user),
    redis_client: redis.Redis = Depends(get_redis_client),
):
    """Replace the stored password hash of the authenticated user.

    The old password is re-checked with ``authenticate_user`` before the new
    hash is written.

    Raises:
        HTTPException: 400 when the old password is rejected or when the
            stored hash could not be updated.
    """
    authenticated = await authenticate_user(
        user.email,
        password_change.old_password.get_secret_value(),
        redis_client,
    )
    # authenticate_user signals failure with a falsy result (the login
    # endpoint relies on the same contract); ignoring it would let any token
    # holder change the password without knowing the old one.
    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect password",
        )

    # The async client returns an awaitable; without ``await`` the command is
    # never sent and the (always truthy) coroutine hides the failure.
    updated = await redis_client.json().set(
        user.get_key(),
        "$.hashed_password",
        password_change.get_new_password_hash(),
        xx=True,
    )
    if not updated:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Password update failed",
        )
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post(
    "/completion",
    description="Complete a chat from providet message as string input.",
)
async def post_completion(
    chat_message: ChatMessage,
    redis_client: redis.Redis = Depends(get_redis_client),
    key: str = Depends(build_get_instance_key(Chat)),
) -> str:
    """Append ``chat_message`` to the chat at ``key`` and stream the completion.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``. The
        ``-> str`` annotation is kept as-is because FastAPI reads it.

    Raises:
        ObjectNotFoundException: when Redis answers the append with a
            ``ResponseError``.
    """
    # Keep the try body to the one call whose ResponseError is translated,
    # and avoid binding the exception to the builtin name ``exec``.
    try:
        _updated, _updated_at, instance = await append_chat_message(
            redis_client=redis_client,
            instance=chat_message,
            key=key,
        )
    except redis.exceptions.ResponseError as error:
        raise ObjectNotFoundException(Chat) from error

    chat_with_meta = ChatWithMeta.model_validate(instance)
    return StreamingResponse(
        chat_acompletion_call(chat_with_meta, redis_client, key),
        media_type="text/event-stream",
    )
@router.post("/create")
async def create_event(
    request: Request,
    redis_client: redis.Redis = Depends(get_redis_client),
):
    """Publish a fixed, completed CREATE event for the requesting user.

    The event always carries the same hard-coded user chat message.
    """
    sample_message = ChatMessage(role="user", content="Hallo world")
    sample_event = Event(
        action=CRUDAction.CREATE,
        status=EventStatus.COMPLETED,
        object=sample_message,
    )
    event_owner = get_user(request)
    return await publish_event(
        event=sample_event,
        owner=event_owner,
        redis_client=redis_client,
    )
@router.post(
    "/messages",
    response_model=ChatWithMeta,
    description="Append user message to chat. Operation does not run completetion",
)
async def add_chat_message(
    chat_message: ChatMessage,
    redis_client: redis.Redis = Depends(get_redis_client),
    key: str = Depends(build_get_instance_key(Chat)),
):
    """Append ``chat_message`` to the chat stored at ``key`` and return the chat.

    Raises:
        ObjectNotFoundException: when Redis answers the append with a
            ``ResponseError``.
    """
    # Only the Redis call can raise the translated error; the exception is
    # bound to a name that does not shadow the builtin ``exec``.
    try:
        _updated, _updated_at, instance = await append_chat_message(
            redis_client=redis_client,
            instance=chat_message,
            key=key,
        )
    except redis.exceptions.ResponseError as error:
        raise ObjectNotFoundException(Chat) from error
    return instance
@router.post("/template/{id}/render", response_model=Prompt)
async def render_prompt_template(
    parameters: dict[str, str | int | list | dict],
    redis_client: redis.Redis = Depends(get_redis_client),
    key: str = Depends(build_get_instance_key(PromptTemplate)),
) -> Prompt:
    """Render the stored prompt template at ``key`` with ``parameters``.

    Returns:
        The rendered prompt, or a 422 ``JSONResponse`` carrying the validation
        errors when rendering fails validation.

    Raises:
        ObjectNotFoundException: when nothing is stored under ``key``.
    """
    stored = await get_instance(redis_client=redis_client, key=key)
    if not stored:
        raise ObjectNotFoundException(PromptTemplate)

    # The template itself lives under the "object" entry of the stored record.
    template = PromptTemplate.model_validate(stored.get("object"))

    try:
        rendered = template.render(parameters)
    except ValidationError as validation_error:
        return JSONResponse(
            content={"detail": validation_error.errors()},
            status_code=422,
        )
    return rendered
@router.get("/{payload}/{signature}")
async def get_shared_object(
    payload: str = paths.payload_path,
    signature: str = paths.signature_path,
    redis_client: redis.Redis = Depends(get_redis_client),
) -> MetaModel:
    """Return the shared object referenced by a signed, encrypted payload.

    Raises:
        HTTPException: 400 when the signature check fails or the decrypted
            payload carries no token; 404 when no object is stored under the
            token.
    """
    # Check the signature before touching Redis for the Fernet key, so
    # invalid requests are rejected without that round trip.
    if not payload_is_valid(payload, signature):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid request",
        )

    fernet = await get_fernet(redis_client)
    token = decrypt_payload(fernet, payload).get("token")
    # A validly signed payload created for another purpose has no "token"
    # entry; reject it instead of querying Redis with None.
    if not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid request",
        )

    instance = await redis_client.json().get(token)
    if not instance:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Shared object has expired",
        )
    return instance
@router.post(
    "/profile",
    response_model=UserProfileWithMeta,
    status_code=status.HTTP_201_CREATED,
)
async def create_user_profile(
    user_profile: UserProfile,
    user: User = Depends(get_user),
    redis_client: redis.Redis = Depends(get_redis_client),
    new_key: tuple = Depends(build_get_new_class_user_key(UserProfile)),
):
    """Create the profile of the authenticated user.

    ``new_key`` is the pair produced by the dependency; it is unpacked into
    the Redis key and the instance id.

    Raises:
        ObjectAlreadyExistsException: when a profile already exists under the
            key, or when ``create_instance`` reports that nothing was created.
    """
    profile_key, instance_id = new_key
    if await redis_client.exists(profile_key):
        raise ObjectAlreadyExistsException(UserProfile)

    created, instance = await create_instance(
        redis_client=redis_client,
        owner=user,
        instance=user_profile,
        instance_id=instance_id,
        key=profile_key,
    )
    # Mirror the signup handler and honour the ``created`` flag.
    # NOTE(review): assumes a falsy flag means the key appeared between the
    # exists() check and the write -- confirm against create_instance.
    if not created:
        raise ObjectAlreadyExistsException(UserProfile)
    return instance
class Settings(BaseSettings):
    """Application configuration with development-oriented defaults."""

    # Placeholder value (note the "insecure-" prefix); presumably meant to be
    # overridden in any real deployment -- TODO confirm how it is supplied.
    secret_key: str = "insecure-09u23lkansld920394u23,njsldk"
    jwt_algorithm: str = "HS256"
    password_hash_algorithm: str = "argon2"
    # Token lifetimes, in minutes.
    access_token_expire_minutes: int = 120
    refresh_token_expire_minutes: int = 240
    email_verification_expire_minutes: int = Field(
        default=240,
        description="Time in minutes until email verification url expires.",
    )
    # Unlike the settings above, this one is expressed in seconds.
    shared_object_expire: int = Field(
        default=3600,
        description="Time in seconds before shared objects expire",
    )
    share_prefix: str = Field(
        default="/share",
        description="APIRouter prefix for the 'share' route",
    )
    # Service endpoints; the defaults all point at localhost.
    base_url: HttpUrl = "http://localhost:8000"
    redis_dsn: RedisDsn = "redis://localhost:6379/0"
    ollama_base_url: HttpUrl = "http://localhost:11434"
    # SMTP credentials and server used when sending email; empty by default.
    email_username: str = ""
    email_password: SecretStr = ""
    email_hostname: str = ""
    email_port: int = 587
async def send_email_verification_url(
    recipient: str,
    verification_url: str,
) -> None:
    """Email ``recipient`` a message containing ``verification_url``.

    Args:
        recipient: Address the verification message is sent to.
        verification_url: Link embedded at the end of the message body.
    """
    await send_email(
        "Email verification",
        recipient,
        f"Hi there \n\n Thanks for signing up! \n\n Please verify your email by using this link: {verification_url}",
    )
def test_generate_secure_url(test_fernet_instance, test_payload):
    """A generated payload/signature pair must pass ``payload_is_valid``."""
    url_parts = generate_secure_url(test_fernet_instance, test_payload)
    payload = url_parts["payload"]
    signature = url_parts["signature"]

    assert payload_is_valid(
        payload, signature
    ), "Failed to verify encrypted payload against signature"
def test_passwords_match():
    """Matching new/confirm passwords are accepted and kept as secret values."""
    new_password = "newpassword123"
    model = ChangePassword(
        old_password="oldpassword123",
        new_password=new_password,
        confirm_new_password=new_password,
    )

    assert model.new_password.get_secret_value() == new_password
    assert model.confirm_new_password.get_secret_value() == new_password
# This fixture was defined twice with identical bodies; one definition is kept.
@pytest.fixture
def invalid_test_variables() -> list[dict]:
    """Return one user message whose template uses ``test`` and ``invalid``."""
    return [{"role": "user", "content": "Hello, {{ test }}, {{ invalid }}"}]
def test_invalid_prompt_template(invalid_test_messages):
    """A message with malformed template syntax must fail validation."""
    with pytest.raises(ValidationError):
        # Construction is expected to raise, so the result is not bound.
        PromptTemplate(
            name="test",
            description="Test description",
            messages=invalid_test_messages,
            arguments=[{"name": "test", "type": "str"}],
            language={"iso639_3": "eng"},
            tags=[PromptTagName.ZEROSHOT],
        )
def test_model_render_from_template_multiple(valid_multiple_messages):
    """Rendering fills the system and user messages with their own arguments."""
    template = PromptTemplate(
        name="test",
        description="Test description",
        messages=valid_multiple_messages,
        arguments=[
            {"name": argument_name, "type": "str"}
            for argument_name in ("test_system", "test_user")
        ],
        language={"iso639_3": "eng"},
        tags=[PromptTagName.ZEROSHOT],
    )

    rendered = template.render({"test_system": "name1", "test_user": "name2"})
    messages = rendered.get("messages")

    assert messages[0].get("content") == "Hello, name1"
    assert messages[1].get("content") == "Hello, name2"