├── .flake8 ├── .github └── workflows │ ├── ci.yml │ ├── codeql-analysis.yml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pyup.yml ├── LICENSE ├── README.md ├── fastapi_async_sqlalchemy ├── __init__.py ├── exceptions.py ├── middleware.py └── py.typed ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── conftest.py └── test_session.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .venv 3 | max-line-length = 100 4 | extend-ignore = E203 5 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | test: 12 | name: test 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | build: [linux_3.9, windows_3.9, mac_3.9] 17 | include: 18 | - build: linux_3.9 19 | os: ubuntu-latest 20 | python: 3.9 21 | - build: windows_3.9 22 | os: windows-latest 23 | python: 3.9 24 | - build: mac_3.9 25 | os: macos-latest 26 | python: 3.9 27 | steps: 28 | - name: Checkout repository 29 | uses: actions/checkout@v2 30 | 31 | - name: Set up Python ${{ matrix.python }} 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: ${{ matrix.python }} 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip wheel 39 | pip install -r requirements.txt 40 | 41 | # test all the builds apart from linux_3.9, which runs the coverage steps below
42 | - name: Test with pytest 43 | if: matrix.build != 'linux_3.9' 44 | run: pytest 45 | 46 | # only do the test coverage for linux_3.9 47 | - name: Produce coverage report 48 | if: matrix.build == 'linux_3.9' 49 | run: pytest --cov=fastapi_async_sqlalchemy --cov-report=xml 50 | 51 | - name: Upload coverage report 52 | if: matrix.build == 'linux_3.9' 53 | uses: codecov/codecov-action@v1 54 | with: 55 | file: ./coverage.xml 56 | 57 | lint: 58 | name: lint 59 | runs-on: ubuntu-latest 60 | steps: 61 | - name: Checkout repository 62 | uses: actions/checkout@v2 63 | 64 | - name: Set up Python 65 | uses: actions/setup-python@v4 66 | with: 67 | python-version: 3.9 68 | 69 | - name: Install dependencies 70 | run: pip install flake8 71 | 72 | - name: Run flake8 73 | run: flake8 --count . 74 | 75 | format: 76 | name: format 77 | runs-on: ubuntu-latest 78 | steps: 79 | - name: Checkout repository 80 | uses: actions/checkout@v2 81 | 82 | - name: Set up Python 83 | uses: actions/setup-python@v4 84 | with: 85 | python-version: 3.9 86 | 87 | - name: Install dependencies 88 | # isort needs all of the packages to be installed so it can 89 | # tell which are third party and which are first party 90 | run: pip install -r requirements.txt 91 | 92 | - name: Check formatting of imports 93 | run: isort --check-only --diff --verbose 94 | 95 | - name: Check formatting of code 96 | run: black . --check --diff 97 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository.
Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '43 19 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 37 | # Learn more: 38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 39 | 40 | steps: 41 | - name: Checkout repository 42 | uses: actions/checkout@v2 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v1 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 51 | # Prefix the list here with "+" to use these queries and those in the config file. 52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v1 58 | 59 | # ℹ️ Command-line programs to run using the OS shell. 
60 | # 📚 https://git.io/JvXDl 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v1 72 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | build/ 3 | .vscode/ 4 | venv/ 5 | .idea/ 6 | __pycache__/ 7 | .pytest_cache/ 8 | .venv/ 9 | .env 10 | .coverage 11 | htmlcov/ 12 | test.py 13 | *.egg-info 14 | coverage.xml 15 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: (alembic|build|dist|docker|esign|kubernetes|migrations) 2 | 3 | default_language_version: 4 | python: python3.8 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.4.0 9 | hooks: 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - repo: https://github.com/asottile/pyupgrade 13 | rev: v2.28.0 14 | hooks: 15 | - id: pyupgrade 16 | args: 17 | - --py37-plus 18 | - repo: https://github.com/myint/autoflake 19 | rev: v1.4 20 | hooks: 21 | - id: autoflake 22 | args: 23 | - --in-place 24 | - --remove-all-unused-imports 25 | - --expand-star-imports 26 | - --remove-duplicate-keys 27 | - --remove-unused-variables 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 
| - id: isort 32 | - repo: https://github.com/psf/black 33 | rev: 22.3.0 34 | hooks: 35 | - id: black 36 | - repo: https://github.com/PyCQA/flake8 37 | rev: 5.0.4 38 | hooks: 39 | - id: flake8 40 | args: 41 | - --max-line-length=100 42 | - --ignore=E203, E501, W503 43 | - repo: https://github.com/pre-commit/mirrors-mypy 44 | rev: v0.982 45 | hooks: 46 | - id: mypy 47 | additional_dependencies: 48 | - pydantic 49 | - types-ujson 50 | -------------------------------------------------------------------------------- /.pyup.yml: -------------------------------------------------------------------------------- 1 | # autogenerated pyup.io config file 2 | # see https://pyup.io/docs/configuration/ for all available options 3 | 4 | schedule: '' 5 | update: false 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael Freeborn 4 | Copyright (c) 2021 Eugene Shershen 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SQLAlchemy FastAPI middleware 2 | 3 | [![ci](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB) 4 | [![ci](https://github.com/h0rn3t/fastapi-async-sqlalchemy/workflows/ci/badge.svg)](https://github.com/h0rn3t/fastapi-async-sqlalchemy/workflows/ci/badge.svg) 5 | [![codecov](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy/branch/main/graph/badge.svg?token=F4NJ34WKPY)](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 7 | [![pip](https://img.shields.io/pypi/v/fastapi_async_sqlalchemy?color=blue)](https://pypi.org/project/fastapi-async-sqlalchemy/) 8 | [![Downloads](https://static.pepy.tech/badge/fastapi-async-sqlalchemy)](https://pepy.tech/project/fastapi-async-sqlalchemy) 9 | [![Updates](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/shield.svg)](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/) 10 | 11 | ### Description 12 | 13 | Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine. 14 | 15 | ### Install 16 | 17 | ```bash 18 | pip install fastapi-async-sqlalchemy 19 | ``` 20 | 21 | 22 | It also works with ```sqlmodel``` 23 | 24 | 25 | ### Examples 26 | 27 | Note that the session object provided by ``db.session`` is based on the Python3.7+ ``ContextVar``. 
This means that 28 | each session is linked to the individual request context in which it was created. 29 | 30 | ```python 31 | 32 | from fastapi import FastAPI 33 | from fastapi_async_sqlalchemy import SQLAlchemyMiddleware 34 | from fastapi_async_sqlalchemy import db # provide access to a database session 35 | from sqlalchemy import column 36 | from sqlalchemy import table 37 | 38 | app = FastAPI() 39 | app.add_middleware( 40 | SQLAlchemyMiddleware, 41 | db_url="postgresql+asyncpg://user:user@192.168.88.200:5432/primary_db", 42 | engine_args={ # engine arguments example 43 | "echo": True, # print all SQL statements 44 | "pool_pre_ping": True, # feature will normally emit SQL equivalent to “SELECT 1” each time a connection is checked out from the pool 45 | "pool_size": 5, # number of connections to keep open at a time 46 | "max_overflow": 10, # number of connections to allow to be opened above pool_size 47 | }, 48 | ) 49 | # once the middleware is applied, any route can then access the database session 50 | # from the global ``db`` 51 | 52 | foo = table("ms_files", column("id")) 53 | 54 | # Usage inside of a route 55 | @app.get("/") 56 | async def get_files(): 57 | result = await db.session.execute(foo.select()) 58 | return result.fetchall() 59 | 60 | async def get_db_fetch(): 61 | # It uses the same ``db`` object and use it as a context manager: 62 | async with db(): 63 | result = await db.session.execute(foo.select()) 64 | return result.fetchall() 65 | 66 | # Usage inside of a route using a db context 67 | @app.get("/db_context") 68 | async def db_context(): 69 | return await get_db_fetch() 70 | 71 | # Usage outside of a route using a db context 72 | @app.on_event("startup") 73 | async def on_startup(): 74 | # We are outside of a request context, therefore we cannot rely on ``SQLAlchemyMiddleware`` 75 | # to create a database session for us. 
76 | result = await get_db_fetch() 77 | 78 | 79 | if __name__ == "__main__": 80 | import uvicorn 81 | uvicorn.run(app, host="0.0.0.0", port=8002) 82 | 83 | ``` 84 | 85 | #### Usage of multiple databases 86 | 87 | databases.py 88 | 89 | ```python 90 | from fastapi import FastAPI 91 | from fastapi_async_sqlalchemy import create_middleware_and_session_proxy 92 | 93 | FirstSQLAlchemyMiddleware, first_db = create_middleware_and_session_proxy() 94 | SecondSQLAlchemyMiddleware, second_db = create_middleware_and_session_proxy() 95 | ``` 96 | 97 | main.py 98 | 99 | ```python 100 | from fastapi import FastAPI 101 | 102 | from databases import FirstSQLAlchemyMiddleware, SecondSQLAlchemyMiddleware 103 | from routes import router 104 | 105 | app = FastAPI() 106 | 107 | app.include_router(router) 108 | 109 | app.add_middleware( 110 | FirstSQLAlchemyMiddleware, 111 | db_url="postgresql+asyncpg://user:user@192.168.88.200:5432/primary_db", 112 | engine_args={ 113 | "pool_size": 5, 114 | "max_overflow": 10, 115 | }, 116 | ) 117 | app.add_middleware( 118 | SecondSQLAlchemyMiddleware, 119 | db_url="mysql+aiomysql://user:user@192.168.88.200:5432/primary_db", 120 | engine_args={ 121 | "pool_size": 5, 122 | "max_overflow": 10, 123 | }, 124 | ) 125 | ``` 126 | 127 | routes.py 128 | 129 | ```python 130 | import asyncio 131 | 132 | from fastapi import APIRouter 133 | from sqlalchemy import column, table, text 134 | 135 | from databases import first_db, second_db 136 | 137 | router = APIRouter() 138 | 139 | foo = table("ms_files", column("id")) 140 | 141 | @router.get("/first-db-files") 142 | async def get_files_from_first_db(): 143 | result = await first_db.session.execute(foo.select()) 144 | return result.fetchall() 145 | 146 | 147 | @router.get("/second-db-files") 148 | async def get_files_from_second_db(): 149 | result = await second_db.session.execute(foo.select()) 150 | return result.fetchall() 151 | 152 | 153 | @router.get("/concurrent-queries") 154 | async def parallel_select(): 155 
| async with first_db(multi_sessions=True): 156 | async def execute_query(query): 157 | return await first_db.session.execute(text(query)) 158 | 159 | tasks = [ 160 | asyncio.create_task(execute_query("SELECT 1")), 161 | asyncio.create_task(execute_query("SELECT 2")), 162 | asyncio.create_task(execute_query("SELECT 3")), 163 | asyncio.create_task(execute_query("SELECT 4")), 164 | asyncio.create_task(execute_query("SELECT 5")), 165 | asyncio.create_task(execute_query("SELECT 6")), 166 | ] 167 | 168 | await asyncio.gather(*tasks) 169 | ``` 170 | -------------------------------------------------------------------------------- /fastapi_async_sqlalchemy/__init__.py: -------------------------------------------------------------------------------- 1 | from fastapi_async_sqlalchemy.middleware import ( 2 | SQLAlchemyMiddleware, 3 | db, 4 | create_middleware_and_session_proxy, 5 | ) 6 | 7 | __all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"] 8 | 9 | __version__ = "0.7.0.dev4" 10 | -------------------------------------------------------------------------------- /fastapi_async_sqlalchemy/exceptions.py: -------------------------------------------------------------------------------- 1 | class MissingSessionError(Exception): 2 | """ 3 | Exception raised for when the user tries to access a database session before it is created. 4 | """ 5 | 6 | def __init__(self): 7 | msg = """ 8 | No session found! Either you are not currently in a request context, 9 | or you need to manually create a session context by using a `db` instance as 10 | a context manager e.g.: 11 | 12 | async with db(): 13 | await db.session.execute(foo.select()).fetchall() 14 | """ 15 | 16 | super().__init__(msg) 17 | 18 | 19 | class SessionNotInitialisedError(Exception): 20 | """ 21 | Exception raised when the user creates a new DB session without first initialising it. 22 | """ 23 | 24 | def __init__(self): 25 | msg = """ 26 | Session not initialised! 
Ensure that DBSessionMiddleware has been initialised before 27 | attempting database access. 28 | """ 29 | 30 | super().__init__(msg) 31 | -------------------------------------------------------------------------------- /fastapi_async_sqlalchemy/middleware.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from contextvars import ContextVar 3 | from typing import Dict, Optional, Union 4 | 5 | from sqlalchemy.engine import Engine 6 | from sqlalchemy.engine.url import URL 7 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine 8 | from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint 9 | from starlette.requests import Request 10 | from starlette.types import ASGIApp 11 | 12 | from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError 13 | 14 | try: 15 | from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811 16 | except ImportError: 17 | from sqlalchemy.orm import sessionmaker as async_sessionmaker 18 | 19 | 20 | def create_middleware_and_session_proxy(): 21 | _Session: Optional[async_sessionmaker] = None 22 | _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) 23 | _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) 24 | _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) 25 | # Usage of context vars inside closures is not recommended, since they are not properly 26 | # garbage collected, but in our use case context var is created on program startup and 27 | # is used throughout the whole its lifecycle. 
28 | 29 | class SQLAlchemyMiddleware(BaseHTTPMiddleware): 30 | def __init__( 31 | self, 32 | app: ASGIApp, 33 | db_url: Optional[Union[str, URL]] = None, 34 | custom_engine: Optional[Engine] = None, 35 | engine_args: Dict = None, 36 | session_args: Dict = None, 37 | commit_on_exit: bool = False, 38 | ): 39 | super().__init__(app) 40 | self.commit_on_exit = commit_on_exit 41 | engine_args = engine_args or {} 42 | session_args = session_args or {} 43 | 44 | if not custom_engine and not db_url: 45 | raise ValueError("You need to pass a db_url or a custom_engine parameter.") 46 | if not custom_engine: 47 | engine = create_async_engine(db_url, **engine_args) 48 | else: 49 | engine = custom_engine 50 | 51 | nonlocal _Session 52 | _Session = async_sessionmaker( 53 | engine, class_=AsyncSession, expire_on_commit=False, **session_args 54 | ) 55 | 56 | async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): 57 | async with DBSession(commit_on_exit=self.commit_on_exit): 58 | return await call_next(request) 59 | 60 | class DBSessionMeta(type): 61 | @property 62 | def session(self) -> AsyncSession: 63 | """Return an instance of Session local to the current async context.""" 64 | if _Session is None: 65 | raise SessionNotInitialisedError 66 | 67 | multi_sessions = _multi_sessions_ctx.get() 68 | if multi_sessions: 69 | """In this case, we need to create a new session for each task. 70 | We also need to commit the session on exit if commit_on_exit is True. 71 | This is useful when we need to run multiple queries in parallel. 72 | For example, when we need to run multiple queries in parallel in a route handler. 
73 | Example: 74 | ```python 75 | async with db(multi_sessions=True): 76 | async def execute_query(query): 77 | return await db.session.execute(text(query)) 78 | 79 | tasks = [ 80 | asyncio.create_task(execute_query("SELECT 1")), 81 | asyncio.create_task(execute_query("SELECT 2")), 82 | asyncio.create_task(execute_query("SELECT 3")), 83 | asyncio.create_task(execute_query("SELECT 4")), 84 | asyncio.create_task(execute_query("SELECT 5")), 85 | asyncio.create_task(execute_query("SELECT 6")), 86 | ] 87 | 88 | await asyncio.gather(*tasks) 89 | ``` 90 | """ 91 | commit_on_exit = _commit_on_exit_ctx.get() 92 | # Always create a new session for each access when multi_sessions=True 93 | session = _Session() 94 | 95 | async def cleanup(): 96 | try: 97 | if commit_on_exit: 98 | await session.commit() 99 | except Exception: 100 | await session.rollback() 101 | raise 102 | finally: 103 | await session.close() 104 | 105 | task = asyncio.current_task() 106 | if task is not None: 107 | task.add_done_callback(lambda t: asyncio.create_task(cleanup())) 108 | return session 109 | else: 110 | session = _session.get() 111 | if session is None: 112 | raise MissingSessionError 113 | return session 114 | 115 | class DBSession(metaclass=DBSessionMeta): 116 | def __init__( 117 | self, 118 | session_args: Dict = None, 119 | commit_on_exit: bool = False, 120 | multi_sessions: bool = False, 121 | ): 122 | self.token = None 123 | self.commit_on_exit_token = None 124 | self.session_args = session_args or {} 125 | self.commit_on_exit = commit_on_exit 126 | self.multi_sessions = multi_sessions 127 | 128 | async def __aenter__(self): 129 | if not isinstance(_Session, async_sessionmaker): 130 | raise SessionNotInitialisedError 131 | 132 | if self.multi_sessions: 133 | self.multi_sessions_token = _multi_sessions_ctx.set(True) 134 | self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit) 135 | else: 136 | self.token = _session.set(_Session(**self.session_args)) 137 | return type(self) 
138 | 139 | async def __aexit__(self, exc_type, exc_value, traceback): 140 | if self.multi_sessions: 141 | _multi_sessions_ctx.reset(self.multi_sessions_token) 142 | _commit_on_exit_ctx.reset(self.commit_on_exit_token) 143 | else: 144 | session = _session.get() 145 | try: 146 | if exc_type is not None: 147 | await session.rollback() 148 | elif self.commit_on_exit: 149 | await session.commit() 150 | finally: 151 | await session.close() 152 | _session.reset(self.token) 153 | 154 | return SQLAlchemyMiddleware, DBSession 155 | 156 | 157 | SQLAlchemyMiddleware, db = create_middleware_and_session_proxy() 158 | -------------------------------------------------------------------------------- /fastapi_async_sqlalchemy/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h0rn3t/fastapi-async-sqlalchemy/a76b7bb67ea6e1fe90f342bbbe002d58ec1f876f/fastapi_async_sqlalchemy/py.typed -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | target-version = ['py37'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 | | .git 8 | | .venv 9 | | build 10 | | dist 11 | ) 12 | ''' 13 | 14 | [tool.isort] 15 | multi_line_output = 3 16 | include_trailing_comma = true 17 | force_grid_wrap = 0 18 | use_parentheses = true 19 | line_length = 100 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.3 2 | atomicwrites==1.3.0 3 | attrs>=19.3.0 4 | certifi>=2023.07.22 5 | chardet==3.0.4 6 | click>=8.1.3 7 | coverage>=5.2.1 8 | entrypoints==0.3 9 | fastapi==0.90.0 # pyup: ignore 10 | flake8==3.7.9 11 | idna==3.7 12 | importlib-metadata==1.5.0 13 | isort==4.3.21 14 | mccabe==0.6.1 15 | more-itertools==7.2.0 16 | 
packaging>=22.0 17 | pathspec>=0.9.0 18 | pluggy==0.13.0 19 | pycodestyle==2.5.0 20 | pydantic==1.10.18 21 | pyflakes==2.1.1 22 | pyparsing==2.4.2 23 | pytest==7.2.0 24 | pytest-cov==2.11.1 25 | PyYAML>=5.4 26 | regex>=2020.2.20 27 | requests>=2.22.0 28 | httpx>=0.20.0 29 | six==1.12.0 30 | SQLAlchemy>=1.4.19 31 | asyncpg>=0.27.0 32 | aiosqlite==0.20.0 33 | sqlparse==0.5.1 34 | starlette>=0.13.6 35 | toml>=0.10.1 36 | typed-ast>=1.4.2 37 | urllib3>=1.25.9 38 | wcwidth==0.1.7 39 | zipp==3.19.1 40 | black==24.4.2 41 | pytest-asyncio==0.21.0 42 | greenlet==3.1.1 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | from setuptools import setup 5 | 6 | with open(Path("fastapi_async_sqlalchemy") / "__init__.py", encoding="utf-8") as fh: 7 | version = re.search(r'__version__ = "(.*?)"', fh.read(), re.M).group(1) # type: ignore 8 | 9 | with open("README.md", encoding="utf-8") as fh: 10 | long_description = fh.read() 11 | 12 | setup( 13 | name="fastapi-async-sqlalchemy", 14 | version=version, 15 | url="https://github.com/h0rn3t/fastapi-async-sqlalchemy.git", 16 | project_urls={ 17 | "Code": "https://github.com/h0rn3t/fastapi-async-sqlalchemy", 18 | "Issue tracker": "https://github.com/h0rn3t/fastapi-async-sqlalchemy/issues", 19 | }, 20 | license="MIT", 21 | author="Eugene Shershen", 22 | author_email="h0rn3t.null@gmail.com", 23 | description="SQLAlchemy middleware for FastAPI", 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | packages=["fastapi_async_sqlalchemy"], 27 | package_data={"fastapi_async_sqlalchemy": ["py.typed"]}, 28 | zip_safe=False, 29 | python_requires=">=3.7", 30 | install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"], 31 | classifiers=[ 32 | "Development Status :: 5 - Production/Stable", 33 | "Environment :: Web Environment", 34 | 
"Framework :: AsyncIO", 35 | "Intended Audience :: Developers", 36 | "License :: OSI Approved :: MIT License", 37 | "Operating System :: OS Independent", 38 | "Programming Language :: Python :: 3.7", 39 | "Programming Language :: Python :: 3.8", 40 | "Programming Language :: Python :: 3.9", 41 | "Programming Language :: Python :: 3.10", 42 | "Programming Language :: Python :: 3.11", 43 | "Programming Language :: Python :: 3.12", 44 | "Programming Language :: Python :: 3 :: Only", 45 | "Programming Language :: Python :: Implementation :: CPython", 46 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 47 | "Topic :: Internet :: WWW/HTTP :: Dynamic Content", 48 | "Topic :: Software Development :: Libraries :: Python Modules", 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h0rn3t/fastapi-async-sqlalchemy/a76b7bb67ea6e1fe90f342bbbe002d58ec1f876f/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | from fastapi import FastAPI 5 | from starlette.testclient import TestClient 6 | 7 | 8 | @pytest.fixture 9 | def app(): 10 | return FastAPI() 11 | 12 | 13 | @pytest.fixture 14 | def client(app): 15 | with TestClient(app) as c: 16 | yield c 17 | 18 | 19 | @pytest.fixture 20 | def SQLAlchemyMiddleware(): 21 | from fastapi_async_sqlalchemy import SQLAlchemyMiddleware 22 | 23 | yield SQLAlchemyMiddleware 24 | 25 | 26 | @pytest.fixture 27 | def db(): 28 | from fastapi_async_sqlalchemy import db 29 | 30 | yield db 31 | 32 | # force reloading of module to clear global state 33 | 34 | try: 35 | del sys.modules["fastapi_async_sqlalchemy"] 36 | except KeyError: 37 | pass 38 | 39 | try: 40 | del 
sys.modules["fastapi_async_sqlalchemy.middleware"] 41 | except KeyError: 42 | pass 43 | -------------------------------------------------------------------------------- /tests/test_session.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | from sqlalchemy import text 5 | from sqlalchemy.exc import IntegrityError 6 | from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine 7 | from starlette.middleware.base import BaseHTTPMiddleware 8 | 9 | from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError 10 | 11 | db_url = "sqlite+aiosqlite://" 12 | 13 | 14 | @pytest.mark.asyncio 15 | async def test_init(app, SQLAlchemyMiddleware): 16 | mw = SQLAlchemyMiddleware(app, db_url=db_url) 17 | assert isinstance(mw, BaseHTTPMiddleware) 18 | 19 | 20 | @pytest.mark.asyncio 21 | async def test_init_required_args(app, SQLAlchemyMiddleware): 22 | with pytest.raises(ValueError) as exc_info: 23 | SQLAlchemyMiddleware(app) 24 | 25 | assert exc_info.value.args[0] == "You need to pass a db_url or a custom_engine parameter." 


@pytest.mark.asyncio
async def test_init_required_args_custom_engine(app, db, SQLAlchemyMiddleware):
    """A pre-built engine passed as ``custom_engine`` is accepted in place of ``db_url``."""
    custom_engine = create_async_engine(db_url)
    mw = SQLAlchemyMiddleware(app, custom_engine=custom_engine)
    assert isinstance(mw, BaseHTTPMiddleware)


@pytest.mark.asyncio
async def test_init_correct_optional_args(app, db, SQLAlchemyMiddleware):
    """``engine_args`` are applied to the engine bound to sessions opened via ``db()``."""
    engine_args = {"echo": True}
    # NOTE(review): a check of session_args={"expire_on_commit": False} used to be commented
    # out here; confirm how the middleware merges session_args before re-enabling such a check.
    SQLAlchemyMiddleware(app, db_url, engine_args=engine_args, session_args={})

    async with db():
        engine = db.session.bind
        assert engine.echo

    async with db() as db_ctx:
        engine = db_ctx.session.bind
        assert engine.echo


@pytest.mark.asyncio
async def test_init_incorrect_optional_args(app, SQLAlchemyMiddleware):
    """Unknown keyword arguments are rejected with a TypeError from ``__init__``."""
    with pytest.raises(TypeError) as exc_info:
        SQLAlchemyMiddleware(app, db_url=db_url, invalid_args="test")

    assert "__init__() got an unexpected keyword argument 'invalid_args'" in exc_info.value.args[0]


@pytest.mark.asyncio
async def test_inside_route(app, client, db, SQLAlchemyMiddleware):
    """Inside a request handled through the middleware, ``db.session`` is an AsyncSession."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    @app.get("/")
    def test_get():
        assert isinstance(db.session, AsyncSession)

    client.get("/")


@pytest.mark.asyncio
async def test_inside_route_without_middleware_fails(app, client, db):
    """Without the middleware, accessing ``db.session`` in a route raises SessionNotInitialisedError."""

    @app.get("/")
    def test_get():
        with pytest.raises(SessionNotInitialisedError):
            db.session

    client.get("/")


@pytest.mark.asyncio
async def test_outside_of_route(app, db, SQLAlchemyMiddleware):
    """``async with db()`` provides a session outside of any request."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db():
        assert isinstance(db.session, AsyncSession)


@pytest.mark.asyncio
async def test_outside_of_route_without_middleware_fails(db):
    """Without the middleware, both ``db.session`` and ``db()`` raise SessionNotInitialisedError."""
    with pytest.raises(SessionNotInitialisedError):
        db.session

    with pytest.raises(SessionNotInitialisedError):
        async with db():
            pass


@pytest.mark.asyncio
async def test_outside_of_route_without_context_fails(app, db, SQLAlchemyMiddleware):
    """With the middleware but no active ``db()`` context, ``db.session`` raises MissingSessionError."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    with pytest.raises(MissingSessionError):
        db.session


@pytest.mark.asyncio
async def test_init_session(app, db, SQLAlchemyMiddleware):
    """Entering ``db()`` creates an AsyncSession."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db():
        assert isinstance(db.session, AsyncSession)


@pytest.mark.asyncio
async def test_db_session_commit_fail(app, db, SQLAlchemyMiddleware):
    """An error raised in a commit_on_exit context closes the session; a new context still works."""
    from unittest.mock import AsyncMock, patch

    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=True)

    # The previous close assertion sat after the raise, so it never executed (and targeted a
    # real method, not a mock). Patching the class attribute records the call made on exit.
    with patch.object(AsyncSession, "close", new_callable=AsyncMock) as close_mock:
        with pytest.raises(IntegrityError):
            async with db():
                raise IntegrityError("test", "test", "test")
    close_mock.assert_awaited_once()

    async with db():
        assert db.session


@pytest.mark.asyncio
async def test_rollback(app, db, SQLAlchemyMiddleware):
    """Leaving a ``db()`` context because of an exception rolls the session back once."""
    from unittest.mock import AsyncMock, patch

    class _Boom(Exception):
        """Test-only error so pytest.raises cannot be satisfied by an unrelated failure."""

    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    # Previously the rollback assertion followed the raise and never ran; patching the class
    # attribute lets the test observe the rollback performed by the context manager's exit.
    with patch.object(AsyncSession, "rollback", new_callable=AsyncMock) as rollback_mock:
        with pytest.raises(_Boom):
            async with db():
                raise _Boom
    rollback_mock.assert_awaited_once()


@pytest.mark.parametrize("commit_on_exit", [True, False])
@pytest.mark.asyncio
async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_exit):
    """Per-context ``session_args`` are forwarded to the session created by ``db()``."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url, commit_on_exit=commit_on_exit)

    async with db(session_args={}, commit_on_exit=True):
        assert isinstance(db.session, AsyncSession)

    async with db(session_args={"expire_on_commit": False}):
        # The original bare ``db.session`` expression checked nothing; verify the option
        # actually reached the synchronous Session wrapped by the AsyncSession.
        assert db.session.sync_session.expire_on_commit is False


@pytest.mark.asyncio
async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
    """With ``multi_sessions=True``, concurrent tasks can each execute through ``db.session``."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db(multi_sessions=True):

        async def execute_query(query):
            return await db.session.execute(text(query))

        tasks = [asyncio.create_task(execute_query(f"SELECT {i}")) for i in range(1, 7)]

        res = await asyncio.gather(*tasks)
        assert len(res) == 6


@pytest.mark.asyncio
async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
    """Concurrent inserts under ``multi_sessions`` with ``commit_on_exit`` are all persisted."""
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    async with db(multi_sessions=True, commit_on_exit=True):
        await db.session.execute(
            text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
        )

        async def insert_data(value):
            await db.session.execute(
                text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value}
            )
            await db.session.flush()

        tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)]

        # insert_data returns None, so this only confirms that every task completed.
        results = await asyncio.gather(*tasks)
        assert len(results) == 10

        records = await db.session.execute(text("SELECT * FROM my_model"))
        records = records.scalars().all()
        assert len(records) == 10