├── .env.example
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── convert.py
├── docker-compose.yml
├── requirements.txt
└── src
    ├── __init__.py
    ├── api
    │   ├── __init__.py
    │   ├── main.py
    │   ├── routers
    │   │   └── query.py
    │   └── test_main.py
    ├── config
    │   └── settings.py
    ├── ingest
    │   ├── README.md
    │   ├── bulk_ingest_async.py
    │   └── bulk_ingest_sync.py
    ├── schemas
    │   ├── __init__.py
    │   ├── response.py
    │   └── wine.py
    └── tests
        ├── test_crud.py
        └── test_schemas.py

--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = ""
NEO4J_URL = "localhost"
NEO4J_SERVICE = "neo4j"

TAG = 0.1.0
API_PORT = 8000

NEO4J_VERSION = "5.7.0"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Mac
.DS_Store

# data
data/*.json
data/*.jsonl
data/*.jsonl.gz
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------

FROM python:3.11-slim-bullseye

WORKDIR /wine

COPY ./requirements.txt /wine/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /wine/requirements.txt

# Mirror the `src/` layout so that imports like `src.api.main` resolve inside the image
COPY ./src/api /wine/src/api
COPY ./src/config /wine/src/config
COPY ./src/schemas /wine/src/schemas

EXPOSE 8000
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Prashanth Rao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neo4j for Pythonistas

This repo contains code for the methods described in this series of blog posts:

1. [Neo4j for Pythonistas: Part 1](https://thedataquarry.com/posts/neo4j-python-1/)
    * Using Pydantic and async Python to build a graph in Neo4j
2. [Neo4j for Pythonistas: Part 2](https://thedataquarry.com/posts/neo4j-python-2/)
    * Build a RESTful API on top of a Neo4j graph

## Goals

The aim of this code is to build a Neo4j graph via its [officially maintained Python client](https://github.com/neo4j/neo4j-python-driver), using either the sync or async database drivers. The async driver supports Python's `asyncio` coroutine-based concurrent workflows, which can be beneficial in I/O-bound scenarios. Both sync and async code is provided in `src/ingest` as a starter template for bulk-ingesting large amounts of data into Neo4j in batches, as efficiently as possible. Code will be added in three parts:


1. Bulk data ingestion into Neo4j using Pydantic and async Python
2. Building a RESTful API on top of the Neo4j graph via [FastAPI](https://fastapi.tiangolo.com/)
3. Building a GraphQL API on top of the Neo4j graph via FastAPI and [Strawberry](https://strawberry.rocks/)

There are lots of clever ways one can write an API on top of Neo4j, but the main focus of this repo is to keep the code readable, and the logic simple and easy enough to extend for future use cases as they arise.


## Requirements

### Install Python dependencies

Install Python dependencies in a virtual environment using `requirements.txt` as follows. All code in this repo has been tested on Python 3.11.

```
# Set up the environment for the first time
python -m venv neo4j_venv

# Activate the environment (for subsequent runs)
source neo4j_venv/bin/activate

python -m pip install -r requirements.txt
```


### Install and run Docker

* [Download Docker](https://docs.docker.com/get-docker/) and run the Docker daemon
* Use the provided `docker-compose.yml` to set up and run the database in a container
    * This ensures reproducibility and ease of setup, regardless of the platform used.
* Copy the file `.env.example` and rename it to `.env`.
* Set the `NEO4J_PASSWORD` field in `.env` to a non-empty value -- this will be the password used to log into the Neo4j database running on `localhost`.

To start the database service, run Docker in detached mode via the compose file.

```sh
docker compose up -d
```

This command starts a Neo4j database with a persistent volume, so that any data that's ingested persists on the local system even after Docker is shut down.

Tear down the database process and containers at any time using the following command:

```
docker compose down
```

## Dataset

The [wine reviews dataset](./data/) provided in this repo is a newline-delimited JSON (`.jsonl`) version of the dataset obtained from Kaggle.


## Run tests

Once the data is ingested into Neo4j, the APIs and schemas can be tested via `pytest` to ensure that endpoints behave as expected.

> 💡 **Note:** Run the tests **inside the Docker container**, as FastAPI communicates with the Neo4j service via its own network inside the container.

To enter the Docker container, first obtain the name of the running container via `docker ps`; in the following example, it is `neo4j-python-fastapi-fastapi-1`.

```
docker exec -it neo4j-python-fastapi-fastapi-1 bash
pytest -v
```

The first line runs an interactive bash shell inside the container, and the second runs the tests in verbose mode. Assuming that the data has been ingested into the database, the tests should pass and return something like this:

```
======================== test session starts ========================
platform linux -- Python 3.11.3, pytest-7.3.1, pluggy-1.0.0 -- /usr/local/bin/python
cachedir: .pytest_cache
rootdir: /wine
plugins: asyncio-0.21.0, anyio-3.6.2
asyncio: mode=Mode.STRICT
collected 7 items

src/api/test_main.py::test_search PASSED [ 14%]
src/api/test_main.py::test_top_by_country PASSED [ 28%]
src/api/test_main.py::test_top_by_province PASSED [ 42%]
src/api/test_main.py::test_most_by_variety PASSED [ 57%]
src/tests/test_crud.py::test_sync_transactions PASSED [ 71%]
src/tests/test_crud.py::test_async_transactions PASSED [ 85%]
src/tests/test_schemas.py::test_wine_schema PASSED [100%]

========================= 7 passed in 0.45s =========================
```
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Dataset: 130k wine reviews

The dataset used is that of [130k wine reviews](https://www.kaggle.com/datasets/zynicide/wine-reviews) from the [WineEnthusiast Magazine](https://www.winemag.com/?s=&drink_type=wine) in 2017, obtained via Kaggle. Full credit is due to [the original author](https://www.kaggle.com/zynicide) and Kaggle for curating/hosting this dataset.

For ease of use, the original JSON data is re-formatted to newline-delimited JSON (`.jsonl`) format and compressed as a GZIP archive. There is no need to run the `convert.py` file to use the data in the `.jsonl.gz` file; the code is provided here purely for reference.

A sample wine review JSON object is shown below.

```json
{
    "points": "90",
    "title": "Castello San Donato in Perano 2009 Riserva (Chianti Classico)",
    "description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. Drink now through 2019.",
    "taster_name": "Kerin O'Keefe",
    "taster_twitter_handle": "@kerinokeefe",
    "price": 30,
    "designation": "Riserva",
    "variety": "Red Blend",
    "region_1": "Chianti Classico",
    "region_2": null,
    "province": "Tuscany",
    "country": "Italy",
    "winery": "Castello San Donato in Perano",
    "id": 40825
}
```
--------------------------------------------------------------------------------
/data/convert.py:
--------------------------------------------------------------------------------
"""
Run `pip install srsly` to use this script.

This script converts the JSON data file from https://www.kaggle.com/datasets/zynicide/wine-reviews
to a gzipped, line-delimited JSON (`.jsonl.gz`) file for use downstream with the databases in question.

Full credit to the original author, @zynicide, on Kaggle, for the data.
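Usage: place the downloaded `winemag-data-130k-v2.json` in this directory and run `python convert.py`.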
"""
from pathlib import Path
from typing import Any

import srsly

JsonBlob = dict[str, Any]


def convert_to_jsonl(filename: str) -> None:
    data = srsly.read_json(filename)
    # Add an `id` field to the start of each dict item so we have a primary key for indexing
    new_data = [{"id": idx, **item} for idx, item in enumerate(data, 1)]
    srsly.write_gzip_jsonl(f"{Path(filename).stem}.jsonl.gz", new_data)


if __name__ == "__main__":
    # Download the JSON data file from https://www.kaggle.com/datasets/zynicide/wine-reviews
    convert_to_jsonl("winemag-data-130k-v2.json")
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: '3.9'

services:
  neo4j:
    image: neo4j:${NEO4J_VERSION}
    restart: unless-stopped
    environment:
      - NEO4J_AUTH=neo4j/${NEO4J_PASSWORD}
      - NEO4J_PLUGINS=["graph-data-science", "apoc"]
      # DB and server
      - NEO4J_server_memory_pagecache_size=1G
      - NEO4J_server_memory_heap_initial__size=1G
      - NEO4J_server_memory_heap_max__size=2G
      - NEO4J_dbms_security_procedures_unrestricted=gds.*,apoc.*
    ports:
      - 7687:7687
    volumes:
      - logs:/logs
      - data:/data
      - plugins:/plugins
      - import:/import
    networks:
      - wine

  fastapi:
    image: neo4j_wine_fastapi:${TAG}
    build: .
    restart: unless-stopped
    env_file:
      - .env
    ports:
      - ${API_PORT}:8000
    depends_on:
      - neo4j
    volumes:
      - ./:/wine
    networks:
      - wine
    command: uvicorn src.api.main:app --host 0.0.0.0 --port 8000 --reload


volumes:
  logs:
  data:
  plugins:
  import:

networks:
  wine:
    driver: bridge
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
codetiming>=1.4.0
neo4j~=5.8.0
# pydantic v2 no longer ships a "dotenv" extra; .env support comes via pydantic-settings
pydantic~=2.0.0
pydantic-settings~=2.0.0
srsly>=2.4.6
uvloop>=0.17.0
fastapi~=0.100.0
httpx>=0.24.0
uvicorn>=0.22.0
pytest>=7.3.1
pytest-asyncio>=0.21.0
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/neo4j-python-fastapi/36a18f76385f227c854e5db9ee6180ae8883497f/src/__init__.py
--------------------------------------------------------------------------------
/src/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/neo4j-python-fastapi/36a18f76385f227c854e5db9ee6180ae8883497f/src/api/__init__.py
--------------------------------------------------------------------------------
/src/api/main.py:
--------------------------------------------------------------------------------
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from functools import lru_cache

from fastapi import FastAPI
from neo4j import AsyncGraphDatabase

from src.config import settings
from src.api.routers import query


@lru_cache()
def get_settings() -> settings.Settings:
    # Use lru_cache to avoid loading the .env file for every request
    config = settings.Settings()
    return config


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Async context manager for the Neo4j database connection."""
    settings = get_settings()
    URI = f"bolt://{settings.neo4j_service}:7687"
    AUTH = (settings.neo4j_user, settings.neo4j_password)
    async with AsyncGraphDatabase.driver(URI, auth=AUTH) as driver:
        async with driver.session(database="neo4j") as session:
            app.session = session
            print("Successfully connected to wine reviews Neo4j DB")
            yield
            print("Successfully closed wine reviews Neo4j connection")


app = FastAPI(
    title="REST API for wine reviews on Neo4j",
    description=(
        "Query from a Neo4j database of 130k wine reviews from the Wine Enthusiast magazine"
    ),
    version=get_settings().tag,
    lifespan=lifespan,
)


@app.get("/", include_in_schema=False)
async def root():
    return {
        "message": "REST API for querying Neo4j database of 130k wine reviews from the Wine Enthusiast magazine"
    }


# Attach routes (the tests in `test_main.py` address the endpoints under this prefix)
app.include_router(query.router, prefix="/v1/rest", tags=["rest"])
--------------------------------------------------------------------------------
/src/api/routers/query.py:
--------------------------------------------------------------------------------
from fastapi import APIRouter, HTTPException, Query, Request
from neo4j import AsyncManagedTransaction
from src.schemas.response import (
    FullTextSearch,
    MostWinesByVariety,
    TopWinesByCountry,
    TopWinesByProvince,
)

router = APIRouter()

# --- Routes ---


@router.get(
    "/search",
    response_model=list[FullTextSearch],
    response_description="Search wines by title and description",
)
async def search_by_keywords(
    request: Request,
    terms: str = Query(description="Search wine by keywords in title or description"),
    max_price: float = Query(
        default=100.0, description="Specify the maximum price for the wine (e.g., 30)"
    ),
) -> list[FullTextSearch] | None:
    session = request.app.session
    result = await session.execute_read(_search_by_keywords, terms, max_price)
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No wine with the provided terms '{terms}' found in database - please try again",
        )
    return result


@router.get(
    "/top_by_country",
    response_model=list[TopWinesByCountry],
    response_description="Get top-rated wines by country",
)
async def top_by_country(
    request: Request,
    country: str = Query(
        description="Get top-rated wines by country name specified (must be exact name)"
    ),
) -> list[TopWinesByCountry] | None:
    session = request.app.session
    result = await session.execute_read(_top_by_country, country)
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No wine from the provided country '{country}' found in database - please enter exact country name",
        )
    return result


@router.get(
    "/top_by_province",
    response_model=list[TopWinesByProvince],
    response_description="Get top-rated wines by province",
)
async def top_by_province(
    request: Request,
    province: str = Query(
        description="Get top-rated wines by province name specified (must be exact name)"
    ),
) -> list[TopWinesByProvince] | None:
    session = request.app.session
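    # execute_read runs the query function in a managed transaction, retrying on transient errors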
    result = await session.execute_read(_top_by_province, province)
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No wine from the provided province '{province}' found in database - please enter exact province name",
        )
    return result


@router.get(
    "/most_by_variety",
    response_model=list[MostWinesByVariety],
    response_description="Get the countries with the most wines above a points-rating of a specified variety (blended or otherwise)",
)
async def most_by_variety(
    request: Request,
    variety: str = Query(
        description="Specify the variety of wine to search for (e.g., 'Pinot Noir' or 'Red Blend')"
    ),
    points: int = Query(
        default=85,
        description="Specify the minimum points-rating for the wine (e.g., 85)",
    ),
) -> list[MostWinesByVariety] | None:
    session = request.app.session
    result = await session.execute_read(_most_by_variety, variety, points)
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No wine of the specified variety '{variety}' found in database - please try a different variety",
        )
    return result


# --- Neo4j query funcs ---


async def _search_by_keywords(
    tx: AsyncManagedTransaction,
    terms: str,
    price: float,
) -> list[FullTextSearch] | None:
    query = """
        CALL db.index.fulltext.queryNodes("searchText", $terms) YIELD node AS wine, score
        WITH DISTINCT wine, score
        MATCH (wine)-[:IS_FROM_COUNTRY]->(c:Country)
        WHERE wine.price <= $price
        RETURN
            c.countryName AS country,
            wine.wineID AS wineID,
            wine.points AS points,
            wine.title AS title,
            wine.description AS description,
            coalesce(wine.price, "Not available") AS price,
            wine.variety AS variety,
            wine.winery AS winery
        ORDER BY score DESC, points DESC LIMIT 5
    """
    response = await tx.run(query, terms=terms, price=price)
    result = await response.data()
    if result:
        return [FullTextSearch(**r) for r in result]
    return None


async def _top_by_country(
    tx: AsyncManagedTransaction,
    country: str,
) -> list[TopWinesByCountry] | None:
    query = """
        MATCH (wine:Wine)-[:IS_FROM_COUNTRY]->(c:Country)
        WHERE tolower(c.countryName) = tolower($country)
        RETURN
            wine.wineID AS wineID,
            wine.points AS points,
            wine.title AS title,
            wine.description AS description,
            c.countryName AS country,
            coalesce(wine.price, "Not available") AS price,
            wine.variety AS variety,
            wine.winery AS winery
        ORDER BY points DESC LIMIT 5
    """
    response = await tx.run(query, country=country)
    result = await response.data()
    if result:
        return [TopWinesByCountry(**r) for r in result]
    return None


async def _top_by_province(
    tx: AsyncManagedTransaction,
    province: str,
) -> list[TopWinesByProvince] | None:
    query = """
        MATCH (wine:Wine)-[:IS_FROM_PROVINCE]->(p:Province)-[:IS_LOCATED_IN]->(c:Country)
        WHERE tolower(p.provinceName) = tolower($province)
        RETURN
            wine.wineID AS wineID,
            wine.points AS points,
            wine.title AS title,
            wine.description AS description,
            c.countryName AS country,
            p.provinceName AS province,
            coalesce(wine.price, "Not available") AS price,
            wine.variety AS variety,
            wine.winery AS winery
        ORDER BY points DESC LIMIT 5
    """
    response = await tx.run(query, province=province)
    result = await response.data()
    if result:
        return [TopWinesByProvince(**r) for r in result]
    return None


async def _most_by_variety(
    tx: AsyncManagedTransaction,
    variety: str,
    points: int,
) -> list[MostWinesByVariety] | None:
    query = """
        CALL db.index.fulltext.queryNodes("searchText", $variety) YIELD node AS wine, score
        WITH wine
        MATCH (wine)-[:IS_FROM_COUNTRY]->(c:Country)
        WHERE wine.points >= $points
        RETURN
            c.countryName AS country,
            count(wine) as wineCount
        ORDER BY wineCount DESC LIMIT 5
    """
    response = await tx.run(query, variety=variety, points=points)
    result = await response.data()
    if result:
        return [MostWinesByVariety(**r) for r in result]
    return None
--------------------------------------------------------------------------------
/src/api/test_main.py:
--------------------------------------------------------------------------------
from fastapi.testclient import TestClient

from src.api.main import app


def test_search():
    with TestClient(app) as client:
        response = client.get("/v1/rest/search?terms=chardonnay&max_price=100")
        assert response.status_code == 200
        sample = response.json()[0]
        assert isinstance(sample["wineID"], int)
        assert isinstance(sample["price"], float)
        assert isinstance(sample["points"], int)
        assert "country" in sample
        assert sample["price"] <= 100.0


def test_top_by_country():
    with TestClient(app) as client:
        response = client.get("/v1/rest/top_by_country?country=new%20zealand")
        assert response.status_code == 200
        first_sample = response.json()[0]
        last_sample = response.json()[-1]
        assert isinstance(first_sample["wineID"], int)
        assert isinstance(first_sample["price"], float)
        assert isinstance(first_sample["points"], int)
        assert first_sample["country"] == "New Zealand"
        # Test sorting
        assert first_sample["points"] >= last_sample["points"]


def test_top_by_province():
    with TestClient(app) as client:
        response = client.get("/v1/rest/top_by_province?province=oregon")
        assert response.status_code == 200
        first_sample = response.json()[0]
        last_sample = response.json()[-1]
        assert isinstance(first_sample["wineID"], int)
        assert isinstance(first_sample["price"], float)
        assert isinstance(first_sample["points"], int)
        assert first_sample["province"] == "Oregon"
        assert first_sample["country"] == "US"
        # Test sorting
        assert first_sample["points"] >= last_sample["points"]


def test_most_by_variety():
    with TestClient(app) as client:
        response = client.get("/v1/rest/most_by_variety?variety=pinot%20noir&points=85")
        assert response.status_code == 200
        assert len(response.json()) > 0
        first_sample = response.json()[0]
        last_sample = response.json()[-1]
        # Test sorting
        assert first_sample["wineCount"] >= last_sample["wineCount"]
--------------------------------------------------------------------------------
/src/config/settings.py:
--------------------------------------------------------------------------------
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
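    # Settings are read from the .env file (see .env.example); extra keys are allowed
    # so the same file can also hold docker-compose variables like API_PORT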
    model_config = SettingsConfigDict(
        env_file=".env",
        extra="allow",
        case_sensitive=False,
    )

    neo4j_url: str
    neo4j_service: str
    neo4j_user: str
    neo4j_password: str
    tag: str
--------------------------------------------------------------------------------
/src/ingest/README.md:
--------------------------------------------------------------------------------
# Data ingestion

This section shows how to effectively bulk-ingest large amounts of data into Neo4j using Python.

## Best practices

### Data validation via Pydantic

Because Neo4j doesn't enforce data types prior to runtime, data validation is done via [Pydantic](https://docs.pydantic.dev/latest/). It is highly recommended to perform validation this way prior to ingesting the data into Neo4j, so that queries perform as expected and there are no errors or unexpected issues in production.

- Number types (like integers and floats) are coerced prior to ingestion
- Fields are renamed when required (`designation` is renamed to `vineyard` for clarity)
- Default values are set using validators (if the `country` field doesn't exist, a default value of "Unknown" is assigned)


### Create indexes and constraints

To efficiently ingest large amounts of data into Neo4j, it helps greatly to set up indexes and constraints beforehand, because node lookups during `MERGE` become increasingly expensive as the graph grows. Typically, constraints are set on items that we expect to be unique, such as an ID or a country name.


### Batch transactions with `UNWIND`

Each transaction with the database is expensive due to network overhead, so batched transactions are performed to improve performance, as shown in the sketch below.

- The validated data is divided up into batches, or "chunks"
- Each chunk is roughly 10k-20k records (where each "record" is a type-validated dict, coming from Pydantic)
- The chunks (lists of dicts) can be fed into Neo4j via the sync or async driver, and easily `UNWIND`ed in Cypher
- `UNWIND`, as the name suggests, expands a list of dicts into rows, each of which provides all the information that Neo4j needs to build the nodes and edges in the graph

💡 **Tip**: Submitting a large list of dicts, 10k-20k in length, to `UNWIND` in Cypher can improve performance massively compared to submitting each record one at a time (due to bolt network and Python overheads)

A more detailed blog post on these best practices will be published soon!
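
For illustration, here is a minimal, self-contained sketch of the chunk-and-`UNWIND` pattern using the sync driver. The connection URI (`bolt://localhost:7687`), the `NEO4J_PASSWORD` environment variable, and the toy `records` list are assumptions made for this example only; the full scripts in this directory (`bulk_ingest_sync.py` / `bulk_ingest_async.py`) apply the same pattern to the validated wine records, adding Pydantic validation and index/constraint creation on top.

```python
import os

from neo4j import GraphDatabase

# Toy stand-in for the Pydantic-validated records
records = [{"id": i, "countryName": f"Country {i % 3}"} for i in range(1, 25_001)]

query = """
UNWIND $data AS record
MERGE (c:Country {countryName: record.countryName})
"""


def chunked(items: list[dict], size: int):
    # Yield successive `size`-length slices of the input list
    for i in range(0, len(items), size):
        yield items[i : i + size]


def ingest_chunk(tx, data: list[dict]) -> None:
    # The whole chunk travels as a single $data parameter; UNWIND expands it server-side
    tx.run(query, data=data)


with GraphDatabase.driver(
    "bolt://localhost:7687", auth=("neo4j", os.environ["NEO4J_PASSWORD"])
) as driver:
    with driver.session(database="neo4j") as session:
        for chunk in chunked(records, 10_000):
            # One managed write transaction (and one network round trip) per chunk
            session.execute_write(ingest_chunk, chunk)
```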

--------------------------------------------------------------------------------
/src/ingest/bulk_ingest_async.py:
--------------------------------------------------------------------------------
import argparse
import asyncio
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, Iterator

import srsly
from codetiming import Timer
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase, AsyncManagedTransaction, AsyncSession

sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
from config.settings import Settings
from schemas.wine import Wine

# Custom types
JsonBlob = dict[str, Any]


class FileNotFoundError(Exception):
    pass


# --- Blocking functions ---


@lru_cache()
def get_settings():
    load_dotenv()
    # Use lru_cache to avoid loading the .env file for every request
    return Settings()


def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[list[JsonBlob]]:
    """
    Break a large iterable into an iterable of smaller iterables of size `chunksize`
    """
    for i in range(0, len(item_list), chunksize):
        yield item_list[i : i + chunksize]


def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:
    """Get all line-delimited json files (.jsonl) from a directory with a given prefix"""
    file_path = data_dir / filename
    if not file_path.is_file():
        # File may not have been uncompressed yet so try to do that first
        data = srsly.read_gzip_jsonl(file_path)
        # This time if it isn't there it really doesn't exist
        if not file_path.is_file():
            raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`")
    else:
        data = srsly.read_gzip_jsonl(file_path)
    return data


@Timer(name="pydantic validator")
def validate(
    data: list[JsonBlob],
    exclude_none: bool = False,
) -> list[JsonBlob]:
    validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
    return validated_data


# --- Async functions ---


async def create_indexes_and_constraints(session: AsyncSession) -> None:
    queries = [
        # constraints
        "CREATE CONSTRAINT countryName IF NOT EXISTS FOR (c:Country) REQUIRE c.countryName IS UNIQUE ",
        "CREATE CONSTRAINT wineID IF NOT EXISTS FOR (w:Wine) REQUIRE w.wineID IS UNIQUE ",
        # indexes
        "CREATE INDEX provinceName IF NOT EXISTS FOR (p:Province) ON (p.provinceName) ",
        "CREATE INDEX tasterName IF NOT EXISTS FOR (p:Person) ON (p.tasterName) ",
        "CREATE FULLTEXT INDEX searchText IF NOT EXISTS FOR (w:Wine) ON EACH [w.title, w.description, w.variety] ",
    ]
    for query in queries:
        await session.run(query)


async def build_query(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None:
    query = """
        UNWIND $data AS record
        MERGE (wine:Wine {wineID: record.id})
            SET wine += {
                points: record.points,
                title: record.title,
                description: record.description,
                price: record.price,
                variety: record.variety,
                winery: record.winery,
                vineyard: record.vineyard,
                region_1: record.region_1,
                region_2: record.region_2
            }
        WITH record, wine
            WHERE record.taster_name IS NOT NULL
            MERGE (taster:Person {tasterName: record.taster_name})
                SET taster += {tasterTwitterHandle: record.taster_twitter_handle}
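                // `+=` merges the property map into the existing node instead of replacing all of its properties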
            MERGE (wine)-[:TASTED_BY]->(taster)
        WITH record, wine
            MERGE (country:Country {countryName: record.country})
            MERGE (wine)-[:IS_FROM_COUNTRY]->(country)
        WITH record, wine, country
            WHERE record.province IS NOT NULL
            MERGE (province:Province {provinceName: record.province})
            MERGE (wine)-[:IS_FROM_PROVINCE]->(province)
        WITH record, wine, country, province
            WHERE record.province IS NOT NULL AND record.country IS NOT NULL
            MERGE (province)-[:IS_LOCATED_IN]->(country)
    """
    await tx.run(query, data=data)


async def main(data: list[JsonBlob]) -> None:
    async with AsyncGraphDatabase.driver(URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        async with driver.session(database="neo4j") as session:
            # Create indexes and constraints
            await create_indexes_and_constraints(session)
            # Ingest the data into Neo4j
            print("Validating data...")
            validated_data = validate(data, exclude_none=True)
            # Break the data into chunks
            chunked_data = chunk_iterable(validated_data, CHUNKSIZE)
            print("Ingesting data...")
            with Timer(name="ingest"):
                for chunk in chunked_data:
                    # Awaiting each chunk in a loop isn't ideal, but it's easiest this way when working with graphs!
                    # Merging edges on top of nodes concurrently can lead to race conditions. Neo4j doesn't allow this,
                    # and prevents the user from merging relationships on nodes that might not exist yet, for good reason.
                    ids = [item["id"] for item in chunk]
                    try:
                        await session.execute_write(build_query, chunk)
                        print(f"Processed ids in range {min(ids)}-{max(ids)}")
                    except Exception as e:
                        print(f"{e}: Failed to ingest IDs in range {min(ids)}-{max(ids)}")


if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser("Build a graph from the wine reviews JSONL data")
    parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
    parser.add_argument("--chunksize", type=int, default=10_000, help="Size of each chunk to break the dataset into before processing")
    parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use")
    args = vars(parser.parse_args())
    # fmt: on

    LIMIT = args["limit"]
    DATA_DIR = Path(__file__).parents[2] / "data"
    FILENAME = args["filename"]
    CHUNKSIZE = args["chunksize"]

    # Neo4j
    settings = get_settings()
    URI = f"bolt://{settings.neo4j_url}:7687"
    NEO4J_USER = settings.neo4j_user
    NEO4J_PASSWORD = settings.neo4j_password

    data = list(get_json_data(DATA_DIR, FILENAME))
    if LIMIT > 0:
        data = data[:LIMIT]

    if data:
        # Neo4j async uses uvloop under the hood, so we can gain marginal performance improvement by using it too
        import uvloop

        uvloop.install()
        asyncio.run(main(data))
--------------------------------------------------------------------------------
/src/ingest/bulk_ingest_sync.py:
--------------------------------------------------------------------------------
import argparse
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, Iterator

import srsly
from codetiming import Timer
from dotenv import load_dotenv
from neo4j import GraphDatabase, ManagedTransaction, Session

sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
from config.settings import Settings
from schemas.wine import Wine

# Custom types
JsonBlob = dict[str, Any]


class FileNotFoundError(Exception):
    pass


@lru_cache()
def get_settings():
    load_dotenv()
    # Use lru_cache to avoid loading the .env file for every request
    return Settings()


def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[list[JsonBlob]]:
    """
    Break a large iterable into an iterable of smaller iterables of size `chunksize`
    """
    for i in range(0, len(item_list), chunksize):
        yield item_list[i : i + chunksize]


def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:
    """Get all line-delimited json files (.jsonl) from a directory with a given prefix"""
    file_path = data_dir / filename
    if not file_path.is_file():
        # File may not have been uncompressed yet so try to do that first
        data = srsly.read_gzip_jsonl(file_path)
        # This time if it isn't there it really doesn't exist
        if not file_path.is_file():
            raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`")
    else:
        data = srsly.read_gzip_jsonl(file_path)
    return data


@Timer(name="pydantic validator")
def validate(
    data: list[JsonBlob],
    exclude_none: bool = False,
) -> list[JsonBlob]:
    validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
    return validated_data


def create_indexes_and_constraints(session: Session) -> None:
    queries = [
        # constraints
        "CREATE CONSTRAINT countryName IF NOT EXISTS FOR (c:Country) REQUIRE c.countryName IS UNIQUE ",
        "CREATE CONSTRAINT wineID IF NOT EXISTS FOR (w:Wine) REQUIRE w.wineID IS UNIQUE ",
        # indexes
        "CREATE INDEX provinceName IF NOT EXISTS FOR (p:Province) ON (p.provinceName) ",
        "CREATE INDEX tasterName IF NOT EXISTS FOR (p:Person) ON (p.tasterName) ",
        "CREATE FULLTEXT INDEX searchText IF NOT EXISTS FOR (w:Wine) ON EACH [w.title, w.description, w.variety] ",
    ]
    for query in queries:
        session.run(query)


def build_query(tx: ManagedTransaction, data: list[JsonBlob]) -> None:
    query = """
        UNWIND $data AS record
        MERGE (wine:Wine {wineID: record.id})
            SET wine += {
                points: record.points,
                title: record.title,
                description: record.description,
                price: record.price,
                variety: record.variety,
                winery: record.winery,
                vineyard: record.vineyard,
                region_1: record.region_1,
                region_2: record.region_2
            }
        WITH record, wine
            WHERE record.taster_name IS NOT NULL
            MERGE (taster:Person {tasterName: record.taster_name})
                SET taster += {tasterTwitterHandle: record.taster_twitter_handle}
            MERGE (wine)-[:TASTED_BY]->(taster)
        WITH record, wine
            MERGE (country:Country {countryName: record.country})
            MERGE (wine)-[:IS_FROM_COUNTRY]->(country)
        WITH record, wine, country
            WHERE record.province IS NOT NULL
            MERGE (province:Province {provinceName: record.province})
            MERGE (wine)-[:IS_FROM_PROVINCE]->(province)
        WITH record, wine, country, province
            WHERE record.province IS NOT NULL AND record.country IS NOT NULL
            MERGE (province)-[:IS_LOCATED_IN]->(country)
    """
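    # The entire chunk is sent as a single $data parameter; UNWIND expands it into one row per record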
    tx.run(query, data=data)


def ingest_data(session: Session, chunked_data: Iterator[list[JsonBlob]]) -> None:
    for chunk in chunked_data:
        ids = [item["id"] for item in chunk]
        try:
            session.execute_write(build_query, chunk)
            print(f"Processed ids in range {min(ids)}-{max(ids)}")
        except Exception as e:
            print(f"{e}: Failed to ingest IDs in range {min(ids)}-{max(ids)}")


def main(data: list[JsonBlob]) -> None:
    with GraphDatabase.driver(URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        with driver.session(database="neo4j") as session:
            # Create indexes and constraints
            create_indexes_and_constraints(session)
            # Ingest the data into Neo4j
            print("Validating data...")
            validated_data = validate(data, exclude_none=True)
            # Break the data into chunks
            chunked_data = chunk_iterable(validated_data, CHUNKSIZE)
            print("Ingesting data...")
            with Timer(name="ingest"):
                ingest_data(session, chunked_data)


if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser("Build a graph from the wine reviews JSONL data")
    parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes")
    parser.add_argument("--chunksize", type=int, default=10_000, help="Size of each chunk to break the dataset into before processing")
    parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use")
    args = vars(parser.parse_args())
    # fmt: on

    LIMIT = args["limit"]
    DATA_DIR = Path(__file__).parents[2] / "data"
    FILENAME = args["filename"]
    CHUNKSIZE = args["chunksize"]

    # Neo4j
    settings = get_settings()
    URI = f"bolt://{settings.neo4j_url}:7687"
    NEO4J_USER = settings.neo4j_user
    NEO4J_PASSWORD = settings.neo4j_password

    data = list(get_json_data(DATA_DIR, FILENAME))
    if LIMIT > 0:
        data = data[:LIMIT]

    if data:
        main(data)
        print("Finished execution!")
--------------------------------------------------------------------------------
/src/schemas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prrao87/neo4j-python-fastapi/36a18f76385f227c854e5db9ee6180ae8883497f/src/schemas/__init__.py
--------------------------------------------------------------------------------
/src/schemas/response.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel


class FullTextSearch(BaseModel):
    wineID: int
    country: str
    title: str
    description: str | None
    points: int
    price: float
    variety: str | None
    winery: str | None

    class Config:
        json_schema_extra = {
            "example": {
                "wineID": 3845,
                "country": "Italy",
                "title": "Castellinuzza e Piuca 2010 Chianti Classico",
                "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.",
                "points": 93,
                "price": 16,
                "variety": "Red Blend",
                "winery": "Castellinuzza e Piuca",
            }
        }


class TopWinesByCountry(BaseModel):
    wineID: int
    country: str
    title: str
    description: str | None
    points: int
    price: float | str
    variety: str | None
    winery: str | None


class TopWinesByProvince(BaseModel):
    wineID: int
    country: str
    province: str
    title: str
    description: str | None
    points: int
    price: float | str
    variety: str | None
    winery: str | None


class MostWinesByVariety(BaseModel):
    country: str
    wineCount: int
--------------------------------------------------------------------------------
/src/schemas/wine.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, Field, field_validator


class Wine(BaseModel):
    id: int
    points: int
    title: str
    description: str | None
    price: float | None
    variety: str | None
    winery: str | None
    country: str | None
    province: str | None
    region_1: str | None
    region_2: str | None
    vineyard: str | None = Field(alias="designation")
    taster_name: str | None
    taster_twitter_handle: str | None

    @field_validator("country")
    def validate_country(cls, value: str | None) -> str:
        if value is None:
            return "Unknown"
        return value
--------------------------------------------------------------------------------
/src/tests/test_crud.py:
--------------------------------------------------------------------------------
import pytest
import pytest_asyncio
from neo4j import AsyncGraphDatabase, GraphDatabase

from src.config import settings


@pytest.fixture(scope="session")
def config():
    from dotenv import load_dotenv

    load_dotenv()
    return settings.Settings()


# Test a sync Neo4j database connection
@pytest.fixture(scope="session")
def sync_connection(config):
    URI = f"bolt://{config.neo4j_service}:7687"
    AUTH = (config.neo4j_user, config.neo4j_password)
    with GraphDatabase.driver(URI, auth=AUTH) as driver:
        with driver.session(database="neo4j") as session:
            yield session


# Test an async Neo4j database connection
@pytest_asyncio.fixture
async def async_connection(config):
    URI = f"bolt://{config.neo4j_service}:7687"
    AUTH = (config.neo4j_user, config.neo4j_password)
    async with AsyncGraphDatabase.driver(URI, auth=AUTH) as driver:
        async with driver.session(database="neo4j") as session:
            yield session


def test_sync_transactions(sync_connection):
    # Merge dummy node
    merge_query = "MERGE (t:Test) RETURN count(t) AS res"
    response = sync_connection.run(merge_query)
    result = response.single()
    assert result.get("res") == 1
    # Delete dummy node
    delete_query = "MATCH (t:Test) DELETE t"
    response = sync_connection.run(delete_query)
    # Check if dummy node was deleted
    match_query = "MATCH (t:Test) RETURN count(t) AS res"
    response = sync_connection.run(match_query)
    result = response.single()
    assert result.get("res") == 0


@pytest.mark.asyncio
async def test_async_transactions(async_connection):
    # Merge dummy node
    merge_query = "MERGE (t:Test) RETURN count(t) AS res"
"MERGE (t:Test) RETURN count(t) AS res" 56 | response = await async_connection.run(merge_query) 57 | result = await response.single() 58 | assert result.get("res") == 1 59 | # Delete dummy node 60 | delete_query = "MATCH (t:Test) DELETE t" 61 | response = await async_connection.run(delete_query) 62 | # Check if dummy node was deleted 63 | match_query = "MATCH (t:Test) RETURN count(t) AS res" 64 | response = await async_connection.run(match_query) 65 | result = await response.single() 66 | assert result.get("res") == 0 67 | -------------------------------------------------------------------------------- /src/tests/test_schemas.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src.schemas.wine import Wine 3 | 4 | 5 | @pytest.fixture(scope="session") 6 | def data(): 7 | example = { 8 | "points": 90, 9 | "title": "Castello San Donato in Perano 2009 Riserva (Chianti Classico)", 10 | "description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. Drink now through 2019.", 11 | "taster_name": "Kerin O'Keefe", 12 | "taster_twitter_handle": "@kerinokeefe", 13 | "price": 30.0, 14 | "designation": "Riserva", 15 | "variety": "Red Blend", 16 | "region_1": "Chianti Classico", 17 | "region_2": "Monti del Chianti", 18 | "province": "Tuscany", 19 | "country": "Italy", 20 | "winery": "Castello San Donato in Perano", 21 | "id": 40825 22 | } 23 | return example 24 | 25 | 26 | # Test pydantic schema validation for wine items 27 | def test_wine_schema(data): 28 | wine = Wine(**data) 29 | assert isinstance(wine.country, str) 30 | assert isinstance(wine.points, int) 31 | assert -1 < wine.points < 101 32 | assert isinstance(wine.price, float) 33 | assert wine.price > 0 34 | assert isinstance(wine.id, int) 35 | assert all( 36 | isinstance(item, str) 37 | for item in [ 38 | wine.country, 39 | wine.province, 40 | wine.description, 41 | wine.vineyard, 42 | wine.region_1, 43 | wine.region_2, 44 | wine.taster_name, 45 | wine.taster_twitter_handle, 46 | wine.title, 47 | wine.variety, 48 | ] 49 | ) --------------------------------------------------------------------------------