├── .github └── workflows │ └── black.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md ├── convert.py └── winemag-data-130k-v2.jsonl.gz └── dbs ├── elasticsearch ├── .env.example ├── Dockerfile ├── README.md ├── api │ ├── __init__.py │ ├── config.py │ ├── main.py │ ├── routers │ │ └── rest.py │ └── schemas │ │ ├── __init__.py │ │ └── rest.py ├── docker-compose.yml ├── requirements.txt └── scripts │ ├── __init__.py │ ├── bulk_index.py │ ├── mapping │ └── mapping.json │ └── schemas │ ├── __init__.py │ └── wine.py ├── lancedb ├── .env.example ├── Dockerfile ├── README.md ├── api │ ├── __init__.py │ ├── config.py │ ├── main.py │ ├── routers │ │ ├── __init__.py │ │ └── rest.py │ └── schemas │ │ ├── __init__.py │ │ └── rest.py ├── docker-compose.yml ├── requirements.txt ├── schemas │ ├── __init__.py │ └── wine.py └── scripts │ └── bulk_index_sbert.py ├── meilisearch ├── .env.example ├── Dockerfile ├── README.md ├── api │ ├── __init__.py │ ├── config.py │ ├── main.py │ ├── routers │ │ └── rest.py │ └── schemas │ │ ├── __init__.py │ │ └── rest.py ├── docker-compose.yml ├── requirements.txt ├── schemas │ ├── __init__.py │ └── wine.py └── scripts │ ├── __init__.py │ ├── bulk_index_async.py │ ├── bulk_index_sync.py │ └── settings │ └── settings.json ├── neo4j ├── .dockerignore ├── .env.example ├── Dockerfile ├── README.md ├── api │ ├── __init__.py │ ├── config.py │ ├── main.py │ ├── routers │ │ └── rest.py │ └── schemas │ │ ├── __init__.py │ │ └── rest.py ├── assets │ └── data_model.png ├── docker-compose.yml ├── requirements.txt └── scripts │ ├── __init__.py │ ├── build_graph.py │ └── schemas │ ├── __init__.py │ └── wine.py ├── qdrant ├── .env.example ├── Dockerfile ├── README.md ├── api │ ├── __init__.py │ ├── config.py │ ├── main.py │ ├── routers │ │ └── rest.py │ └── schemas │ │ ├── __init__.py │ │ └── rest.py ├── docker-compose.yml ├── requirements.txt ├── schemas │ ├── __init__.py │ └── wine.py └── scripts │ ├── __init__.py │ └── bulk_index_sbert.py └── weaviate ├── .env.example ├── Dockerfile ├── README.md ├── api ├── __init__.py ├── config.py ├── main.py ├── routers │ └── rest.py └── schemas │ ├── __init__.py │ └── rest.py ├── docker-compose.yml ├── requirements.txt ├── schemas ├── __init__.py └── wine.py └── scripts ├── __init__.py ├── bulk_index_onnx.py ├── bulk_index_sbert.py └── settings └── schema.json /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | # Trigger the workflow on push or pull request, 5 | # but only for the main branch 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | run-linters: 12 | name: Run linters 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Check out Git repository 17 | uses: actions/checkout@v3 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: 3.11 23 | 24 | - name: Install Python dependencies 25 | run: pip install black 26 | 27 | - name: Run linters 28 | uses: wearerequired/lint-action@v2 29 | with: 30 | github_token: ${{ secrets.GITHUB_TOKEN }} 31 | black: true 32 | black_args: -l 100 33 | auto_fix: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | 
dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac 132 | .DS_Store 133 | 134 | # data 135 | data/*.json 136 | data/*.jsonl 137 | dbs/meilisearch/meili_data 138 | dbs/lancedb/winemag -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Prashanth Rao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DBHub 2 | 3 | ## Boilerplate for async ingestion and querying of DBs 4 | 5 | This repo aims to provide working code and reproducible setups for bulk data ingestion and querying from numerous databases via their Python clients. Wherever possible, async database client APIs are utilized for data ingestion. The query interface to the data is exposed via async FastAPI endpoints. To enable reproducibility across environments, Dockerfiles are provided as well. 6 | 7 | The `docker-compose.yml` does the following: 8 | 1. Set up a local DB server in a container 9 | 2. Set up local volume mounts to persist the data 10 | 3. Set up a FastAPI server in another container 11 | 4. Set up a network bridge such that the DB server can be accessed from the FastAPI server 12 | 5. Tear down all the containers once development and testing is complete 13 | 14 | ### Currently implemented 15 | * Neo4j 16 | * Elasticsearch 17 | * Meilisearch 18 | * Qdrant 19 | * Weaviate 20 | * LanceDB 21 | 22 | ## Goals 23 | 24 | The main goals of this repo are explained as follows. 25 | 26 | 1. **Ease of setup**: There are tons of databases and client APIs out there, so it's useful to have a clean, efficient and reproducible workflow to experiment with a range of datasets, as well as databases for the problem at hand. 27 | 28 | 2. **Ease of distribution**: We may want to expose (potentially sensitive) data to downstream client applications, so building an API on top of the database can be a very useful tool to share the data in a controlled manner 29 | 30 | 3. **Ease of testing advanced use cases**: Search databases (either full-text keyword search or vector DBs) can be important "sources of truth" for contextual querying via LLMs like ChatGPT, allowing us to ground our model's results with factual data. 31 | 32 | 33 | ## Pre-requisites 34 | 35 | * Python 3.10+ 36 | * Docker 37 | * A passion to learn more about and experiment with databases! 38 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Dataset: 130k wine reviews 2 | 3 | We use this [wine reviews dataset from Kaggle](https://www.kaggle.com/datasets/zynicide/wine-reviews) as input. The data consists of 130k wine reviews with the variety, location, winery, price, description, and some other metadata provided for each wine. Refer to the Kaggle source for more information on the data. 4 | 5 | For quick reference, a sample wine item in JSON format is shown below. 6 | 7 | ```json 8 | { 9 | "points": "90", 10 | "title": "Castello San Donato in Perano 2009 Riserva (Chianti Classico)", 11 | "description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. 
Drink now through 2019.",
12 |     "taster_name": "Kerin O'Keefe",
13 |     "taster_twitter_handle": "@kerinokeefe",
14 |     "price": 30,
15 |     "designation": "Riserva",
16 |     "variety": "Red Blend",
17 |     "region_1": "Chianti Classico",
18 |     "region_2": null,
19 |     "province": "Tuscany",
20 |     "country": "Italy",
21 |     "winery": "Castello San Donato in Perano",
22 |     "id": 40825
23 | }
24 | 
25 | ```
26 | 
27 | The data is converted to a gzipped, line-delimited JSON (`.jsonl.gz`) archive, and the conversion code as well as the compressed data are provided here for reference. There is no need to rerun the code to reproduce the results in the rest of the code base in this repo.
28 | 
--------------------------------------------------------------------------------
/data/convert.py:
--------------------------------------------------------------------------------
1 | """
2 | Run `pip install srsly` to use this script
3 | 
4 | This script converts the JSON data file from https://www.kaggle.com/datasets/zynicide/wine-reviews
5 | to a gzipped, line-delimited JSON (.jsonl.gz) file for use downstream with the databases in question.
6 | 
7 | Full credit to the original author, @zynicide, on Kaggle, for the data.
8 | """
9 | from pathlib import Path
10 | from typing import Any
11 | 
12 | import srsly
13 | 
14 | JsonBlob = dict[str, Any]
15 | 
16 | 
17 | def convert_to_jsonl(filename: str) -> None:
18 |     data = srsly.read_json(filename)
19 |     # Add an `id` field to the start of each dict item so we have a primary key for indexing
20 |     new_data = [{"id": idx, **item} for idx, item in enumerate(data, 1)]
21 |     srsly.write_gzip_jsonl(f"{Path(filename).stem}.jsonl.gz", new_data)
22 | 
23 | 
24 | if __name__ == "__main__":
25 |     # Download the JSON data file from https://www.kaggle.com/datasets/zynicide/wine-reviews
26 |     convert_to_jsonl("winemag-data-130k-v2.json")
27 | 
--------------------------------------------------------------------------------
/data/winemag-data-130k-v2.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/data/winemag-data-130k-v2.jsonl.gz
--------------------------------------------------------------------------------
/dbs/elasticsearch/.env.example:
--------------------------------------------------------------------------------
1 | 
2 | ELASTIC_USER = "elastic"
3 | ELASTIC_PASSWORD = ""
4 | STACK_VERSION = "8.10.2"
5 | ELASTIC_INDEX_ALIAS = "wines"
6 | ELASTIC_PORT = 9200
7 | KIBANA_PORT = 5601
8 | ELASTIC_URL = "localhost"
9 | ELASTIC_SERVICE = "elasticsearch"
10 | API_PORT = 8002
11 | 
12 | # Container image tag
13 | TAG = "0.2.0"
14 | 
15 | # Docker project namespace (defaults to the current folder name if not set)
16 | COMPOSE_PROJECT_NAME = elastic_wine
--------------------------------------------------------------------------------
/dbs/elasticsearch/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim-bullseye
2 | 
3 | WORKDIR /wine
4 | 
5 | COPY ./requirements.txt /wine/requirements.txt
6 | 
7 | RUN pip install --no-cache-dir --upgrade -r /wine/requirements.txt
8 | 
9 | COPY ./api /wine/api
10 | 
11 | EXPOSE 8000
--------------------------------------------------------------------------------
/dbs/elasticsearch/README.md:
--------------------------------------------------------------------------------
1 | # Elasticsearch
2 | 
3 | [Elasticsearch](https://www.elastic.co/what-is/elasticsearch) is a distributed search and analytics engine for various kinds of structured and
unstructured data. The primary use case for Elasticsearch is to answer business questions that involve searching and retrieving information from full-text fields, such as descriptions and titles.
4 | 
5 | Code is provided for ingesting the wine reviews dataset into Elasticsearch in an async fashion. In addition, a query API written in FastAPI is provided, which allows a user to query the available endpoints. As always in FastAPI, documentation is available via OpenAPI (http://localhost:8000/docs).
6 | 
7 | * All code (wherever possible) is async
8 | * [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling
9 | * The same schema is used for data ingestion and for the API, so there is only one source of truth regarding how the data is handled
10 | * For ease of reproducibility, the whole setup is orchestrated and deployed via Docker
11 | 
12 | ## Setup
13 | 
14 | Note that this code base has been tested in Python 3.11, and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`.
15 | 
16 | ```sh
17 | # Setup the environment for the first time
18 | python -m venv elastic_venv  # python -> python 3.10+
19 | 
20 | # Activate the environment (for subsequent runs)
21 | source elastic_venv/bin/activate
22 | 
23 | python -m pip install -r requirements.txt
24 | ```
25 | 
26 | ---
27 | 
28 | ## Step 1: Set up containers
29 | 
30 | Use the provided `docker-compose.yml` to initiate separate containers: one that runs Elasticsearch, and another that serves as an API on top of the database.
31 | 
32 | ```
33 | docker compose up -d
34 | ```
35 | 
36 | This compose file starts a persistent-volume Elasticsearch database with credentials specified in `.env`. The `elasticsearch` service variable in the environment file indicates that we are opening up the database service to a FastAPI server (running as a separate service, in a separate container) downstream. The two containers communicate with one another via the common network that they share, on the exact port numbers specified.
37 | 
38 | The services can be stopped at any time for maintenance and updates.
39 | 
40 | ```
41 | docker compose down
42 | ```
43 | 
44 | **Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via Docker alone, but it is a good starting point to begin experimenting!
45 | 
46 | 
47 | ## Step 2: Ingest the data
48 | 
49 | The first step is to ingest the wine reviews dataset into Elasticsearch. Data is asynchronously ingested into the Elasticsearch database through the scripts in the `scripts` directory.
50 | 
51 | ```sh
52 | cd scripts
53 | python bulk_index.py
54 | ```
55 | 
56 | * This script first checks the database for a mapping (which tells Elasticsearch what fields to analyze and how to index them). Each index is attached to an alias, "wines", which is used to reference all the operations downstream
57 | * If no existing index or alias is found, new ones are created
58 | * The script then validates the input JSON data via [Pydantic](https://docs.pydantic.dev) and asynchronously indexes it into the database using the [`AsyncElasticsearch` client](https://elasticsearch-py.readthedocs.io/en/v8.7.0/async.html) for the fastest performance (see the condensed sketch after this list)
59 | 
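For orientation, the minimal sketch below condenses the essential steps of this flow: reading the gzipped data, validating it via the Pydantic `Wine` schema, and bulk-indexing it asynchronously. The connection values follow the defaults in `.env.example`, and the relative data path assumes the script is run from `scripts/`; the real `bulk_index.py` additionally creates the index and alias from `mapping/mapping.json`, chunks the data, and handles errors.

```py
import asyncio

import srsly
from elasticsearch import AsyncElasticsearch, helpers
from schemas.wine import Wine


async def ingest() -> None:
    # Credentials and port follow the defaults in `.env.example`
    client = AsyncElasticsearch("http://localhost:9200", basic_auth=("elastic", "<password>"))
    data = srsly.read_gzip_jsonl("../../../data/winemag-data-130k-v2.jsonl.gz")
    # Validate each record via the Pydantic `Wine` schema before indexing
    validated = [Wine(**item).model_dump() for item in data]
    # Index all documents against the "wines" alias in batched, async calls
    await helpers.async_bulk(client, validated, index="wines", chunk_size=10_000)
    await client.close()


asyncio.run(ingest())
```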
60 | 
61 | ## Step 3: Test API
62 | 
63 | Once the data has been successfully loaded into Elasticsearch and the containers are up and running, we can test out a search query via an HTTP request as follows.
64 | 
65 | ```sh
66 | curl -X 'GET' \
67 |   'http://localhost:8000/wine/search?terms=tuscany%20red'
68 | ```
69 | 
70 | This cURL request passes the search terms "**tuscany red**" to the `/wine/search/` endpoint, which is then parsed into a working Elasticsearch JSON query by the FastAPI backend. The query runs and retrieves results from the database (which looks for these keywords in the wine's title, description and variety fields), and, if the setup was done correctly, we should see the following response:
71 | 
72 | ```json
73 | [
74 |     {
75 |         "id": 109409,
76 |         "country": "Italy",
77 |         "title": "Castello Banfi 2007 Excelsus Red (Toscana)",
78 |         "description": "The 2007 Excelsus is a gorgeous super Tuscan expression (with Cabernet Sauvignon and Merlot) that shows quality and superior fruit on all levels. Castello Banfi has really hit a home run with this vintage. You'll encounter persuasive aromas of cassis, blackberry, chocolate, tobacco, curry leaf and deep renderings of exotic spice. The wine's texture is exceedingly smooth, rich and long lasting.",
79 |         "points": 97,
80 |         "price": 81.0,
81 |         "variety": "Red Blend",
82 |         "winery": "Castello Banfi"
83 |     },
84 |     {
85 |         "id": 21079,
86 |         "country": "Italy",
87 |         "title": "Marchesi Antinori 2010 Solaia Red (Toscana)",
88 |         "description": "Already one of Italy's most iconic bottlings, this gorgeous 2010 is already a classic. Its complex and intense bouquet unfolds with ripe blackberries, violets, leather, thyme and balsamic herbs. The palate shows structure, poise and complexity, delivering rich black currants, black cherry, licorice, mint and menthol notes alongside assertive but polished tannins and vibrant energy. This wine will age and develop for decades. Drink 2018–2040.",
89 |         "points": 97,
90 |         "price": 325.0,
91 |         "variety": "Red Blend",
92 |         "winery": "Marchesi Antinori"
93 |     },
94 |     {
95 |         "id": 35520,
96 |         "country": "Italy",
97 |         "title": "Marchesi Antinori 2012 Solaia Red (Toscana)",
98 |         "description": "This stunning expression of Solaia opens with ample aromas of exotic spices, tilled soil, mature black-skinned fruit and an underlying whiff of fragrant blue flowers. The vibrant, elegantly structured palate doles out high-toned black cherry, ripe blackberry, white pepper, cinnamon, clove and Mediterranean herbs alongside a backbone of firm, polished tannins and bright acidity. Drink 2017–2032.",
99 |         "points": 97,
100 |         "price": 325.0,
101 |         "variety": "Red Blend",
102 |         "winery": "Marchesi Antinori"
103 |     }
104 | ]
105 | ```
106 | 
107 | Not bad! This example correctly returns some highly rated Tuscan red wines along with their price and country of origin (obviously, Italy in this case).
108 | 
109 | ## Step 4: Extend the API
110 | 
111 | The API can be easily extended with the provided structure.
112 | 
113 | - The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs
114 | - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here
115 | - The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions
116 |   - e.g., "What are the top rated wines from Argentina?" (a sketch of such an endpoint is shown after this list)
117 | - In general, it makes sense to organize specific business use cases into their own router files
118 | - The `api/main.py` file collects all the routes and schemas to run the API
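To illustrate the pattern, a hypothetical route answering the question above might look like the sketch below. The endpoint name and query are illustrative (they are not part of the existing codebase), but the sketch reuses the same client, schemas and error-handling conventions as the routes in `api/routers/rest.py`:

```py
@router.get(
    "/top_by_variety",
    response_model=list[TopWinesByCountry],
    response_description="Get top-rated wines by variety",
)
async def top_by_variety(
    request: Request,
    variety: str = Query(description="Wine variety to get top-rated wines for"),
) -> list[TopWinesByCountry] | None:
    # Reuse the app-level Elasticsearch client attached in the lifespan handler
    response = await request.app.client.search(
        index="wines",
        size=5,
        query={"bool": {"must": [{"match_phrase": {"variety": variety}}]}},
        sort={"points": {"order": "desc"}},
    )
    result = response["hits"].get("hits")
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No wines of the provided variety '{variety}' found in database",
        )
    return [item["_source"] for item in result]
```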
119 | 
120 | 
121 | #### Existing endpoints
122 | 
123 | So far, the following endpoints that help answer interesting questions have been implemented.
124 | 
125 | ```
126 | GET
127 | /wine/search
128 | Search By Keywords
129 | ```
130 | 
131 | ```
132 | GET
133 | /wine/top_by_country
134 | Top By Country
135 | ```
136 | 
137 | ```
138 | GET
139 | /wine/top_by_province
140 | Top By Province
141 | ```
142 | 
143 | ```
144 | GET
145 | /wine/count_by_country
146 | Get counts of wines by country
147 | ```
148 | 
149 | ```
150 | GET
151 | /wine/count_by_filters
152 | Get counts of wines by country, price and points (review ratings)
153 | ```
154 | 
--------------------------------------------------------------------------------
/dbs/elasticsearch/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/elasticsearch/api/__init__.py
--------------------------------------------------------------------------------
/dbs/elasticsearch/api/config.py:
--------------------------------------------------------------------------------
1 | from pydantic_settings import BaseSettings, SettingsConfigDict
2 | 
3 | 
4 | class Settings(BaseSettings):
5 |     model_config = SettingsConfigDict(
6 |         env_file=".env",
7 |         extra="allow",
8 |     )
9 | 
10 |     elastic_service: str
11 |     elastic_user: str
12 |     elastic_password: str
13 |     elastic_url: str
14 |     elastic_port: int
15 |     elastic_index_alias: str
16 |     tag: str
17 | 
--------------------------------------------------------------------------------
/dbs/elasticsearch/api/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections.abc import AsyncGenerator
3 | from contextlib import asynccontextmanager
4 | from functools import lru_cache
5 | 
6 | from elasticsearch import AsyncElasticsearch
7 | from fastapi import FastAPI
8 | 
9 | from api.config import Settings
10 | from api.routers import rest
11 | 
12 | 
13 | @lru_cache()
14 | def get_settings():
15 |     # Use lru_cache to avoid loading .env file for every request
16 |     return Settings()
17 | 
18 | 
19 | @asynccontextmanager
20 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
21 |     """Async context manager for Elasticsearch connection."""
22 |     settings = get_settings()
23 |     username = settings.elastic_user
24 |     password = settings.elastic_password
25 |     port = settings.elastic_port
26 |     service = settings.elastic_service
27 |     with warnings.catch_warnings():
28 |         warnings.simplefilter("ignore")
29 |         elastic_client = AsyncElasticsearch(
30 |             f"http://{service}:{port}",
31 |             basic_auth=(username, password),
32 |             request_timeout=60,
33 |             max_retries=3,
34 |             retry_on_timeout=True,
35 |             verify_certs=False,
36 |         )
37 |     # Attach the client to the app instance so routers can access it via `request.app.client`
38 |     app.client = elastic_client
39 |     print("Successfully connected to Elasticsearch")
40 |     yield
41 |     await elastic_client.close()
42 |     print("Successfully closed Elasticsearch connection")
43 | 
44 | 
45 | app = FastAPI(
46 |     title="REST API for wine reviews on Elasticsearch",
47 |     description=(
48 |         "Query from an Elasticsearch database of 130k wine reviews from the Wine Enthusiast magazine"
49 |     ),
50 |     version=get_settings().tag,
51 |     lifespan=lifespan,
52 | )
53 | 
54 | 
55 | @app.get("/", include_in_schema=False)
56 | async def root():
57 |     return {
58 |         "message": "REST API for querying Elasticsearch database of 130k wine reviews from the Wine Enthusiast magazine"
59 |     }
60 | 
61 | 
62 | # Attach routes
63 | app.include_router(rest.router, prefix="/wine", tags=["wine"])
64 | 
--------------------------------------------------------------------------------
/dbs/elasticsearch/api/routers/rest.py:
--------------------------------------------------------------------------------
1 | from api.schemas.rest import (
2 |     CountByCountry,
3 |     FullTextSearch,
4 |     TopWinesByCountry,
5 |     TopWinesByProvince,
6 | )
7 | from elasticsearch import AsyncElasticsearch
8 | from fastapi import APIRouter, HTTPException, Query, Request
9 | 
10 | router = APIRouter()
11 | 
12 | 
13 | # --- Routes ---
14 | 
15 | 
16 | @router.get(
17 |     "/search",
18 |     response_model=list[FullTextSearch],
19 |     response_description="Search wines by title, description and variety",
20 | )
21 | async def search_by_keywords(
22 |     request: Request,
23 |     terms: str = Query(description="Search wine by keywords in title, description and variety"),
24 |     max_price: float = Query(
25 |         default=100.0, description="Specify the maximum price for the wine (e.g., 30)"
26 |     ),
27 | ) -> list[FullTextSearch] | None:
28 |     result = await _search_by_keywords(request.app.client, terms, max_price)
29 |     if not result:
30 |         raise HTTPException(
31 |             status_code=404,
32 |             detail=f"No wine with the provided terms '{terms}' found in database - please try again",
33 |         )
34 |     return result
35 | 
36 | 
37 | @router.get(
38 |     "/top_by_country",
39 |     response_model=list[TopWinesByCountry],
40 |     response_description="Get top-rated wines by country",
41 | )
42 | async def top_by_country(
43 |     request: Request,
44 |     country: str = Query(
45 |         description="Get top-rated wines by country name specified (must be exact name)"
46 |     ),
47 | ) -> list[TopWinesByCountry] | None:
48 |     result = await _top_by_country(request.app.client, country)
49 |     if not result:
50 |         raise HTTPException(
51 |             status_code=404,
52 |             detail=f"No wine from the provided country '{country}' found in database - please enter exact country name",
53 |         )
54 |     return result
55 | 
56 | 
57 | @router.get(
58 |     "/top_by_province",
59 |     response_model=list[TopWinesByProvince],
60 |     response_description="Get top-rated wines by province",
61 | )
62 | async def top_by_province(
63 |     request: Request,
64 |     province: str = Query(
65 |         description="Get top-rated wines by province name specified (must be exact name)"
66 |     ),
67 | ) -> list[TopWinesByProvince] | None:
68 |     result = await _top_by_province(request.app.client, province)
69 |     if not result:
70 |         raise HTTPException(
71 |             status_code=404,
72 |             detail=f"No wine from the provided province '{province}' found in database - please enter exact province name",
73 |         )
74 |     return result
75 | 
76 | 
77 | @router.get(
78 |     "/count_by_country",
79 |     response_model=CountByCountry,
80 |     response_description="Get counts of wine for a particular country",
81 | )
82 | async def count_by_country(
83 |     request: Request,
84 |     country: str = Query(description="Country name to get counts for"),
85 | ) -> CountByCountry | None:
86 |     result = await _count_by_country(request.app.client, country)
87 |     if not result:
88 |         raise HTTPException(
89 |             status_code=404,
90 |             detail=f"No wines from the provided country '{country}' found in database - please enter exact country name",
91 |         )
92 |     return result
93 | 
94 | 
95 | @router.get(
96 |     
"/count_by_filters", 97 | response_model=CountByCountry, 98 | response_description="Get counts of wine for a particular country, filtered by points and price", 99 | ) 100 | async def count_by_filters( 101 | request: Request, 102 | country: str = Query(description="Country name to get counts for"), 103 | points: int = Query(default=85, description="Minimum number of points for a wine"), 104 | price: float = Query(default=100.0, description="Maximum price for a wine"), 105 | ) -> CountByCountry | None: 106 | result = await _count_by_filters(request.app.client, country, points, price) 107 | if not result: 108 | raise HTTPException( 109 | status_code=404, 110 | detail=f"No wines from the provided province '{country}' found in database - please enter exact province name", 111 | ) 112 | return result 113 | 114 | 115 | # --- Elasticsearch query funcs --- 116 | 117 | 118 | async def _search_by_keywords( 119 | client: AsyncElasticsearch, terms: str, max_price: int 120 | ) -> list[FullTextSearch] | None: 121 | response = await client.search( 122 | index="wines", 123 | size=5, 124 | query={ 125 | "bool": { 126 | "must": [ 127 | { 128 | "multi_match": { 129 | "query": terms, 130 | "fields": ["title", "description", "variety"], 131 | "minimum_should_match": 2, 132 | "fuzziness": "AUTO", 133 | } 134 | } 135 | ], 136 | "filter": {"range": {"price": {"lte": max_price}}}, 137 | } 138 | }, 139 | sort={"points": {"order": "desc"}}, 140 | ) 141 | result = response["hits"].get("hits") 142 | if result: 143 | data = [] 144 | for item in result: 145 | data_dict = item["_source"] 146 | data.append(data_dict) 147 | return data 148 | return None 149 | 150 | 151 | async def _top_by_country( 152 | client: AsyncElasticsearch, country: str 153 | ) -> list[TopWinesByCountry] | None: 154 | response = await client.search( 155 | index="wines", 156 | size=5, 157 | query={ 158 | "bool": { 159 | "must": [ 160 | { 161 | "match_phrase": { 162 | "country": country, 163 | } 164 | } 165 | ] 166 | } 167 | }, 168 | sort={"points": {"order": "desc"}}, 169 | ) 170 | result = response["hits"].get("hits") 171 | if result: 172 | data = [] 173 | for item in result: 174 | data_dict = item["_source"] 175 | data.append(data_dict) 176 | return data 177 | return None 178 | 179 | 180 | async def _top_by_province( 181 | client: AsyncElasticsearch, province: str 182 | ) -> list[TopWinesByProvince] | None: 183 | response = await client.search( 184 | index="wines", 185 | size=5, 186 | query={ 187 | "bool": { 188 | "must": [ 189 | { 190 | "match_phrase": { 191 | "province": province, 192 | } 193 | } 194 | ] 195 | } 196 | }, 197 | sort={"points": {"order": "desc"}}, 198 | ) 199 | result = response["hits"].get("hits") 200 | if result: 201 | data = [] 202 | for item in result: 203 | data_dict = item["_source"] 204 | data.append(data_dict) 205 | return data 206 | return None 207 | 208 | 209 | async def _count_by_country(client: AsyncElasticsearch, country: str) -> CountByCountry | None: 210 | response = await client.count( 211 | index="wines", query={"bool": {"must": [{"match": {"country": country}}]}} 212 | ) 213 | result = {"count": response.get("count", 0)} 214 | return result 215 | 216 | 217 | async def _count_by_filters( 218 | client: AsyncElasticsearch, country: str, points: float, price: int 219 | ) -> CountByCountry | None: 220 | response = await client.count( 221 | index="wines", 222 | query={ 223 | "bool": { 224 | "must": [ 225 | {"match": {"country": country}}, 226 | {"range": {"points": {"gte": points}}}, 227 | {"range": {"price": {"lte": 
price}}}, 228 | ] 229 | } 230 | }, 231 | ) 232 | result = {"count": response.get("count", 0)} 233 | return result 234 | -------------------------------------------------------------------------------- /dbs/elasticsearch/api/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/elasticsearch/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/elasticsearch/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class FullTextSearch(BaseModel): 5 | model_config = ConfigDict( 6 | json_schema_extra={ 7 | "example": { 8 | "id": 3845, 9 | "country": "Italy", 10 | "title": "Castellinuzza e Piuca 2010 Chianti Classico", 11 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 12 | "points": 93, 13 | "price": 16, 14 | "variety": "Red Blend", 15 | "winery": "Castellinuzza e Piuca", 16 | } 17 | } 18 | ) 19 | 20 | id: int 21 | country: str 22 | title: str 23 | description: str | None 24 | points: int 25 | price: float | str | None 26 | variety: str | None 27 | winery: str | None 28 | 29 | 30 | class TopWinesByCountry(BaseModel): 31 | id: int 32 | country: str 33 | title: str 34 | description: str | None 35 | points: int 36 | price: float | str | None = "Not available" 37 | variety: str | None 38 | winery: str | None 39 | 40 | 41 | class TopWinesByProvince(BaseModel): 42 | id: int 43 | country: str 44 | province: str 45 | title: str 46 | description: str | None 47 | points: int 48 | price: float | str | None = "Not available" 49 | variety: str | None 50 | winery: str | None 51 | 52 | 53 | class MostWinesByVariety(BaseModel): 54 | country: str 55 | wineCount: int 56 | 57 | 58 | class CountByCountry(BaseModel): 59 | count: int 60 | -------------------------------------------------------------------------------- /dbs/elasticsearch/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | elasticsearch: 5 | image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION} 6 | environment: 7 | - discovery.type=single-node 8 | - ES_JAVA_OPTS=-Xms1g -Xmx1g 9 | - xpack.security.enabled=false 10 | volumes: 11 | - esdata:/usr/share/elasticsearch/data 12 | - eslogs:/usr/share/elasticsearch/logs 13 | ports: 14 | - ${ELASTIC_PORT}:9200 15 | networks: 16 | - wine 17 | 18 | kibana: 19 | image: docker.elastic.co/kibana/kibana:${STACK_VERSION} 20 | volumes: 21 | - kibanadata:/usr/share/kibana/data 22 | ports: 23 | - ${KIBANA_PORT}:5601 24 | networks: 25 | - wine 26 | depends_on: 27 | - elasticsearch 28 | 29 | fastapi: 30 | image: elastic_wine_fastapi:${TAG} 31 | build: . 
32 |     restart: unless-stopped
33 |     env_file:
34 |       - .env
35 |     ports:
36 |       - ${API_PORT}:8000
37 |     depends_on:
38 |       - elasticsearch
39 |     volumes:
40 |       - ./:/wine
41 |     networks:
42 |       - wine
43 |     command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
44 | 
45 | volumes:
46 |   esdata:
47 |   eslogs:
48 |   kibanadata:
49 | 
50 | networks:
51 |   wine:
52 |     driver: bridge
--------------------------------------------------------------------------------
/dbs/elasticsearch/requirements.txt:
--------------------------------------------------------------------------------
1 | elasticsearch~=8.10.0
2 | pydantic~=2.4.0
3 | pydantic-settings~=2.0.0
4 | python-dotenv>=1.0.0
5 | fastapi~=0.100.0
6 | httpx>=0.24.0
7 | aiohttp>=3.8.4
8 | uvicorn>=0.21.0, <1.0.0
9 | srsly>=2.4.6
--------------------------------------------------------------------------------
/dbs/elasticsearch/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/elasticsearch/scripts/__init__.py
--------------------------------------------------------------------------------
/dbs/elasticsearch/scripts/bulk_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import os
4 | import sys
5 | import warnings
6 | from functools import lru_cache
7 | from pathlib import Path
8 | from typing import Any, Iterator
9 | 
10 | import srsly
11 | from dotenv import load_dotenv
12 | from elasticsearch import AsyncElasticsearch, helpers
13 | from schemas.wine import Wine
14 | 
15 | # Make the `api` package importable when running this script from `scripts/`
16 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
17 | from api.config import Settings
18 | 
19 | load_dotenv()
20 | # Custom types
21 | JsonBlob = dict[str, Any]
22 | 
23 | 
24 | # --- Blocking functions ---
25 | 
26 | 
27 | @lru_cache()
28 | def get_settings():
29 |     # Use lru_cache to avoid loading .env file for every request
30 |     return Settings()
31 | 
32 | 
33 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]:
34 |     """
35 |     Break a large iterable into an iterable of smaller iterables of size `chunksize`
36 |     """
37 |     for i in range(0, len(item_list), chunksize):
38 |         yield tuple(item_list[i : i + chunksize])
39 | 
40 | 
41 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:
42 |     """Read the gzipped, line-delimited JSON (.jsonl.gz) file of wine reviews"""
43 |     file_path = data_dir / filename
44 |     if not file_path.is_file():
45 |         # Fail fast with a clear message rather than attempting to read a
46 |         # nonexistent file and surfacing a more cryptic error from the reader
47 |         raise FileNotFoundError(f"No valid .jsonl.gz file found at `{file_path}`")
48 |     # srsly transparently decompresses the gzip archive while reading, so no
49 |     # separate decompression step is needed
50 |     data = srsly.read_gzip_jsonl(file_path)
51 |     return data
52 | 
53 | 
54 | def validate(
55 |     data: tuple[JsonBlob],
56 |     exclude_none: bool = False,
57 | ) -> list[JsonBlob]:
58 |     # Round-trip each record through the `Wine` model to coerce types and apply aliases
59 |     validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
60 |     return validated_data
61 | 
62 | 
63 | def process_chunks(data: list[JsonBlob]) -> list[JsonBlob]:
64 |     validated_data = validate(data, exclude_none=True)
65 |     return validated_data
66 | 
67 | 
68 | # --- Async functions ---
69 | 
70 | 
71 | async def get_elastic_client(settings) -> AsyncElasticsearch:
72 |     # 
Get environment variables
73 |     USERNAME = settings.elastic_user
74 |     PASSWORD = settings.elastic_password
75 |     PORT = settings.elastic_port
76 |     ELASTIC_URL = settings.elastic_url
77 |     # Connect to Elasticsearch
78 |     elastic_client = AsyncElasticsearch(
79 |         f"http://{ELASTIC_URL}:{PORT}",
80 |         basic_auth=(USERNAME, PASSWORD),
81 |         request_timeout=300,
82 |         max_retries=3,
83 |         retry_on_timeout=True,
84 |         verify_certs=False,
85 |     )
86 |     return elastic_client
87 | 
88 | 
89 | async def create_index(client: AsyncElasticsearch, index: str, mappings_path: Path) -> None:
90 |     """Create an index associated with an alias in Elasticsearch"""
91 |     elastic_config = dict(srsly.read_json(mappings_path))
92 |     assert elastic_config is not None
93 | 
94 |     exists_alias = await client.indices.exists_alias(name=index)
95 |     if not exists_alias:
96 |         print(f"Did not find index {index} in db, creating index...\n")
97 |         with warnings.catch_warnings():
98 |             warnings.simplefilter("ignore")
99 |             # Get settings and mappings from the mappings.json file
100 |             mappings = elastic_config.get("mappings")
101 |             settings = elastic_config.get("settings")
102 |             index_name = f"{index}-1"
103 |             try:
104 |                 await client.indices.create(index=index_name, mappings=mappings, settings=settings)
105 |                 await client.indices.put_alias(index=index_name, name=INDEX_ALIAS)
106 |                 # Verify that the new index has been created
107 |                 assert await client.indices.exists(index=index_name)
108 |                 index_and_alias = await client.indices.get_alias(index=index_name)
109 |                 print(index_and_alias)
110 |             except Exception as e:
111 |                 print(f"Warning: Did not create index {index_name} due to exception {e}\n")
112 |     else:
113 |         print(f"Found index {index} in db, skipping index creation...\n")
114 | 
115 | 
116 | async def update_documents_to_index(
117 |     client: AsyncElasticsearch, index: str, data: list[Wine]
118 | ) -> None:
119 |     await helpers.async_bulk(
120 |         client,
121 |         data,
122 |         index=index,
123 |         chunk_size=CHUNKSIZE,
124 |     )
125 |     ids = [item["id"] for item in data]
126 |     print(f"Processed ids in range {min(ids)}-{max(ids)}")
127 | 
128 | 
129 | async def main(data: list[JsonBlob]) -> None:
130 |     settings = get_settings()
131 |     with warnings.catch_warnings():
132 |         warnings.simplefilter("ignore")
133 |         elastic_client = await get_elastic_client(settings)
134 |         assert await elastic_client.ping()
135 |         # Create the index and alias if they don't already exist
136 |         await create_index(elastic_client, INDEX_ALIAS, Path("mapping/mapping.json"))
137 |         # Validate data and chunk it for ingesting in batches
138 |         validated_data = validate(data, exclude_none=False)
139 |         chunked_data = chunk_iterable(validated_data, chunksize=CHUNKSIZE)
140 |         for chunk in chunked_data:
141 |             try:
142 |                 ids = [item["id"] for item in chunk]
143 |                 # Index the chunk first, and report success only once the bulk call returns
144 |                 await helpers.async_bulk(elastic_client, chunk, index=INDEX_ALIAS)
145 |                 print(f"Finished indexing ID range {min(ids)}-{max(ids)}")
146 |             except Exception as e:
147 |                 print(f"{e}: Error while indexing ID range {min(ids)}-{max(ids)}")
148 |     # Close AsyncElasticsearch client
149 |     await elastic_client.close()
150 | 
151 | 
152 | if __name__ == "__main__":
153 |     # fmt: off
154 |     parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data")
155 |     parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
156 |     parser.add_argument("--chunksize", type=int, default=10_000, help="Size of each chunk to break the dataset into before processing")
157 |     parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file
to use") 158 | args = vars(parser.parse_args()) 159 | # fmt: on 160 | 161 | LIMIT = args["limit"] 162 | DATA_DIR = Path(__file__).parents[3] / "data" 163 | FILENAME = args["filename"] 164 | CHUNKSIZE = args["chunksize"] 165 | 166 | # Specify an alias to index the data under 167 | INDEX_ALIAS = get_settings().elastic_index_alias 168 | assert INDEX_ALIAS 169 | 170 | data = list(get_json_data(DATA_DIR, FILENAME)) 171 | if LIMIT > 0: 172 | data = data[:LIMIT] 173 | 174 | # Run main async event loop 175 | if data: 176 | asyncio.run(main(data)) 177 | -------------------------------------------------------------------------------- /dbs/elasticsearch/scripts/mapping/mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "settings": { 3 | "analysis": { 4 | "analyzer": { 5 | "custom_analyzer": { 6 | "type": "custom", 7 | "tokenizer": "standard", 8 | "filter": [ 9 | "lowercase" 10 | ] 11 | } 12 | } 13 | } 14 | }, 15 | "mappings": { 16 | "properties": { 17 | "id": { 18 | "type": "keyword" 19 | }, 20 | "points": { 21 | "type": "unsigned_long" 22 | }, 23 | "title": { 24 | "type": "text", 25 | "analyzer": "custom_analyzer" 26 | }, 27 | "description": { 28 | "type": "text", 29 | "analyzer": "custom_analyzer" 30 | }, 31 | "price": { 32 | "type": "half_float" 33 | }, 34 | "variety": { 35 | "type": "text", 36 | "analyzer": "custom_analyzer", 37 | "fields": { 38 | "raw": { 39 | "type": "keyword" 40 | } 41 | } 42 | }, 43 | "winery": { 44 | "type": "text", 45 | "analyzer": "custom_analyzer", 46 | "fields": { 47 | "raw": { 48 | "type": "keyword" 49 | } 50 | } 51 | }, 52 | "vineyard": { 53 | "type": "text", 54 | "analyzer": "custom_analyzer" 55 | }, 56 | "country": { 57 | "type": "text", 58 | "analyzer": "custom_analyzer", 59 | "fields": { 60 | "raw": { 61 | "type": "keyword" 62 | } 63 | } 64 | }, 65 | "province": { 66 | "type": "text", 67 | "analyzer": "custom_analyzer", 68 | "fields": { 69 | "raw": { 70 | "type": "keyword" 71 | } 72 | } 73 | }, 74 | "region_1": { 75 | "type": "text", 76 | "analyzer": "custom_analyzer", 77 | "fields": { 78 | "raw": { 79 | "type": "keyword" 80 | } 81 | } 82 | }, 83 | "region_2": { 84 | "type": "text", 85 | "analyzer": "custom_analyzer", 86 | "fields": { 87 | "raw": { 88 | "type": "keyword" 89 | } 90 | } 91 | }, 92 | "taster_name": { 93 | "type": "text", 94 | "analyzer": "custom_analyzer", 95 | "fields": { 96 | "raw": { 97 | "type": "keyword" 98 | } 99 | } 100 | }, 101 | "taster_twitter_handle": { 102 | "type": "keyword" 103 | } 104 | } 105 | } 106 | } -------------------------------------------------------------------------------- /dbs/elasticsearch/scripts/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/elasticsearch/scripts/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/elasticsearch/scripts/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field, model_validator 2 | 3 | 4 | class Wine(BaseModel): 5 | model_config = ConfigDict( 6 | populate_by_name=True, 7 | validate_assignment=True, 8 | extra="allow", 9 | str_strip_whitespace=True, 10 | json_schema_extra={ 11 | "example": { 12 | "id": 45100, 13 | "points": 85, 14 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 15 | "description": "Ripe in color and aromas, this 
chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.", 16 | "price": 10.0, 17 | "variety": "Merlot", 18 | "winery": "Balduzzi", 19 | "vineyard": "Reserva", 20 | "country": "Chile", 21 | "province": "Maule Valley", 22 | "region_1": "null", 23 | "region_2": "null", 24 | "taster_name": "Michael Schachner", 25 | "taster_twitter_handle": "@wineschach", 26 | } 27 | }, 28 | ) 29 | 30 | id: int 31 | points: int 32 | title: str 33 | description: str | None 34 | price: float | None 35 | variety: str | None 36 | winery: str | None 37 | vineyard: str | None = Field(..., alias="designation") 38 | country: str | None 39 | province: str | None 40 | region_1: str | None 41 | region_2: str | None 42 | taster_name: str | None 43 | taster_twitter_handle: str | None 44 | 45 | @model_validator(mode="before") 46 | def _fill_country_unknowns(cls, values): 47 | "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 48 | country = values.get("country") 49 | if country is None or country == "null": 50 | values["country"] = "Unknown" 51 | return values 52 | 53 | @model_validator(mode="before") 54 | def _create_id(cls, values): 55 | "Create an _id field because Elastic needs this to store as primary key" 56 | values["_id"] = values["id"] 57 | return values 58 | 59 | 60 | if __name__ == "__main__": 61 | data = { 62 | "id": 45100, 63 | "points": 85, 64 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 65 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.", 66 | "price": 10, # Test if field is cast to float 67 | "variety": "Merlot", 68 | "winery": "Balduzzi", 69 | "designation": "Reserva", # Test if field is renamed 70 | "country": "null", # Test unknown country 71 | "province": " Maule Valley ", # Test if field is stripped 72 | "region_1": "null", 73 | "region_2": "null", 74 | "taster_name": "Michael Schachner", 75 | "taster_twitter_handle": "@wineschach", 76 | } 77 | from pprint import pprint 78 | 79 | wine = Wine(**data) 80 | pprint(wine.model_dump(), sort_dicts=False) 81 | -------------------------------------------------------------------------------- /dbs/lancedb/.env.example: -------------------------------------------------------------------------------- 1 | LANCEDB_DIR = "winemag" 2 | API_PORT = 8006 3 | EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" 4 | 5 | # Container image tag 6 | TAG = "0.1.0" 7 | 8 | # Docker project namespace (defaults to the current folder name if not set) 9 | COMPOSE_PROJECT_NAME = lancedb_wine -------------------------------------------------------------------------------- /dbs/lancedb/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bullseye 2 | 3 | WORKDIR /wine 4 | 5 | COPY ./requirements.txt /wine/requirements.txt 6 | 7 | RUN pip install --no-cache-dir -U pip wheel setuptools 8 | RUN pip install --no-cache-dir -r /wine/requirements.txt 9 | 10 | COPY ./winemag /wine/winemag 11 | COPY ./api /wine/api 12 | COPY ./schemas /wine/schemas 13 | 14 | EXPOSE 8000 -------------------------------------------------------------------------------- /dbs/lancedb/README.md: -------------------------------------------------------------------------------- 1 | # LanceDB 2 | 3 | 
[LanceDB](https://github.com/lancedb/lancedb) is an embedded vector database written in Rust. The primary advantage of LanceDB's serverless architecture is that it places the database right next to the application. Being a vector database, it retrieves results that are most semantically similar to the input natural language query; the semantic similarity is obtained by comparing the sentence embeddings (which are n-dimensional vectors) of the input query with those of the data stored in the database.
4 | 
5 | Code is provided for ingesting the wine reviews dataset into LanceDB. In addition, a query API written in FastAPI is provided, which allows a user to query the available endpoints. As always in FastAPI, documentation is available via OpenAPI (http://localhost:8000/docs).
6 | 
7 | * Unlike "normal" databases, in a vector DB the vectorization process is the biggest bottleneck
8 | * [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling
9 | * For ease of reproducibility during development, the whole setup is orchestrated and deployed via Docker
10 | 
11 | ## Setup
12 | 
13 | Note that this code base has been tested in Python 3.10 and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`.
14 | 
15 | ```sh
16 | # Setup the environment for the first time
17 | python -m venv .venv  # python -> python 3.10
18 | 
19 | # Activate the environment (for subsequent runs)
20 | source .venv/bin/activate
21 | 
22 | python -m pip install -r requirements.txt
23 | ```
24 | 
25 | ---
26 | 
27 | ## Step 1: Set up containers
28 | 
29 | A `docker-compose.yml` file is provided, which starts a FastAPI container with the information supplied in `.env`. Because LanceDB is serverless, the database doesn't run in a separate process -- it is simply part of the Python code that is imported into the FastAPI backend. The API is then served via `uvicorn`, a production-ready ASGI server that is used by FastAPI.
30 | 
31 | The FastAPI service can be restarted at any time for maintenance and updates by simply running the `docker restart <container_id>` command.
32 | 
33 | **💡 Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via Docker alone, but it is a good starting point to begin experimenting!
34 | 
35 | ### Use `sbert` model
36 | 
37 | If using the `sbert` model [from the sentence-transformers repo](https://www.sbert.net/) directly, use the provided `docker-compose.yml` to start the service; because LanceDB is embedded, it runs inside the same container as the FastAPI server rather than as a separate database service.
38 | 
39 | **⚠️ Note**: This approach will attempt to run `sbert` on a GPU if available, and if not, on CPU (while utilizing all CPU cores).
40 | 
41 | ```
42 | docker compose -f docker-compose.yml up -d
43 | ```
44 | Tear down the services using the following command.
45 | 
46 | ```
47 | docker compose -f docker-compose.yml down
48 | ```
49 | 
50 | ## Step 2: Ingest the data
51 | 
52 | We ingest both the JSON data for filtering, as well as the sentence embedding vectors (for similarity search), into LanceDB. For this dataset, it's reasonable to expect that a simple concatenation of fields like `title`, `variety` and `description` would result in a useful sentence embedding that can be compared against a search query, which is also converted to a vector at query time.
53 | 
54 | As an example, consider the following data snippet from the `data/` directory in this repo:
55 | 
56 | ```json
57 | "title": "Castello San Donato in Perano 2009 Riserva (Chianti Classico)",
58 | "description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. Drink now through 2019.",
59 | "variety": "Red Blend"
60 | ```
61 | 
62 | The three fields are concatenated for vectorization as follows:
63 | 
64 | ```py
65 | to_vectorize = data["variety"] + data["title"] + data["description"]
66 | ```
67 | 
68 | ### Choice of embedding model
69 | 
70 | [SentenceTransformers](https://www.sbert.net/) is a Python framework for a range of sentence and text embeddings. It results from extensive work on fine-tuning BERT to work well on semantic similarity tasks using Siamese BERT networks, where the model is trained to predict the similarity between sentence pairs. The original work is [described here](https://arxiv.org/abs/1908.10084).
71 | 
72 | #### Why use sentence transformers?
73 | 
74 | Although larger and more powerful text embedding models exist (such as [OpenAI embeddings](https://platform.openai.com/docs/guides/embeddings)), they can become expensive, since they are paid services that charge per token of text. SentenceTransformers are free and open-source, and have been optimized over several years, both to utilize all CPU cores and to reduce model size while maintaining accuracy. A full list of sentence transformer models [is in the project page](https://www.sbert.net/docs/pretrained_models.html).
75 | 
76 | For this work, it makes sense to use one of the fastest models in this list, the `multi-qa-MiniLM-L6-cos-v1` **uncased** model. As per the docs, it was tuned for semantic search and question answering, and generates sentence embeddings for single sentences or paragraphs up to a maximum sequence length of 512. It was trained on 215M question-answer pairs from various sources. Compared to the more general-purpose `all-MiniLM-L6-v2` model, it shows slightly improved performance on semantic search tasks at a comparable speed. [See the sbert docs](https://www.sbert.net/docs/pretrained_models.html) for more details on performance comparisons between the various pretrained models.
77 | 
78 | ### Run data loader
79 | 
80 | Data is ingested into the LanceDB database through the scripts in the `scripts` directory. The scripts validate the input JSON data via [Pydantic](https://docs.pydantic.dev), and then index both the JSON data and the vectors into LanceDB using the [LanceDB Python client](https://lancedb.github.io/lancedb/).
81 | 
82 | Prior to indexing and vectorizing, we simply concatenate the key fields that contain useful information about each wine and vectorize this instead.
83 | 
84 | If running on a MacBook or other development machine, it's possible to generate sentence embeddings using the original `sbert` model as per the `EMBEDDING_MODEL_CHECKPOINT` variable in the `.env` file.
85 | 
86 | ```sh
87 | cd scripts
88 | python bulk_index_sbert.py
89 | ```
90 | 
91 | Depending on the CPU on your machine, this may take a while. On a 2022 M2 MacBook Pro, vectorizing and bulk-indexing ~130k records took about 25 minutes. When tested on an AWS EC2 T2 medium instance, the same process took just over an hour.
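For reference, the heart of the loader reduces to the minimal sketch below, assuming the `.env` defaults for the database directory and model checkpoint (the actual `bulk_index_sbert.py` script additionally validates records via Pydantic and processes the data in chunks):

```py
import lancedb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
db = lancedb.connect("./winemag")


def vectorize_and_index(data: list[dict]) -> None:
    # Encode the concatenated fields in one batched call, which is far
    # faster than encoding records one at a time
    to_vectorize = [f"{d['variety']} {d['title']} {d['description']}" for d in data]
    vectors = model.encode(to_vectorize, batch_size=64).tolist()
    # Store the embedding alongside the original fields; the embedding
    # column is conventionally named "vector" in LanceDB
    records = [{**d, "vector": v} for d, v in zip(data, vectors)]
    db.create_table("wines", data=records, mode="overwrite")
```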
92 | 
93 | ## Step 3: Test API
94 | 
95 | Once the data has been successfully loaded into LanceDB and the containers are up and running, we can test out a search query via an HTTP request as follows.
96 | 
97 | ```sh
98 | curl -X 'GET' \
99 |   'http://0.0.0.0:8000/wine/search?terms=tuscany%20red&max_price=100&country=Italy'
100 | ```
101 | 
102 | This cURL request passes the search terms "**tuscany red**", along with the country "Italy" and a maximum price of "100", to the `/wine/search/` endpoint, which is then parsed into a working filter query for LanceDB by the FastAPI backend. The query runs and retrieves results that are semantically similar to the input query for red Tuscan wines, and, if the setup was done correctly, we should see the following response:
103 | 
104 | ```json
105 | [
106 |     {
107 |         "id": 8456,
108 |         "country": "Italy",
109 |         "province": "Tuscany",
110 |         "title": "Petra 2008 Petra Red (Toscana)",
111 |         "description": "From one of Italy's most important showcase designer wineries, this blend of Cabernet Sauvignon and Merlot lives up to its super Tuscan celebrity. It is gently redolent of dark chocolate, ripe fruit, leather, tobacco and crushed black pepper—the bouquet's elegant moderation is one of its strongest points. The mouthfeel is rich, creamy and long. Drink after 2018.",
112 |         "points": 92,
113 |         "price": 80.0,
114 |         "variety": "Red Blend",
115 |         "winery": "Petra"
116 |     },
117 |     {
118 |         "id": 896,
119 |         "country": "Italy",
120 |         "province": "Tuscany",
121 |         "title": "Le Buche 2006 Giuseppe Olivi Memento Red (Toscana)",
122 |         "description": "Le Buche is an interesting winery to watch, and its various Tuscan blends show great promise. Memento is equal parts Sangiovese and Syrah with a soft, velvety texture and a bright berry finish.",
123 |         "points": 90,
124 |         "price": 45.0,
125 |         "variety": "Red Blend",
126 |         "winery": "Le Buche"
127 |     },
128 |     {
129 |         "id": 9343,
130 |         "country": "Italy",
131 |         "province": "Tuscany",
132 |         "title": "Poggio Mandorlo 2008 Red (Toscana)",
133 |         "description": "Made from Merlot and Cabernet Franc, this structured red offers aromas of black currant, toast, graphite and a whiff of cedar. The firm palate offers coconut, coffee, grilled sage and red berry alongside bracing tannins. Drink sooner rather than later to capture the fruit richness.",
134 |         "points": 89,
135 |         "price": 60.0,
136 |         "variety": "Red Blend",
137 |         "winery": "Poggio Mandorlo"
138 |     }
139 | ]
140 | ```
141 | 
142 | Not bad! This example correctly returns some highly rated Tuscan red wines from Italy along with their price. More specific search queries, such as for low/high acidity or particular flavour profiles of wines, can also be entered to get more relevant results by country.
143 | 
144 | ## Step 4: Extend the API
145 | 
146 | The API can be easily extended with the provided structure.
147 | 
148 | - The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs
149 | - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here
150 | - The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions
151 |   - e.g., "What are the top rated wines from Argentina?"
152 | - In general, it makes sense to organize specific business use cases into their own router files 153 | - The `api/main.py` file collects all the routes and schemas to run the API 154 | 155 | 156 | #### Existing endpoints 157 | 158 | As examples, the following search and count endpoints are implemented and can be accessed via the API. 159 | 160 | ``` 161 | GET 162 | /wine/search 163 | Search By Similarity 164 | 165 | 166 | GET 167 | /wine/search_by_country 168 | Search By Similarity And Country 169 | 170 | 171 | GET 172 | /wine/search_by_filters 173 | Search By Similarity And Filters 174 | 175 | 176 | GET 177 | /wine/count_by_country 178 | Count By Country 179 | 180 | 181 | GET 182 | /wine/count_by_filters 183 | Count By Filters 184 | ``` -------------------------------------------------------------------------------- /dbs/lancedb/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/lancedb/api/__init__.py -------------------------------------------------------------------------------- /dbs/lancedb/api/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 | model_config = SettingsConfigDict( 6 | env_file=".env", 7 | extra="allow", 8 | ) 9 | 10 | lancedb_dir: str 11 | api_port: str 12 | embedding_model_checkpoint: str 13 | tag: str 14 | -------------------------------------------------------------------------------- /dbs/lancedb/api/main.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from contextlib import asynccontextmanager 3 | from functools import lru_cache 4 | 5 | import lancedb 6 | from fastapi import FastAPI 7 | from fastapi.middleware.cors import CORSMiddleware 8 | from sentence_transformers import SentenceTransformer 9 | 10 | from api.config import Settings 11 | from api.routers.rest import router 12 | 13 | 14 | @lru_cache() 15 | def get_settings(): 16 | # Use lru_cache to avoid loading .env file for every request 17 | return Settings() 18 | 19 | 20 | @asynccontextmanager 21 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 22 | """Async context manager for lancedb connection.""" 23 | settings = get_settings() 24 | model_checkpoint = settings.embedding_model_checkpoint 25 | app.model = SentenceTransformer(model_checkpoint) 26 | # Define LanceDB client 27 | db = lancedb.connect("./winemag") 28 | app.table = db.open_table("wines") 29 | print("Successfully connected to LanceDB") 30 | yield 31 | print("Successfully closed LanceDB connection and released resources") 32 | 33 | 34 | app = FastAPI( 35 | title="REST API for wine reviews on LanceDB", 36 | description=( 37 | "Query from a LanceDB database of 130k wine reviews from the Wine Enthusiast magazine" 38 | ), 39 | version=get_settings().tag, 40 | lifespan=lifespan, 41 | ) 42 | 43 | 44 | @app.get("/", include_in_schema=False) 45 | async def root(): 46 | return { 47 | "message": "REST API for querying LanceDB database of 130k wine reviews from the Wine Enthusiast magazine" 48 | } 49 | 50 | 51 | # Attach routes 52 | app.include_router(router, prefix="/wine", tags=["wine"]) 53 | 54 | # Add CORS middleware 55 | app.add_middleware( 56 | CORSMiddleware, 57 | allow_origins=["http://localhost:8000"], 58 | allow_methods=["GET"], 59 | 
allow_headers=["*"], 60 | ) 61 | -------------------------------------------------------------------------------- /dbs/lancedb/api/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/lancedb/api/routers/__init__.py -------------------------------------------------------------------------------- /dbs/lancedb/api/routers/rest.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Query, Request 2 | 3 | from api.schemas.rest import CountByCountry, SimilaritySearch 4 | 5 | router = APIRouter() 6 | 7 | NUM_PROBES = 20 8 | 9 | # --- Routes --- 10 | 11 | 12 | @router.get( 13 | "/search", 14 | response_model=list[SimilaritySearch], 15 | response_description="Search for wines via semantically similar terms", 16 | ) 17 | def search_by_similarity( 18 | request: Request, 19 | terms: str = Query( 20 | description="Specify terms to search for in the variety, title and description" 21 | ), 22 | ) -> list[SimilaritySearch] | None: 23 | result = _search_by_similarity(request, terms) 24 | if not result: 25 | raise HTTPException( 26 | status_code=404, 27 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 28 | ) 29 | return result 30 | 31 | 32 | @router.get( 33 | "/search_by_country", 34 | response_model=list[SimilaritySearch], 35 | response_description="Search for wines via semantically similar terms from a particular country", 36 | ) 37 | def search_by_similarity_and_country( 38 | request: Request, 39 | terms: str = Query( 40 | description="Specify terms to search for in the variety, title and description" 41 | ), 42 | country: str = Query(description="Country name to search for wines from"), 43 | ) -> list[SimilaritySearch] | None: 44 | result = _search_by_similarity_and_country(request, terms, country) 45 | if not result: 46 | raise HTTPException( 47 | status_code=404, 48 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 49 | ) 50 | return result 51 | 52 | 53 | @router.get( 54 | "/search_by_filters", 55 | response_model=list[SimilaritySearch], 56 | response_description="Search for wines via semantically similar terms with added filters", 57 | ) 58 | def search_by_similarity_and_filters( 59 | request: Request, 60 | terms: str = Query( 61 | description="Specify terms to search for in the variety, title and description" 62 | ), 63 | country: str = Query(description="Country name to search for wines from"), 64 | points: int = Query(default=85, description="Minimum number of points for a wine"), 65 | price: float = Query(default=100.0, description="Maximum price for a wine"), 66 | ) -> list[SimilaritySearch] | None: 67 | result = _search_by_similarity_and_filters(request, terms, country, points, price) 68 | if not result: 69 | raise HTTPException( 70 | status_code=404, 71 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 72 | ) 73 | return result 74 | 75 | 76 | @router.get( 77 | "/count_by_country", 78 | response_model=CountByCountry, 79 | response_description="Get counts of wine for a particular country", 80 | ) 81 | def count_by_country( 82 | request: Request, 83 | country: str = Query(description="Country name to get counts for"), 84 | ) -> CountByCountry: 85 | result = _count_by_country(request, country) 86 | if not result: 87 | raise HTTPException( 88 | 
status_code=404, 89 | detail=f"No wine with the provided country '{country}' found in database - please try again", 90 | ) 91 | return result 92 | 93 | 94 | @router.get( 95 | "/count_by_filters", 96 | response_model=CountByCountry, 97 | response_description="Get counts of wine for a particular country, filtered by points and price", 98 | ) 99 | def count_by_filters( 100 | request: Request, 101 | country: str = Query(description="Country name to get counts for"), 102 | points: int = Query(default=85, description="Minimum number of points for a wine"), 103 | price: float = Query(default=100.0, description="Maximum price for a wine"), 104 | ) -> CountByCountry: 105 | result = _count_by_filters(request, country, points, price) 106 | if not result: 107 | raise HTTPException( 108 | status_code=404, 109 | detail=f"No wine with the provided country '{country}' found in database - please try again", 110 | ) 111 | return result 112 | 113 | 114 | # --- Helper functions --- 115 | 116 | 117 | def _search_by_similarity( 118 | request: Request, 119 | terms: str, 120 | ) -> list[SimilaritySearch] | None: 121 | query_vector = request.app.model.encode(terms.lower()) 122 | search_result = ( 123 | request.app.table.search(query_vector).metric("cosine").nprobes(NUM_PROBES).limit(5).to_df() 124 | ).to_dict(orient="records") 125 | if not search_result: 126 | return None 127 | return search_result 128 | 129 | 130 | def _search_by_similarity_and_country( 131 | request: Request, terms: str, country: str 132 | ) -> list[SimilaritySearch] | None: 133 | query_vector = request.app.model.encode(terms.lower()) 134 | search_result = ( 135 | request.app.table.search(query_vector) 136 | .metric("cosine") 137 | .nprobes(NUM_PROBES) 138 | .where( 139 | f""" 140 | country = '{country}' 141 | """ 142 | ) 143 | .limit(5) 144 | .to_df() 145 | ).to_dict(orient="records") 146 | if not search_result: 147 | return None 148 | return search_result 149 | 150 | 151 | def _search_by_similarity_and_filters( 152 | request: Request, 153 | terms: str, 154 | country: str, 155 | points: int, 156 | price: float, 157 | ) -> list[SimilaritySearch] | None: 158 | query_vector = request.app.model.encode(terms.lower()) 159 | price = float(price) 160 | search_result = ( 161 | request.app.table.search(query_vector) 162 | .metric("cosine") 163 | .nprobes(NUM_PROBES) 164 | .where( 165 | f""" 166 | country = '{country}' 167 | and points >= {points} 168 | and price <= {price} 169 | """ 170 | ) 171 | .limit(5) 172 | .to_df() 173 | ).to_dict(orient="records") 174 | if not search_result: 175 | return None 176 | return search_result 177 | 178 | 179 | def _count_by_country( 180 | request: Request, 181 | country: str, 182 | ) -> CountByCountry: 183 | search_result = ( 184 | request.app.table.search() 185 | .where( 186 | f""" 187 | country = '{country}' 188 | """ 189 | ) 190 | .to_df() 191 | ).shape[0] 192 | final_result = CountByCountry(count=search_result) 193 | return final_result 194 | 195 | 196 | def _count_by_filters( 197 | request: Request, 198 | country: str, 199 | points: int, 200 | price: float, 201 | ) -> CountByCountry: 202 | price = float(price) 203 | search_result = ( 204 | request.app.table.search() 205 | .where( 206 | f""" 207 | country = '{country}' 208 | and points >= {points} 209 | and price <= {price} 210 | """ 211 | ) 212 | .to_df() 213 | ).shape[0] 214 | final_result = CountByCountry(count=search_result) 215 | return final_result 216 | -------------------------------------------------------------------------------- 
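For context, the helper functions above all reduce to the same LanceDB query-builder chain; the following minimal sketch (assuming the local `./winemag` database and `wines` table created by the bulk-indexing script) shows a standalone equivalent of `_search_by_similarity_and_filters`:

```py
import lancedb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")  # assumed checkpoint from `.env`
tbl = lancedb.connect("./winemag").open_table("wines")

# Encode the lowercased search terms, then combine vector search with a SQL-style filter
query_vector = model.encode("tuscany red".lower())
results = (
    tbl.search(query_vector)
    .metric("cosine")
    .nprobes(20)  # probe more IVF partitions, trading a little speed for better recall
    .where("country = 'Italy' and points >= 85 and price <= 100.0")
    .limit(5)
    .to_df()
)
print(results[["title", "points", "price"]])
```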
/dbs/lancedb/api/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/lancedb/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/lancedb/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, ConfigDict 4 | 5 | 6 | class SimilaritySearch(BaseModel): 7 | model_config = ConfigDict( 8 | extra="ignore", 9 | json_schema_extra={ 10 | "example": { 11 | "id": 3845, 12 | "country": "Italy", 13 | "title": "Castellinuzza e Piuca 2010 Chianti Classico", 14 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 15 | "points": 93, 16 | "price": 16, 17 | "variety": "Red Blend", 18 | "winery": "Castellinuzza e Piuca", 19 | } 20 | }, 21 | ) 22 | 23 | id: int 24 | country: str 25 | province: Optional[str] 26 | title: str 27 | description: Optional[str] 28 | points: int 29 | price: Optional[float] 30 | variety: Optional[str] 31 | winery: Optional[str] 32 | 33 | 34 | class CountByCountry(BaseModel): 35 | count: int 36 | -------------------------------------------------------------------------------- /dbs/lancedb/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | fastapi: 5 | image: lancedb_wine_fastapi:${TAG} 6 | build: 7 | context: . 
8 | dockerfile: Dockerfile 9 | restart: unless-stopped 10 | env_file: 11 | - .env 12 | ports: 13 | - ${API_PORT}:8000 14 | environment: 15 | - LANCEDB_CONFIG_DIR=/wine 16 | command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload 17 | -------------------------------------------------------------------------------- /dbs/lancedb/requirements.txt: -------------------------------------------------------------------------------- 1 | lancedb~=0.3.0 2 | transformers~=4.28.0 3 | sentence-transformers~=2.2.0 4 | pydantic~=2.3.0 5 | pydantic-settings~=2.0.0 6 | python-dotenv>=1.0.0 7 | fastapi~=0.104.0 8 | httpx>=0.24.0 9 | aiohttp>=3.8.4 10 | uvicorn>=0.21.0, <1.0.0 11 | srsly>=2.4.6 12 | pandas~=2.1.0 13 | codetiming~=1.4.0 -------------------------------------------------------------------------------- /dbs/lancedb/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/lancedb/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/lancedb/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from lancedb.pydantic import Vector 4 | from pydantic import BaseModel, ConfigDict, Field, model_validator 5 | 6 | 7 | class Wine(BaseModel): 8 | model_config = ConfigDict( 9 | populate_by_name=True, 10 | validate_assignment=True, 11 | extra="allow", 12 | str_strip_whitespace=True, 13 | json_schema_extra={ 14 | "example": { 15 | "id": 45100, 16 | "points": 85, 17 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 18 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. 
Raisin and cooked berry flavors finish plump, with earthy notes.", 19 | "price": 10.0, 20 | "variety": "Merlot", 21 | "winery": "Balduzzi", 22 | "vineyard": "Reserva", 23 | "country": "Chile", 24 | "province": "Maule Valley", 25 | "region_1": "null", 26 | "region_2": "null", 27 | "taster_name": "Michael Schachner", 28 | "taster_twitter_handle": "@wineschach", 29 | } 30 | }, 31 | ) 32 | 33 | id: int 34 | points: int 35 | title: str 36 | description: Optional[str] 37 | price: Optional[float] 38 | variety: Optional[str] 39 | winery: Optional[str] 40 | vineyard: Optional[str] = Field(..., alias="designation") 41 | country: Optional[str] 42 | province: Optional[str] 43 | region_1: Optional[str] 44 | region_2: Optional[str] 45 | taster_name: Optional[str] 46 | taster_twitter_handle: Optional[str] 47 | 48 | @model_validator(mode="before") 49 | def _fill_country_unknowns(cls, values): 50 | "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 51 | country = values.get("country") 52 | if not country: 53 | values["country"] = "Unknown" 54 | return values 55 | 56 | @model_validator(mode="before") 57 | def _add_to_vectorize_fields(cls, values): 58 | "Add a field to_vectorize that will be used to create sentence embeddings" 59 | variety = values.get("variety", "") 60 | title = values.get("title", "") 61 | description = values.get("description", "") 62 | to_vectorize = list(filter(None, [variety, title, description])) 63 | values["to_vectorize"] = " ".join(to_vectorize).strip() 64 | return values 65 | 66 | 67 | class LanceModelWine(BaseModel): 68 | model_config = ConfigDict( 69 | populate_by_name=True, 70 | validate_assignment=True, 71 | extra="allow", 72 | str_strip_whitespace=True, 73 | json_schema_extra={ 74 | "example": { 75 | "id": 45100, 76 | "points": 85, 77 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 78 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. 
Raisin and cooked berry flavors finish plump, with earthy notes.", 79 | "price": 10.0, 80 | "variety": "Merlot", 81 | "winery": "Balduzzi", 82 | "vineyard": "Reserva", 83 | "country": "Chile", 84 | "province": "Maule Valley", 85 | "region_1": "null", 86 | "region_2": "null", 87 | "taster_name": "Michael Schachner", 88 | "taster_twitter_handle": "@wineschach", 89 | } 90 | }, 91 | ) 92 | 93 | id: int 94 | points: int 95 | title: str 96 | description: Optional[str] 97 | price: Optional[float] 98 | variety: Optional[str] 99 | winery: Optional[str] 100 | vineyard: Optional[str] = Field(..., alias="designation") 101 | country: Optional[str] 102 | province: Optional[str] 103 | region_1: Optional[str] 104 | region_2: Optional[str] 105 | taster_name: Optional[str] 106 | taster_twitter_handle: Optional[str] 107 | to_vectorize: str 108 | vector: Vector(384) 109 | -------------------------------------------------------------------------------- /dbs/lancedb/scripts/bulk_index_sbert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from concurrent.futures import ProcessPoolExecutor, as_completed 5 | from functools import lru_cache 6 | from pathlib import Path 7 | from typing import Any, Iterator 8 | 9 | import lancedb 10 | import pandas as pd 11 | import srsly 12 | from codetiming import Timer 13 | from dotenv import load_dotenv 14 | from lancedb.pydantic import pydantic_to_schema 15 | from sentence_transformers import SentenceTransformer 16 | from tqdm import tqdm 17 | 18 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 19 | from api.config import Settings 20 | from schemas.wine import LanceModelWine, Wine 21 | 22 | 23 | load_dotenv() 24 | # Custom types 25 | JsonBlob = dict[str, Any] 26 | 27 | 28 | class FileNotFoundError(Exception): 29 | pass 30 | 31 | 32 | @lru_cache() 33 | def get_settings(): 34 | # Use lru_cache to avoid loading .env file for every request 35 | return Settings() 36 | 37 | 38 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[list[JsonBlob]]: 39 | """ 40 | Break a large iterable into an iterable of smaller iterables of size `chunksize` 41 | """ 42 | for i in range(0, len(item_list), chunksize): 43 | yield item_list[i : i + chunksize] 44 | 45 | 46 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: 47 | """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" 48 | file_path = data_dir / filename 49 | if not file_path.is_file(): 50 | # File may not have been uncompressed yet so try to do that first 51 | data = srsly.read_gzip_jsonl(file_path) 52 | # This time if it isn't there it really doesn't exist 53 | if not file_path.is_file(): 54 | raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`") 55 | else: 56 | data = srsly.read_gzip_jsonl(file_path) 57 | return data 58 | 59 | 60 | def validate( 61 | data: list[JsonBlob], 62 | exclude_none: bool = False, 63 | ) -> list[JsonBlob]: 64 | validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data] 65 | return validated_data 66 | 67 | 68 | def embed_func(batch: list[str], model) -> list[list[float]]: 69 | return [model.encode(sentence.lower()) for sentence in batch] 70 | 71 | 72 | def vectorize_text(data: list[JsonBlob]) -> list[LanceModelWine] | None: 73 | # Load a sentence transformer model for semantic similarity from a specified checkpoint 74 | model_id = get_settings().embedding_model_checkpoint 75 | assert 
model_id, "Invalid embedding model checkpoint specified in .env file" 76 | MODEL = SentenceTransformer(model_id) 77 | 78 | ids = [item["id"] for item in data] 79 | to_vectorize = [text.get("to_vectorize") for text in data] 80 | vectors = embed_func(to_vectorize, MODEL) 81 | try: 82 | data_batch = [{**d, "vector": vector} for d, vector in zip(data, vectors)] 83 | except Exception as e: 84 | print(f"{e}: Failed to add ID range {min(ids)}-{max(ids)}") 85 | return None 86 | return data_batch 87 | 88 | 89 | def embed_batches(tbl: str, validated_data: list[JsonBlob]) -> pd.DataFrame: 90 | with ProcessPoolExecutor(max_workers=WORKERS) as executor: 91 | chunked_data = chunk_iterable(validated_data, CHUNKSIZE) 92 | embed_data = [] 93 | for chunk in tqdm(chunked_data, total=len(validated_data) // CHUNKSIZE): 94 | futures = [executor.submit(vectorize_text, chunk)] 95 | embed_data = [f.result() for f in as_completed(futures) if f.result()][0] 96 | df = pd.DataFrame.from_dict(embed_data) 97 | tbl.add(df, mode="overwrite") 98 | 99 | 100 | def main(data: list[JsonBlob]) -> None: 101 | DB_NAME = f"../{get_settings().lancedb_dir}" 102 | TABLE = "wines" 103 | db = lancedb.connect(DB_NAME) 104 | 105 | tbl = db.create_table(TABLE, schema=pydantic_to_schema(LanceModelWine), mode="overwrite") 106 | print(f"Created table `{TABLE}`, with length {len(tbl)}") 107 | 108 | with Timer(name="Bulk Index", text="Validated data using Pydantic in {:.4f} sec"): 109 | validated_data = validate(data, exclude_none=False) 110 | 111 | with Timer(name="Embed batches", text="Created sentence embeddings in {:.4f} sec"): 112 | embed_batches(tbl, validated_data) 113 | 114 | print(f"Finished inserting {len(tbl)} items into LanceDB table") 115 | 116 | with Timer(name="Create index", text="Created IVF-PQ index in {:.4f} sec"): 117 | # Creating index (choose num partitions as a power of 2 that's closest to len(dataset) // 5000) 118 | # In this case, we have 130k datapoints, so the nearest power of 2 is 130000//5000 ~ 32) 119 | tbl.create_index(metric="cosine", num_partitions=4, num_sub_vectors=32) 120 | 121 | 122 | if __name__ == "__main__": 123 | # fmt: off 124 | parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data") 125 | parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") 126 | parser.add_argument("--chunksize", type=int, default=1000, help="Size of each chunk to break the dataset into before processing") 127 | parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") 128 | parser.add_argument("--workers", type=int, default=4, help="Number of workers to use for vectorization") 129 | args = vars(parser.parse_args()) 130 | # fmt: on 131 | 132 | LIMIT = args["limit"] 133 | DATA_DIR = Path(__file__).parents[3] / "data" 134 | FILENAME = args["filename"] 135 | CHUNKSIZE = args["chunksize"] 136 | WORKERS = args["workers"] 137 | 138 | data = list(get_json_data(DATA_DIR, FILENAME)) 139 | assert data, "No data found in the specified file" 140 | data = data[:LIMIT] if LIMIT > 0 else data 141 | main(data) 142 | print("Finished execution!") 143 | -------------------------------------------------------------------------------- /dbs/meilisearch/.env.example: -------------------------------------------------------------------------------- 1 | # Master key must be at least 16 bytes, composed of valid UTF-8 characters 2 | MEILI_MASTER_KEY = "" 3 | MEILI_VERSION = "v1.2.0" 4 | 
MEILI_PORT = 7700 5 | MEILI_URL = "localhost" 6 | MEILI_SERVICE = "meilisearch" 7 | API_PORT = 8003 8 | 9 | # Container image tag 10 | TAG = "0.2.0" 11 | 12 | # Docker project namespace (defaults to the current folder name if not set) 13 | COMPOSE_PROJECT_NAME = meili_wine -------------------------------------------------------------------------------- /dbs/meilisearch/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bullseye 2 | 3 | WORKDIR /wine 4 | 5 | COPY ./requirements.txt /wine/requirements.txt 6 | 7 | RUN pip install --no-cache-dir --upgrade -r /wine/requirements.txt 8 | 9 | COPY ./api /wine/api 10 | 11 | EXPOSE 8000 -------------------------------------------------------------------------------- /dbs/meilisearch/README.md: -------------------------------------------------------------------------------- 1 | # Meilisearch 2 | 3 | [Meilisearch](https://www.meilisearch.com/docs/learn/what_is_meilisearch/overview) is a fast, responsive RESTful search engine database [built in Rust](https://github.com/meilisearch/meilisearch). The primary use case for Meilisearch is to answer business questions that involve typo-tolerant and near-instant searching on keywords or phrases, all enabled by its efficient indexing and storage techniques. 4 | 5 | Code is provided for ingesting the wine reviews dataset into Meilisearch in an async fashion [via this excellent client](https://github.com/sanders41/meilisearch-python-async). In addition, a query API written in FastAPI is also provided that allows a user to query available endpoints. As always in FastAPI, documentation is available via OpenAPI. 6 | 7 | * All code (wherever possible) is async 8 | * [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling 9 | * The same schema is used for data ingestion and for the API, so there is only one source of truth regarding how the data is handled 10 | * For ease of reproducibility, the whole setup is orchestrated and deployed via docker 11 | 12 | ## Setup 13 | 14 | Note that this code base has been tested in Python 3.11, and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`. 15 | 16 | ```sh 17 | # Setup the environment for the first time 18 | python -m venv meili_venv # python -> python 3.10+ 19 | 20 | # Activate the environment (for subsequent runs) 21 | source meili_venv/bin/activate 22 | 23 | python -m pip install -r requirements.txt 24 | 25 | ``` 26 | 27 | ## Create a .env file 28 | 29 | In the `.env` file, add the following parameters, updating the values to match your setup. 30 | 31 | ``` 32 | MEILI_SERVICE = "meilisearch" 33 | MEILI_MASTER_KEY = "masterKey" 34 | MEILI_PORT = 7700 35 | MEILI_URL = "127.0.0.1" 36 | TAG = "latest" 37 | ``` 38 | 39 | ## Step 1: Set up containers 40 | 41 | Use the provided `docker-compose.yml` to initiate separate containers, one that runs Meilisearch, and another that serves as an API on top of the database. 42 | 43 | ``` 44 | docker compose up -d 45 | ``` 46 | 47 | This compose file starts a persistent-volume Meilisearch database with credentials specified in `.env`. The `meilisearch` service variable in the environment file indicates that we are opening up the database service to a FastAPI server (running as a separate service, in a separate container) downstream. Both containers can communicate with one another via the common network they share, on the exact port numbers specified. 
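Once the containers are up, it's worth confirming that Meilisearch itself is reachable before ingesting any data; a quick health check (assuming the default port of 7700 from `.env`) looks like this:

```sh
curl http://localhost:7700/health
# Expected response: {"status":"available"}
```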
48 | 49 | The services can be stopped at any time for maintenance and updates. 50 | 51 | ``` 52 | docker compose down 53 | ``` 54 | 55 | **Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via simple docker, but this is a good starting point to begin experimenting! 56 | 57 | 58 | ## Step 2: Ingest the data 59 | 60 | The first step is to ingest the wine reviews dataset into Meilisearch. Data is asynchronously ingested into the Meilisearch database through the scripts in the `scripts` directory. 61 | 62 | ```sh 63 | cd scripts 64 | python bulk_index_async.py 65 | ``` 66 | 67 | * This script first sets important items like which fields are searchable, filterable and sortable: 68 | * To speed up indexing, Meilisearch allows us to explicitly specify which fields are searchable, filterable and sortable 69 | * Choosing these fields carefully can really help speed up the indexing of a large dataset, of the order of $10^5-10^6$ records 70 | * The script then validates the input JSON data via [Pydantic](https://docs.pydantic.dev) and asynchronously indexes the validated records into the database using the [`meilisearch-python-async` client](https://github.com/sanders41/meilisearch-python-async) for the best performance 71 | * The third-party async Python client is chosen over the (currently sync) [official client](https://github.com/meilisearch/meilisearch-python) for Meilisearch, as the goal is to provide an async-compatible API via FastAPI 72 | 73 | 74 | ## Step 3: Test API 75 | 76 | Once the data has been successfully loaded into Meilisearch and the containers are up and running, we can test out a search query via an HTTP request as follows. 77 | 78 | ```sh 79 | curl -X 'GET' \ 80 | 'http://localhost:8003/wine/search?terms=tuscany%20red' 81 | ``` 82 | 83 | This cURL request passes the search terms "**tuscany red**" to the `/wine/search` endpoint, which is then parsed into a working Meilisearch JSON query by the FastAPI backend. The query runs and retrieves results from the database (which looks for these keywords in the wine's title, description and variety fields), and, if the setup was done correctly, we should see the following response: 84 | 85 | ```json 86 | [ 87 | { 88 | "id": 22170, 89 | "country": "Italy", 90 | "title": "Kirkland Signature 2004 Tuscany Red (Toscana)", 91 | "description": "Here is a masculine and robust blend of Sangiovese, Cab Sauvignon and Merlot that exhibits thick concentration and aromas of exotic spices, cherry, prune, plum, vanilla and Amaretto. The nose is gorgeous but the mouthfeel is less convincing, with firm tannins.", 92 | "points": 87, 93 | "price": 20.0, 94 | "variety": "Red Blend", 95 | "winery": "Kirkland Signature" 96 | }, 97 | { 98 | "id": 55924, 99 | "country": "Italy", 100 | "title": "Col d'Orcia 2011 Spezieri Red (Toscana)", 101 | "description": "This easy going blended red from Tuscany opens with bright cherry and blackberry aromas against a backdrop of bitter almond and a touch of Indian spice. The fresh acidity makes this a perfect pasta wine.", 102 | "points": 87, 103 | "price": 17.0, 104 | "variety": "Red Blend", 105 | "winery": "Col d'Orcia" 106 | }, 107 | { 108 | "id": 40960, 109 | "country": "Italy", 110 | "title": "Fattoria di Grignano 2011 Pietramaggio Red (Toscana)", 111 | "description": "Here's a simple but well made red from Tuscany that has floral aromas of violet and rose with berry notes. The palate offers bright cherry, red currant and a touch of spice. 
Pair this with pasta dishes or grilled vegetables.", 112 | "points": 86, 113 | "price": 11.0, 114 | "variety": "Red Blend", 115 | "winery": "Fattoria di Grignano" 116 | } 117 | ] 118 | ``` 119 | 120 | Not bad! This example correctly returns some highly rated Tuscan red wines along with their price and country of origin (obviously, Italy in this case). 121 | 122 | ## Step 4: Extend the API 123 | 124 | The API can be easily extended with the provided structure. 125 | 126 | - The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs 127 | - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here 128 | - The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions 129 | - For example: "What are the top-rated wines from Argentina?" 130 | - In general, it makes sense to organize specific business use cases into their own router files 131 | - The `api/main.py` file collects all the routes and schemas to run the API 132 | 133 | 134 | #### Existing endpoints 135 | 136 | So far, the following endpoints that help answer interesting questions have been implemented. 137 | 138 | ``` 139 | GET 140 | /wine/search 141 | Search By Keywords 142 | ``` 143 | 144 | ``` 145 | GET 146 | /wine/top_by_country 147 | Top By Country 148 | ``` 149 | 150 | ``` 151 | GET 152 | /wine/top_by_province 153 | Top By Province 154 | ``` 155 | 156 | Run the FastAPI app in a docker container to explore them! 157 | 158 | --- 159 | 160 | ### 💡 Limitations of Meilisearch 161 | 162 | Because Meilisearch was designed from the ground up to be a near-instant search data store, it does not have great support for aggregations or analytics, which are features we might be used to from other NoSQL databases like Elasticsearch and MongoDB. More info on this is provided in [this excellent blog post](https://blog.meilisearch.com/why-should-you-use-meilisearch-over-elasticsearch/) by the Meilisearch creators themselves. 163 | 164 | As stated in that blog post by Meilisearch: 165 | 166 | > Meilisearch is not made to search through billions of large text files or parse complex queries. This kind of searching power would require a higher degree of complexity and lead to slower search experiences, which runs against our instant search philosophy. For those purposes, look no further than Elasticsearch; it’s an excellent solution for companies with the necessary resources, whether that be the financial means to hire consultants or the time and money required to implement it themselves. 167 | 168 | **Bottom Line:** If your goal is to run analytics on your unstructured data, or more complex queries than string-based information retrieval, then Meilisearch isn't the best choice -- stick to more established alternatives like MongoDB or Elasticsearch that were designed for much more versatile use cases. 
169 | -------------------------------------------------------------------------------- /dbs/meilisearch/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/meilisearch/api/__init__.py -------------------------------------------------------------------------------- /dbs/meilisearch/api/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 | model_config = SettingsConfigDict( 6 | env_file=".env", 7 | extra="allow", 8 | ) 9 | 10 | meili_service: str 11 | meili_master_key: str 12 | meili_port: int 13 | meili_url: str 14 | tag: str 15 | -------------------------------------------------------------------------------- /dbs/meilisearch/api/main.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from contextlib import asynccontextmanager 3 | from functools import lru_cache 4 | 5 | from fastapi import FastAPI 6 | from meilisearch_python_async import Client 7 | 8 | from api.config import Settings 9 | from api.routers import rest 10 | 11 | 12 | @lru_cache() 13 | def get_settings(): 14 | # Use lru_cache to avoid loading .env file for every request 15 | return Settings() 16 | 17 | 18 | async def get_search_api_key(settings) -> str: 19 | URI = f"http://{settings.meili_service}:{settings.meili_port}" 20 | MASTER_KEY = settings.meili_master_key 21 | async with Client(URI, MASTER_KEY) as client: 22 | response = await client.get_keys() 23 | # Search key is always the first result obtained (followed by admin key) 24 | search_key = response.results[0].key 25 | return search_key 26 | 27 | 28 | @asynccontextmanager 29 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 30 | """Async context manager for Meilisearch connection.""" 31 | settings = get_settings() 32 | 33 | search_key = await get_search_api_key(settings) 34 | URI = f"http://{settings.meili_service}:{settings.meili_port}" 35 | async with Client(URI, search_key) as client: 36 | app.client = client 37 | print("Successfully connected to Meilisearch") 38 | yield 39 | print("Successfully closed Meilisearch connection") 40 | 41 | 42 | app = FastAPI( 43 | title="REST API for wine reviews on Meilisearch", 44 | description=( 45 | "Query from a Meilisearch database of 130k wine reviews from the Wine Enthusiast magazine" 46 | ), 47 | version=get_settings().tag, 48 | lifespan=lifespan, 49 | ) 50 | 51 | 52 | @app.get("/", include_in_schema=False) 53 | async def root(): 54 | return { 55 | "message": "REST API for querying Meilisearch database of 130k wine reviews from the Wine Enthusiast magazine" 56 | } 57 | 58 | 59 | # Attach routes 60 | app.include_router(rest.router, prefix="/wine", tags=["wine"]) 61 | -------------------------------------------------------------------------------- /dbs/meilisearch/api/routers/rest.py: -------------------------------------------------------------------------------- 1 | from api.schemas.rest import ( 2 | FullTextSearch, 3 | TopWinesByCountry, 4 | TopWinesByProvince, 5 | ) 6 | from fastapi import APIRouter, HTTPException, Query, Request 7 | from meilisearch_python_async import Client 8 | 9 | router = APIRouter() 10 | 11 | 12 | # --- Routes --- 13 | 14 | 15 | @router.get( 16 | "/search", 17 | 
response_model=list[FullTextSearch], 18 | response_description="Search wines by title, description and variety", 19 | ) 20 | async def search_by_keywords( 21 | request: Request, 22 | terms: str = Query(description="Search wine by keywords in title, description and variety"), 23 | max_price: float = Query( 24 | default=100.0, description="Specify the maximum price for the wine (e.g., 30)" 25 | ), 26 | ) -> list[FullTextSearch] | None: 27 | result = await _search_by_keywords(request.app.client, terms, max_price) 28 | if not result: 29 | raise HTTPException( 30 | status_code=404, 31 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 32 | ) 33 | return result 34 | 35 | 36 | @router.get( 37 | "/top_by_country", 38 | response_model=list[TopWinesByCountry], 39 | response_description="Get top-rated wines by country", 40 | ) 41 | async def top_by_country( 42 | request: Request, 43 | country: str = Query( 44 | description="Get top-rated wines by country name specified (must be exact name)" 45 | ), 46 | ) -> list[TopWinesByCountry] | None: 47 | result = await _top_by_country(request.app.client, country) 48 | if not result: 49 | raise HTTPException( 50 | status_code=404, 51 | detail=f"No wine from the provided country '{country}' found in database - please enter exact country name", 52 | ) 53 | return result 54 | 55 | 56 | @router.get( 57 | "/top_by_province", 58 | response_model=list[TopWinesByProvince], 59 | response_description="Get top-rated wines by province", 60 | ) 61 | async def top_by_province( 62 | request: Request, 63 | province: str = Query( 64 | description="Get top-rated wines by province name specified (must be exact name)" 65 | ), 66 | ) -> list[TopWinesByProvince] | None: 67 | result = await _top_by_province(request.app.client, province) 68 | if not result: 69 | raise HTTPException( 70 | status_code=404, 71 | detail=f"No wine from the provided province '{province}' found in database - please enter exact province name", 72 | ) 73 | return result 74 | 75 | 76 | # --- Meilisearch query funcs --- 77 | 78 | 79 | async def _search_by_keywords( 80 | client: Client, terms: str, max_price: float, index="wines" 81 | ) -> list[FullTextSearch] | None: 82 | index = client.index(index) 83 | response = await index.search( 84 | terms, 85 | limit=5, 86 | filter=f"price < {max_price}", 87 | sort=["points:desc", "price:asc"], 88 | ) 89 | if response: 90 | return response.hits 91 | return None 92 | 93 | 94 | async def _top_by_country( 95 | client: Client, country: str, index="wines" 96 | ) -> list[TopWinesByCountry] | None: 97 | index = client.index(index) 98 | response = await index.search( 99 | "", 100 | limit=5, 101 | filter=f'country = "{country}"', 102 | sort=["points:desc", "price:asc"], 103 | ) 104 | if response: 105 | return response.hits 106 | return None 107 | 108 | 109 | async def _top_by_province( 110 | client: Client, province: str, index="wines" 111 | ) -> list[TopWinesByProvince] | None: 112 | index = client.index(index) 113 | response = await index.search( 114 | "", 115 | limit=5, 116 | filter=f'province = "{province}"', 117 | sort=["points:desc", "price:asc"], 118 | ) 119 | if response: 120 | return response.hits 121 | return None 122 | -------------------------------------------------------------------------------- /dbs/meilisearch/api/schemas/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/meilisearch/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/meilisearch/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class FullTextSearch(BaseModel): 5 | model_config = ConfigDict( 6 | json_schema_extra={ 7 | "example": { 8 | "id": 3845, 9 | "country": "Italy", 10 | "title": "Castellinuzza e Piuca 2010 Chianti Classico", 11 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 12 | "points": 93, 13 | "price": 16, 14 | "variety": "Red Blend", 15 | "winery": "Castellinuzza e Piuca", 16 | } 17 | } 18 | ) 19 | 20 | id: int 21 | country: str 22 | title: str 23 | description: str | None 24 | points: int 25 | price: float | str | None 26 | variety: str | None 27 | winery: str | None 28 | 29 | 30 | class TopWinesByCountry(BaseModel): 31 | model_config = ConfigDict( 32 | validate_assignment=True, 33 | ) 34 | 35 | id: int 36 | country: str 37 | title: str 38 | description: str | None 39 | points: int 40 | price: float | str | None = "Not available" 41 | variety: str | None 42 | winery: str | None 43 | 44 | 45 | class TopWinesByProvince(BaseModel): 46 | model_config = ConfigDict( 47 | validate_assignment=True, 48 | ) 49 | 50 | id: int 51 | country: str 52 | province: str 53 | title: str 54 | description: str | None 55 | points: int 56 | price: float | str | None = "Not available" 57 | variety: str | None 58 | winery: str | None 59 | -------------------------------------------------------------------------------- /dbs/meilisearch/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | meilisearch: 5 | image: getmeili/meilisearch:${MEILI_VERSION} 6 | restart: unless-stopped 7 | environment: 8 | - http_proxy 9 | - https_proxy 10 | - MEILI_MASTER_KEY=${MEILI_MASTER_KEY} 11 | - MEILI_NO_ANALYTICS=true 12 | - MEILI_ENV=development 13 | ports: 14 | - ${MEILI_PORT:-7700}:7700 15 | volumes: 16 | - meili_data:/meili_data 17 | networks: 18 | - wine 19 | 20 | fastapi: 21 | image: meili_wine_fastapi:${TAG} 22 | build: . 
23 | restart: unless-stopped 24 | env_file: 25 | - .env 26 | ports: 27 | - ${API_PORT}:8000 28 | depends_on: 29 | - meilisearch 30 | volumes: 31 | - ./:/wine 32 | networks: 33 | - wine 34 | command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload 35 | 36 | volumes: 37 | meili_data: 38 | 39 | networks: 40 | wine: 41 | driver: bridge -------------------------------------------------------------------------------- /dbs/meilisearch/requirements.txt: -------------------------------------------------------------------------------- 1 | meilisearch-python-async~=1.4.0 2 | meilisearch~=0.28.0 3 | pydantic~=2.0.0 4 | pydantic-settings~=2.0.0 5 | python-dotenv>=1.0.0 6 | fastapi~=0.100.0 7 | httpx>=0.24.0 8 | aiohttp>=3.8.4 9 | uvicorn>=0.21.0, <1.0.0 10 | srsly>=2.4.6 11 | codetiming>=1.4.0 12 | tqdm>=4.65.0 13 | -------------------------------------------------------------------------------- /dbs/meilisearch/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/meilisearch/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/meilisearch/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field, model_validator 2 | 3 | 4 | class Wine(BaseModel): 5 | model_config = ConfigDict( 6 | populate_by_name=True, 7 | validate_assignment=True, 8 | extra="allow", 9 | str_strip_whitespace=True, 10 | json_schema_extra={ 11 | "example": { 12 | "id": 45100, 13 | "points": 85, 14 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 15 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.", 16 | "price": 10.0, 17 | "variety": "Merlot", 18 | "winery": "Balduzzi", 19 | "vineyard": "Reserva", 20 | "country": "Chile", 21 | "province": "Maule Valley", 22 | "region_1": "null", 23 | "region_2": "null", 24 | "taster_name": "Michael Schachner", 25 | "taster_twitter_handle": "@wineschach", 26 | } 27 | }, 28 | ) 29 | 30 | id: int 31 | points: int 32 | title: str 33 | description: str | None 34 | price: float | None 35 | variety: str | None 36 | winery: str | None 37 | vineyard: str | None = Field(..., alias="designation") 38 | country: str | None 39 | province: str | None 40 | region_1: str | None 41 | region_2: str | None 42 | taster_name: str | None 43 | taster_twitter_handle: str | None 44 | 45 | @model_validator(mode="before") 46 | def _fill_country_unknowns(cls, values): 47 | "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 48 | country = values.get("country") 49 | if country is None or country == "null": 50 | values["country"] = "Unknown" 51 | return values 52 | 53 | 54 | if __name__ == "__main__": 55 | data = { 56 | "id": 45100, 57 | "points": 85, 58 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 59 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. 
Raisin and cooked berry flavors finish plump, with earthy notes.", 60 | "price": 10, # Test if field is cast to float 61 | "variety": "Merlot", 62 | "winery": "Balduzzi", 63 | "designation": "Reserva", # Test if field is renamed 64 | "country": "null", # Test unknown country 65 | "province": " Maule Valley ", # Test if field is stripped 66 | "region_1": "null", 67 | "region_2": "null", 68 | "taster_name": "Michael Schachner", 69 | "taster_twitter_handle": "@wineschach", 70 | } 71 | from pprint import pprint 72 | 73 | wine = Wine(**data) 74 | pprint(wine.model_dump(), sort_dicts=False) 75 | -------------------------------------------------------------------------------- /dbs/meilisearch/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/meilisearch/scripts/__init__.py -------------------------------------------------------------------------------- /dbs/meilisearch/scripts/bulk_index_async.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import asyncio 5 | import os 6 | import sys 7 | from functools import lru_cache 8 | from pathlib import Path 9 | from typing import Any, Iterator 10 | 11 | import srsly 12 | from codetiming import Timer 13 | from dotenv import load_dotenv 14 | from meilisearch_python_async import Client 15 | from meilisearch_python_async.index import Index 16 | from meilisearch_python_async.models.settings import MeilisearchSettings 17 | from tqdm import tqdm 18 | from tqdm.asyncio import tqdm_asyncio 19 | 20 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 21 | from api.config import Settings 22 | from schemas.wine import Wine 23 | 24 | load_dotenv() 25 | # Custom types 26 | JsonBlob = dict[str, Any] 27 | 28 | 29 | class FileNotFoundError(Exception): 30 | pass 31 | 32 | 33 | # --- Blocking functions --- 34 | 35 | 36 | @lru_cache() 37 | def get_settings(): 38 | # Use lru_cache to avoid loading .env file for every request 39 | return Settings() 40 | 41 | 42 | def chunk_files(item_list: list[Any], file_chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: 43 | """ 44 | Break a large list of files into a list of lists of files, where each inner list is of size `file_chunksize` 45 | """ 46 | for i in range(0, len(item_list), file_chunksize): 47 | yield tuple(item_list[i : i + file_chunksize]) 48 | 49 | 50 | def get_json_data(file_path: Path) -> list[JsonBlob]: 51 | """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" 52 | if not file_path.is_file(): 53 | # File may not have been uncompressed yet so try to do that first 54 | data = srsly.read_gzip_jsonl(file_path) 55 | # This time if it isn't there it really doesn't exist 56 | if not file_path.is_file(): 57 | raise FileNotFoundError( 58 | f"`{file_path}` doesn't contain a valid `.jsonl.gz` file - check and try again." 
59 | ) 60 | else: 61 | data = srsly.read_gzip_jsonl(file_path) 62 | return data 63 | 64 | 65 | def validate( 66 | data: list[JsonBlob], 67 | exclude_none: bool = True, 68 | ) -> list[JsonBlob]: 69 | validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data] 70 | return validated_data 71 | 72 | 73 | def get_meili_settings(filename: str) -> MeilisearchSettings: 74 | settings = dict(srsly.read_json(filename)) 75 | # Convert to MeilisearchSettings pydantic model object 76 | settings = MeilisearchSettings(**settings) 77 | return settings 78 | 79 | 80 | # --- Async functions --- 81 | 82 | 83 | async def update_documents(filepath: Path, index: Index, primary_key: str, batch_size: int): 84 | data = list(get_json_data(filepath)) 85 | if LIMIT > 0: 86 | data = data[:LIMIT] 87 | validated_data = validate(data) 88 | await index.update_documents_in_batches( 89 | validated_data, 90 | batch_size=batch_size, 91 | primary_key=primary_key, 92 | ) 93 | 94 | 95 | async def main(data_files: list[Path]) -> None: 96 | meili_settings = get_meili_settings(filename="settings/settings.json") 97 | config = Settings() 98 | URI = f"http://{config.meili_url}:{config.meili_port}" 99 | MASTER_KEY = config.meili_master_key 100 | index_name = "wines" 101 | primary_key = "id" 102 | async with Client(URI, MASTER_KEY) as client: 103 | with Timer(name="Bulk Index", text="Bulk index took {:.4f} seconds"): 104 | # Create index 105 | index = client.index(index_name) 106 | # Update settings 107 | await client.index(index_name).update_settings(meili_settings) 108 | print("Finished updating database index settings") 109 | file_chunks = chunk_files(data_files, file_chunksize=FILE_CHUNKSIZE) 110 | for chunk in tqdm( 111 | file_chunks, desc="Handling file chunks", total=len(data_files) // FILE_CHUNKSIZE 112 | ): 113 | try: 114 | tasks = [ 115 | # Update index 116 | update_documents( 117 | filepath, 118 | index, 119 | primary_key=primary_key, 120 | batch_size=BATCHSIZE, 121 | ) 122 | # In a real case we'd be iterating through a list of files 123 | # For this example, it's just looping through the same file N times 124 | for filepath in chunk 125 | ] 126 | await tqdm_asyncio.gather(*tasks) 127 | except Exception as e: 128 | print(f"{e}: Error while indexing to db") 129 | print(f"Finished running benchmarks") 130 | 131 | 132 | if __name__ == "__main__": 133 | # fmt: off 134 | parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data") 135 | parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") 136 | parser.add_argument("--batchsize", "-b", type=int, default=10_000, help="Size of each batch to break the dataset into before ingesting") 137 | parser.add_argument("--file_chunksize", "-c", type=int, default=5, help="Size of file chunk that will be concurrently processed and passed to the client in batches") 138 | parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") 139 | parser.add_argument("--benchmark_num", "-n", type=int, default=1, help="Run a benchmark of the script N times") 140 | args = vars(parser.parse_args()) 141 | # fmt: on 142 | 143 | LIMIT = args["limit"] 144 | DATA_DIR = Path(__file__).parents[3] / "data" 145 | FILENAME = args["filename"] 146 | BATCHSIZE = args["batchsize"] 147 | BENCHMARK_NUM = args["benchmark_num"] 148 | FILE_CHUNKSIZE = args["file_chunksize"] 149 | 150 | # Get a list of all files in the data directory 151 | data_files 
= [f for f in DATA_DIR.glob("*.jsonl.gz") if f.is_file()] 152 | # For benchmarking, we want to run on the same data multiple times (in the real world this would be many different files) 153 | benchmark_data_files = data_files * BENCHMARK_NUM 154 | 155 | meili_settings = get_meili_settings(filename="settings/settings.json") 156 | 157 | # Run main async event loop 158 | asyncio.run(main(benchmark_data_files)) 159 | -------------------------------------------------------------------------------- /dbs/meilisearch/scripts/bulk_index_sync.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import os 5 | import sys 6 | from functools import lru_cache 7 | from pathlib import Path 8 | from typing import Any 9 | 10 | import srsly 11 | from codetiming import Timer 12 | from dotenv import load_dotenv 13 | from meilisearch import Client 14 | from meilisearch.index import Index 15 | from tqdm import tqdm 16 | 17 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 18 | from api.config import Settings 19 | from schemas.wine import Wine 20 | 21 | load_dotenv() 22 | # Custom types 23 | JsonBlob = dict[str, Any] 24 | 25 | 26 | class FileNotFoundError(Exception): 27 | pass 28 | 29 | 30 | # --- Blocking functions --- 31 | 32 | 33 | @lru_cache() 34 | def get_settings(): 35 | # Use lru_cache to avoid loading .env file for every request 36 | return Settings() 37 | 38 | 39 | def get_json_data(file_path: Path) -> list[JsonBlob]: 40 | """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" 41 | if not file_path.is_file(): 42 | # File may not have been uncompressed yet so try to do that first 43 | data = srsly.read_gzip_jsonl(file_path) 44 | # This time if it isn't there it really doesn't exist 45 | if not file_path.is_file(): 46 | raise FileNotFoundError( 47 | f"`{file_path}` doesn't contain a valid `.jsonl.gz` file - check and try again." 
48 | ) 49 | else: 50 | data = srsly.read_gzip_jsonl(file_path) 51 | return data 52 | 53 | 54 | def validate( 55 | data: list[JsonBlob], 56 | exclude_none: bool = True, 57 | ) -> list[JsonBlob]: 58 | validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data] 59 | return validated_data 60 | 61 | 62 | def get_meili_settings(filename: str) -> dict[str, Any]: 63 | settings = dict(srsly.read_json(filename)) 64 | return settings 65 | 66 | 67 | def update_documents(filepath: Path, index: Index, primary_key: str, batch_size: int): 68 | data = list(get_json_data(filepath)) 69 | if LIMIT > 0: 70 | data = data[:LIMIT] 71 | validated_data = validate(data) 72 | index.update_documents_in_batches( 73 | validated_data, 74 | batch_size=batch_size, 75 | primary_key=primary_key, 76 | ) 77 | 78 | 79 | def main(data_files: list[Path]) -> None: 80 | meili_settings = get_meili_settings(filename="settings/settings.json") 81 | config = Settings() 82 | URI = f"http://{config.meili_url}:{config.meili_port}" 83 | MASTER_KEY = config.meili_master_key 84 | index_name = "wines" 85 | primary_key = "id" 86 | 87 | client = Client(URI, MASTER_KEY) 88 | with Timer(name="Bulk Index", text="Bulk index took {:.4f} seconds"): 89 | # Create index 90 | index = client.index(index_name) 91 | # Update settings 92 | client.index(index_name).update_settings(meili_settings) 93 | print("Finished updating database index settings") 94 | try: 95 | # In a real case we'd be iterating through a list of files 96 | # For this example, it's just looping through the same file N times 97 | for filepath in tqdm(data_files): 98 | # Update index 99 | update_documents(filepath, index, primary_key=primary_key, batch_size=BATCHSIZE) 100 | except Exception as e: 101 | print(f"{e}: Error while indexing to db") 102 | 103 | 104 | if __name__ == "__main__": 105 | # fmt: off 106 | parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data") 107 | parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") 108 | parser.add_argument("--batchsize", "-b", type=int, default=10_000, help="Size of each chunk to break the dataset into before processing") 109 | parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") 110 | parser.add_argument("--benchmark_num", "-n", type=int, default=1, help="Run a benchmark of the script N times") 111 | args = vars(parser.parse_args()) 112 | # fmt: on 113 | 114 | LIMIT = args["limit"] 115 | DATA_DIR = Path(__file__).parents[3] / "data" 116 | FILENAME = args["filename"] 117 | BATCHSIZE = args["batchsize"] 118 | BENCHMARK_NUM = args["benchmark_num"] 119 | 120 | # Get a list of all files in the data directory 121 | data_files = [f for f in DATA_DIR.glob("*.jsonl.gz") if f.is_file()] 122 | # For benchmarking, we want to run on the same data multiple times (in the real world this would be many different files) 123 | benchmark_data_files = data_files * BENCHMARK_NUM 124 | 125 | meili_settings = get_meili_settings(filename="settings/settings.json") 126 | 127 | # Run main function 128 | main(benchmark_data_files) 129 | -------------------------------------------------------------------------------- /dbs/meilisearch/scripts/settings/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "rankingRules": [ 3 | "words", 4 | "typo", 5 | "proximity", 6 | "attribute", 7 | "sort", 8 | "exactness", 9 | "points:desc" 10 
| ], 11 | "searchableAttributes": [ 12 | "title", 13 | "description", 14 | "country", 15 | "province", 16 | "variety", 17 | "region_1", 18 | "region_2", 19 | "taster_name" 20 | ], 21 | "filterableAttributes": [ 22 | "price", 23 | "points", 24 | "country", 25 | "province", 26 | "variety" 27 | ], 28 | "displayedAttributes": [ 29 | "id", 30 | "points", 31 | "price", 32 | "title", 33 | "country", 34 | "province", 35 | "variety", 36 | "winery", 37 | "taster_name", 38 | "description" 39 | ], 40 | "sortableAttributes": [ 41 | "points", 42 | "price" 43 | ], 44 | "stopWords": [ 45 | "the", 46 | "a", 47 | "an", 48 | "of", 49 | "to", 50 | "in", 51 | "for", 52 | "on" 53 | ], 54 | "typoTolerance": { 55 | "minWordSizeForTypos": { 56 | "oneTypo": 4, 57 | "twoTypos": 9 58 | } 59 | }, 60 | "pagination": { 61 | "maxTotalHits": 500 62 | }, 63 | "faceting": { 64 | "maxValuesPerFacet": 100 65 | } 66 | } -------------------------------------------------------------------------------- /dbs/neo4j/.dockerignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pyo 4 | *.pyd 5 | .Python 6 | env 7 | pip-log.txt 8 | pip-delete-this-directory.txt 9 | .tox 10 | .coverage 11 | .coverage.* 12 | .cache 13 | nosetests.xml 14 | coverage.xml 15 | *.cover 16 | *.log 17 | .git 18 | .mypy_cache 19 | .pytest_cache 20 | -------------------------------------------------------------------------------- /dbs/neo4j/.env.example: -------------------------------------------------------------------------------- 1 | # Neo4j 2 | NEO4J_PASSWORD = "" 3 | NEO4J_VERSION = 5.6.0 4 | DB_SERVICE = "db" 5 | API_PORT = 8001 6 | 7 | # Container image tag 8 | TAG = "0.2.0" 9 | 10 | # Docker project namespace (defaults to the current folder name if not set) 11 | COMPOSE_PROJECT_NAME = neo4j_wine -------------------------------------------------------------------------------- /dbs/neo4j/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bullseye 2 | 3 | WORKDIR /wine 4 | 5 | COPY ./requirements.txt /wine/requirements.txt 6 | 7 | RUN pip install --no-cache-dir --upgrade -r /wine/requirements.txt 8 | 9 | COPY ./api /wine/api 10 | 11 | EXPOSE 8000 -------------------------------------------------------------------------------- /dbs/neo4j/README.md: -------------------------------------------------------------------------------- 1 | # Neo4j 2 | 3 | [Neo4j](https://neo4j.com/) is an ACID-compliant transactional database with native graph storage and processing, popular for use cases where we want to query connected data. The primary use case for a graph database is to answer business questions that involve connected data. 4 | 5 | * Which wines from Chile were tasted by at least two different tasters? 6 | * What are the top-rated wines from Italy that share their variety with my favourite ones from Portugal? 7 | 8 | Code is provided for ingesting the wine reviews dataset into Neo4j in an async fashion. In addition, a query API written in FastAPI is also provided that allows a user to query available endpoints. As always in FastAPI, documentation is available via OpenAPI (http://localhost:8000/docs). 
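To make that concrete, here is a rough sketch (not code from this repo) of how the first of those questions could be phrased in Cypher; the labels and relationship types (`Wine`, `Person`, `Country`, `TASTED_BY`, `IS_FROM_COUNTRY`) follow the data model built by `scripts/build_graph.py` later in this README, but the query itself is purely illustrative:

```py
# Illustrative sketch only: "Which wines from Chile were tasted by at least
# two different tasters?", using the labels from scripts/build_graph.py
query = """
MATCH (w:Wine)-[:IS_FROM_COUNTRY]->(:Country {countryName: "Chile"})
MATCH (w)-[:TASTED_BY]->(t:Person)
WITH w, count(DISTINCT t) AS numTasters
WHERE numTasters >= 2
RETURN w.title AS title, numTasters
ORDER BY numTasters DESC
"""
```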
9 | 10 | * All code (wherever possible) is async 11 | * [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling 12 | * The same schema is used for data ingestion and for the API, so there is only one source of truth regarding how the data is handled 13 | * For ease of reproducibility, the whole setup is orchestrated and deployed via docker 14 | 15 | ## Setup 16 | 17 | Note that this code base has been tested in Python 3.11, and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`. 18 | 19 | ```sh 20 | # Setup the environment for the first time 21 | python -m venv neo4j_venv  # python -> python 3.10+ 22 | 23 | # Activate the environment (for subsequent runs) 24 | source neo4j_venv/bin/activate 25 | 26 | python -m pip install -r requirements.txt 27 | ``` 28 | 29 | --- 30 | 31 | ## Step 1: Set up containers 32 | 33 | Use the provided `docker-compose.yml` to initiate separate containers, one that runs Neo4j and another that serves as an API on top of the database. 34 | 35 | ``` 36 | docker compose up -d 37 | ``` 38 | 39 | This compose file starts a persistent-volume Neo4j database with credentials specified in `.env`. The `db` variable in the environment file indicates that we are opening up the database service to a FastAPI server (running as a separate service, in a separate container) downstream. Both containers can communicate with one another via the common network that they share, on the exact port numbers specified. 40 | 41 | The services can be stopped at any time for maintenance and updates. 42 | 43 | ``` 44 | docker compose down 45 | ``` 46 | 47 | **Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via simple docker, but this is a good starting point to begin experimenting! 48 | 49 | 50 | ## Step 2: Ingest the data 51 | 52 | The first step is to ingest the wine reviews dataset into Neo4j. To do this, we first conceptualize the following data model: 53 | 54 | ![](./assets/data_model.png) 55 | 56 | The idea behind this data model is as follows: 57 | 58 | * We want to be able to query for wines from a specific region, or country, or both 59 | * The taste of wine is influenced by both the province (i.e., specific region) of a country, as well as the general part of the world it is from 60 | * We also want to be able to relate each wine to the person who tasted it and awarded it a particular number of points in a review 61 | 62 | The data model can be far more detailed than this example, and depends heavily on the use cases we want to query for. At present, this model will suffice. 63 | 64 | ### Run async data loader 65 | 66 | Data is asynchronously ingested into the Neo4j database through the scripts in the `scripts` directory. 67 | 68 | ```sh 69 | cd scripts 70 | python build_graph.py 71 | ``` 72 | 73 | This script validates the input JSON data via [Pydantic](https://docs.pydantic.dev), and then asynchronously ingests the records into Neo4j using the [Neo4j `AsyncGraphDatabase` driver](https://neo4j.com/docs/api/python-driver/current/async_api.html), with appropriate constraints and indexes for best performance. 74 | 75 | ## Step 3: Test API 76 | 77 | Once the data has been successfully loaded into Neo4j and the containers are up and running, we can test out a search query via an HTTP request as follows.
78 | 79 | ```sh 80 | curl -X 'GET' \ 81 |   'http://localhost:8000/wine/search?terms=tuscany%20red&max_price=50' 82 | ``` 83 | 84 | This cURL request passes the search terms "**tuscany red**" to the `/wine/search/` endpoint, which is then parsed into a working Cypher query by the FastAPI backend. The query runs and retrieves results from a full-text search index (that looks for these keywords in the wine's title and description), and, if the setup was done correctly, we should see the following response: 85 | 86 | ```json 87 | [ 88 |     { 89 |         "wineID": 66393, 90 |         "country": "Italy", 91 |         "title": "Capezzana 1999 Ghiaie Della Furba Red (Tuscany)", 92 |         "description": "Very much a baby, this is one big, bold, burly Cab-Merlot-Syrah blend that's filled to the brim with extracted plum fruit, bitter chocolate and earth. It takes a long time in the glass for it to lose its youthful, funky aromatics, and on the palate things are still a bit scattered. But in due time things will settle and integrate", 93 |         "points": 90, 94 |         "price": 49, 95 |         "variety": "Red Blend", 96 |         "winery": "Capezzana" 97 |     }, 98 |     { 99 |         "wineID": 40960, 100 |         "country": "Italy", 101 |         "title": "Fattoria di Grignano 2011 Pietramaggio Red (Toscana)", 102 |         "description": "Here's a simple but well made red from Tuscany that has floral aromas of violet and rose with berry notes. The palate offers bright cherry, red currant and a touch of spice. Pair this with pasta dishes or grilled vegetables.", 103 |         "points": 86, 104 |         "price": 11, 105 |         "variety": "Red Blend", 106 |         "winery": "Fattoria di Grignano" 107 |     }, 108 |     { 109 |         "wineID": 73595, 110 |         "country": "Italy", 111 |         "title": "I Giusti e Zanza 2011 Belcore Red (Toscana)", 112 |         "description": "With aromas of violet, tilled soil and red berries, this blend of Sangiovese and Merlot recalls sunny Tuscany. It's loaded with wild cherry flavors accented by white pepper, cinnamon and vanilla. The palate is uplifted by vibrant acidity and fine tannins.", 113 |         "points": 89, 114 |         "price": 27, 115 |         "variety": "Red Blend", 116 |         "winery": "I Giusti e Zanza" 117 |     } 118 | ] 119 | ``` 120 | 121 | Not bad! This example correctly returns some highly rated Tuscan red wines along with their price and country of origin (obviously, Italy in this case). 122 | 123 | ## Step 4: Extend the API 124 | 125 | The API can be easily extended with the provided structure. 126 | 127 | - The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs 128 |   - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here 129 | - The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions 130 |   - e.g., "What are the top-rated wines from Argentina?" 131 |   - In general, it makes sense to organize specific business use cases into their own router files 132 | - The `api/main.py` file collects all the routes and schemas to run the API 133 | 134 | 135 | #### Existing endpoints 136 | 137 | So far, the following endpoints that help answer interesting questions have been implemented.
138 | 139 | ``` 140 | GET 141 | /wine/search 142 | Search By Keywords 143 | ``` 144 | 145 | ``` 146 | GET 147 | /wine/top_by_country 148 | Top By Country 149 | ``` 150 | 151 | ``` 152 | GET 153 | /wine/top_by_province 154 | Top By Province 155 | ``` 156 | 157 | ``` 158 | GET 159 | /wine/most_by_variety 160 | Most By Variety 161 | ``` 162 | 163 | Run the FastAPI app in a docker container to explore them! 164 | 165 | -------------------------------------------------------------------------------- /dbs/neo4j/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/neo4j/api/__init__.py -------------------------------------------------------------------------------- /dbs/neo4j/api/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 |     model_config = SettingsConfigDict( 6 |         env_file=".env", 7 |         extra="allow", 8 |     ) 9 | 10 |     neo4j_service: str 11 |     neo4j_url: str 12 |     neo4j_user: str 13 |     neo4j_password: str 14 |     tag: str 15 | -------------------------------------------------------------------------------- /dbs/neo4j/api/main.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from contextlib import asynccontextmanager 3 | from functools import lru_cache 4 | 5 | from fastapi import FastAPI 6 | from neo4j import AsyncGraphDatabase 7 | 8 | from api.config import Settings 9 | from api.routers import rest 10 | 11 | 12 | @lru_cache() 13 | def get_settings(): 14 |     # Use lru_cache to avoid loading .env file for every request 15 |     return Settings() 16 | 17 | 18 | @asynccontextmanager 19 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 20 |     """Async context manager for the Neo4j database connection.""" 21 |     settings = get_settings() 22 |     service = settings.neo4j_service 23 |     URI = f"bolt://{service}:7687" 24 |     AUTH = (settings.neo4j_user, settings.neo4j_password) 25 |     async with AsyncGraphDatabase.driver(URI, auth=AUTH) as driver: 26 |         async with driver.session(database="neo4j") as session: 27 |             app.session = session 28 |             print("Successfully connected to wine reviews Neo4j DB") 29 |             yield 30 |             print("Successfully closed wine reviews Neo4j connection") 31 | 32 | 33 | app = FastAPI( 34 |     title="REST API for wine reviews on Neo4j", 35 |     description=( 36 |         "Query from a Neo4j database of 130k wine reviews from the Wine Enthusiast magazine" 37 |     ), 38 |     version=get_settings().tag, 39 |     lifespan=lifespan, 40 | ) 41 | 42 | 43 | @app.get("/", include_in_schema=False) 44 | async def root(): 45 |     return { 46 |         "message": "REST API for querying Neo4j database of 130k wine reviews from the Wine Enthusiast magazine" 47 |     } 48 | 49 | 50 | # Attach routes 51 | app.include_router(rest.router, prefix="/wine", tags=["wine"]) 52 | -------------------------------------------------------------------------------- /dbs/neo4j/api/routers/rest.py: -------------------------------------------------------------------------------- 1 | from api.schemas.rest import ( 2 |     FullTextSearch, 3 |     MostWinesByVariety, 4 |     TopWinesByCountry, 5 |     TopWinesByProvince, 6 | ) 7 | from fastapi import APIRouter, HTTPException, Query, Request 8 | from neo4j import AsyncManagedTransaction 9 | 10 | router = APIRouter() 11 | 12 | 13 | # --- Routes --- 14 | 15 | 16 | @router.get(
17 | "/search", 18 | response_model=list[FullTextSearch], 19 | response_description="Search wines by title and description", 20 | ) 21 | async def search_by_keywords( 22 | request: Request, 23 | terms: str = Query(description="Search wine by keywords in title or description"), 24 | max_price: float = Query( 25 | default=100.0, description="Specify the maximum price for the wine (e.g., 30)" 26 | ), 27 | ) -> list[FullTextSearch] | None: 28 | session = request.app.session 29 | result = await session.execute_read(_search_by_keywords, terms, max_price) 30 | if not result: 31 | raise HTTPException( 32 | status_code=404, 33 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 34 | ) 35 | return result 36 | 37 | 38 | @router.get( 39 | "/top_by_country", 40 | response_model=list[TopWinesByCountry], 41 | response_description="Get top-rated wines by country", 42 | ) 43 | async def top_by_country( 44 | request: Request, 45 | country: str = Query( 46 | description="Get top-rated wines by country name specified (must be exact name)" 47 | ), 48 | ) -> list[TopWinesByCountry] | None: 49 | session = request.app.session 50 | result = await session.execute_read(_top_by_country, country) 51 | if not result: 52 | raise HTTPException( 53 | status_code=404, 54 | detail=f"No wine from the provided country '{country}' found in database - please enter exact country name", 55 | ) 56 | return result 57 | 58 | 59 | @router.get( 60 | "/top_by_province", 61 | response_model=list[TopWinesByProvince], 62 | response_description="Get top-rated wines by province", 63 | ) 64 | async def top_by_province( 65 | request: Request, 66 | province: str = Query( 67 | description="Get top-rated wines by province name specified (must be exact name)" 68 | ), 69 | ) -> list[TopWinesByProvince] | None: 70 | session = request.app.session 71 | result = await session.execute_read(_top_by_province, province) 72 | if not result: 73 | raise HTTPException( 74 | status_code=404, 75 | detail=f"No wine from the provided province '{province}' found in database - please enter exact province name", 76 | ) 77 | return result 78 | 79 | 80 | @router.get( 81 | "/most_by_variety", 82 | response_model=list[MostWinesByVariety], 83 | response_description="Get the countries with the most wines above a points-rating of a specified variety (blended or otherwise)", 84 | ) 85 | async def most_by_variety( 86 | request: Request, 87 | variety: str = Query( 88 | description="Specify the variety of wine to search for (e.g., 'Pinot Noir' or 'Red Blend')" 89 | ), 90 | points: int = Query( 91 | default=85, 92 | description="Specify the minimum points-rating for the wine (e.g., 85)", 93 | ), 94 | ) -> list[MostWinesByVariety] | None: 95 | session = request.app.session 96 | result = await session.execute_read(_most_by_variety, variety, points) 97 | if not result: 98 | raise HTTPException( 99 | status_code=404, 100 | detail=f"No wine of the specified variety '{variety}' found in database - please try a different variety", 101 | ) 102 | return result 103 | 104 | 105 | # --- Neo4j query funcs --- 106 | 107 | 108 | async def _search_by_keywords( 109 | tx: AsyncManagedTransaction, 110 | terms: str, 111 | price: float, 112 | ) -> list[FullTextSearch] | None: 113 | query = """ 114 | CALL db.index.fulltext.queryNodes("searchText", $terms) YIELD node AS wine, score 115 | WITH DISTINCT wine, score 116 | MATCH (wine)-[:IS_FROM_COUNTRY]->(c:Country) 117 | WHERE wine.price <= $price 118 | RETURN 119 | c.countryName AS country, 120 | wine.wineID AS 
wineID, 121 | wine.points AS points, 122 | wine.title AS title, 123 | wine.description AS description, 124 | coalesce(wine.price, "Not available") AS price, 125 | wine.variety AS variety, 126 | wine.winery AS winery 127 | ORDER BY score DESC, points DESC LIMIT 5 128 | """ 129 | response = await tx.run(query, terms=terms, price=price) 130 | result = await response.data() 131 | if result: 132 | return [FullTextSearch(**r) for r in result] 133 | return None 134 | 135 | 136 | async def _top_by_country( 137 | tx: AsyncManagedTransaction, 138 | country: str, 139 | ) -> list[TopWinesByCountry] | None: 140 | query = """ 141 | MATCH (wine:Wine)-[:IS_FROM_COUNTRY]->(c:Country) 142 | WHERE tolower(c.countryName) = tolower($country) 143 | RETURN 144 | wine.wineID AS wineID, 145 | wine.points AS points, 146 | wine.title AS title, 147 | wine.description AS description, 148 | c.countryName AS country, 149 | coalesce(wine.price, "Not available") AS price, 150 | wine.variety AS variety, 151 | wine.winery AS winery 152 | ORDER BY points DESC LIMIT 5 153 | """ 154 | response = await tx.run(query, country=country) 155 | result = await response.data() 156 | if result: 157 | return [TopWinesByCountry(**r) for r in result] 158 | return None 159 | 160 | 161 | async def _top_by_province( 162 | tx: AsyncManagedTransaction, 163 | province: str, 164 | ) -> list[TopWinesByProvince] | None: 165 | query = """ 166 | MATCH (wine:Wine)-[:IS_FROM_PROVINCE]->(p:Province)-[:IS_LOCATED_IN]->(c:Country) 167 | WHERE tolower(p.provinceName) = tolower($province) 168 | RETURN 169 | wine.wineID AS wineID, 170 | wine.points AS points, 171 | wine.title AS title, 172 | wine.description AS description, 173 | c.countryName AS country, 174 | p.provinceName AS province, 175 | coalesce(wine.price, "Not available") AS price, 176 | wine.variety AS variety, 177 | wine.winery AS winery 178 | ORDER BY points DESC LIMIT 5 179 | """ 180 | response = await tx.run(query, province=province) 181 | result = await response.data() 182 | if result: 183 | return [TopWinesByProvince(**r) for r in result] 184 | return None 185 | 186 | 187 | async def _most_by_variety( 188 | tx: AsyncManagedTransaction, 189 | variety: str, 190 | points: int, 191 | ) -> list[MostWinesByVariety] | None: 192 | query = """ 193 | CALL db.index.fulltext.queryNodes("searchText", $variety) YIELD node AS wine, score 194 | WITH wine 195 | MATCH (wine)-[:IS_FROM_COUNTRY]->(c:Country) 196 | WHERE wine.points >= $points 197 | RETURN 198 | c.countryName AS country, 199 | count(wine) as wineCount 200 | ORDER BY wineCount DESC LIMIT 5 201 | """ 202 | response = await tx.run(query, variety=variety, points=points) 203 | result = await response.data() 204 | if result: 205 | return [MostWinesByVariety(**r) for r in result] 206 | return None 207 | -------------------------------------------------------------------------------- /dbs/neo4j/api/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/neo4j/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/neo4j/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class FullTextSearch(BaseModel): 5 | model_config = ConfigDict( 6 | json_schema_extra={ 7 | "example": { 8 | "wineID": 3845, 9 | "country": "Italy", 10 | "title": "Castellinuzza e 
Piuca 2010 Chianti Classico", 11 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 12 | "points": 93, 13 | "price": 16, 14 | "variety": "Red Blend", 15 | "winery": "Castellinuzza e Piuca", 16 | } 17 | } 18 | ) 19 | 20 | wineID: int 21 | country: str 22 | title: str 23 | description: str | None 24 | points: int 25 | price: float | str 26 | variety: str | None 27 | winery: str | None 28 | 29 | 30 | class TopWinesByCountry(BaseModel): 31 | wineID: int 32 | country: str 33 | title: str 34 | description: str | None 35 | points: int 36 | price: float | str 37 | variety: str | None 38 | winery: str | None 39 | 40 | 41 | class TopWinesByProvince(BaseModel): 42 | wineID: int 43 | country: str 44 | province: str 45 | title: str 46 | description: str | None 47 | points: int 48 | price: float | str 49 | variety: str | None 50 | winery: str | None 51 | 52 | 53 | class MostWinesByVariety(BaseModel): 54 | country: str 55 | wineCount: int 56 | -------------------------------------------------------------------------------- /dbs/neo4j/assets/data_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/neo4j/assets/data_model.png -------------------------------------------------------------------------------- /dbs/neo4j/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.9' 2 | 3 | services: 4 | neo4j: 5 | container_name: neo4j_wine 6 | image: neo4j:${NEO4J_VERSION} 7 | restart: unless-stopped 8 | environment: 9 | - NEO4J_AUTH=neo4j/${NEO4J_PASSWORD} 10 | - NEO4J_PLUGINS=["graph-data-science", "apoc"] 11 | # DB and server 12 | - NEO4J_server_memory_pagecache_size=1G 13 | - NEO4J_server_memory_heap_initial__size=1G 14 | - NEO4J_server_memory_heap_max__size=2G 15 | - NEO4J_dbms_security_procedures_unrestricted=gds.*,apoc.* 16 | ports: 17 | - 7687:7687 18 | volumes: 19 | - logs:/logs 20 | - data:/data 21 | - plugins:/plugins 22 | - import:/import 23 | networks: 24 | - wine 25 | 26 | fastapi: 27 | image: neo4j_wine_fastapi:${TAG} 28 | build: . 
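    # Build the API image locally from the Dockerfile in this directory; TAG is read from .env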
29 |     restart: unless-stopped 30 |     env_file: 31 |       - .env 32 |     ports: 33 |       - ${API_PORT}:8000 34 |     depends_on: 35 |       - neo4j 36 |     volumes: 37 |       - ./:/wine 38 |     networks: 39 |       - wine 40 |     command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload 41 | 42 | volumes: 43 |   logs: 44 |   data: 45 |   plugins: 46 |   import: 47 | 48 | networks: 49 |   wine: 50 |     driver: bridge -------------------------------------------------------------------------------- /dbs/neo4j/requirements.txt: -------------------------------------------------------------------------------- 1 | neo4j~=5.9.0 2 | pydantic~=2.0.0 3 | pydantic-settings~=2.0.0 4 | python-dotenv>=1.0.0 5 | fastapi~=0.100.0 6 | httpx>=0.24.0 7 | aiohttp>=3.8.4 8 | uvloop>=0.17.0 9 | uvicorn>=0.21.0, <1.0.0 10 | srsly>=2.4.6 -------------------------------------------------------------------------------- /dbs/neo4j/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/neo4j/scripts/__init__.py -------------------------------------------------------------------------------- /dbs/neo4j/scripts/build_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import os 4 | import sys 5 | import time 6 | from functools import lru_cache 7 | from pathlib import Path 8 | from typing import Any, Iterator 9 | 10 | import srsly 11 | from dotenv import load_dotenv 12 | from neo4j import AsyncGraphDatabase, AsyncManagedTransaction, AsyncSession 13 | 14 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 15 | from api.config import Settings 16 | from schemas.wine import Wine 17 | 18 | # Custom types 19 | JsonBlob = dict[str, Any] 20 | 21 | 22 | class FileNotFoundError(Exception): 23 |     pass 24 | 25 | 26 | # --- Blocking functions --- 27 | 28 | 29 | @lru_cache() 30 | def get_settings(): 31 |     load_dotenv() 32 |     # Use lru_cache to avoid loading .env file for every request 33 |     return Settings() 34 | 35 | 36 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: 37 |     """ 38 |     Break a large iterable into an iterable of smaller iterables of size `chunksize` 39 |     """ 40 |     for i in range(0, len(item_list), chunksize): 41 |         yield tuple(item_list[i : i + chunksize]) 42 | 43 | 44 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: 45 |     """Read the line-delimited, gzipped JSON (.jsonl.gz) file of wine reviews""" 46 |     file_path = data_dir / filename 47 |     if not file_path.is_file(): 48 |         # Fail early with a clear error: srsly cannot stream records from a 49 |         # missing file, and nothing downstream can run without the data 50 |         raise FileNotFoundError(f"No valid .jsonl.gz file found at `{file_path}`") 51 |     # srsly reads the records directly from the gzipped archive, so the 52 |     # .jsonl.gz file never needs to be manually uncompressed first; 53 |     # it yields one dict per line of the underlying JSONL file 54 |     data = srsly.read_gzip_jsonl(file_path) 55 |     return data 56 | 57 | 58 | def validate( 59 |     data: list[JsonBlob], 60 |     exclude_none: bool = False, 61 | ) -> list[JsonBlob]: 62 |     validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data] 63 |     return validated_data 64 | 65 | 66 | def process_chunks(data: list[JsonBlob]) -> list[JsonBlob]: 67 |     validated_data = validate(data, exclude_none=True) 68 |     return validated_data 69 | 70 | 71 | # --- Async functions --- 72 | 73 | 74 | async def create_indexes_and_constraints(session: AsyncSession) -> None: 75 |     queries = [ 76 |
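        # Uniqueness constraints below are also backed by indexes in Neo4j, so
        # wineID and countryName lookups during MERGE stay fast; the full-text
        # index is what the /search endpoint's keyword queries run against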
# constraints 77 | "CREATE CONSTRAINT countryName IF NOT EXISTS FOR (c:Country) REQUIRE c.countryName IS UNIQUE ", 78 | "CREATE CONSTRAINT wineID IF NOT EXISTS FOR (w:Wine) REQUIRE w.wineID IS UNIQUE ", 79 | # indexes 80 | "CREATE INDEX provinceName IF NOT EXISTS FOR (p:Province) ON (p.provinceName) ", 81 | "CREATE INDEX tasterName IF NOT EXISTS FOR (p:Person) ON (p.tasterName) ", 82 | "CREATE FULLTEXT INDEX searchText IF NOT EXISTS FOR (w:Wine) ON EACH [w.title, w.description, w.variety] ", 83 | ] 84 | for query in queries: 85 | await session.run(query) 86 | 87 | 88 | async def build_query(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: 89 | query = """ 90 | UNWIND $data AS record 91 | MERGE (wine:Wine {wineID: record.id}) 92 | SET wine += { 93 | points: record.points, 94 | title: record.title, 95 | description: record.description, 96 | price: record.price, 97 | variety: record.variety, 98 | winery: record.winery, 99 | vineyard: record.vineyard, 100 | region_1: record.region_1, 101 | region_2: record.region_2 102 | } 103 | WITH record, wine 104 | WHERE record.taster_name IS NOT NULL 105 | MERGE (taster:Person {tasterName: record.taster_name}) 106 | SET taster += {tasterTwitterHandle: record.taster_twitter_handle} 107 | MERGE (wine)-[:TASTED_BY]->(taster) 108 | WITH record, wine 109 | MERGE (country:Country {countryName: record.country}) 110 | MERGE (wine)-[:IS_FROM_COUNTRY]->(country) 111 | WITH record, wine, country 112 | WHERE record.province IS NOT NULL 113 | MERGE (province:Province {provinceName: record.province}) 114 | MERGE (wine)-[:IS_FROM_PROVINCE]->(province) 115 | WITH record, wine, country, province 116 | WHERE record.province IS NOT NULL AND record.country IS NOT NULL 117 | MERGE (province)-[:IS_LOCATED_IN]->(country) 118 | """ 119 | await tx.run(query, data=data) 120 | 121 | 122 | async def main(data: list[JsonBlob]) -> None: 123 | async with AsyncGraphDatabase.driver(URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver: 124 | async with driver.session(database="neo4j") as session: 125 | # Create indexes and constraints 126 | await create_indexes_and_constraints(session) 127 | # Validate 128 | validation_start_time = time.time() 129 | print("Validating data...") 130 | validated_data = validate(data, exclude_none=True) 131 | chunked_data = chunk_iterable(validated_data, CHUNKSIZE) 132 | print( 133 | f"Finished validating data in pydantic in {(time.time() - validation_start_time):.4f} sec" 134 | ) 135 | # Bulk ingest 136 | ingestion_time = time.time() 137 | # Ingest the data into Neo4j 138 | for chunk in chunked_data: 139 | ids = [item["id"] for item in chunk] 140 | try: 141 | await session.execute_write(build_query, chunk) 142 | print(f"Processed ids in range {min(ids)}-{max(ids)}") 143 | except Exception as e: 144 | print(f"{e}: Failed to ingest IDs in range {min(ids)}-{max(ids)}") 145 | print(f"Finished ingesting data in {(time.time() - ingestion_time):.4f} sec") 146 | 147 | 148 | if __name__ == "__main__": 149 | # fmt: off 150 | parser = argparse.ArgumentParser("Build a graph from the wine reviews JSONL data") 151 | parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") 152 | parser.add_argument("--chunksize", type=int, default=10_000, help="Size of each chunk to break the dataset into before processing") 153 | parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") 154 | args = vars(parser.parse_args()) 155 | # fmt: on 156 | 157 | 
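    # Constants for this run: LIMIT truncates the dataset for quick tests
    # (0 = keep everything); CHUNKSIZE bounds how many records go into each
    # UNWIND transaction in `build_query`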
LIMIT = args["limit"] 158 |     DATA_DIR = Path(__file__).parents[3] / "data" 159 |     FILENAME = args["filename"] 160 |     CHUNKSIZE = args["chunksize"] 161 | 162 |     # Neo4j 163 |     settings = get_settings() 164 |     URI = f"bolt://{settings.neo4j_url}:7687" 165 |     NEO4J_USER = settings.neo4j_user 166 |     NEO4J_PASSWORD = settings.neo4j_password 167 | 168 |     data = list(get_json_data(DATA_DIR, FILENAME)) 169 |     if LIMIT > 0: 170 |         data = data[:LIMIT] 171 | 172 |     # Run the main async event loop using uvloop for slightly better performance 173 |     # The Neo4j async Python driver is compatible with uvloop, which is why it makes sense 174 |     # to attach the uvloop policy to the asyncio event loop for our main function 175 |     import uvloop 176 | 177 |     uvloop.install() 178 |     asyncio.run(main(data)) 179 | -------------------------------------------------------------------------------- /dbs/neo4j/scripts/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/neo4j/scripts/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/neo4j/scripts/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field, model_validator 2 | 3 | 4 | class Wine(BaseModel): 5 |     model_config = ConfigDict( 6 |         populate_by_name=True, 7 |         validate_assignment=True, 8 |         extra="allow", 9 |         str_strip_whitespace=True, 10 |         json_schema_extra={ 11 |             "example": { 12 |                 "id": 45100, 13 |                 "points": 85, 14 |                 "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 15 |                 "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.", 16 |                 "price": 10.0, 17 |                 "variety": "Merlot", 18 |                 "winery": "Balduzzi", 19 |                 "vineyard": "Reserva", 20 |                 "country": "Chile", 21 |                 "province": "Maule Valley", 22 |                 "region_1": "null", 23 |                 "region_2": "null", 24 |                 "taster_name": "Michael Schachner", 25 |                 "taster_twitter_handle": "@wineschach", 26 |             } 27 |         }, 28 |     ) 29 | 30 |     id: int 31 |     points: int 32 |     title: str 33 |     description: str | None 34 |     price: float | None 35 |     variety: str | None 36 |     winery: str | None 37 |     vineyard: str | None = Field(..., alias="designation") 38 |     country: str | None 39 |     province: str | None 40 |     region_1: str | None 41 |     region_2: str | None 42 |     taster_name: str | None 43 |     taster_twitter_handle: str | None 44 | 45 |     @model_validator(mode="before") 46 |     def _fill_country_unknowns(cls, values): 47 |         "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 48 |         country = values.get("country") 49 |         if country is None or country == "null": 50 |             values["country"] = "Unknown" 51 |         return values 52 | 53 | 54 | if __name__ == "__main__": 55 |     data = { 56 |         "id": 45100, 57 |         "points": 85, 58 |         "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 59 |         "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate.
Raisin and cooked berry flavors finish plump, with earthy notes.", 60 | "price": 10, # Test if field is cast to float 61 | "variety": "Merlot", 62 | "winery": "Balduzzi", 63 | "designation": "Reserva", # Test if field is renamed 64 | "country": "null", # Test unknown country 65 | "province": " Maule Valley ", # Test if field is stripped 66 | "region_1": "null", 67 | "region_2": "null", 68 | "taster_name": "Michael Schachner", 69 | "taster_twitter_handle": "@wineschach", 70 | } 71 | from pprint import pprint 72 | 73 | wine = Wine(**data) 74 | pprint(wine.model_dump()) 75 | -------------------------------------------------------------------------------- /dbs/qdrant/.env.example: -------------------------------------------------------------------------------- 1 | QDRANT_VERSION = "v1.6.1" 2 | QDRANT_PORT = 6333 3 | QDRANT_HOST = "localhost" 4 | QDRANT_SERVICE = "qdrant" 5 | API_PORT = 8000 6 | EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" 7 | ONNX_MODEL_FILENAME = "model_optimized_quantized.onnx" 8 | 9 | # Container image tag 10 | TAG = "0.1.0" 11 | 12 | # Docker project namespace (defaults to the current folder name if not set) 13 | COMPOSE_PROJECT_NAME = qdrant_wine -------------------------------------------------------------------------------- /dbs/qdrant/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bullseye 2 | 3 | WORKDIR /wine 4 | 5 | COPY ./requirements.txt /wine/requirements.txt 6 | 7 | RUN pip install --no-cache-dir -U pip wheel setuptools 8 | RUN pip install --no-cache-dir -r /wine/requirements.txt 9 | 10 | COPY ./api /wine/api 11 | COPY ./schemas /wine/schemas 12 | 13 | EXPOSE 8000 -------------------------------------------------------------------------------- /dbs/qdrant/README.md: -------------------------------------------------------------------------------- 1 | # Qdrant 2 | 3 | [Qdrant](https://qdrant.tech/) is a vector database and vector similarity search engine written in Rust. The primary use case for a vector database is to retrieve results that are most semantically similar to the input natural language query. The semantic similarity is obtained by comparing the sentence embeddings (which are n-dimensional vectors) between the input query and the data stored in the database. Most vector DBs, including Qdrant, store both the metadata (as JSON) and the sentence embeddings of text on which we want to search (as vectors), allowing us to perform much more flexible searches than keyword-only search databases. 4 | 5 | Code is provided for ingesting the wine reviews dataset into Qdrant. In addition, a query API written in FastAPI is also provided that allows a user to query available endpoints. As always in FastAPI, documentation is available via OpenAPI (http://localhost:8005/docs). 6 | 7 | * Unlike "normal" databases, in a vector DB, the vectorization process is the biggest bottleneck, and because a lot of vector DBs are relatively new, they do not yet support async indexing (although they will soon). 8 | * [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling 9 | * For ease of reproducibility during development, the whole setup is orchestrated and deployed via docker 10 | 11 | ## Setup 12 | 13 | Note that this code base has been tested in Python 3.10, and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`. 
14 | 15 | ```sh 16 | # Setup the environment for the first time 17 | python -m venv qdrant_venv  # python -> python 3.10 18 | 19 | # Activate the environment (for subsequent runs) 20 | source qdrant_venv/bin/activate 21 | 22 | python -m pip install -r requirements.txt 23 | ``` 24 | 25 | --- 26 | 27 | ## Step 1: Set up containers 28 | 29 | Docker compose files are provided, which start a persistent-volume Qdrant database with credentials specified in `.env`. The `qdrant` variable in the environment file under the `fastapi` service indicates that we are opening up the database service to FastAPI (running as a separate service, in a separate container) downstream. Both containers can communicate with one another via the common network that they share, on the exact port numbers specified. 30 | 31 | The database and API services can be restarted at any time for maintenance and updates by simply running the `docker restart ` command. 32 | 33 | **💡 Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via simple docker, but this is a good starting point to begin experimenting! 34 | 35 | ### Use `sbert` model 36 | 37 | If using the `sbert` model [from the sentence-transformers repo](https://www.sbert.net/) directly, use the provided `docker-compose.yml` to initiate separate containers, one that runs Qdrant and another that serves as an API on top of the database. 38 | 39 | **⚠️ Note**: This approach will attempt to run `sbert` on a GPU if available, and if not, on CPU (while utilizing all CPU cores). 40 | 41 | ``` 42 | docker compose -f docker-compose.yml up -d 43 | ``` 44 | Tear down the services using the following command. 45 | 46 | ``` 47 | docker compose -f docker-compose.yml down 48 | ``` 49 | 50 | ## Step 2: Ingest the data 51 | 52 | We ingest both the JSON data for full-text search and filtering, as well as the sentence embedding vectors (for similarity search) into Qdrant. For this dataset, it's reasonable to expect that a simple concatenation of fields like `title`, `variety` and `description` would result in a useful sentence embedding that can be compared against a search query, which is also converted to a vector at query time. 53 | 54 | As an example, consider the following data snippet from the `data/` directory in this repo: 55 | 56 | ```json 57 | "title": "Castello San Donato in Perano 2009 Riserva  (Chianti Classico)", 58 | "description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. Drink now through 2019.", 59 | "variety": "Red Blend" 60 | ``` 61 | 62 | The three fields are concatenated for vectorization as follows: 63 | 64 | ```py 65 | to_vectorize = data["variety"] + data["title"] + data["description"] 66 | ``` 67 | 68 | ### Choice of embedding model 69 | 70 | [SentenceTransformers](https://www.sbert.net/) is a Python framework for a range of sentence and text embeddings. It results from extensive work on fine-tuning BERT to work well on semantic similarity tasks using Siamese BERT networks, where the model is trained to predict the similarity between sentence pairs. The original work is [described here](https://arxiv.org/abs/1908.10084). 71 | 72 | #### Why use sentence transformers?
73 | 74 | Although larger and more powerful text embedding models exist (such as [OpenAI embeddings](https://platform.openai.com/docs/guides/embeddings)), they can become really expensive as they are not free, and charge per token of text. SentenceTransformers are free and open-source, and have been optimized for years, both to utilize all CPU cores and for reduced size while maintaining accuracy. A full list of sentence transformer models [is on the project page](https://www.sbert.net/docs/pretrained_models.html). 75 | 76 | For this work, it makes sense to use one of the fastest models in this list, the `multi-qa-MiniLM-L6-cos-v1` **uncased** model. As per the docs, it was tuned for semantic search and question answering, and generates sentence embeddings for single sentences or paragraphs up to a maximum sequence length of 512. It was trained on 215M question-answer pairs from various sources. Compared to the more general-purpose `all-MiniLM-L6-v2` model, it shows slightly improved performance on semantic search tasks while running at a similar speed. [See the sbert docs](https://www.sbert.net/docs/pretrained_models.html) for more details on performance comparisons between the various pretrained models. 77 | 78 | ### Run data loader 79 | 80 | Data is ingested into the Qdrant database through the scripts in the `scripts` directory. The scripts validate the input JSON data via [Pydantic](https://docs.pydantic.dev), and then index both the JSON data and the vectors to Qdrant using the [Qdrant Python client](https://github.com/qdrant/qdrant-client). 81 | 82 | Prior to indexing and vectorizing, we simply concatenate the key fields that contain useful information about each wine and vectorize this instead. 83 | 84 | If running on a MacBook or other development machine, it's possible to generate sentence embeddings using the original `sbert` model as per the `EMBEDDING_MODEL_CHECKPOINT` variable in the `.env` file. 85 | 86 | ```sh 87 | cd scripts 88 | python bulk_index_sbert.py 89 | ``` 90 | 91 | Depending on the CPU on your machine, this may take a while. On a 2022 M2 MacBook Pro, vectorizing and bulk-indexing ~130k records took about 25 minutes. When tested on an AWS EC2 T2 medium instance, the same process took just over an hour. 92 | 93 | ## Step 3: Test API 94 | 95 | Once the data has been successfully loaded into Qdrant and the containers are up and running, we can test out a search query via an HTTP request as follows. 96 | 97 | ```sh 98 | curl -X 'GET' \ 99 |   'http://0.0.0.0:8005/wine/search?terms=tuscany%20red&max_price=100&country=Italy' 100 | ``` 101 | 102 | This cURL request passes the search terms "**tuscany red**", along with the country "Italy" and a maximum price of "100" to the `/wine/search/` endpoint, which is then parsed into a working filter query to Qdrant by the FastAPI backend. The query runs and retrieves results that are semantically similar to the input query for red Tuscan wines, and, if the setup was done correctly, we should see the following response: 103 | 104 | ```json 105 | [ 106 |     { 107 |         "id": 8456, 108 |         "country": "Italy", 109 |         "province": "Tuscany", 110 |         "title": "Petra 2008 Petra Red (Toscana)", 111 |         "description": "From one of Italy's most important showcase designer wineries, this blend of Cabernet Sauvignon and Merlot lives up to its super Tuscan celebrity.
It is gently redolent of dark chocolate, ripe fruit, leather, tobacco and crushed black pepper—the bouquet's elegant moderation is one of its strongest points. The mouthfeel is rich, creamy and long. Drink after 2018.", 112 |         "points": 92, 113 |         "price": 80.0, 114 |         "variety": "Red Blend", 115 |         "winery": "Petra" 116 |     }, 117 |     { 118 |         "id": 896, 119 |         "country": "Italy", 120 |         "province": "Tuscany", 121 |         "title": "Le Buche 2006 Giuseppe Olivi Memento Red (Toscana)", 122 |         "description": "Le Buche is an interesting winery to watch, and its various Tuscan blends show great promise. Memento is equal parts Sangiovese and Syrah with a soft, velvety texture and a bright berry finish.", 123 |         "points": 90, 124 |         "price": 45.0, 125 |         "variety": "Red Blend", 126 |         "winery": "Le Buche" 127 |     }, 128 |     { 129 |         "id": 9343, 130 |         "country": "Italy", 131 |         "province": "Tuscany", 132 |         "title": "Poggio Mandorlo 2008 Red (Toscana)", 133 |         "description": "Made from Merlot and Cabernet Franc, this structured red offers aromas of black currant, toast, graphite and a whiff of cedar. The firm palate offers coconut, coffee, grilled sage and red berry alongside bracing tannins. Drink sooner rather than later to capture the fruit richness.", 134 |         "points": 89, 135 |         "price": 60.0, 136 |         "variety": "Red Blend", 137 |         "winery": "Poggio Mandorlo" 138 |     } 139 | ] 140 | ``` 141 | 142 | Not bad! This example correctly returns some highly rated Tuscan red wines from Italy along with their price. More specific search queries, such as low/high acidity, or flavour profiles of wines can also be entered to get more relevant results by country. 143 | 144 | ## Step 4: Extend the API 145 | 146 | The API can be easily extended with the provided structure. 147 | 148 | - The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs 149 |   - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here 150 | - The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions 151 |   - e.g., "What are the top-rated wines from Argentina?" 152 |   - In general, it makes sense to organize specific business use cases into their own router files 153 | - The `api/main.py` file collects all the routes and schemas to run the API 154 | 155 | 156 | #### Existing endpoints 157 | 158 | As an example, a search endpoint is implemented and can be accessed via the API at the following URL.
159 | 160 | ``` 161 | GET 162 | /wine/search 163 | Semantic similarity search 164 | ``` -------------------------------------------------------------------------------- /dbs/qdrant/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/qdrant/api/__init__.py -------------------------------------------------------------------------------- /dbs/qdrant/api/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 |     model_config = SettingsConfigDict( 6 |         env_file=".env", 7 |         extra="allow", 8 |     ) 9 | 10 |     qdrant_service: str 11 |     qdrant_port: str 12 |     qdrant_host: str 13 | 14 |     api_port: str 15 |     embedding_model_checkpoint: str 16 |     onnx_model_filename: str 17 |     tag: str 18 | -------------------------------------------------------------------------------- /dbs/qdrant/api/main.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from contextlib import asynccontextmanager 3 | from functools import lru_cache 4 | 5 | from fastapi import FastAPI 6 | from fastapi.middleware.cors import CORSMiddleware 7 | from qdrant_client import QdrantClient 8 | from sentence_transformers import SentenceTransformer 9 | 10 | from api.config import Settings 11 | from api.routers import rest 12 | 13 | model_type = "sbert" 14 | 15 | 16 | @lru_cache() 17 | def get_settings(): 18 |     # Use lru_cache to avoid loading .env file for every request 19 |     return Settings() 20 | 21 | 22 | @asynccontextmanager 23 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 24 |     """Async context manager for Qdrant database connection.""" 25 |     settings = get_settings() 26 |     model_checkpoint = settings.embedding_model_checkpoint 27 |     app.model = SentenceTransformer(model_checkpoint) 28 |     app.model_type = "sbert" 29 |     # Define Qdrant client 30 |     app.client = QdrantClient(host=settings.qdrant_service, port=settings.qdrant_port, timeout=None) 31 |     print("Successfully connected to Qdrant") 32 |     yield 33 |     print("Successfully closed Qdrant connection and released resources") 34 | 35 | 36 | app = FastAPI( 37 |     title="REST API for wine reviews on Qdrant", 38 |     description=( 39 |         "Query from a Qdrant database of 130k wine reviews from the Wine Enthusiast magazine" 40 |     ), 41 |     version=get_settings().tag, 42 |     lifespan=lifespan, 43 | ) 44 | 45 | 46 | @app.get("/", include_in_schema=False) 47 | async def root(): 48 |     return { 49 |         "message": "REST API for querying Qdrant database of 130k wine reviews from the Wine Enthusiast magazine" 50 |     } 51 | 52 | 53 | app.add_middleware( 54 |     CORSMiddleware, 55 |     allow_origins=["*"], 56 |     allow_credentials=True, 57 |     allow_methods=["*"], 58 |     allow_headers=["*"], 59 |     expose_headers=["*"], 60 | ) 61 | 62 | # Attach routes 63 | app.include_router(rest.router, prefix="/wine", tags=["wine"]) 64 | -------------------------------------------------------------------------------- /dbs/qdrant/api/routers/rest.py: -------------------------------------------------------------------------------- 1 | from api.schemas.rest import CountByCountry, SimilaritySearch 2 | from fastapi import APIRouter, HTTPException, Query, Request 3 | from qdrant_client.http import models 4 | 5 | router = APIRouter() 6 | 7 | 8 | # --- Routes --- 9 |
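# Each route below pulls the shared SentenceTransformer model and QdrantClient
# off `request.app`, where they were attached once in the lifespan handler in
# `api/main.py`, so neither is re-created per request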
10 | 11 | @router.get( 12 | "/search", 13 | response_model=list[SimilaritySearch], 14 | response_description="Search for wines via semantically similar terms", 15 | ) 16 | def search_by_similarity( 17 | request: Request, 18 | terms: str = Query( 19 | description="Specify terms to search for in the variety, title and description" 20 | ), 21 | ) -> list[SimilaritySearch] | None: 22 | COLLECTION = "wines" 23 | result = _search_by_similarity(request, COLLECTION, terms) 24 | if not result: 25 | raise HTTPException( 26 | status_code=404, 27 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 28 | ) 29 | return result 30 | 31 | 32 | @router.get( 33 | "/search_by_country", 34 | response_model=list[SimilaritySearch], 35 | response_description="Search for wines via semantically similar terms from a particular country", 36 | ) 37 | def search_by_similarity_and_country( 38 | request: Request, 39 | terms: str = Query( 40 | description="Specify terms to search for in the variety, title and description" 41 | ), 42 | country: str = Query(description="Country name to search for wines from"), 43 | ) -> list[SimilaritySearch] | None: 44 | COLLECTION = "wines" 45 | result = _search_by_similarity_and_country(request, COLLECTION, terms, country) 46 | if not result: 47 | raise HTTPException( 48 | status_code=404, 49 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 50 | ) 51 | return result 52 | 53 | 54 | @router.get( 55 | "/search_by_filters", 56 | response_model=list[SimilaritySearch], 57 | response_description="Search for wines via semantically similar terms with added filters", 58 | ) 59 | def search_by_similarity_and_filters( 60 | request: Request, 61 | terms: str = Query( 62 | description="Specify terms to search for in the variety, title and description" 63 | ), 64 | country: str = Query(description="Country name to search for wines from"), 65 | points: int = Query(default=85, description="Minimum number of points for a wine"), 66 | price: float = Query(default=100.0, description="Maximum price for a wine"), 67 | ) -> list[SimilaritySearch] | None: 68 | COLLECTION = "wines" 69 | result = _search_by_similarity_and_filters(request, COLLECTION, terms, country, points, price) 70 | if not result: 71 | raise HTTPException( 72 | status_code=404, 73 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 74 | ) 75 | return result 76 | 77 | 78 | @router.get( 79 | "/count_by_country", 80 | response_model=CountByCountry, 81 | response_description="Get counts of wine for a particular country", 82 | ) 83 | def count_by_country( 84 | request: Request, 85 | country: str = Query(description="Country name to get counts for"), 86 | ) -> CountByCountry | None: 87 | COLLECTION = "wines" 88 | result = _count_by_country(request, COLLECTION, country) 89 | if not result: 90 | raise HTTPException( 91 | status_code=404, 92 | detail=f"No wine with the provided country '{country}' found in database - please try again", 93 | ) 94 | return result 95 | 96 | 97 | @router.get( 98 | "/count_by_filters", 99 | response_model=CountByCountry, 100 | response_description="Get counts of wine for a particular country, filtered by points and price", 101 | ) 102 | def count_by_filters( 103 | request: Request, 104 | country: str = Query(description="Country name to get counts for"), 105 | points: int = Query(default=85, description="Minimum number of points for a wine"), 106 | price: float = Query(default=100.0, description="Maximum 
price for a wine"), 107 | ) -> CountByCountry | None: 108 | COLLECTION = "wines" 109 | result = _count_by_filters(request, COLLECTION, country, points, price) 110 | if not result: 111 | raise HTTPException( 112 | status_code=404, 113 | detail=f"No wine with the provided country '{country}' found in database - please try again", 114 | ) 115 | return result 116 | 117 | 118 | # --- Helper functions --- 119 | 120 | 121 | def _search_by_similarity( 122 | request: Request, 123 | collection: str, 124 | terms: str, 125 | ) -> list[SimilaritySearch] | None: 126 | vector = request.app.model.encode(terms, batch_size=64).tolist() 127 | # Use `vector` for similarity search on the closest vectors in the collection 128 | search_result = request.app.client.search( 129 | collection_name=collection, query_vector=vector, limit=5 130 | ) 131 | # `search_result` contains found vector ids with similarity scores along with the stored payload 132 | # For now we are interested in payload only 133 | payloads = [hit.payload for hit in search_result] 134 | if not payloads: 135 | return None 136 | return payloads 137 | 138 | 139 | def _search_by_similarity_and_country( 140 | request: Request, collection: str, terms: str, country: str 141 | ) -> list[SimilaritySearch] | None: 142 | vector = request.app.model.encode(terms, batch_size=64).tolist() 143 | filter = models.Filter( 144 | **{ 145 | "must": [ 146 | { 147 | "key": "country", 148 | "match": { 149 | "value": country, 150 | }, 151 | }, 152 | ] 153 | } 154 | ) 155 | search_result = request.app.client.search( 156 | collection_name=collection, query_vector=vector, query_filter=filter, limit=5 157 | ) 158 | payloads = [hit.payload for hit in search_result] 159 | if not payloads: 160 | return None 161 | return payloads 162 | 163 | 164 | def _search_by_similarity_and_filters( 165 | request: Request, 166 | collection: str, 167 | terms: str, 168 | country: str, 169 | points: int, 170 | price: float, 171 | ) -> list[SimilaritySearch] | None: 172 | vector = request.app.model.encode(terms, batch_size=64).tolist() 173 | filter = models.Filter( 174 | **{ 175 | "must": [ 176 | { 177 | "key": "country", 178 | "match": { 179 | "value": country, 180 | }, 181 | }, 182 | { 183 | "key": "price", 184 | "range": { 185 | "lte": price, 186 | }, 187 | }, 188 | { 189 | "key": "points", 190 | "range": { 191 | "gte": points, 192 | }, 193 | }, 194 | ] 195 | } 196 | ) 197 | search_result = request.app.client.search( 198 | collection_name=collection, query_vector=vector, query_filter=filter, limit=5 199 | ) 200 | payloads = [hit.payload for hit in search_result] 201 | if not payloads: 202 | return None 203 | return payloads 204 | 205 | 206 | def _count_by_country( 207 | request: Request, 208 | collection: str, 209 | country: str, 210 | ) -> CountByCountry | None: 211 | filter = models.Filter( 212 | **{ 213 | "must": [ 214 | { 215 | "key": "country", 216 | "match": { 217 | "value": country, 218 | }, 219 | }, 220 | ] 221 | } 222 | ) 223 | result = request.app.client.count(collection_name=collection, count_filter=filter) 224 | if not result: 225 | return None 226 | return result 227 | 228 | 229 | def _count_by_filters( 230 | request: Request, 231 | collection: str, 232 | country: str, 233 | points: int, 234 | price: float, 235 | ) -> CountByCountry | None: 236 | filter = models.Filter( 237 | **{ 238 | "must": [ 239 | { 240 | "key": "country", 241 | "match": { 242 | "value": country, 243 | }, 244 | }, 245 | { 246 | "key": "price", 247 | "range": { 248 | "lte": price, 249 | }, 250 | }, 251 | { 252 | 
"key": "points", 253 | "range": { 254 | "gte": points, 255 | }, 256 | }, 257 | ] 258 | } 259 | ) 260 | result = request.app.client.count(collection_name=collection, count_filter=filter) 261 | if not result: 262 | return None 263 | return result 264 | -------------------------------------------------------------------------------- /dbs/qdrant/api/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/qdrant/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/qdrant/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class SimilaritySearch(BaseModel): 5 | model_config = ConfigDict( 6 | extra="ignore", 7 | json_schema_extra={ 8 | "example": { 9 | "wineID": 3845, 10 | "country": "Italy", 11 | "title": "Castellinuzza e Piuca 2010 Chianti Classico", 12 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 13 | "points": 93, 14 | "price": 16, 15 | "variety": "Red Blend", 16 | "winery": "Castellinuzza e Piuca", 17 | } 18 | }, 19 | ) 20 | 21 | id: int 22 | country: str 23 | province: str | None 24 | title: str 25 | description: str | None 26 | points: int 27 | price: float | str | None 28 | variety: str | None 29 | winery: str | None 30 | 31 | 32 | class CountByCountry(BaseModel): 33 | count: int | None 34 | -------------------------------------------------------------------------------- /dbs/qdrant/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | qdrant: 5 | image: qdrant/qdrant:${QDRANT_VERSION} 6 | restart: unless-stopped 7 | environment: 8 | - QDRANT_HOST=${QDRANT_HOST} 9 | ports: 10 | - ${QDRANT_PORT}:6333 11 | volumes: 12 | - qdrant_storage:/qdrant/storage 13 | networks: 14 | - wine 15 | 16 | fastapi: 17 | image: qdrant_wine_fastapi:${TAG} 18 | build: 19 | context: . 
20 | dockerfile: Dockerfile 21 | restart: unless-stopped 22 | env_file: 23 | - .env 24 | ports: 25 | - ${API_PORT}:8000 26 | depends_on: 27 | - qdrant 28 | volumes: 29 | - ./:/wine 30 | networks: 31 | - wine 32 | command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload 33 | 34 | volumes: 35 | qdrant_storage: 36 | 37 | networks: 38 | wine: 39 | driver: bridge -------------------------------------------------------------------------------- /dbs/qdrant/requirements.txt: -------------------------------------------------------------------------------- 1 | qdrant-client~=1.6.0 2 | transformers~=4.33.0 3 | sentence-transformers~=2.2.0 4 | pydantic~=2.4.0 5 | pydantic-settings>=2.0.0 6 | python-dotenv>=1.0.0 7 | fastapi~=0.104.0 8 | httpx>=0.24.0 9 | aiohttp>=3.8.4 10 | uvloop>=0.17.0 11 | uvicorn>=0.21.0, <1.0.0 12 | srsly>=2.4.6 13 | -------------------------------------------------------------------------------- /dbs/qdrant/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/qdrant/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/qdrant/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field, model_validator 2 | 3 | 4 | class Wine(BaseModel): 5 | model_config = ConfigDict( 6 | populate_by_name=True, 7 | validate_assignment=True, 8 | extra="allow", 9 | str_strip_whitespace=True, 10 | json_schema_extra={ 11 | "example": { 12 | "id": 45100, 13 | "points": 85, 14 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 15 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. 
Raisin and cooked berry flavors finish plump, with earthy notes.", 16 | "price": 10.0, 17 | "variety": "Merlot", 18 | "winery": "Balduzzi", 19 | "vineyard": "Reserva", 20 | "country": "Chile", 21 | "province": "Maule Valley", 22 | "region_1": "null", 23 | "region_2": "null", 24 | "taster_name": "Michael Schachner", 25 | "taster_twitter_handle": "@wineschach", 26 | } 27 | }, 28 | ) 29 | 30 | id: int 31 | points: int 32 | title: str 33 | description: str | None 34 | price: float | None 35 | variety: str | None 36 | winery: str | None 37 | vineyard: str | None = Field(..., alias="designation") 38 | country: str | None 39 | province: str | None 40 | region_1: str | None 41 | region_2: str | None 42 | taster_name: str | None 43 | taster_twitter_handle: str | None 44 | 45 | @model_validator(mode="before") 46 | def _fill_country_unknowns(cls, values): 47 | "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 48 | country = values.get("country") 49 | if not country: 50 | values["country"] = "Unknown" 51 | return values 52 | 53 | @model_validator(mode="before") 54 | def _add_to_vectorize_fields(cls, values): 55 | "Add a field to_vectorize that will be used to create sentence embeddings" 56 | variety = values.get("variety", "") 57 | title = values.get("title", "") 58 | description = values.get("description", "") 59 | to_vectorize = list(filter(None, [variety, title, description])) 60 | values["to_vectorize"] = " ".join(to_vectorize).strip() 61 | return values 62 | -------------------------------------------------------------------------------- /dbs/qdrant/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/qdrant/scripts/__init__.py -------------------------------------------------------------------------------- /dbs/qdrant/scripts/bulk_index_sbert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from concurrent.futures import ProcessPoolExecutor 5 | from functools import lru_cache 6 | from pathlib import Path 7 | from typing import Any, Iterator 8 | 9 | import srsly 10 | from dotenv import load_dotenv 11 | from qdrant_client import QdrantClient 12 | from qdrant_client.http import models 13 | 14 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 15 | from api.config import Settings 16 | from schemas.wine import Wine 17 | from sentence_transformers import SentenceTransformer 18 | 19 | load_dotenv() 20 | # Custom types 21 | JsonBlob = dict[str, Any] 22 | 23 | 24 | class FileNotFoundError(Exception): 25 | pass 26 | 27 | 28 | @lru_cache() 29 | def get_settings(): 30 | # Use lru_cache to avoid loading .env file for every request 31 | return Settings() 32 | 33 | 34 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: 35 | """ 36 | Break a large iterable into an iterable of smaller iterables of size `chunksize` 37 | """ 38 | for i in range(0, len(item_list), chunksize): 39 | yield tuple(item_list[i : i + chunksize]) 40 | 41 | 42 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: 43 | """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" 44 | file_path = data_dir / filename 45 | if not file_path.is_file(): 46 | # File may not have been uncompressed yet so try to do that first 47 | data = 
srsly.read_gzip_jsonl(file_path)
48 |         # This time if it isn't there it really doesn't exist
49 |         if not file_path.is_file():
50 |             raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`")
51 |     else:
52 |         data = srsly.read_gzip_jsonl(file_path)
53 |     return data
54 | 
55 | 
56 | def validate(
57 |     data: list[JsonBlob],
58 |     exclude_none: bool = False,
59 | ) -> list[JsonBlob]:
60 |     validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
61 |     return validated_data
62 | 
63 | 
64 | def create_index(
65 |     client: QdrantClient,
66 |     collection_name: str,
67 | ) -> None:
68 |     # Field that will be vectorized requires special treatment for tokenization
69 |     client.create_payload_index(
70 |         collection_name=collection_name,
71 |         field_name="description",
72 |         field_schema=models.TextIndexParams(
73 |             type="text",
74 |             tokenizer=models.TokenizerType.WORD,
75 |             min_token_len=3,
76 |             max_token_len=15,
77 |             lowercase=True,
78 |         ),
79 |     )
80 |     # Lowercase fields that will be filtered on
81 |     for field_name in ["country", "province", "region_1", "region_2", "variety"]:
82 |         client.create_payload_index(
83 |             collection_name=collection_name,
84 |             field_name=field_name,
85 |             field_schema="keyword",
86 |         )
87 | 
88 | 
89 | def add_vectors_to_index(data_chunk: tuple[JsonBlob, ...]) -> list[int]:
90 |     settings = get_settings()
91 |     collection = "wines"
92 |     client = QdrantClient(host=settings.qdrant_host, port=settings.qdrant_port, timeout=None)
93 |     data = validate(data_chunk, exclude_none=True)
94 | 
95 |     # Load a sentence transformer model for semantic similarity from a specified checkpoint
96 |     model_id = get_settings().embedding_model_checkpoint
97 |     MODEL = SentenceTransformer(model_id)
98 | 
99 |     ids = [item["id"] for item in data]
100 |     to_vectorize = [text.pop("to_vectorize") for text in data]
101 |     # Encode the whole chunk in one call so that `batch_size=64` actually batches the texts
102 |     # (encoding one string at a time would make the batch size a no-op)
103 |     sentence_embeddings = MODEL.encode([text.lower() for text in to_vectorize], batch_size=64).tolist()
104 |     print(f"Finished vectorizing data in the ID range {min(ids)}-{max(ids)}")
105 |     try:
106 |         # Upsert payload
107 |         client.upsert(
108 |             collection_name=collection,
109 |             points=models.Batch(
110 |                 ids=ids,
111 |                 payloads=data,
112 |                 vectors=sentence_embeddings,
113 |             ),
114 |         )
115 |         print(f"Indexed ID range {min(ids)}-{max(ids)} to db")
116 |     except Exception as e:
117 |         print(f"{e}: Failed to index ID range {min(ids)}-{max(ids)} to db")
118 |     return ids
119 | 
120 | 
121 | def main(data: list[JsonBlob]) -> None:
122 |     settings = get_settings()
123 |     COLLECTION = "wines"
124 |     client = QdrantClient(host=settings.qdrant_host, port=settings.qdrant_port, timeout=None)
125 |     client.recreate_collection(
126 |         collection_name=COLLECTION,
127 |         vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE),
128 |     )
129 |     # Create payload with text field whose sentence embeddings will be added to the index
130 |     create_index(client, COLLECTION)
131 |     print("Created index")
132 | 
133 |     print("Processing chunks")
134 |     with ProcessPoolExecutor(max_workers=WORKERS) as executor:
135 |         chunked_data = chunk_iterable(data, CHUNKSIZE)
136 |         for _ in executor.map(add_vectors_to_index, chunked_data):
137 |             pass
138 | 
139 | 
140 | if __name__ == "__main__":
141 |     # fmt: off
142 |     parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data")
143 |     parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
144 |     parser.add_argument("--chunksize", type=int, default=512, help="Size of each chunk to break the dataset into before processing")
145 |     parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use")
146 |     parser.add_argument("--workers", type=int, default=4, help="Number of workers to use for vectorization")
147 |     args = vars(parser.parse_args())
148 |     # fmt: on
149 | 
150 |     LIMIT = args["limit"]
151 |     DATA_DIR = Path(__file__).parents[3] / "data"
152 |     FILENAME = args["filename"]
153 |     CHUNKSIZE = args["chunksize"]
154 |     WORKERS = args["workers"]
155 | 
156 |     data = list(get_json_data(DATA_DIR, FILENAME))
157 | 
158 |     if data:
159 |         data = data[:LIMIT] if LIMIT > 0 else data
160 |         main(data)
161 |         print("Finished execution!")
162 | 
--------------------------------------------------------------------------------
/dbs/weaviate/.env.example:
--------------------------------------------------------------------------------
1 | WEAVIATE_VERSION = "1.20.2"
2 | WEAVIATE_PORT = 8080
3 | WEAVIATE_HOST = "localhost"
4 | WEAVIATE_SERVICE = "weaviate"
5 | API_PORT = 8004
6 | EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
7 | ONNX_MODEL_FILENAME = "model_optimized_quantized.onnx"
8 | 
9 | # Container image tag
10 | TAG = "0.2.0"
11 | 
12 | # Docker project namespace (defaults to the current folder name if not set)
13 | COMPOSE_PROJECT_NAME = weaviate_wine
--------------------------------------------------------------------------------
/dbs/weaviate/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim-bullseye
2 | 
3 | WORKDIR /wine
4 | 
5 | COPY ./requirements.txt /wine/requirements.txt
6 | 
7 | RUN pip install --no-cache-dir -U pip wheel setuptools
8 | RUN pip install --no-cache-dir -r /wine/requirements.txt
9 | 
10 | COPY ./api /wine/api
11 | 
12 | EXPOSE 8000
--------------------------------------------------------------------------------
/dbs/weaviate/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/weaviate/api/__init__.py
--------------------------------------------------------------------------------
/dbs/weaviate/api/config.py:
--------------------------------------------------------------------------------
1 | from pydantic_settings import BaseSettings, SettingsConfigDict
2 | 
3 | 
4 | class Settings(BaseSettings):
5 |     model_config = SettingsConfigDict(
6 |         env_file=".env",
7 |         extra="allow",
8 |     )
9 | 
10 |     weaviate_service: str
11 |     weaviate_port: str
12 |     weaviate_host: str
13 |     api_port: int
14 |     embedding_model_checkpoint: str
15 |     onnx_model_filename: str
16 |     tag: str
17 | 
--------------------------------------------------------------------------------
/dbs/weaviate/api/main.py:
--------------------------------------------------------------------------------
1 | from collections.abc import AsyncGenerator
2 | from contextlib import asynccontextmanager
3 | from functools import lru_cache
4 | 
5 | import weaviate
6 | from fastapi import FastAPI
7 | from sentence_transformers import SentenceTransformer
8 | 
9 | from api.config import Settings
10 | from api.routers import rest
11 | 
12 | model_type = "sbert"
13 | 
14 | 
15 | @lru_cache()
16 | def get_settings():
17 |     # Use lru_cache to avoid loading .env file for every request
18 |     return Settings()
19 | 
20 | 
21 | @asynccontextmanager
22 | async def lifespan(app: FastAPI) -> AsyncGenerator[None,
None]: 23 | """Async context manager for Weaviate database connection.""" 24 | settings = get_settings() 25 | model_checkpoint = settings.embedding_model_checkpoint 26 | app.model = SentenceTransformer(model_checkpoint) 27 | app.model_type = "sbert" 28 | # Create Weaviate client 29 | HOST = settings.weaviate_service 30 | PORT = settings.weaviate_port 31 | app.client = weaviate.Client(f"http://{HOST}:{PORT}") 32 | print("Successfully connected to Weaviate") 33 | yield 34 | print("Successfully closed Weaviate connection and released resources") 35 | 36 | 37 | app = FastAPI( 38 | title="REST API for wine reviews on Weaviate", 39 | description=( 40 | "Query from a Weaviate database of 130k wine reviews from the Wine Enthusiast magazine" 41 | ), 42 | version=get_settings().tag, 43 | lifespan=lifespan, 44 | ) 45 | 46 | 47 | @app.get("/", include_in_schema=False) 48 | async def root(): 49 | return { 50 | "message": "REST API for querying Weaviate database of 130k wine reviews from the Wine Enthusiast magazine" 51 | } 52 | 53 | 54 | # Attach routes 55 | app.include_router(rest.router, prefix="/wine", tags=["wine"]) 56 | -------------------------------------------------------------------------------- /dbs/weaviate/api/routers/rest.py: -------------------------------------------------------------------------------- 1 | from api.schemas.rest import CountByCountry, SimilaritySearch 2 | from fastapi import APIRouter, HTTPException, Query, Request 3 | 4 | router = APIRouter() 5 | 6 | 7 | # --- Routes --- 8 | 9 | 10 | @router.get( 11 | "/search", 12 | response_model=list[SimilaritySearch], 13 | response_description="Search for wines via semantically similar terms", 14 | ) 15 | def search_by_similarity( 16 | request: Request, 17 | terms: str = Query( 18 | description="Specify terms to search for in the variety, title and description" 19 | ), 20 | ) -> list[SimilaritySearch] | None: 21 | CLASS_NAME = "Wine" 22 | result = _search_by_similarity(request, CLASS_NAME, terms) 23 | if not result: 24 | raise HTTPException( 25 | status_code=404, 26 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 27 | ) 28 | return result 29 | 30 | 31 | @router.get( 32 | "/search_by_country", 33 | response_model=list[SimilaritySearch], 34 | response_description="Search for wines via semantically similar terms from a particular country", 35 | ) 36 | def search_by_similarity_and_country( 37 | request: Request, 38 | terms: str = Query( 39 | description="Specify terms to search for in the variety, title and description" 40 | ), 41 | country: str = Query(description="Country name to search for wines from"), 42 | ) -> list[SimilaritySearch] | None: 43 | CLASS_NAME = "Wine" 44 | result = _search_by_similarity_and_country(request, CLASS_NAME, terms, country) 45 | if not result: 46 | raise HTTPException( 47 | status_code=404, 48 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 49 | ) 50 | return result 51 | 52 | 53 | @router.get( 54 | "/search_by_filters", 55 | response_model=list[SimilaritySearch], 56 | response_description="Search for wines via semantically similar terms with added filters", 57 | ) 58 | def search_by_similarity_and_filters( 59 | request: Request, 60 | terms: str = Query( 61 | description="Specify terms to search for in the variety, title and description" 62 | ), 63 | country: str = Query(description="Country name to search for wines from"), 64 | points: int = Query(default=85, description="Minimum number of points for a wine"), 65 | 
price: float = Query(default=100.0, description="Maximum price for a wine"), 66 | ) -> list[SimilaritySearch] | None: 67 | CLASS_NAME = "Wine" 68 | result = _search_by_similarity_and_filters(request, CLASS_NAME, terms, country, points, price) 69 | if not result: 70 | raise HTTPException( 71 | status_code=404, 72 | detail=f"No wine with the provided terms '{terms}' found in database - please try again", 73 | ) 74 | return result 75 | 76 | 77 | @router.get( 78 | "/count_by_country", 79 | response_model=CountByCountry, 80 | response_description="Get counts of wine for a particular country", 81 | ) 82 | def count_by_country( 83 | request: Request, 84 | country: str = Query(description="Country name to get counts for"), 85 | ) -> CountByCountry | None: 86 | CLASS_NAME = "Wine" 87 | result = _count_by_country(request, CLASS_NAME, country) 88 | if not result: 89 | raise HTTPException( 90 | status_code=404, 91 | detail=f"No wine with the provided country '{country}' found in database - please try again", 92 | ) 93 | return result 94 | 95 | 96 | @router.get( 97 | "/count_by_filters", 98 | response_model=CountByCountry, 99 | response_description="Get counts of wine for a particular country, filtered by points and price", 100 | ) 101 | def count_by_filters( 102 | request: Request, 103 | country: str = Query(description="Country name to get counts for"), 104 | points: int = Query(default=85, description="Minimum number of points for a wine"), 105 | price: float = Query(default=100.0, description="Maximum price for a wine"), 106 | ) -> CountByCountry | None: 107 | CLASS_NAME = "Wine" 108 | result = _count_by_filters(request, CLASS_NAME, country, points, price) 109 | if not result: 110 | raise HTTPException( 111 | status_code=404, 112 | detail=f"No wine with the provided country '{country}' found in database - please try again", 113 | ) 114 | return result 115 | 116 | 117 | # --- Helper functions --- 118 | 119 | 120 | def _search_by_similarity( 121 | request: Request, class_name: str, terms: str 122 | ) -> list[SimilaritySearch] | None: 123 | # Convert input text query into a vector for lookup in the db 124 | vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist() 125 | near_vec = {"vector": vector} 126 | response = ( 127 | request.app.client.query.get( 128 | class_name, 129 | [ 130 | "wineID", 131 | "title", 132 | "description", 133 | "country", 134 | "province", 135 | "points", 136 | "price", 137 | "variety", 138 | "winery", 139 | "_additional {certainty}", 140 | ], 141 | ) 142 | .with_near_vector(near_vec) 143 | .with_limit(5) 144 | .do() 145 | ) 146 | try: 147 | payload = response["data"]["Get"][class_name] 148 | return payload 149 | except Exception as e: 150 | print(f"Error {e}: Did not obtain appropriate response from Weaviate") 151 | return None 152 | 153 | 154 | def _search_by_similarity_and_country( 155 | request: Request, 156 | class_name: str, 157 | terms: str, 158 | country: str, 159 | ) -> list[SimilaritySearch] | None: 160 | # Convert input text query into a vector for lookup in the db 161 | vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist() 162 | near_vec = {"vector": vector} 163 | where_filter = { 164 | "path": "country", 165 | "operator": "Equal", 166 | "valueText": country, 167 | } 168 | response = ( 169 | request.app.client.query.get( 170 | class_name, 171 | [ 172 | "wineID", 173 | "title", 174 | "description", 175 | "country", 176 | "province", 177 | "points", 178 | "price", 179 | "variety", 180 | "winery", 
181 |                 "_additional {certainty}",
182 |             ],
183 |         )
184 |         .with_near_vector(near_vec)
185 |         .with_where(where_filter)
186 |         .with_limit(5)
187 |         .do()
188 |     )
189 |     try:
190 |         payload = response["data"]["Get"][class_name]
191 |         return payload
192 |     except Exception as e:
193 |         print(f"Error {e}: Did not obtain appropriate response from Weaviate")
194 |         return None
195 | 
196 | 
197 | def _search_by_similarity_and_filters(
198 |     request: Request,
199 |     class_name: str,
200 |     terms: str,
201 |     country: str,
202 |     points: int,
203 |     price: float,
204 | ) -> list[SimilaritySearch] | None:
205 |     # Convert input text query into a vector for lookup in the db
206 |     vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist()
207 |     near_vec = {"vector": vector}
208 |     # Use inclusive bounds so the filters match the route docs ("minimum" points, "maximum" price)
209 |     where_filter = {
210 |         "operator": "And",
211 |         "operands": [
212 |             {
213 |                 "path": "country",
214 |                 "operator": "Equal",
215 |                 "valueText": country,
216 |             },
217 |             {
218 |                 "path": "price",
219 |                 "operator": "LessThanEqual",
220 |                 "valueNumber": price,
221 |             },
222 |             {
223 |                 "path": "points",
224 |                 "operator": "GreaterThanEqual",
225 |                 "valueInt": points,
226 |             },
227 |         ],
228 |     }
229 |     response = (
230 |         request.app.client.query.get(
231 |             class_name,
232 |             [
233 |                 "wineID",
234 |                 "title",
235 |                 "description",
236 |                 "country",
237 |                 "province",
238 |                 "points",
239 |                 "price",
240 |                 "variety",
241 |                 "winery",
242 |                 "_additional {certainty}",
243 |             ],
244 |         )
245 |         .with_near_vector(near_vec)
246 |         .with_where(where_filter)
247 |         .with_limit(5)
248 |         .do()
249 |     )
250 |     try:
251 |         payload = response["data"]["Get"][class_name]
252 |         return payload
253 |     except Exception as e:
254 |         print(f"Error {e}: Did not obtain appropriate response from Weaviate")
255 |         return None
256 | 
257 | 
258 | def _count_by_country(
259 |     request: Request,
260 |     class_name: str,
261 |     country: str,
262 | ) -> CountByCountry | None:
263 |     where_filter = {
264 |         "operator": "And",
265 |         "operands": [
266 |             {
267 |                 "path": "country",
268 |                 "operator": "Equal",
269 |                 "valueText": country,
270 |             }
271 |         ],
272 |     }
273 |     response = (
274 |         request.app.client.query.aggregate(class_name)
275 |         .with_where(where_filter)
276 |         .with_fields("meta {count}")
277 |         .do()
278 |     )
279 |     try:
280 |         payload = response["data"]["Aggregate"][class_name]
281 |         count = payload[0]["meta"]
282 |         return count
283 |     except Exception as e:
284 |         print(f"Error {e}: Did not obtain appropriate response from Weaviate")
285 |         return None
286 | 
287 | 
288 | def _count_by_filters(
289 |     request: Request,
290 |     class_name: str,
291 |     country: str,
292 |     points: int,
293 |     price: float,
294 | ) -> CountByCountry | None:
295 |     # Same inclusive bounds as in _search_by_similarity_and_filters above
296 |     where_filter = {
297 |         "operator": "And",
298 |         "operands": [
299 |             {
300 |                 "path": "country",
301 |                 "operator": "Equal",
302 |                 "valueText": country,
303 |             },
304 |             {
305 |                 "path": "price",
306 |                 "operator": "LessThanEqual",
307 |                 "valueNumber": price,
308 |             },
309 |             {
310 |                 "path": "points",
311 |                 "operator": "GreaterThanEqual",
312 |                 "valueInt": points,
313 |             },
314 |         ],
315 |     }
316 |     response = (
317 |         request.app.client.query.aggregate(class_name)
318 |         .with_where(where_filter)
319 |         .with_fields("meta {count}")
320 |         .do()
321 |     )
322 |     try:
323 |         payload = response["data"]["Aggregate"][class_name]
324 |         count = payload[0]["meta"]
325 |         return count
326 |     except Exception as e:
327 |         print(f"Error {e}: Did not obtain appropriate response from Weaviate")
328 |         return None
329 | 
--------------------------------------------------------------------------------
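Note: the routes above can be exercised end-to-end once the compose stack is up. A minimal
sketch, assuming the FastAPI service is published on localhost:8004 (the API_PORT in the
weaviate .env.example) and purely illustrative query values; httpx is already listed in
dbs/weaviate/requirements.txt:

    import httpx

    # Hit the /wine/search_by_filters route defined in rest.py above:
    # Italian wines semantically similar to the search phrase, with at
    # least 90 points and priced at 50.0 or below
    with httpx.Client(base_url="http://localhost:8004") as client:
        resp = client.get(
            "/wine/search_by_filters",
            params={
                "terms": "fruity red with soft tannins",
                "country": "Italy",
                "points": 90,
                "price": 50.0,
            },
        )
        resp.raise_for_status()  # the API returns 404 if nothing matched
        for wine in resp.json():
            print(wine["title"], wine["points"], wine["price"])

--------------------------------------------------------------------------------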
/dbs/weaviate/api/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/weaviate/api/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/weaviate/api/schemas/rest.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class SimilaritySearch(BaseModel): 5 | model_config = ConfigDict( 6 | extra="ignore", 7 | json_schema_extra={ 8 | "example": { 9 | "wineID": 3845, 10 | "country": "Italy", 11 | "title": "Castellinuzza e Piuca 2010 Chianti Classico", 12 | "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.", 13 | "points": 93, 14 | "price": 16, 15 | "variety": "Red Blend", 16 | "winery": "Castellinuzza e Piuca", 17 | } 18 | }, 19 | ) 20 | 21 | wineID: int 22 | country: str 23 | province: str | None 24 | title: str 25 | description: str | None 26 | points: int 27 | price: float | str | None 28 | variety: str | None 29 | winery: str | None 30 | 31 | 32 | class CountByCountry(BaseModel): 33 | count: int | None 34 | -------------------------------------------------------------------------------- /dbs/weaviate/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | weaviate: 5 | image: semitechnologies/weaviate:${WEAVIATE_VERSION} 6 | ports: 7 | - ${WEAVIATE_PORT}:8080 8 | restart: on-failure:0 9 | environment: 10 | QUERY_DEFAULTS_LIMIT: 25 11 | AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' 12 | PERSISTENCE_DATA_PATH: '/var/lib/weaviate' 13 | DEFAULT_VECTORIZER_MODULE: 'none' 14 | CLUSTER_HOSTNAME: 'node1' 15 | volumes: 16 | - weaviate_data:/var/lib/weaviate 17 | networks: 18 | - wine 19 | 20 | fastapi: 21 | image: weaviate_wine_fastapi:${TAG} 22 | build: 23 | context: . 
24 | dockerfile: Dockerfile 25 | restart: unless-stopped 26 | env_file: 27 | - .env 28 | ports: 29 | - ${API_PORT}:8000 30 | depends_on: 31 | - weaviate 32 | volumes: 33 | - ./:/wine 34 | networks: 35 | - wine 36 | command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload 37 | 38 | volumes: 39 | weaviate_data: 40 | 41 | networks: 42 | wine: 43 | driver: bridge -------------------------------------------------------------------------------- /dbs/weaviate/requirements.txt: -------------------------------------------------------------------------------- 1 | weaviate-client~=3.22.0 2 | transformers~=4.28.0 3 | sentence-transformers~=2.2.0 4 | pydantic~=2.0.0 5 | pydantic-settings~=2.0.0 6 | python-dotenv>=1.0.0 7 | fastapi~=0.100.0 8 | httpx>=0.24.0 9 | aiohttp>=3.8.4 10 | uvloop>=0.17.0 11 | uvicorn>=0.21.0, <1.0.0 12 | srsly>=2.4.6 -------------------------------------------------------------------------------- /dbs/weaviate/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/weaviate/schemas/__init__.py -------------------------------------------------------------------------------- /dbs/weaviate/schemas/wine.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field, model_validator 2 | 3 | 4 | class Wine(BaseModel): 5 | model_config = ConfigDict( 6 | populate_by_name=True, 7 | validate_assignment=True, 8 | extra="allow", 9 | str_strip_whitespace=True, 10 | json_schema_extra={ 11 | "example": { 12 | "id": 45100, 13 | "points": 85, 14 | "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)", 15 | "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. 
Raisin and cooked berry flavors finish plump, with earthy notes.", 16 | "price": 10.0, 17 | "variety": "Merlot", 18 | "winery": "Balduzzi", 19 | "vineyard": "Reserva", 20 | "country": "Chile", 21 | "province": "Maule Valley", 22 | "region_1": "null", 23 | "region_2": "null", 24 | "taster_name": "Michael Schachner", 25 | "taster_twitter_handle": "@wineschach", 26 | } 27 | }, 28 | ) 29 | 30 | id: int 31 | points: int 32 | title: str 33 | description: str | None 34 | price: float | None 35 | variety: str | None 36 | winery: str | None 37 | vineyard: str | None = Field(..., alias="designation") 38 | country: str | None 39 | province: str | None 40 | region_1: str | None 41 | region_2: str | None 42 | taster_name: str | None 43 | taster_twitter_handle: str | None 44 | 45 | @model_validator(mode="before") 46 | def _fill_country_unknowns(cls, values): 47 | "Fill in missing country values with 'Unknown', as we always want this field to be queryable" 48 | country = values.get("country") 49 | if not country: 50 | values["country"] = "Unknown" 51 | return values 52 | 53 | @model_validator(mode="before") 54 | def _add_to_vectorize_fields(cls, values): 55 | "Add a field to_vectorize that will be used to create sentence embeddings" 56 | variety = values.get("variety", "") 57 | country = values.get("country", "Unknown") 58 | province = values.get("province", "") 59 | title = values.get("title", "") 60 | description = values.get("description", "") 61 | to_vectorize = list(filter(None, [variety, country, province, title, description])) 62 | values["to_vectorize"] = " ".join(to_vectorize).strip() 63 | return values 64 | -------------------------------------------------------------------------------- /dbs/weaviate/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prrao87/db-hub-fastapi/532732625b5f0cbb0d1f42044346c0ad57ee4045/dbs/weaviate/scripts/__init__.py -------------------------------------------------------------------------------- /dbs/weaviate/scripts/bulk_index_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from concurrent.futures import ProcessPoolExecutor 6 | from functools import lru_cache 7 | from pathlib import Path 8 | from typing import Any, Iterator 9 | 10 | import srsly 11 | import weaviate 12 | from dotenv import load_dotenv 13 | from optimum.onnxruntime import ORTModelForCustomTasks 14 | from optimum.pipelines import pipeline 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer 17 | from weaviate.client import Client 18 | 19 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) 20 | from api.config import Settings 21 | from schemas.wine import Wine 22 | 23 | load_dotenv() 24 | # Custom types 25 | JsonBlob = dict[str, Any] 26 | 27 | 28 | class FileNotFoundError(Exception): 29 | pass 30 | 31 | 32 | @lru_cache() 33 | def get_settings(): 34 | # Use lru_cache to avoid loading .env file for every request 35 | return Settings() 36 | 37 | 38 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: 39 | """ 40 | Break a large iterable into an iterable of smaller iterables of size `chunksize` 41 | """ 42 | for i in range(0, len(item_list), chunksize): 43 | yield tuple(item_list[i : i + chunksize]) 44 | 45 | 46 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: 47 | """Get all line-delimited json files 
(.jsonl) from a directory with a given prefix"""
48 |     file_path = data_dir / filename
49 |     if not file_path.is_file():
50 |         # File may not have been uncompressed yet so try to do that first
51 |         data = srsly.read_gzip_jsonl(file_path)
52 |         # This time if it isn't there it really doesn't exist
53 |         if not file_path.is_file():
54 |             raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`")
55 |     else:
56 |         data = srsly.read_gzip_jsonl(file_path)
57 |     return data
58 | 
59 | 
60 | def validate(
61 |     data: list[JsonBlob],
62 |     exclude_none: bool = False,
63 | ) -> list[JsonBlob]:
64 |     validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
65 |     return validated_data
66 | 
67 | 
68 | def get_embedding_pipeline(onnx_path, model_filename: str):
69 |     """
70 |     Create a sentence embedding pipeline using the optimized ONNX model
71 |     """
72 |     # Reload tokenizer
73 |     tokenizer = AutoTokenizer.from_pretrained(onnx_path)
74 |     optimized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename)
75 |     embedding_pipeline = pipeline("feature-extraction", model=optimized_model, tokenizer=tokenizer)
76 |     return embedding_pipeline
77 | 
78 | 
79 | def create_or_update_schema(client: Client) -> None:
80 |     # Create a schema with no vectorizer (we will be adding our own vectors)
81 |     with open("settings/schema.json", "r") as f:
82 |         schema = json.load(f)
83 |     class_names = [class_["class"] for class_ in schema["classes"]]
84 |     assert class_names, "No classes found in schema, please check schema definition and try again"
85 |     if not client.schema.get()["classes"]:
86 |         print(f"Creating schema with classes: {', '.join(class_names)}")
87 |         client.schema.create(schema)
88 |     else:
89 |         print("Existing schema found, deleting and recreating it...")
90 |         client.schema.delete_all()
91 |         client.schema.create(schema)
92 | 
93 | 
94 | def add_vectors_to_index(data_chunk: tuple[JsonBlob, ...]) -> None:
95 |     settings = get_settings()
96 |     CLASS_NAME = "Wine"
97 |     HOST = settings.weaviate_host
98 |     PORT = settings.weaviate_port
99 |     client = weaviate.Client(f"http://{HOST}:{PORT}")
100 |     data = validate(data_chunk, exclude_none=True)
101 | 
102 |     # Preload optimized, quantized ONNX sentence transformers model
103 |     # NOTE: This requires that the script ../onnx_model/onnx_optimizer.py has been run beforehand
104 |     # Avoid shadowing the imported `pipeline` factory; take the filename from settings (.env)
105 |     embedding_pipeline = get_embedding_pipeline(ONNX_PATH, model_filename=settings.onnx_model_filename)
106 | 
107 |     ids = [item.pop("id") for item in data]
108 |     # Rename "id" (Weaviate reserves the "id" key for its own uuid assignment, so we can't use it)
109 |     data = [{"wineID": id, **fields} for id, fields in zip(ids, data)]
110 |     to_vectorize = [text.pop("to_vectorize") for text in data]
111 |     sentence_embeddings = [embedding_pipeline(text.lower(), truncate=True)[0][0] for text in to_vectorize]
112 |     print(f"Finished vectorizing data in the ID range {min(ids)}-{max(ids)}")
113 |     try:
114 |         # Use a context manager to manage batch flushing
115 |         with client.batch as batch:
116 |             batch.batch_size = 64
117 |             batch.dynamic = True
118 |             for i, item in enumerate(data):
119 |                 batch.add_data_object(
120 |                     item,
121 |                     CLASS_NAME,
122 |                     vector=sentence_embeddings[i],
123 |                 )
124 |         print(f"Indexed ID range {min(ids)}-{max(ids)} to db")
125 |     except Exception as e:
126 |         print(f"{e}: Failed to index items in the ID range {min(ids)}-{max(ids)} to db")
127 | 
128 | 
129 | def main(data: list[JsonBlob]) -> None:
130 |     settings = get_settings()
131 |     HOST = settings.weaviate_host
132 |     PORT = settings.weaviate_port
133 |     client = weaviate.Client(f"http://{HOST}:{PORT}")
134 |     # Add schema
135 |     create_or_update_schema(client)
136 | 
137 |     print("Processing chunks")
138 |     with ProcessPoolExecutor(max_workers=WORKERS) as executor:
139 |         chunked_data = chunk_iterable(data, CHUNKSIZE)
140 |         for _ in executor.map(add_vectors_to_index, chunked_data):
141 |             pass
142 | 
143 | 
144 | if __name__ == "__main__":
145 |     # fmt: off
146 |     parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data")
147 |     parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
148 |     parser.add_argument("--chunksize", type=int, default=512, help="Size of each chunk to break the dataset into before processing")
149 |     parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use")
150 |     parser.add_argument("--workers", type=int, default=3, help="Number of workers to use for vectorization")
151 |     args = vars(parser.parse_args())
152 |     # fmt: on
153 | 
154 |     LIMIT = args["limit"]
155 |     DATA_DIR = Path(__file__).parents[3] / "data"
156 |     FILENAME = args["filename"]
157 |     CHUNKSIZE = args["chunksize"]
158 |     WORKERS = args["workers"]
159 |     ONNX_PATH = Path(__file__).parents[1] / "onnx_model" / "onnx"
160 | 
161 |     data = list(get_json_data(DATA_DIR, FILENAME))
162 | 
163 |     if data:
164 |         data = data[:LIMIT] if LIMIT > 0 else data
165 |         main(data)
166 |         print("Finished execution!")
167 | 
--------------------------------------------------------------------------------
/dbs/weaviate/scripts/bulk_index_sbert.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from concurrent.futures import ProcessPoolExecutor
5 | from functools import lru_cache
6 | from pathlib import Path
7 | from typing import Any, Iterator
8 | 
9 | import srsly
10 | import weaviate
11 | from dotenv import load_dotenv
12 | from sentence_transformers import SentenceTransformer
13 | from weaviate.client import Client
14 | 
15 | sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
16 | from api.config import Settings
17 | from schemas.wine import Wine
18 | 
19 | load_dotenv()
20 | # Custom types
21 | JsonBlob = dict[str, Any]
22 | 
23 | 
24 | class FileNotFoundError(Exception):
25 |     pass
26 | 
27 | 
28 | @lru_cache()
29 | def get_settings():
30 |     # Use lru_cache to avoid loading .env file for every request
31 |     return Settings()
32 | 
33 | 
34 | def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]:
35 |     """
36 |     Break a large iterable into an iterable of smaller iterables of size `chunksize`
37 |     """
38 |     for i in range(0, len(item_list), chunksize):
39 |         yield tuple(item_list[i : i + chunksize])
40 | 
41 | 
42 | def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:
43 |     """Get all line-delimited json files (.jsonl) from a directory with a given prefix"""
44 |     file_path = data_dir / filename
45 |     if not file_path.is_file():
46 |         # File may not have been uncompressed yet so try to do that first
47 |         data = srsly.read_gzip_jsonl(file_path)
48 |         # This time if it isn't there it really doesn't exist
49 |         if not file_path.is_file():
50 |             raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`")
51 |     else:
52 |         data = srsly.read_gzip_jsonl(file_path)
53 |     return data
54 | 
55 | 
56 | def validate(
57 |     data: list[JsonBlob],
58 |     exclude_none: bool = False,
59 | ) -> list[JsonBlob]:
60 |     validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
61 |     return validated_data
62 | 
63 | 
64 | def create_or_update_schema(client: Client) -> None:
65 |     # Create a schema with no vectorizer (we will be adding our own vectors)
66 |     schema = dict(srsly.read_json("settings/schema.json"))
67 |     assert schema, "No schema found, please check schema definition and try again"
68 |     class_names = [class_["class"] for class_ in schema["classes"]]
69 |     assert class_names, "No classes found in schema, please check schema definition and try again"
70 |     if not client.schema.get()["classes"]:
71 |         print(f"Creating schema with classes: {', '.join(class_names)}")
72 |         client.schema.create(schema)
73 |     else:
74 |         print("Existing schema found, deleting and recreating it...")
75 |         client.schema.delete_all()
76 |         client.schema.create(schema)
77 | 
78 | 
79 | def add_vectors_to_index(data_chunk: tuple[JsonBlob, ...]) -> None:
80 |     settings = get_settings()
81 |     CLASS_NAME = "Wine"
82 |     HOST = settings.weaviate_host
83 |     PORT = settings.weaviate_port
84 |     client = weaviate.Client(f"http://{HOST}:{PORT}")
85 |     data = validate(data_chunk, exclude_none=True)
86 | 
87 |     # Load a sentence transformer model for semantic similarity from a specified checkpoint
88 |     model_id = get_settings().embedding_model_checkpoint
89 |     MODEL = SentenceTransformer(model_id)
90 | 
91 |     ids = [item.pop("id") for item in data]
92 |     # Rename "id" (Weaviate reserves the "id" key for its own uuid assignment, so we can't use it)
93 |     data = [{"wineID": id, **fields} for id, fields in zip(ids, data)]
94 |     to_vectorize = [text.pop("to_vectorize") for text in data]
95 |     sentence_embeddings = MODEL.encode([text.lower() for text in to_vectorize], batch_size=64).tolist()
96 |     print(f"Finished vectorizing data in the ID range {min(ids)}-{max(ids)}")
97 |     try:
98 |         # Use a context manager to manage batch flushing
99 |         with client.batch as batch:
100 |             batch.dynamic = True
101 |             for i, item in enumerate(data):
102 |                 batch.add_data_object(
103 |                     item,
104 |                     CLASS_NAME,
105 |                     vector=sentence_embeddings[i],
106 |                 )
107 |         print(f"Indexed ID range {min(ids)}-{max(ids)} to db")
108 |     except Exception as e:
109 |         print(f"{e}: Failed to index items in the ID range {min(ids)}-{max(ids)} to db")
110 | 
111 | 
112 | def main(data: list[JsonBlob]) -> None:
113 |     settings = get_settings()
114 |     HOST = settings.weaviate_host
115 |     PORT = settings.weaviate_port
116 |     client = weaviate.Client(f"http://{HOST}:{PORT}")
117 |     # Add schema
118 |     create_or_update_schema(client)
119 | 
120 |     print("Processing chunks")
121 |     with ProcessPoolExecutor(max_workers=WORKERS) as executor:
122 |         chunked_data = chunk_iterable(data, CHUNKSIZE)
123 |         for _ in executor.map(add_vectors_to_index, chunked_data):
124 |             pass
125 | 
126 | 
127 | if __name__ == "__main__":
128 |     # fmt: off
129 |     parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data")
130 |     parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
131 |     parser.add_argument("--chunksize", type=int, default=512, help="Size of each chunk to break the dataset into before processing")
132 |     parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use")
133 |     parser.add_argument("--workers", type=int, default=4, help="Number of workers to
use for vectorization") 134 | args = vars(parser.parse_args()) 135 | # fmt: on 136 | 137 | LIMIT = args["limit"] 138 | DATA_DIR = Path(__file__).parents[3] / "data" 139 | FILENAME = args["filename"] 140 | CHUNKSIZE = args["chunksize"] 141 | WORKERS = args["workers"] 142 | 143 | data = list(get_json_data(DATA_DIR, FILENAME)) 144 | 145 | if data: 146 | data = data[:LIMIT] if LIMIT > 0 else data 147 | main(data) 148 | print("Finished execution!") 149 | -------------------------------------------------------------------------------- /dbs/weaviate/scripts/settings/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "classes": [ 3 | { 4 | "class": "Wine", 5 | "vectorizer": "none", 6 | "properties": [ 7 | { 8 | "name": "wineID", 9 | "dataType": ["int"] 10 | }, 11 | { 12 | "name": "points", 13 | "dataType": ["int"] 14 | }, 15 | { 16 | "name": "variety", 17 | "dataType": ["text"] 18 | }, 19 | { 20 | "name": "title", 21 | "dataType": ["text"] 22 | }, 23 | { 24 | "name": "description", 25 | "dataType": ["text"] 26 | }, 27 | { 28 | "name": "price", 29 | "dataType": ["number"] 30 | }, 31 | { 32 | "name": "country", 33 | "dataType": ["text"] 34 | }, 35 | { 36 | "name": "province", 37 | "dataType": ["text"] 38 | }, 39 | { 40 | "name": "taster_name", 41 | "dataType": ["text"] 42 | }, 43 | { 44 | "name": "taster_twitter_handle", 45 | "dataType": ["text"] 46 | } 47 | ] 48 | } 49 | ] 50 | } --------------------------------------------------------------------------------