├── .dockerignore ├── .env.example ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── app ├── __init__.py ├── chats │ ├── __init__.py │ ├── api.py │ ├── constants.py │ ├── exceptions.py │ ├── ingest.py │ ├── models.py │ ├── retrieval.py │ ├── services.py │ └── streaming.py ├── core │ ├── __init__.py │ ├── api.py │ ├── constants.py │ ├── logs.py │ ├── middlewares.py │ └── models.py ├── data │ ├── __init__.py │ └── models.py ├── db.py ├── ingest │ ├── __init__.py │ ├── __main__.py │ ├── ingest.py │ └── publisher.py ├── main.py └── process │ ├── __init__.py │ ├── __main__.py │ └── subscriber.py ├── data └── .gitkeep ├── db ├── migrations │ └── 20230925090948_init.sql ├── queries │ ├── assistants │ │ ├── get_first_assistant.sql │ │ └── insert_chat.sql │ ├── chats │ │ ├── get_chat.sql │ │ └── insert_chat.sql │ └── messages │ │ ├── .gitkeep │ │ ├── insert.sql │ │ └── select_all.sql └── schema.sql ├── docker-compose.yml ├── pyproject.toml ├── requirements ├── requirements-dev.txt └── requirements.txt ├── settings ├── __init__.py ├── base.py └── gunicorn.conf.py └── tests ├── conftest.py └── test_app.py /.dockerignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | docs/static 3 | .venv 4 | *.pyc 5 | __pycache__ 6 | terraform 7 | volumes 8 | data 9 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | DATABASE_URL="postgesql connection string" 2 | OPENAI_API_KEY=sk-xxxxxx 3 | NATS_URI="nats://nats:4222/" # nats in docker localhost locally -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | .idea 9 | 
.env 10 | data/* 11 | 12 | # Distribution / packaging 13 | .Python 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .nox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | *.py,cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | cover/ 50 | pytest.xml 51 | pytest-coverage.txt 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | profile_default/ 64 | ipython_config.py 65 | .python-version 66 | 67 | 68 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 69 | __pypackages__/ 70 | 71 | # Environments 72 | .env 73 | .venv 74 | env/ 75 | venv/ 76 | ENV/ 77 | env.bak/ 78 | venv.bak/ 79 | 80 | # Spyder project settings 81 | .spyderproject 82 | .spyproject 83 | 84 | # Rope project settings 85 | .ropeproject 86 | 87 | # mkdocs documentation 88 | /site 89 | # Pyre type checker 90 | .pyre/ 91 | # pytype static type analyzer 92 | .pytype/ 93 | # Cython debug symbols 94 | cython_debug/ 95 | # Ruff 96 | .ruff_cache/ 97 | 98 | /local_cache/ 99 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/charliermarsh/ruff-pre-commit 3 | rev: 'v0.1.4' 4 | hooks: 5 | - id: ruff 6 | args: ["--fix"] 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.6 as server 2 
| RUN curl -fsSL -o /usr/local/bin/dbmate https://github.com/amacneil/dbmate/releases/latest/download/dbmate-linux-amd64 && \ 3 | chmod +x /usr/local/bin/dbmate 4 | RUN pip install --no-cache-dir --upgrade pip 5 | WORKDIR /project 6 | COPY /requirements/requirements.txt /project/requirements.txt 7 | RUN pip install --no-cache-dir --upgrade -r /project/requirements.txt 8 | COPY . /project 9 | CMD ["bash", "-c", "dbmate up & gunicorn -c settings/gunicorn.conf.py app.main:app"] 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Olaf Górski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PATH := $(PATH) 2 | SHELL := /bin/bash 3 | 4 | black: 5 | black . 6 | 7 | ruff: 8 | ruff check . --fix 9 | 10 | lint: 11 | make black 12 | make ruff 13 | 14 | pre-commit: 15 | pre-commit run --all-files 16 | 17 | test-without-clean: 18 | set -o pipefail; \ 19 | pytest \ 20 | --color=yes \ 21 | --junitxml=pytest.xml \ 22 | --cov-report=term-missing:skip-covered \ 23 | --cov=app \ 24 | tests \ 25 | | tee pytest-coverage.txt 26 | 27 | test: 28 | make test-without-clean 29 | make clean 30 | 31 | upgrade: 32 | make pip-compile-upgrade 33 | make pip-compile-dev-upgrade 34 | 35 | 36 | compile: 37 | make pip-compile 38 | make pip-compile-dev 39 | 40 | sync: 41 | pip-sync requirements/requirements.txt 42 | sync-dev: 43 | pip-sync requirements/requirements.txt requirements/requirements-dev.txt 44 | 45 | pip-compile: 46 | pip-compile -o requirements/requirements.txt pyproject.toml 47 | pip-compile-dev: 48 | pip-compile --extra dev -o requirements/requirements-dev.txt pyproject.toml 49 | 50 | pip-compile-upgrade: 51 | pip-compile --strip-extras --upgrade -o requirements/requirements.txt pyproject.toml 52 | pip-compile-dev-upgrade: 53 | pip-compile --extra dev --upgrade -o requirements/requirements-dev.txt pyproject.toml 54 | 55 | run: 56 | python -m gunicorn -c settings/gunicorn.conf.py app.main:app 57 | 58 | run-dev: 59 | python -m uvicorn --reload app.main:app 60 | 61 | clean: 62 | rm -f pytest.xml 63 | rm -f pytest-coverage.txt 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bRAG [WIP] 2 | bRAG is an example repo how you can do Basic Retrieval Augmented Generation the alternative way. 3 | No langchain. No ORMs. No alembic/migration generation tools. 
4 | Instead we go with PDF. What in the world? pugsql, dbmate & fastapi. That's it. 5 | 6 | To better understand this project, you should read the following article: 7 | [Build your own LLM RAG API without langchain using OpenAI, qdrant and PDF stack](https://grski.pl/pdf-brag) 8 | 9 | More coming. 10 | 11 | Or to be more precise, we will use: 12 | 13 | 1. [plain python openai](https://github.com/openai/openai-python) client to interact with our LLM 14 | 2. [qdrant](https://qdrant.tech/) as our vector database of choice and their [client](https://github.com/qdrant/qdrant-client). 15 | 3. For the embeddings we will use ~~openai~~ Qdrant's [fastembed](https://github.com/qdrant/fastembed) lib. 16 | 4. the API will be written using [fastapi](https://github.com/tiangolo/fastapi) 17 | 5. no ORM for us, instead we will roll with [pugsql](https://github.com/mcfunley/pugsql) 18 | 6. migrations will be done using plain sql and [dbmate](https://github.com/amacneil/dbmate). 19 | 7. instead of poetry, we will go for [pyenv](https://github.com/pyenv/pyenv) + [pip-tools](https://github.com/jazzband/pip-tools) 20 | 8. on the db front - [postgres](https://www.postgresql.org/) 21 | 9. the app will be containerised - [docker](https://www.docker.com/) 22 | 10. simple ingestion with [faststream](https://github.com/airtai/faststream) and [nats](https://nats.io/) 23 | 24 | Does that sound weird to you? Because it is! Even more fun. 25 | 26 | ## Basics 27 | 28 | ### One click deployment 29 | ```bash 30 | docker-compose up 31 | ``` 32 | 33 | ### local development for the api + db & vdb in docker 34 | ```bash 35 | docker compose up -d database qdrant nats api worker 36 | make run-dev 37 | ``` 38 | that's it. Then head over to http://localhost:8000/docs to see the swagger docs. 
39 | 40 | To run the worker locally with faststream, 41 | i.e. the worker that processes the reviews: 42 | ```bash 43 | faststream run --reload app.process.subscriber:app 44 | ``` 45 | Lastly, to ingest: 46 | ```bash 47 | python -m app.ingest 48 | ``` 49 | 50 | 51 | To get run-dev working proceed through the following steps: 52 | 53 | To start, let's make sure you have pyenv installed. What is pyenv? You can read a bit about it on my blog. 54 | [Pyenv, poetry and other rascals - modern Python dependency and version management.](https://grski.pl/pyenv-en) 55 | 56 | Long story short it's like a virtualenv but for python versions that keeps everything isolated, without polluting your system interpreter or interpreters of other projects, which would be bad. You don't need to know much more than that. 57 | 58 | If you are on mac you can just 59 | ```bash 60 | brew install pyenv 61 | ``` 62 | 63 | or if you do not like homebrew/are linux based: 64 | 65 | ```bash 66 | curl https://pyenv.run | bash 67 | ``` 68 | 69 | Remember to [set up your shell for pyenv](https://github.com/pyenv/pyenv#set-up-your-shell-environment-for-pyenv). 70 | In short you have to add some stuff to either your `.bashrc`, `.zshrc` or `.profile`. Why? So the proper 'commands' are available in your terminal and so that stuff works. How? 71 | 72 | ```bash 73 | # For bash: 74 | echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc 75 | echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc 76 | echo 'eval "$(pyenv init -)"' >> ~/.bashrc 77 | 78 | # For Zsh: 79 | echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.zshrc 80 | echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.zshrc 81 | echo 'eval "$(pyenv init -)"' >> ~/.zshrc 82 | ``` 83 | Done? 
Now you just need to get pyenv to get python 3.11 downloaded & installed locally (no worries it won't change your system interpreter): 84 | 85 | ```bash 86 | pyenv install 3.11 87 | ``` 88 | When you are done installing pyenv, we can get going. 89 | 90 | We will name our project **bRAG** - the name coming from **b**asic **R**etrieval **A**ugmented **G**eneration API. 91 | First of all, let's create a directory of the project: 92 | 93 | ```bash 94 | mkdir bRAG 95 | cd bRAG 96 | ``` 97 | 98 | Now we can set up a python version we want to use in this particular directory. How? With pyenv and pyenv-virtualenv, which is nowadays installed by default with the basic pyenv installer. If you need to, check the article I referred before to understand what's happening. 99 | 100 | ```bash 101 | pyenv virtualenv 3.11 bRAG-3-11 # this creates a 'virtualenv' based on python 3.11 named bRAG-3-11 102 | pyenv local bRAG-3-11 103 | ``` 104 | 105 | after that just install stuff from requirements.txt. 106 | 107 | 108 | ## Project structure 109 | 110 | The file structure is: 111 | ``` 112 | |-- app/ # Main codebase directory 113 | |---- chat/ 114 | |------ __init__.py 115 | |------ api.py 116 | |------ constants.py 117 | |------ message.py 118 | |------ models.py 119 | |------ streams.py 120 | 121 | |---- core/ 122 | |------ __init__.py 123 | |------ api.py 124 | |------ logs.py 125 | |------ middlewares.py 126 | |------ models.py 127 | 128 | |---- __init__.py 129 | |---- db.py 130 | |---- main.py 131 | 132 | 133 | |-- db/ # Database/migration/pugsql related code 134 | |---- migrations/ 135 | |---- queries/ 136 | |---- schema.sql 137 | 138 | |-- settings/ # Settings files directory 139 | |---- base.py 140 | |---- gunicorn.conf.py 141 | |-- tests/ # Main tests directory 142 | |-- requirements/ 143 | 144 | |-- .dockerignore 145 | |-- .env.example 146 | |-- .gitignore 147 | |-- pre-commit-config.yaml # for linting in during development 148 | |-- docker-compose.yml 149 | |-- Dockerfile 150 
| |-- Makefile # useful shortcuts 151 | |-- README.md 152 | ``` 153 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | import tomllib 2 | from pathlib import Path 3 | 4 | with Path.open("pyproject.toml", "rb") as f: 5 | data = tomllib.load(f) 6 | version = data["project"]["version"] 7 | -------------------------------------------------------------------------------- /app/chats/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/chats/__init__.py -------------------------------------------------------------------------------- /app/chats/api.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | from uuid import UUID 4 | 5 | import openai 6 | from starlette.responses import StreamingResponse 7 | 8 | from app.chats.exceptions import OpenAIException 9 | from app.chats.models import BaseMessage, Message 10 | from app.chats.services import BetaOpenAIService, OpenAIService 11 | from app.db import messages_queries 12 | 13 | router = APIRouter(tags=["Chat Endpoints"]) 14 | 15 | 16 | @router.post("/v1/chat/{uuid}") 17 | async def chat_create(uuid: UUID, input_message: BaseMessage): 18 | try: 19 | return await BetaOpenAIService(chat_uuid=uuid).send_message_to_a_thread( 20 | input_message=input_message 21 | ) 22 | except openai.OpenAIError as e: 23 | print(e) 24 | raise OpenAIException from e 25 | 26 | @router.get("/v1/chat/{uuid}") 27 | async def chat_get(uuid: UUID): 28 | try: 29 | return await BetaOpenAIService(chat_uuid=uuid).get_chat() 30 | except openai.OpenAIError as e: 31 | print(e) 32 | raise OpenAIException from e 33 | 34 | 35 | @router.get("/v1/messages") 36 | async def get_messages() -> list[Message]: 37 | # a bit messy as we might 
want to move this to a service 38 | return [Message(**message) for message in messages_queries.select_all()] 39 | 40 | 41 | @router.post("/v1/completion") 42 | async def completion_create(input_message: BaseMessage) -> Message: 43 | try: 44 | return await OpenAIService.chat_completion_without_streaming( 45 | input_message=input_message 46 | ) 47 | except openai.OpenAIError as e: 48 | raise OpenAIException from e 49 | 50 | 51 | @router.post("/v1/completion-stream") 52 | async def completion_stream(input_message: BaseMessage) -> StreamingResponse: 53 | """Streaming response won't return json but rather a properly formatted string for SSE.""" 54 | try: 55 | return await OpenAIService.chat_completion_with_streaming( 56 | input_message=input_message 57 | ) 58 | except openai.OpenAIError as e: 59 | raise OpenAIException from e 60 | 61 | 62 | @router.post("/v1/qa-create") 63 | async def qa_create(input_message: BaseMessage) -> Message: 64 | try: 65 | return await OpenAIService.qa_without_stream(input_message=input_message) 66 | except openai.OpenAIError as e: 67 | raise OpenAIException from e 68 | 69 | 70 | @router.post("/v1/qa-stream") 71 | async def qa_stream(input_message: BaseMessage) -> StreamingResponse: 72 | """Streaming response won't return json but rather a properly formatted string for SSE.""" 73 | try: 74 | return await OpenAIService.qa_with_stream(input_message=input_message) 75 | except openai.OpenAIError as e: 76 | raise OpenAIException from e 77 | -------------------------------------------------------------------------------- /app/chats/constants.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | 3 | 4 | class FailureReasonsEnum(StrEnum): 5 | OPENAI_ERROR = "OpenAI call failed" 6 | STREAM_TIMEOUT = "Stream timed out" 7 | FAILED_PROCESSING = "Post processing failed" 8 | 9 | 10 | class ChatRolesEnum(StrEnum): 11 | USER = "user" 12 | SYSTEM = "system" 13 | ASSISTANT = "assistant" 14 | 15 | 16 
| class ModelsEnum(StrEnum): 17 | GPT4 = "gpt-4-0613" 18 | 19 | 20 | NO_DOCUMENTS_FOUND: str = ( 21 | "No documents found in context. Please try again with a different query." 22 | ) 23 | -------------------------------------------------------------------------------- /app/chats/exceptions.py: -------------------------------------------------------------------------------- 1 | from fastapi import HTTPException 2 | 3 | from starlette import status 4 | 5 | from app.chats.constants import NO_DOCUMENTS_FOUND, FailureReasonsEnum 6 | 7 | 8 | class OpenAIException(HTTPException): 9 | def __init__(self): 10 | super().__init__( 11 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 12 | detail=FailureReasonsEnum.OPENAI_ERROR.value, 13 | ) 14 | 15 | 16 | class OpenAIFailedProcessingException(HTTPException): 17 | def __init__(self): 18 | super().__init__( 19 | status_code=status.HTTP_503_SERVICE_UNAVAILABLE, 20 | detail=FailureReasonsEnum.FAILED_PROCESSING.value, 21 | ) 22 | 23 | 24 | class OpenAIStreamTimeoutException(HTTPException): 25 | def __init__(self): 26 | super().__init__( 27 | status_code=status.HTTP_504_GATEWAY_TIMEOUT, 28 | detail=FailureReasonsEnum.STREAM_TIMEOUT.value, 29 | ) 30 | 31 | 32 | class RetrievalNoDocumentsFoundException(HTTPException): 33 | def __init__(self): 34 | super().__init__( 35 | status_code=status.HTTP_404_NOT_FOUND, detail=NO_DOCUMENTS_FOUND 36 | ) 37 | -------------------------------------------------------------------------------- /app/chats/ingest.py: -------------------------------------------------------------------------------- 1 | from qdrant_client import QdrantClient 2 | 3 | from settings import settings 4 | 5 | 6 | def test_populate_vector_db(client: QdrantClient): 7 | # Prepare your documents, metadata, and IDs 8 | docs = ( 9 | "LoremIpsum sit dolorem", 10 | "Completely random phrase", 11 | "Another random phrase", 12 | "pdf-BRAG is awesome.", 13 | ) 14 | metadata = ( 15 | {"source": "Olaf"}, 16 | {"source": "grski"}, 17 | 
{"source": "Olaf"}, 18 | {"source": "grski"}, 19 | ) 20 | ids = [42, 2, 3, 5] 21 | 22 | # Use the new add method 23 | client.add( 24 | collection_name=settings.QDRANT_COLLECTION_NAME, 25 | documents=docs, 26 | metadata=metadata, 27 | ids=ids, 28 | ) 29 | -------------------------------------------------------------------------------- /app/chats/models.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID, uuid4 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from app.chats.constants import ChatRolesEnum, ModelsEnum 6 | from app.chats.exceptions import OpenAIFailedProcessingException 7 | from app.core.models import TimestampAbstractModel 8 | 9 | 10 | class BaseMessage(BaseModel): 11 | """Base pydantic model that we use to interact with the API.""" 12 | 13 | model: ModelsEnum = Field(default=ModelsEnum.GPT4.value) 14 | message: str 15 | 16 | 17 | class Message(TimestampAbstractModel, BaseMessage): 18 | id: int = Field(default=None) 19 | uuid: UUID = Field(default_factory=uuid4) 20 | role: ChatRolesEnum 21 | 22 | 23 | class Chunk(BaseModel): 24 | id: str 25 | created: int = Field(default=0) 26 | model: ModelsEnum = Field(default="gpt-4-0613") 27 | content: str 28 | finish_reason: str | None = None 29 | 30 | @classmethod 31 | def from_chunk(cls, chunk): 32 | delta_content: str = cls.get_chunk_delta_content(chunk=chunk) 33 | return cls( 34 | id=chunk["id"], 35 | created=chunk["created"], 36 | model=chunk["model"], 37 | content=delta_content, 38 | finish_reason=chunk["choices"][0].get("finish_reason", None), 39 | ) 40 | 41 | @staticmethod 42 | def get_chunk_delta_content(chunk: dict | str) -> str: 43 | try: 44 | match chunk: 45 | case str(chunk): 46 | return chunk 47 | case dict(chunk): 48 | return chunk["choices"][0]["delta"].get("content", "") 49 | except Exception as e: 50 | raise OpenAIFailedProcessingException from e 51 | 52 | 53 | class Chat(BaseModel): 54 | """Simple model to tie together our local 
chat with the OpenAI thread. Later you could add user here.""" 55 | 56 | uuid: UUID = Field(default_factory=uuid4) 57 | openai_id: str 58 | 59 | 60 | class Assistant(BaseModel): 61 | uuid: UUID = Field(default_factory=uuid4) 62 | openai_id: str 63 | -------------------------------------------------------------------------------- /app/chats/retrieval.py: -------------------------------------------------------------------------------- 1 | from qdrant_client import QdrantClient 2 | 3 | from app.chats.exceptions import RetrievalNoDocumentsFoundException 4 | from app.chats.ingest import test_populate_vector_db 5 | from app.chats.models import BaseMessage 6 | from app.core.logs import logger 7 | from settings import settings 8 | 9 | client = QdrantClient(settings.QDRANT_HOST, port=settings.QDRANT_PORT) 10 | 11 | # ONLY FOR TEST PURPOSE 12 | test_populate_vector_db(client=client) 13 | 14 | 15 | def process_retrieval(message: BaseMessage) -> BaseMessage: 16 | """Search for a given query using vector similarity search. If no documents are found we raise an exception. 17 | If we do find documents we take the top 3 and put them into the context. 18 | """ 19 | search_result = search(query=message.message) 20 | resulting_query: str = ( 21 | f"Answer based only on the context, nothing else. \n" 22 | f"QUERY:\n{message.message}\n" 23 | f"CONTEXT:\n{search_result}" 24 | ) 25 | logger.info(f"Resulting Query: {resulting_query}") 26 | return BaseMessage(message=resulting_query, model=message.model) 27 | 28 | 29 | def search(query: str) -> str: 30 | """Function takes our query string, transforms the string into vectors and then searches for the most similar 31 | documents, returning the top 3. If no documents are found we raise an exception - as the user is asking 32 | about something not in the context of our documents. 
33 | """ 34 | search_result = client.query( 35 | collection_name=settings.QDRANT_COLLECTION_NAME, limit=3, query_text=query 36 | ) 37 | print(search_result) 38 | if not search_result: 39 | raise RetrievalNoDocumentsFoundException 40 | return "\n".join(result.document for result in search_result) 41 | -------------------------------------------------------------------------------- /app/chats/services.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID 2 | 3 | import openai 4 | from openai import AsyncOpenAI, OpenAI 5 | from openai.types.chat import ChatCompletion 6 | from starlette.responses import StreamingResponse 7 | 8 | from app.chats.constants import NO_DOCUMENTS_FOUND, ChatRolesEnum 9 | from app.chats.exceptions import RetrievalNoDocumentsFoundException 10 | from app.chats.models import Assistant, BaseMessage, Chat, Message 11 | from app.chats.retrieval import process_retrieval 12 | from app.chats.streaming import format_to_event_stream, stream_generator 13 | from app.core.logs import logger 14 | from app.db import assistants_queries, chats_queries, messages_queries 15 | from settings import settings 16 | 17 | openai_client = OpenAI(api_key=settings.OPENAI_API_KEY) 18 | async_openai_client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY) 19 | 20 | 21 | def create_openai_thread(chat_uuid): 22 | openai_id = openai_client.beta.threads.create().id 23 | chat = Chat(openai_id=openai_id, uuid=chat_uuid) 24 | chats_queries.insert_chat(uuid=chat.uuid, openai_thread_id=chat.openai_id) 25 | return chat 26 | 27 | 28 | def create_openai_assistant(): 29 | openai_id = openai_client.beta.assistants.create( 30 | instructions="You are a helpul instruction following assistant.", 31 | name="Helpful Assistant", 32 | model="gpt-4", 33 | ).id 34 | assistant = Assistant(openai_id=openai_id) 35 | assistants_queries.insert_assistant(uuid=assistant.uuid, openai_id=assistant.openai_id) 36 | return assistant 37 | 38 | 39 | class 
BetaOpenAIService: 40 | def __init__(self, chat_uuid: UUID): 41 | self.assistant: Assistant = self.get_or_init_assistant() 42 | self.chat: Chat = self.get_or_init_chat(chat_uuid=chat_uuid) 43 | 44 | @staticmethod 45 | def get_or_init_chat(chat_uuid: UUID) -> Chat: 46 | if chat := chats_queries.get_chat(uuid=chat_uuid): 47 | return Chat(uuid=chat["uuid"], openai_id=chat["openai_id"]) 48 | return create_openai_thread(chat_uuid=chat_uuid) 49 | 50 | @staticmethod 51 | def get_or_init_assistant(): 52 | if assistant := assistants_queries.get_first_assistant(): 53 | return Assistant(uuid=assistant["uuid"], openai_id=assistant["openai_id"]) 54 | return create_openai_assistant() 55 | 56 | async def get_chat(self): 57 | chat = chats_queries.get_chat(uuid=self.chat.uuid) 58 | return await async_openai_client.beta.threads.messages.list( 59 | thread_id=chat["openai_id"] 60 | ) 61 | 62 | async def send_message_to_a_thread(self, input_message: BaseMessage): 63 | await async_openai_client.beta.threads.messages.create( 64 | thread_id=self.chat.openai_id, 65 | role=ChatRolesEnum.USER.value, 66 | content=input_message.message, 67 | ) 68 | await async_openai_client.beta.threads.runs.create( 69 | thread_id=self.chat.openai_id, 70 | assistant_id=self.assistant.openai_id, 71 | ) 72 | return self.chat 73 | 74 | 75 | class OpenAIService: 76 | @classmethod 77 | async def chat_completion_without_streaming( 78 | cls, input_message: BaseMessage 79 | ) -> Message: 80 | completion: openai.ChatCompletion = await openai.ChatCompletion.acreate( 81 | model=input_message.model, 82 | api_key=settings.OPENAI_API_KEY, 83 | messages=[ 84 | {"role": ChatRolesEnum.USER.value, "content": input_message.message} 85 | ], 86 | ) 87 | logger.info(f"Got the following response: {completion}") 88 | message = Message( 89 | model=input_message.model, 90 | message=cls.extract_response_from_completion(completion), 91 | role=ChatRolesEnum.ASSISTANT.value, 92 | ) 93 | messages_queries.insert( 94 | model=message.model, 
message=message.message, role=message.role 95 | ) 96 | return message 97 | 98 | @staticmethod 99 | async def chat_completion_with_streaming( 100 | input_message: BaseMessage, 101 | ) -> StreamingResponse: 102 | subscription: openai.ChatCompletion = await openai.ChatCompletion.acreate( 103 | model=input_message.model, 104 | api_key=settings.OPENAI_API_KEY, 105 | messages=[ 106 | {"role": ChatRolesEnum.USER.value, "content": input_message.message} 107 | ], 108 | stream=True, 109 | ) 110 | return StreamingResponse( 111 | stream_generator(subscription), media_type="text/event-stream" 112 | ) 113 | 114 | @staticmethod 115 | def extract_response_from_completion(chat_completion: ChatCompletion) -> str: 116 | return chat_completion.choices[0].message.content 117 | 118 | @classmethod 119 | async def qa_without_stream(cls, input_message: BaseMessage) -> Message: 120 | try: 121 | augmented_message: BaseMessage = process_retrieval(message=input_message) 122 | return await cls.chat_completion_without_streaming( 123 | input_message=augmented_message 124 | ) 125 | except RetrievalNoDocumentsFoundException: 126 | return Message( 127 | model=input_message.model, 128 | message=NO_DOCUMENTS_FOUND, 129 | role=ChatRolesEnum.ASSISTANT.value, 130 | ) 131 | 132 | @classmethod 133 | async def qa_with_stream(cls, input_message: BaseMessage) -> StreamingResponse: 134 | try: 135 | augmented_message: BaseMessage = process_retrieval(message=input_message) 136 | return await cls.chat_completion_with_streaming( 137 | input_message=augmented_message 138 | ) 139 | except RetrievalNoDocumentsFoundException: 140 | return StreamingResponse( 141 | (format_to_event_stream(y) for y in "Not found"), 142 | media_type="text/event-stream", 143 | ) 144 | -------------------------------------------------------------------------------- /app/chats/streaming.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | 4 | import async_timeout 5 | 6 | 
from app.chats.constants import ChatRolesEnum
from app.chats.exceptions import (
    OpenAIFailedProcessingException,
    OpenAIStreamTimeoutException,
)
from app.chats.models import Chunk, Message
from app.core.logs import logger
from app.db import messages_queries
from settings import settings


async def stream_generator(subscription):
    """Yield SSE-formatted events from an OpenAI streaming subscription.

    Accumulates the full assistant response while streaming and persists it
    once the stream ends. Raises OpenAIStreamTimeoutException if generation
    exceeds ``settings.GENERATION_TIMEOUT_SEC``.

    In the future if we are paranoid about performance we can pre-allocate a
    buffer instead of f-strings, but it is totally not worth the complexity.
    """
    async with async_timeout.timeout(settings.GENERATION_TIMEOUT_SEC):
        try:
            complete_response: str = ""
            last_chunk = None
            async for chunk in subscription:
                last_chunk = chunk
                # f-string is faster than concatenation so we use it here
                complete_response = (
                    f"{complete_response}{Chunk.get_chunk_delta_content(chunk=chunk)}"
                )
                yield format_to_event_stream(post_processing(chunk))
            if last_chunk is None:
                # Bug fix: an empty subscription previously left `chunk`
                # unbound and raised NameError below; nothing to persist.
                logger.warning("Stream produced no chunks; skipping persistence.")
                return
            message: Message = Message(
                model=last_chunk.model,
                message=complete_response,
                role=ChatRolesEnum.ASSISTANT.value,
            )
            messages_queries.insert(
                model=message.model, message=message.message, role=message.role
            )
            logger.info("Complete Streamed Message: %s", message)
        except asyncio.TimeoutError as e:
            raise OpenAIStreamTimeoutException from e


def format_to_event_stream(data: str) -> str:
    """Frame *data* in the standard Server-Sent Events message format."""
    return f"event: message\ndata: {data}\n\n"


def post_processing(chunk) -> str:
    """Normalize a raw OpenAI chunk into a JSON string for the SSE payload.

    Raises OpenAIFailedProcessingException if the chunk cannot be parsed.
    """
    try:
        logger.info("Chunk: %s", chunk)
        formatted_chunk = Chunk.from_chunk(chunk=chunk)
        logger.info("Formatted Chunk: %s", formatted_chunk)
        return json.dumps(formatted_chunk.model_dump())
    except Exception as e:
        raise OpenAIFailedProcessingException from e
-------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/core/__init__.py -------------------------------------------------------------------------------- /app/core/api.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | from starlette import status 4 | from starlette.requests import Request 5 | 6 | router = APIRouter(tags=["Core Endpoints"]) 7 | 8 | 9 | @router.get("/healthcheck", status_code=status.HTTP_200_OK) 10 | async def healthcheck(request: Request) -> dict: 11 | return {"version": request.app.version} 12 | -------------------------------------------------------------------------------- /app/core/constants.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | 3 | DEFAULT_NUMBER_OF_WORKERS_ON_LOCAL = 1 4 | 5 | 6 | class Environments(StrEnum): 7 | LOCAL = "local" 8 | DEV = "dev" 9 | PROD = "prod" 10 | -------------------------------------------------------------------------------- /app/core/logs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO) 4 | 5 | 6 | logger = logging.getLogger(__name__) 7 | ConsoleOutputHandler = logging.StreamHandler() 8 | logger.addHandler(ConsoleOutputHandler) 9 | -------------------------------------------------------------------------------- /app/core/middlewares.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | import sentry_sdk 4 | from sentry_sdk.integrations.asgi import SentryAsgiMiddleware 5 | from starlette.middleware.cors import CORSMiddleware 6 | 7 | from settings import settings 8 | 9 | 10 | def 
add_cors_middleware(app: FastAPI) -> FastAPI: 11 | app.add_middleware( 12 | CORSMiddleware, 13 | allow_origins=["*"], 14 | allow_credentials=True, 15 | allow_methods=["GET", "POST", "PUT", "PATCH", "HEAD", "OPTIONS", "DELETE"], 16 | allow_headers=[ 17 | "Access-Control-Allow-Headers", 18 | "Content-Type", 19 | "Authorization", 20 | "Access-Control-Allow-Origin", 21 | "Set-Cookie", 22 | ], 23 | ) 24 | return app 25 | 26 | 27 | def add_sentry_middleware(app: FastAPI) -> FastAPI: 28 | if settings.SENTRY_DSN: # pragma: no cover 29 | sentry_sdk.init( 30 | dsn=settings.SENTRY_DSN, 31 | environment=settings.ENVIRONMENT, 32 | send_default_pii=False, 33 | traces_sample_rate=0.3, 34 | profiles_sample_rate=0.3, 35 | ) 36 | app.add_middleware(SentryAsgiMiddleware) 37 | return app 38 | 39 | 40 | MIDDLEWARES = ( 41 | add_cors_middleware, 42 | add_sentry_middleware, 43 | ) 44 | 45 | 46 | def apply_middlewares(app: FastAPI) -> FastAPI: 47 | for middleware in MIDDLEWARES: 48 | app = middleware(app) 49 | return app 50 | -------------------------------------------------------------------------------- /app/core/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class TimestampAbstractModel(BaseModel): 7 | created_at: datetime = Field(default_factory=datetime.utcnow) 8 | -------------------------------------------------------------------------------- /app/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/data/__init__.py -------------------------------------------------------------------------------- /app/data/models.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class TextIngestion(BaseModel): 7 | text: 
str 8 | run_uuid: UUID 9 | 10 | 11 | class CSVIngestion(BaseModel): 12 | row: dict 13 | -------------------------------------------------------------------------------- /app/db.py: -------------------------------------------------------------------------------- 1 | import pugsql 2 | 3 | from settings import settings 4 | 5 | messages_queries = pugsql.module("./db/queries/messages") 6 | messages_queries.connect(settings.DATABASE_URL) 7 | 8 | chats_queries = pugsql.module("./db/queries/chats") 9 | chats_queries.connect(settings.DATABASE_URL) 10 | 11 | assistants_queries = pugsql.module("./db/queries/assistants") 12 | assistants_queries.connect(settings.DATABASE_URL) 13 | -------------------------------------------------------------------------------- /app/ingest/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/ingest/__init__.py -------------------------------------------------------------------------------- /app/ingest/__main__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from app.ingest.ingest import ingest_markdown 4 | 5 | asyncio.run(ingest_markdown()) 6 | -------------------------------------------------------------------------------- /app/ingest/ingest.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import uuid 4 | from pathlib import Path 5 | 6 | from faststream.nats import NatsBroker 7 | 8 | from app.core.logs import logger 9 | from app.data.models import TextIngestion 10 | from app.ingest.publisher import publish_review 11 | from settings import settings 12 | 13 | 14 | def find_md_files(directory): 15 | # List to store found md files 16 | md_files = [] 17 | 18 | # Walk through directory 19 | for root, _dirs, files in os.walk(directory): 20 | for filename in fnmatch.filter(files, "*.md"): 
21 | # append full file path to our files list 22 | md_files.append(Path(root) / Path(filename)) 23 | 24 | return md_files 25 | 26 | 27 | async def ingest_markdown(): 28 | run_uuid = uuid.uuid4() 29 | async with NatsBroker(settings.NATS_URI) as broker: 30 | md_files = find_md_files("data") 31 | for md_file in md_files: 32 | with Path.open(md_file) as f: 33 | ingested_file = TextIngestion(text=f.read(), run_uuid=run_uuid) 34 | logger.info(f"Publishing review: {ingested_file.text[:50]}") 35 | await publish_review(broker=broker, ingested_text=ingested_file) 36 | return run_uuid 37 | -------------------------------------------------------------------------------- /app/ingest/publisher.py: -------------------------------------------------------------------------------- 1 | from faststream.nats import NatsBroker 2 | 3 | from app.data.models import TextIngestion 4 | from settings import settings 5 | 6 | 7 | async def publish_review(broker: NatsBroker, ingested_text: TextIngestion): 8 | await broker.publish( 9 | ingested_text, 10 | subject=settings.NATS_SUBJECT, 11 | ) 12 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | from app import version 4 | from app.chats.api import router as chat_router 5 | from app.core.api import router as core_router 6 | from app.core.logs import logger 7 | 8 | # from app.core.middlewares import apply_middlewares 9 | 10 | app = FastAPI(version=version) 11 | # app = apply_middlewares(app) 12 | 13 | app.include_router(core_router) 14 | app.include_router(chat_router) 15 | 16 | logger.info("App is ready!") 17 | -------------------------------------------------------------------------------- /app/process/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/process/__init__.py -------------------------------------------------------------------------------- /app/process/__main__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/app/process/__main__.py -------------------------------------------------------------------------------- /app/process/subscriber.py: -------------------------------------------------------------------------------- 1 | from faststream import FastStream 2 | from faststream.nats import NatsBroker 3 | from qdrant_client import QdrantClient 4 | 5 | from app.data.models import TextIngestion 6 | from settings import settings 7 | 8 | broker = NatsBroker(settings.NATS_URI) 9 | app = FastStream(broker) 10 | 11 | qdrant_client = QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT) 12 | 13 | 14 | @broker.subscriber(settings.NATS_SUBJECT) 15 | async def process_ingestion(text_ingestion: TextIngestion) -> bool: 16 | qdrant_client.add(collection_name=settings.QDRANT_COLLECTION_NAME, documents=[text_ingestion.text]) 17 | return True 18 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/data/.gitkeep -------------------------------------------------------------------------------- /db/migrations/20230925090948_init.sql: -------------------------------------------------------------------------------- 1 | -- migrate:up 2 | CREATE TABLE chat ( 3 | id SERIAL PRIMARY KEY, 4 | uuid UUID NOT NULL UNIQUE, 5 | openai_id VARCHAR(255) NOT NULL UNIQUE, 6 | created_at TIMESTAMPTZ DEFAULT NOW() 7 | ); 8 | 9 | CREATE TABLE message ( 10 | id SERIAL PRIMARY KEY, 11 | chat_uuid UUID NOT NULL UNIQUE 
REFERENCES chat(uuid) ON DELETE CASCADE, 12 | model VARCHAR(255), 13 | role VARCHAR(255), 14 | message TEXT, 15 | created_at TIMESTAMPTZ DEFAULT NOW() 16 | ); 17 | 18 | CREATE TABLE assistant ( 19 | id SERIAL PRIMARY KEY, 20 | uuid UUID NOT NULL UNIQUE, 21 | openai_id VARCHAR(255) NOT NULL UNIQUE, 22 | created_at TIMESTAMPTZ DEFAULT NOW() 23 | ); 24 | 25 | 26 | -- migrate:down 27 | 28 | DROP TABLE IF EXISTS message; 29 | -------------------------------------------------------------------------------- /db/queries/assistants/get_first_assistant.sql: -------------------------------------------------------------------------------- 1 | -- :name get_first_assistant :one 2 | SELECT * FROM assistant LIMIT 1 -------------------------------------------------------------------------------- /db/queries/assistants/insert_chat.sql: -------------------------------------------------------------------------------- 1 | -- :name insert_assistant :insert 2 | INSERT INTO assistant (uuid, openai_id) 3 | VALUES (:uuid, :openai_id); 4 | -------------------------------------------------------------------------------- /db/queries/chats/get_chat.sql: -------------------------------------------------------------------------------- 1 | -- :name get_chat :one 2 | SELECT * FROM chat WHERE uuid = :uuid; -------------------------------------------------------------------------------- /db/queries/chats/insert_chat.sql: -------------------------------------------------------------------------------- 1 | -- :name insert_chat :insert 2 | INSERT INTO chat (uuid, openai_id) 3 | VALUES (:uuid, :openai_thread_id); 4 | -------------------------------------------------------------------------------- /db/queries/messages/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grski/bRAG/ec4c9b50db60a521f4b28815998a651fb0c6e55d/db/queries/messages/.gitkeep -------------------------------------------------------------------------------- 
/db/queries/messages/insert.sql: -------------------------------------------------------------------------------- 1 | -- :name insert :insert 2 | INSERT INTO message (chat_uuid, model, role, message) 3 | VALUES (:chat_uuid, :model, :role, :message); -------------------------------------------------------------------------------- /db/queries/messages/select_all.sql: -------------------------------------------------------------------------------- 1 | -- :name select_all :many 2 | SELECT * FROM message ORDER BY created_at DESC -------------------------------------------------------------------------------- /db/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS "schema_migrations" (version varchar(128) primary key); 2 | CREATE TABLE Message ( 3 | id SERIAL PRIMARY KEY, 4 | user_uuid VARCHAR(255), 5 | output_message TEXT 6 | 7 | ); 8 | CREATE TABLE Chat ( 9 | id SERIAL PRIMARY KEY, 10 | uuid VARCHAR(255), 11 | user_uuid VARCHAR(255), 12 | is_active BOOLEAN 13 | 14 | ); 15 | -- Dbmate schema migrations 16 | INSERT INTO "schema_migrations" (version) VALUES 17 | ('20231008214110'); 18 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | services: 3 | 4 | api: 5 | build: 6 | context: . 7 | depends_on: 8 | database: 9 | condition: service_healthy 10 | restart: always 11 | deploy: 12 | replicas: 1 13 | ports: 14 | - "8000:8000" 15 | volumes: 16 | - .:/code/ 17 | command: ["bash", "-c","dbmate up && gunicorn -c settings/gunicorn.conf.py app.main:app"] 18 | env_file: 19 | - .env 20 | ingest: 21 | build: 22 | context: . 23 | depends_on: 24 | - nats 25 | restart: always 26 | volumes: 27 | - .:/project/ 28 | command: [ "bash", "-c", "python -m app.ingest" ] 29 | env_file: 30 | - .env 31 | 32 | worker: 33 | build: 34 | context: . 
35 | depends_on: 36 | database: 37 | condition: service_healthy 38 | restart: always 39 | volumes: 40 | - .:/project/ 41 | command: [ "bash", "-c", "python -m faststream run --reload app.process.subscriber:app" ] 42 | env_file: 43 | - .env 44 | 45 | qdrant: 46 | restart: always 47 | image: qdrant/qdrant:latest 48 | ports: 49 | - "6333:6333" 50 | nats: 51 | image: nats 52 | restart: always 53 | ports: 54 | - "4222:4222" 55 | - "8222:8222" 56 | - "6222:6222" 57 | 58 | database: 59 | restart: always 60 | image: postgres:15-alpine 61 | volumes: 62 | - pg-data:/var/lib/postgresql/data 63 | healthcheck: 64 | test: [ "CMD-SHELL", "pg_isready -U application" ] 65 | interval: 5s 66 | timeout: 5s 67 | retries: 5 68 | environment: 69 | POSTGRES_USER: application 70 | POSTGRES_PASSWORD: secret_pass 71 | POSTGRES_DB: application 72 | ports: 73 | - "5432:5432" 74 | 75 | volumes: 76 | pg-data: 77 | driver: local 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "brag" 7 | requires-python = ">=3.11" 8 | version = "0.0.1" 9 | dependencies = ["fastapi", "openai", "pydantic-settings", "uvicorn", "gunicorn", "qdrant-client[fastembed]", "sentry-sdk", "pugsql", "psycopg2-binary", "faststream[nats]", "async-timeout"] 10 | 11 | [project.optional-dependencies] 12 | dev = ["pre-commit", "ruff", "pytest"] 13 | 14 | 15 | 16 | 17 | [tool.ruff] 18 | target-version = 'py311' 19 | select = [ 20 | "B", # flake8-bugbear 21 | "C4", # flake8-comprehensions 22 | "D", # pydocstyle 23 | "E", # Error 24 | "F", # pyflakes 25 | "I", # isort 26 | "ISC", # flake8-implicit-str-concat 27 | "N", # pep8-naming 28 | "PGH", # pygrep-hooks 29 | "PTH", # flake8-use-pathlib 30 | "Q", # flake8-quotes 31 | "S", # bandit 32 | "SIM", # flake8-simplify 33 | "TRY", # tryceratops 
34 | "UP", # pyupgrade 35 | "W", # Warning 36 | "YTT", # flake8-2020 37 | ] 38 | 39 | exclude = [ 40 | "migrations", 41 | "__pycache__", 42 | "manage.py", 43 | "settings.py", 44 | "env", 45 | ".env", 46 | "venv", 47 | ".venv", 48 | ] 49 | 50 | ignore = [ 51 | "D100", 52 | "D101", 53 | "D102", 54 | "D103", 55 | "D104", 56 | "D105", 57 | "D106", 58 | "D107", 59 | "D200", 60 | "D205", 61 | "D401", 62 | "E402", 63 | "E501", 64 | "F401", 65 | ] 66 | line-length = 120 # Must agree with Black 67 | 68 | [tool.ruff.flake8-bugbear] 69 | extend-immutable-calls = [ 70 | "chr", 71 | "typer.Argument", 72 | "typer.Option", 73 | ] 74 | 75 | [tool.ruff.pydocstyle] 76 | convention = "numpy" 77 | 78 | [tool.ruff.per-file-ignores] 79 | "tests/*.py" = [ 80 | "D100", 81 | "D101", 82 | "D102", 83 | "D103", 84 | "D104", 85 | "D105", 86 | "D106", 87 | "D107", 88 | "S101", # use of "assert" 89 | "S102", # use of "exec" 90 | "S106", # possible hardcoded password. 91 | "PGH001", # use of "eval" 92 | ] 93 | "settings/*.py" = [ 94 | "N999" 95 | ] 96 | 97 | [tool.ruff.pep8-naming] 98 | staticmethod-decorators = [ 99 | "pydantic.validator", 100 | "pydantic.root_validator", 101 | ] 102 | 103 | 104 | [tool.ruff.isort] 105 | section-order = ["fastapi", "future", "standard-library", "third-party", "first-party", "local-folder"] 106 | 107 | 108 | [tool.ruff.isort.sections] 109 | fastapi = ["fastapi"] -------------------------------------------------------------------------------- /requirements/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements/requirements-dev.txt pyproject.toml 6 | # 7 | annotated-types==0.6.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via 11 | # fast-depends 12 | # fastapi 13 | # httpx 14 | # openai 15 | # starlette 16 | # watchfiles 17 | certifi==2023.7.22 18 | # via 
19 | # httpcore 20 | # httpx 21 | # requests 22 | # sentry-sdk 23 | cfgv==3.4.0 24 | # via pre-commit 25 | charset-normalizer==3.3.2 26 | # via requests 27 | click==8.1.7 28 | # via 29 | # typer 30 | # uvicorn 31 | coloredlogs==15.0.1 32 | # via onnxruntime 33 | distlib==0.3.7 34 | # via virtualenv 35 | distro==1.8.0 36 | # via openai 37 | fast-depends==2.2.1 38 | # via faststream 39 | fastapi==0.104.1 40 | # via brag (pyproject.toml) 41 | fastembed==0.1.1 42 | # via qdrant-client 43 | faststream[nats]==0.2.11 44 | # via brag (pyproject.toml) 45 | filelock==3.13.1 46 | # via virtualenv 47 | flatbuffers==23.5.26 48 | # via onnxruntime 49 | grpcio==1.59.2 50 | # via 51 | # grpcio-tools 52 | # qdrant-client 53 | grpcio-tools==1.59.2 54 | # via qdrant-client 55 | gunicorn==21.2.0 56 | # via brag (pyproject.toml) 57 | h11==0.14.0 58 | # via 59 | # httpcore 60 | # uvicorn 61 | h2==4.1.0 62 | # via httpx 63 | hpack==4.0.0 64 | # via h2 65 | httpcore==1.0.1 66 | # via httpx 67 | httpx[http2]==0.25.1 68 | # via 69 | # openai 70 | # qdrant-client 71 | humanfriendly==10.0 72 | # via coloredlogs 73 | hyperframe==6.0.1 74 | # via h2 75 | identify==2.5.31 76 | # via pre-commit 77 | idna==3.4 78 | # via 79 | # anyio 80 | # httpx 81 | # requests 82 | iniconfig==2.0.0 83 | # via pytest 84 | mpmath==1.3.0 85 | # via sympy 86 | nats-py==2.6.0 87 | # via faststream 88 | nodeenv==1.8.0 89 | # via pre-commit 90 | numpy==1.26.1 91 | # via 92 | # onnx 93 | # onnxruntime 94 | # qdrant-client 95 | onnx==1.15.0 96 | # via fastembed 97 | onnxruntime==1.16.1 98 | # via fastembed 99 | openai==1.1.1 100 | # via brag (pyproject.toml) 101 | packaging==23.2 102 | # via 103 | # gunicorn 104 | # onnxruntime 105 | # pytest 106 | platformdirs==3.11.0 107 | # via virtualenv 108 | pluggy==1.3.0 109 | # via pytest 110 | portalocker==2.8.2 111 | # via qdrant-client 112 | pre-commit==3.5.0 113 | # via brag (pyproject.toml) 114 | protobuf==4.25.0 115 | # via 116 | # grpcio-tools 117 | # onnx 118 | # 
onnxruntime 119 | psycopg2-binary==2.9.9 120 | # via brag (pyproject.toml) 121 | pugsql==0.2.4 122 | # via brag (pyproject.toml) 123 | pydantic==2.4.2 124 | # via 125 | # fast-depends 126 | # fastapi 127 | # openai 128 | # pydantic-settings 129 | # qdrant-client 130 | pydantic-core==2.10.1 131 | # via pydantic 132 | pydantic-settings==2.0.3 133 | # via brag (pyproject.toml) 134 | pytest==7.4.3 135 | # via brag (pyproject.toml) 136 | python-dotenv==1.0.0 137 | # via pydantic-settings 138 | pyyaml==6.0.1 139 | # via pre-commit 140 | qdrant-client[fastembed]==1.6.4 141 | # via brag (pyproject.toml) 142 | requests==2.31.0 143 | # via fastembed 144 | ruff==0.1.4 145 | # via brag (pyproject.toml) 146 | sentry-sdk==1.34.0 147 | # via brag (pyproject.toml) 148 | sniffio==1.3.0 149 | # via 150 | # anyio 151 | # httpx 152 | sqlalchemy==1.4.50 153 | # via pugsql 154 | starlette==0.27.0 155 | # via fastapi 156 | sympy==1.12 157 | # via onnxruntime 158 | tokenizers==0.13.3 159 | # via fastembed 160 | tqdm==4.66.1 161 | # via 162 | # fastembed 163 | # openai 164 | typer==0.9.0 165 | # via faststream 166 | typing-extensions==4.8.0 167 | # via 168 | # fastapi 169 | # openai 170 | # pydantic 171 | # pydantic-core 172 | # typer 173 | urllib3==1.26.18 174 | # via 175 | # qdrant-client 176 | # requests 177 | # sentry-sdk 178 | uvicorn==0.24.0.post1 179 | # via brag (pyproject.toml) 180 | uvloop==0.19.0 181 | # via faststream 182 | virtualenv==20.24.6 183 | # via pre-commit 184 | watchfiles==0.21.0 185 | # via faststream 186 | 187 | # The following packages are considered to be unsafe in a requirements file: 188 | # setuptools 189 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile 
--output-file=requirements/requirements.txt --strip-extras pyproject.toml 6 | # 7 | annotated-types==0.6.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via 11 | # fast-depends 12 | # fastapi 13 | # httpx 14 | # openai 15 | # starlette 16 | # watchfiles 17 | certifi==2023.7.22 18 | # via 19 | # httpcore 20 | # httpx 21 | # requests 22 | # sentry-sdk 23 | charset-normalizer==3.3.2 24 | # via requests 25 | click==8.1.7 26 | # via 27 | # typer 28 | # uvicorn 29 | coloredlogs==15.0.1 30 | # via onnxruntime 31 | distro==1.8.0 32 | # via openai 33 | fast-depends==2.2.1 34 | # via faststream 35 | fastapi==0.104.1 36 | # via brag (pyproject.toml) 37 | fastembed==0.1.1 38 | # via qdrant-client 39 | faststream==0.2.11 40 | # via brag (pyproject.toml) 41 | flatbuffers==23.5.26 42 | # via onnxruntime 43 | grpcio==1.59.2 44 | # via 45 | # grpcio-tools 46 | # qdrant-client 47 | grpcio-tools==1.59.2 48 | # via qdrant-client 49 | gunicorn==21.2.0 50 | # via brag (pyproject.toml) 51 | h11==0.14.0 52 | # via 53 | # httpcore 54 | # uvicorn 55 | h2==4.1.0 56 | # via httpx 57 | hpack==4.0.0 58 | # via h2 59 | httpcore==1.0.1 60 | # via httpx 61 | httpx==0.25.1 62 | # via 63 | # openai 64 | # qdrant-client 65 | humanfriendly==10.0 66 | # via coloredlogs 67 | hyperframe==6.0.1 68 | # via h2 69 | idna==3.4 70 | # via 71 | # anyio 72 | # httpx 73 | # requests 74 | mpmath==1.3.0 75 | # via sympy 76 | nats-py==2.6.0 77 | # via faststream 78 | numpy==1.26.1 79 | # via 80 | # onnx 81 | # onnxruntime 82 | # qdrant-client 83 | onnx==1.15.0 84 | # via fastembed 85 | onnxruntime==1.16.1 86 | # via fastembed 87 | openai==1.1.1 88 | # via brag (pyproject.toml) 89 | packaging==23.2 90 | # via 91 | # gunicorn 92 | # onnxruntime 93 | portalocker==2.8.2 94 | # via qdrant-client 95 | protobuf==4.25.0 96 | # via 97 | # grpcio-tools 98 | # onnx 99 | # onnxruntime 100 | psycopg2-binary==2.9.9 101 | # via brag (pyproject.toml) 102 | pugsql==0.2.4 103 | # via brag (pyproject.toml) 104 | pydantic==2.4.2 105 | # 
via 106 | # fast-depends 107 | # fastapi 108 | # openai 109 | # pydantic-settings 110 | # qdrant-client 111 | pydantic-core==2.10.1 112 | # via pydantic 113 | pydantic-settings==2.0.3 114 | # via brag (pyproject.toml) 115 | python-dotenv==1.0.0 116 | # via pydantic-settings 117 | qdrant-client==1.6.4 118 | # via brag (pyproject.toml) 119 | requests==2.31.0 120 | # via fastembed 121 | sentry-sdk==1.34.0 122 | # via brag (pyproject.toml) 123 | sniffio==1.3.0 124 | # via 125 | # anyio 126 | # httpx 127 | sqlalchemy==1.4.50 128 | # via pugsql 129 | starlette==0.27.0 130 | # via fastapi 131 | sympy==1.12 132 | # via onnxruntime 133 | tokenizers==0.13.3 134 | # via fastembed 135 | tqdm==4.66.1 136 | # via 137 | # fastembed 138 | # openai 139 | typer==0.9.0 140 | # via faststream 141 | typing-extensions==4.8.0 142 | # via 143 | # fastapi 144 | # openai 145 | # pydantic 146 | # pydantic-core 147 | # typer 148 | urllib3==1.26.18 149 | # via 150 | # qdrant-client 151 | # requests 152 | # sentry-sdk 153 | uvicorn==0.24.0.post1 154 | # via brag (pyproject.toml) 155 | uvloop==0.19.0 156 | # via faststream 157 | watchfiles==0.21.0 158 | # via faststream 159 | 160 | # The following packages are considered to be unsafe in a requirements file: 161 | # setuptools 162 | -------------------------------------------------------------------------------- /settings/__init__.py: -------------------------------------------------------------------------------- 1 | from settings.base import Settings 2 | 3 | settings = Settings() 4 | -------------------------------------------------------------------------------- /settings/base.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | from pydantic_settings import BaseSettings 3 | 4 | from app.core.constants import Environments 5 | 6 | 7 | class Settings(BaseSettings): 8 | APP_NAME: str = "bRAG" 9 | ENVIRONMENT: str = Field(env="ENVIRONMENT", default=Environments.LOCAL.value) 10 | 11 
| DATABASE_URL: str = Field(env="DATABASE_URL") 12 | OPENAI_API_KEY: str = Field(env="OPENAI_API_KEY", default="None") 13 | GENERATION_TIMEOUT_SEC: int = Field(env="GENERATION_TIMEOUT_SEC", default=120) 14 | 15 | QDRANT_HOST: str = Field( 16 | env="QDRANT_HOST", default="qdrant" 17 | ) # qdrant inside docker, localhost locally 18 | QDRANT_PORT: int = Field(env="QDRANT_PORT", default=6333) 19 | QDRANT_COLLECTION_NAME: str = Field(env="QDRANT_COLLECTION_NAME", default="demo") 20 | 21 | NATS_URI: str = Field(env="NATS_URI", default="nats://localhost:4222/") 22 | NATS_SUBJECT: str = Field(env="NATS_SUBJECT", default="ingestion") 23 | 24 | @property 25 | def is_local(self): 26 | return Environments.LOCAL.value == self.ENVIRONMENT 27 | 28 | class Config: 29 | env_file = ".env" 30 | -------------------------------------------------------------------------------- /settings/gunicorn.conf.py: -------------------------------------------------------------------------------- 1 | from app.core.constants import DEFAULT_NUMBER_OF_WORKERS_ON_LOCAL 2 | from settings import settings 3 | 4 | THREADS_PER_CORE = 2 5 | 6 | 7 | def calculate_workers(): 8 | if settings.is_local: 9 | return DEFAULT_NUMBER_OF_WORKERS_ON_LOCAL 10 | 11 | import multiprocessing 12 | 13 | # for cpus with two threads per core this is the recommended formula 14 | # we might want to adjust this for our use case and infrastructure in the future 15 | # as in general one vcpu in aws is not the same as one cpu core in a physical machine 16 | # usually one vcpu is one thread of a cpu core, for starters it's good enough though 17 | 18 | return multiprocessing.cpu_count() * THREADS_PER_CORE + 1 19 | 20 | 21 | bind = "0.0.0.0:8000" 22 | workers = settings.WORKERS 23 | worker_class = "uvicorn.workers.UvicornWorker" 24 | timeout = settings.TIMEOUT 25 | reload = settings.RELOAD 26 | -------------------------------------------------------------------------------- /tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | import pytest_asyncio 2 | from httpx import AsyncClient 3 | 4 | from app.main import app 5 | 6 | 7 | @pytest_asyncio.fixture 8 | async def fastapi_client() -> AsyncClient: 9 | async with AsyncClient(app=app, base_url="http://test") as client: 10 | yield client 11 | -------------------------------------------------------------------------------- /tests/test_app.py: -------------------------------------------------------------------------------- 1 | from fastapi import status 2 | 3 | import pytest 4 | from httpx import AsyncClient 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_read_heartbeat(fastapi_client: AsyncClient) -> None: 9 | response = await fastapi_client.get("/healthcheck") 10 | response.json() 11 | # assert data["version"] == version 12 | assert response.status_code == status.HTTP_200_OK 13 | --------------------------------------------------------------------------------