├── tests ├── __init__.py ├── conftest.py ├── test_access_token.py └── test_users.py ├── fastapi_users_db_sqlalchemy ├── py.typed ├── generics.py ├── access_token.py └── __init__.py ├── .github ├── FUNDING.yml ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── config.yml │ └── bug_report.md ├── stale.yml └── workflows │ └── build.yml ├── test_build.py ├── .editorconfig ├── Makefile ├── .vscode └── settings.json ├── LICENSE ├── .gitignore ├── README.md └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastapi_users_db_sqlalchemy/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: frankie567 2 | -------------------------------------------------------------------------------- /test_build.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import sys 3 | 4 | try: 5 | from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase 6 | except: 7 | sys.exit(1) 8 | 9 | sys.exit(0) 10 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "04:00" 8 | open-pull-requests-limit: 10 9 | reviewers: 10 | - frankie567 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | contact_links: 2 | - name: I have a question 🤔 3 | url: https://github.com/fastapi-users/fastapi-users/discussions 4 | about: If you have any question about the usage of FastAPI Users that's not clearly a bug, please open a discussion first. 5 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.yml] 14 | indent_size = 2 15 | 16 | [Makefile] 17 | indent_style = tab 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | isort: 2 | isort ./fastapi_users_db_sqlalchemy ./tests 3 | 4 | format: isort 5 | black . 6 | 7 | test: 8 | pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=term-missing --cov-fail-under=100 9 | 10 | bumpversion-major: 11 | bumpversion major 12 | 13 | bumpversion-minor: 14 | bumpversion minor 15 | 16 | bumpversion-patch: 17 | bumpversion patch 18 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 14 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | - bug 10 | - enhancement 11 | - documentation 12 | # Label to use when marking an issue as stale 13 | staleLabel: stale 14 | # Comment to post when marking an issue as stale. Set to `false` to disable 15 | markComment: > 16 | This issue has been automatically marked as stale because it has not had 17 | recent activity. It will be closed if no further activity occurs. Thank you 18 | for your contributions. 19 | # Comment to post when closing a stale issue. Set to `false` to disable 20 | closeComment: false 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 1. Go to '...' 18 | 2. Click on '....' 19 | 3. Scroll down to '....' 20 | 4. See error 21 | 22 | ## Expected behavior 23 | 24 | A clear and concise description of what you expected to happen. 25 | 26 | ## Configuration 27 | - Python version : 28 | - FastAPI version : 29 | - FastAPI Users version : 30 | 31 | ### FastAPI Users configuration 32 | 33 | ```py 34 | # Please copy/paste your FastAPI Users configuration here. 35 | ``` 36 | 37 | ## Additional context 38 | 39 | Add any other context about the problem here. 40 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic", 3 | "python.analysis.autoImportCompletions": true, 4 | "python.terminal.activateEnvironment": true, 5 | "python.terminal.activateEnvInCurrentTerminal": true, 6 | "python.testing.unittestEnabled": false, 7 | "python.testing.pytestEnabled": true, 8 | "editor.rulers": [88], 9 | "python.defaultInterpreterPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/python", 10 | "python.testing.pytestPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/pytest", 11 | "python.testing.cwd": "${workspaceFolder}", 12 | "python.testing.pytestArgs": ["--no-cov"], 13 | "[python]": { 14 | "editor.formatOnSave": true, 15 | "editor.codeActionsOnSave": { 16 | "source.fixAll": "explicit", 17 | "source.organizeImports": "explicit" 18 | }, 19 | "editor.defaultFormatter": "charliermarsh.ruff" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 François Voron 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Optional 3 | 4 | import pytest 5 | from fastapi_users import schemas 6 | 7 | DATABASE_URL = os.getenv( 8 | "DATABASE_URL", "sqlite+aiosqlite:///./test-sqlalchemy-user.db" 9 | ) 10 | 11 | 12 | class User(schemas.BaseUser): 13 | first_name: Optional[str] 14 | 15 | 16 | class UserCreate(schemas.BaseUserCreate): 17 | first_name: Optional[str] 18 | 19 | 20 | class UserUpdate(schemas.BaseUserUpdate): 21 | pass 22 | 23 | 24 | class UserOAuth(User, schemas.BaseOAuthAccountMixin): 25 | pass 26 | 27 | 28 | @pytest.fixture 29 | def oauth_account1() -> dict[str, Any]: 30 | return { 31 | "oauth_name": "service1", 32 | "access_token": "TOKEN", 33 | "expires_at": 1579000751, 34 | "account_id": "user_oauth1", 35 | "account_email": "king.arthur@camelot.bt", 36 | } 37 | 38 | 39 | @pytest.fixture 40 | def oauth_account2() -> dict[str, Any]: 41 | return { 42 | "oauth_name": "service2", 43 | "access_token": "TOKEN", 44 | "expires_at": 1579000751, 45 | "account_id": "user_oauth2", 46 | "account_email": "king.arthur@camelot.bt", 47 | } 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | junit/ 50 | junit.xml 51 | test*.db* 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # OS files 108 | .DS_Store 109 | 110 | # .idea 111 | .idea/ 112 | -------------------------------------------------------------------------------- /fastapi_users_db_sqlalchemy/generics.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime, timezone 3 | from typing import Optional 4 | 5 | from pydantic import UUID4 6 | from sqlalchemy import CHAR, TIMESTAMP, TypeDecorator 7 | from sqlalchemy.dialects.postgresql import UUID 8 | 9 | 10 | class GUID(TypeDecorator): # pragma: no cover 11 | """ 12 | Platform-independent GUID type. 13 | 14 | Uses PostgreSQL's UUID type, otherwise uses 15 | CHAR(36), storing as regular strings. 16 | """ 17 | 18 | class UUIDChar(CHAR): 19 | python_type = UUID4 # type: ignore 20 | 21 | impl = UUIDChar 22 | cache_ok = True 23 | 24 | def load_dialect_impl(self, dialect): 25 | if dialect.name == "postgresql": 26 | return dialect.type_descriptor(UUID()) 27 | else: 28 | return dialect.type_descriptor(CHAR(36)) 29 | 30 | def process_bind_param(self, value, dialect): 31 | if value is None: 32 | return value 33 | elif dialect.name == "postgresql": 34 | return str(value) 35 | else: 36 | if not isinstance(value, uuid.UUID): 37 | return str(uuid.UUID(value)) 38 | else: 39 | return str(value) 40 | 41 | def process_result_value(self, value, dialect): 42 | if value is None: 43 | return value 44 | else: 45 | if not isinstance(value, uuid.UUID): 46 | value = uuid.UUID(value) 47 | return value 48 | 49 | 50 | def now_utc(): 51 | return datetime.now(timezone.utc) 52 | 53 | 54 | class TIMESTAMPAware(TypeDecorator): # pragma: no cover 55 | """ 56 | MySQL and SQLite will always return naive-Python datetimes. 57 | 58 | We store everything as UTC, but we want to have 59 | only offset-aware Python datetimes, even with MySQL and SQLite. 60 | """ 61 | 62 | impl = TIMESTAMP 63 | cache_ok = True 64 | 65 | def process_result_value(self, value: Optional[datetime], dialect): 66 | if value is not None and dialect.name != "postgresql": 67 | return value.replace(tzinfo=timezone.utc) 68 | return value 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI Users - Database adapter for SQLAlchemy ORM 2 | 3 |
4 |
5 |
8 | Ready-to-use and customizable users management for FastAPI 9 |
10 | 11 | [](https://github.com/fastapi-users/fastapi-users/actions) 12 | [](https://codecov.io/gh/fastapi-users/fastapi-users-db-sqlalchemy) 13 | [](https://badge.fury.io/py/fastapi-users-db-sqlalchemy) 14 | [](https://pepy.tech/project/fastapi-users-db-sqlalchemy) 15 | 18 | 19 | --- 20 | 21 | **Documentation**: https://fastapi-users.github.io/fastapi-users/ 22 | 23 | **Source Code**: https://github.com/fastapi-users/fastapi-users 24 | 25 | --- 26 | 27 | Add quickly a registration and authentication system to your [FastAPI](https://fastapi.tiangolo.com/) project. **FastAPI Users** is designed to be as customizable and adaptable as possible. 28 | 29 | **Sub-package for SQLAlchemy ORM support in FastAPI Users.** 30 | 31 | ## Development 32 | 33 | ### Setup environment 34 | 35 | We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system. 36 | 37 | ### Run unit tests 38 | 39 | You can run all the tests with: 40 | 41 | ```bash 42 | hatch run test 43 | ``` 44 | 45 | ### Format the code 46 | 47 | Execute the following command to apply `isort` and `black` formatting: 48 | 49 | ```bash 50 | hatch run lint 51 | ``` 52 | 53 | ## License 54 | 55 | This project is licensed under the terms of the MIT license. 56 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.mypy] 2 | plugins = "sqlalchemy.ext.mypy.plugin" 3 | 4 | [tool.pytest.ini_options] 5 | asyncio_mode = "strict" 6 | asyncio_default_fixture_loop_scope = "function" 7 | addopts = "--ignore=test_build.py" 8 | 9 | [tool.ruff] 10 | 11 | [tool.ruff.lint] 12 | extend-select = ["I", "UP"] 13 | 14 | [tool.hatch] 15 | 16 | [tool.hatch.metadata] 17 | allow-direct-references = true 18 | 19 | [tool.hatch.version] 20 | source = "regex_commit" 21 | commit_extra_args = ["-e"] 22 | path = "fastapi_users_db_sqlalchemy/__init__.py" 23 | 24 | [tool.hatch.envs.default] 25 | installer = "uv" 26 | dependencies = [ 27 | "aiosqlite", 28 | "asyncpg", 29 | "aiomysql", 30 | "pytest", 31 | "pytest-asyncio", 32 | "black", 33 | "mypy", 34 | "pytest-cov", 35 | "pytest-mock", 36 | "asynctest", 37 | "httpx", 38 | "asgi_lifespan", 39 | "ruff", 40 | "sqlalchemy[asyncio,mypy]", 41 | ] 42 | 43 | [tool.hatch.envs.default.scripts] 44 | test = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=term-missing --cov-fail-under=100" 45 | test-cov-xml = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=xml --cov-fail-under=100" 46 | lint = [ 47 | "ruff format . ", 48 | "ruff check --fix .", 49 | "mypy fastapi_users_db_sqlalchemy/", 50 | ] 51 | lint-check = [ 52 | "ruff format --check .", 53 | "ruff check .", 54 | "mypy fastapi_users_db_sqlalchemy/", 55 | ] 56 | 57 | [tool.hatch.build.targets.sdist] 58 | support-legacy = true # Create setup.py 59 | 60 | [build-system] 61 | requires = ["hatchling", "hatch-regex-commit"] 62 | build-backend = "hatchling.build" 63 | 64 | [project] 65 | name = "fastapi-users-db-sqlalchemy" 66 | authors = [ 67 | { name = "François Voron", email = "fvoron@gmail.com" }, 68 | ] 69 | description = "FastAPI Users database adapter for SQLAlchemy" 70 | readme = "README.md" 71 | dynamic = ["version"] 72 | classifiers = [ 73 | "License :: OSI Approved :: MIT License", 74 | "Development Status :: 5 - Production/Stable", 75 | "Framework :: FastAPI", 76 | "Framework :: AsyncIO", 77 | "Intended Audience :: Developers", 78 | "Programming Language :: Python :: 3.9", 79 | "Programming Language :: Python :: 3.10", 80 | "Programming Language :: Python :: 3.11", 81 | "Programming Language :: Python :: 3.12", 82 | "Programming Language :: Python :: 3.13", 83 | "Programming Language :: Python :: 3 :: Only", 84 | "Topic :: Internet :: WWW/HTTP :: Session", 85 | ] 86 | requires-python = ">=3.9" 87 | dependencies = [ 88 | "fastapi-users >= 10.0.0", 89 | "sqlalchemy[asyncio] >=2.0.0,<2.1.0", 90 | ] 91 | 92 | [project.urls] 93 | Documentation = "https://fastapi-users.github.io/fastapi-users" 94 | Source = "https://github.com/fastapi-users/fastapi-users-db-sqlalchemy" 95 | -------------------------------------------------------------------------------- /fastapi_users_db_sqlalchemy/access_token.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime 3 | from typing import TYPE_CHECKING, Any, Generic, Optional 4 | 5 | from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase 6 | from fastapi_users.models import ID 7 | from sqlalchemy import ForeignKey, String, select 8 | from sqlalchemy.ext.asyncio import AsyncSession 9 | from sqlalchemy.orm import Mapped, declared_attr, mapped_column 10 | 11 | from fastapi_users_db_sqlalchemy.generics import GUID, TIMESTAMPAware, now_utc 12 | 13 | 14 | class SQLAlchemyBaseAccessTokenTable(Generic[ID]): 15 | """Base SQLAlchemy access token table definition.""" 16 | 17 | __tablename__ = "accesstoken" 18 | 19 | if TYPE_CHECKING: # pragma: no cover 20 | token: str 21 | created_at: datetime 22 | user_id: ID 23 | else: 24 | token: Mapped[str] = mapped_column(String(length=43), primary_key=True) 25 | created_at: Mapped[datetime] = mapped_column( 26 | TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc 27 | ) 28 | 29 | 30 | class SQLAlchemyBaseAccessTokenTableUUID(SQLAlchemyBaseAccessTokenTable[uuid.UUID]): 31 | if TYPE_CHECKING: # pragma: no cover 32 | user_id: uuid.UUID 33 | else: 34 | 35 | @declared_attr 36 | def user_id(cls) -> Mapped[GUID]: 37 | return mapped_column( 38 | GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False 39 | ) 40 | 41 | 42 | class SQLAlchemyAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): 43 | """ 44 | Access token database adapter for SQLAlchemy. 45 | 46 | :param session: SQLAlchemy session instance. 47 | :param access_token_table: SQLAlchemy access token model. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | session: AsyncSession, 53 | access_token_table: type[AP], 54 | ): 55 | self.session = session 56 | self.access_token_table = access_token_table 57 | 58 | async def get_by_token( 59 | self, token: str, max_age: Optional[datetime] = None 60 | ) -> Optional[AP]: 61 | statement = select(self.access_token_table).where( 62 | self.access_token_table.token == token # type: ignore 63 | ) 64 | if max_age is not None: 65 | statement = statement.where( 66 | self.access_token_table.created_at >= max_age # type: ignore 67 | ) 68 | 69 | results = await self.session.execute(statement) 70 | return results.scalar_one_or_none() 71 | 72 | async def create(self, create_dict: dict[str, Any]) -> AP: 73 | access_token = self.access_token_table(**create_dict) 74 | self.session.add(access_token) 75 | await self.session.commit() 76 | await self.session.refresh(access_token) 77 | return access_token 78 | 79 | async def update(self, access_token: AP, update_dict: dict[str, Any]) -> AP: 80 | for key, value in update_dict.items(): 81 | setattr(access_token, key, value) 82 | self.session.add(access_token) 83 | await self.session.commit() 84 | await self.session.refresh(access_token) 85 | return access_token 86 | 87 | async def delete(self, access_token: AP) -> None: 88 | await self.session.delete(access_token) 89 | await self.session.commit() 90 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | 9 | services: 10 | postgres: 11 | image: postgres:alpine 12 | ports: 13 | - 5432:5432 14 | env: 15 | POSTGRES_USER: fastapiusers 16 | POSTGRES_PASSWORD: fastapiuserspassword 17 | # Set health checks to wait until postgres has started 18 | options: >- 19 | --health-cmd pg_isready 20 | --health-interval 10s 21 | --health-timeout 5s 22 | --health-retries 5 23 | 24 | mariadb: 25 | image: mariadb 26 | ports: 27 | - 3306:3306 28 | env: 29 | MARIADB_ROOT_PASSWORD: fastapiuserspassword 30 | MARIADB_DATABASE: fastapiusers 31 | MARIADB_USER: fastapiusers 32 | MARIADB_PASSWORD: fastapiuserspassword 33 | 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | python_version: [3.9, '3.10', '3.11', '3.12', '3.13'] 38 | database_url: 39 | [ 40 | "sqlite+aiosqlite:///./test-fastapiusers.db", 41 | "postgresql+asyncpg://fastapiusers:fastapiuserspassword@localhost:5432/fastapiusers", 42 | "mysql+aiomysql://root:fastapiuserspassword@localhost:3306/fastapiusers", 43 | ] 44 | 45 | steps: 46 | - uses: actions/checkout@v4 47 | - name: Set up Python 48 | uses: actions/setup-python@v5 49 | with: 50 | python-version: ${{ matrix.python_version }} 51 | - name: Install dependencies 52 | run: | 53 | python -m pip install --upgrade pip 54 | pip install hatch 55 | hatch env create 56 | - name: Lint and typecheck 57 | run: | 58 | hatch run lint-check 59 | - name: Test 60 | env: 61 | DATABASE_URL: ${{ matrix.database_url }} 62 | run: | 63 | hatch run test-cov-xml 64 | - uses: codecov/codecov-action@v5 65 | with: 66 | token: ${{ secrets.CODECOV_TOKEN }} 67 | fail_ci_if_error: true 68 | verbose: true 69 | - name: Build and install it on system host 70 | run: | 71 | hatch build 72 | pip install dist/fastapi_users_db_sqlalchemy-*.whl 73 | python test_build.py 74 | 75 | release: 76 | runs-on: ubuntu-latest 77 | needs: test 78 | if: startsWith(github.ref, 'refs/tags/') 79 | 80 | steps: 81 | - uses: actions/checkout@v4 82 | - name: Set up Python 83 | uses: actions/setup-python@v5 84 | with: 85 | python-version: 3.9 86 | - name: Install dependencies 87 | run: | 88 | python -m pip install --upgrade pip 89 | pip install hatch 90 | - name: Build and publish on PyPI 91 | env: 92 | HATCH_INDEX_USER: ${{ secrets.HATCH_INDEX_USER }} 93 | HATCH_INDEX_AUTH: ${{ secrets.HATCH_INDEX_AUTH }} 94 | run: | 95 | hatch build 96 | hatch publish 97 | - name: Create release 98 | uses: ncipollo/release-action@v1 99 | with: 100 | draft: true 101 | body: ${{ github.event.head_commit.message }} 102 | artifacts: dist/*.whl,dist/*.tar.gz 103 | token: ${{ secrets.GITHUB_TOKEN }} 104 | -------------------------------------------------------------------------------- /tests/test_access_token.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from collections.abc import AsyncGenerator 3 | from datetime import datetime, timedelta, timezone 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from pydantic import UUID4 8 | from sqlalchemy import exc 9 | from sqlalchemy.ext.asyncio import ( 10 | AsyncEngine, 11 | AsyncSession, 12 | async_sessionmaker, 13 | create_async_engine, 14 | ) 15 | from sqlalchemy.orm import DeclarativeBase 16 | 17 | from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID 18 | from fastapi_users_db_sqlalchemy.access_token import ( 19 | SQLAlchemyAccessTokenDatabase, 20 | SQLAlchemyBaseAccessTokenTableUUID, 21 | ) 22 | from tests.conftest import DATABASE_URL 23 | 24 | 25 | class Base(DeclarativeBase): 26 | pass 27 | 28 | 29 | class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): 30 | pass 31 | 32 | 33 | class User(SQLAlchemyBaseUserTableUUID, Base): 34 | pass 35 | 36 | 37 | def create_async_session_maker(engine: AsyncEngine): 38 | return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) 39 | 40 | 41 | @pytest.fixture 42 | def user_id() -> UUID4: 43 | return uuid.uuid4() 44 | 45 | 46 | @pytest_asyncio.fixture 47 | async def sqlalchemy_access_token_db( 48 | user_id: UUID4, 49 | ) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]: 50 | engine = create_async_engine(DATABASE_URL) 51 | sessionmaker = create_async_session_maker(engine) 52 | 53 | async with engine.begin() as connection: 54 | await connection.run_sync(Base.metadata.create_all) 55 | 56 | async with sessionmaker() as session: 57 | user = User( 58 | id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" 59 | ) 60 | session.add(user) 61 | await session.commit() 62 | 63 | yield SQLAlchemyAccessTokenDatabase(session, AccessToken) 64 | 65 | async with engine.begin() as connection: 66 | await connection.run_sync(Base.metadata.drop_all) 67 | 68 | 69 | @pytest.mark.asyncio 70 | async def test_queries( 71 | sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], 72 | user_id: UUID4, 73 | ): 74 | access_token_create = {"token": "TOKEN", "user_id": user_id} 75 | 76 | # Create 77 | access_token = await sqlalchemy_access_token_db.create(access_token_create) 78 | assert access_token.token == "TOKEN" 79 | assert access_token.user_id == user_id 80 | 81 | # Update 82 | update_dict = {"created_at": datetime.now(timezone.utc)} 83 | updated_access_token = await sqlalchemy_access_token_db.update( 84 | access_token, update_dict 85 | ) 86 | assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ 87 | "created_at" 88 | ].replace(microsecond=0) 89 | 90 | # Get by token 91 | access_token_by_token = await sqlalchemy_access_token_db.get_by_token( 92 | access_token.token 93 | ) 94 | assert access_token_by_token is not None 95 | 96 | # Get by token expired 97 | access_token_by_token = await sqlalchemy_access_token_db.get_by_token( 98 | access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) 99 | ) 100 | assert access_token_by_token is None 101 | 102 | # Get by token not expired 103 | access_token_by_token = await sqlalchemy_access_token_db.get_by_token( 104 | access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) 105 | ) 106 | assert access_token_by_token is not None 107 | 108 | # Get by token unknown 109 | access_token_by_token = await sqlalchemy_access_token_db.get_by_token( 110 | "NOT_EXISTING_TOKEN" 111 | ) 112 | assert access_token_by_token is None 113 | 114 | # Delete token 115 | await sqlalchemy_access_token_db.delete(access_token) 116 | deleted_access_token = await sqlalchemy_access_token_db.get_by_token( 117 | access_token.token 118 | ) 119 | assert deleted_access_token is None 120 | 121 | 122 | @pytest.mark.asyncio 123 | async def test_insert_existing_token( 124 | sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], 125 | user_id: UUID4, 126 | ): 127 | access_token_create = {"token": "TOKEN", "user_id": user_id} 128 | await sqlalchemy_access_token_db.create(access_token_create) 129 | 130 | with pytest.raises(exc.IntegrityError): 131 | await sqlalchemy_access_token_db.create(access_token_create) 132 | -------------------------------------------------------------------------------- /fastapi_users_db_sqlalchemy/__init__.py: -------------------------------------------------------------------------------- 1 | """FastAPI Users database adapter for SQLAlchemy.""" 2 | 3 | import uuid 4 | from typing import TYPE_CHECKING, Any, Generic, Optional 5 | 6 | from fastapi_users.db.base import BaseUserDatabase 7 | from fastapi_users.models import ID, OAP, UP 8 | from sqlalchemy import Boolean, ForeignKey, Integer, String, func, select 9 | from sqlalchemy.ext.asyncio import AsyncSession 10 | from sqlalchemy.orm import Mapped, declared_attr, mapped_column 11 | from sqlalchemy.sql import Select 12 | 13 | from fastapi_users_db_sqlalchemy.generics import GUID 14 | 15 | __version__ = "7.0.0" 16 | 17 | UUID_ID = uuid.UUID 18 | 19 | 20 | class SQLAlchemyBaseUserTable(Generic[ID]): 21 | """Base SQLAlchemy users table definition.""" 22 | 23 | __tablename__ = "user" 24 | 25 | if TYPE_CHECKING: # pragma: no cover 26 | id: ID 27 | email: str 28 | hashed_password: str 29 | is_active: bool 30 | is_superuser: bool 31 | is_verified: bool 32 | else: 33 | email: Mapped[str] = mapped_column( 34 | String(length=320), unique=True, index=True, nullable=False 35 | ) 36 | hashed_password: Mapped[str] = mapped_column( 37 | String(length=1024), nullable=False 38 | ) 39 | is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) 40 | is_superuser: Mapped[bool] = mapped_column( 41 | Boolean, default=False, nullable=False 42 | ) 43 | is_verified: Mapped[bool] = mapped_column( 44 | Boolean, default=False, nullable=False 45 | ) 46 | 47 | 48 | class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]): 49 | if TYPE_CHECKING: # pragma: no cover 50 | id: UUID_ID 51 | else: 52 | id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) 53 | 54 | 55 | class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): 56 | """Base SQLAlchemy OAuth account table definition.""" 57 | 58 | __tablename__ = "oauth_account" 59 | 60 | if TYPE_CHECKING: # pragma: no cover 61 | id: ID 62 | oauth_name: str 63 | access_token: str 64 | expires_at: Optional[int] 65 | refresh_token: Optional[str] 66 | account_id: str 67 | account_email: str 68 | else: 69 | oauth_name: Mapped[str] = mapped_column( 70 | String(length=100), index=True, nullable=False 71 | ) 72 | access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False) 73 | expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) 74 | refresh_token: Mapped[Optional[str]] = mapped_column( 75 | String(length=1024), nullable=True 76 | ) 77 | account_id: Mapped[str] = mapped_column( 78 | String(length=320), index=True, nullable=False 79 | ) 80 | account_email: Mapped[str] = mapped_column(String(length=320), nullable=False) 81 | 82 | 83 | class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]): 84 | if TYPE_CHECKING: # pragma: no cover 85 | id: UUID_ID 86 | user_id: UUID_ID 87 | else: 88 | id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) 89 | 90 | @declared_attr 91 | def user_id(cls) -> Mapped[GUID]: 92 | return mapped_column( 93 | GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False 94 | ) 95 | 96 | 97 | class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): 98 | """ 99 | Database adapter for SQLAlchemy. 100 | 101 | :param session: SQLAlchemy session instance. 102 | :param user_table: SQLAlchemy user model. 103 | :param oauth_account_table: Optional SQLAlchemy OAuth accounts model. 104 | """ 105 | 106 | session: AsyncSession 107 | user_table: type[UP] 108 | oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] 109 | 110 | def __init__( 111 | self, 112 | session: AsyncSession, 113 | user_table: type[UP], 114 | oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] = None, 115 | ): 116 | self.session = session 117 | self.user_table = user_table 118 | self.oauth_account_table = oauth_account_table 119 | 120 | async def get(self, id: ID) -> Optional[UP]: 121 | statement = select(self.user_table).where(self.user_table.id == id) 122 | return await self._get_user(statement) 123 | 124 | async def get_by_email(self, email: str) -> Optional[UP]: 125 | statement = select(self.user_table).where( 126 | func.lower(self.user_table.email) == func.lower(email) 127 | ) 128 | return await self._get_user(statement) 129 | 130 | async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: 131 | if self.oauth_account_table is None: 132 | raise NotImplementedError() 133 | 134 | statement = ( 135 | select(self.user_table) 136 | .join(self.oauth_account_table) 137 | .where(self.oauth_account_table.oauth_name == oauth) # type: ignore 138 | .where(self.oauth_account_table.account_id == account_id) # type: ignore 139 | ) 140 | return await self._get_user(statement) 141 | 142 | async def create(self, create_dict: dict[str, Any]) -> UP: 143 | user = self.user_table(**create_dict) 144 | self.session.add(user) 145 | await self.session.commit() 146 | await self.session.refresh(user) 147 | return user 148 | 149 | async def update(self, user: UP, update_dict: dict[str, Any]) -> UP: 150 | for key, value in update_dict.items(): 151 | setattr(user, key, value) 152 | self.session.add(user) 153 | await self.session.commit() 154 | await self.session.refresh(user) 155 | return user 156 | 157 | async def delete(self, user: UP) -> None: 158 | await self.session.delete(user) 159 | await self.session.commit() 160 | 161 | async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP: 162 | if self.oauth_account_table is None: 163 | raise NotImplementedError() 164 | 165 | await self.session.refresh(user) 166 | oauth_account = self.oauth_account_table(**create_dict) 167 | self.session.add(oauth_account) 168 | user.oauth_accounts.append(oauth_account) # type: ignore 169 | self.session.add(user) 170 | 171 | await self.session.commit() 172 | 173 | return user 174 | 175 | async def update_oauth_account( 176 | self, user: UP, oauth_account: OAP, update_dict: dict[str, Any] 177 | ) -> UP: 178 | if self.oauth_account_table is None: 179 | raise NotImplementedError() 180 | 181 | for key, value in update_dict.items(): 182 | setattr(oauth_account, key, value) 183 | self.session.add(oauth_account) 184 | await self.session.commit() 185 | 186 | return user 187 | 188 | async def _get_user(self, statement: Select) -> Optional[UP]: 189 | results = await self.session.execute(statement) 190 | return results.unique().scalar_one_or_none() 191 | -------------------------------------------------------------------------------- /tests/test_users.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator 2 | from typing import Any 3 | 4 | import pytest 5 | import pytest_asyncio 6 | from sqlalchemy import String, exc 7 | from sqlalchemy.ext.asyncio import ( 8 | AsyncEngine, 9 | async_sessionmaker, 10 | create_async_engine, 11 | ) 12 | from sqlalchemy.orm import ( 13 | DeclarativeBase, 14 | Mapped, 15 | mapped_column, 16 | relationship, 17 | ) 18 | 19 | from fastapi_users_db_sqlalchemy import ( 20 | UUID_ID, 21 | SQLAlchemyBaseOAuthAccountTableUUID, 22 | SQLAlchemyBaseUserTableUUID, 23 | SQLAlchemyUserDatabase, 24 | ) 25 | from tests.conftest import DATABASE_URL 26 | 27 | 28 | def create_async_session_maker(engine: AsyncEngine): 29 | return async_sessionmaker(engine, expire_on_commit=False) 30 | 31 | 32 | class Base(DeclarativeBase): 33 | pass 34 | 35 | 36 | class User(SQLAlchemyBaseUserTableUUID, Base): 37 | first_name: Mapped[str] = mapped_column(String(255), nullable=True) 38 | 39 | 40 | class OAuthBase(DeclarativeBase): 41 | pass 42 | 43 | 44 | class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, OAuthBase): 45 | pass 46 | 47 | 48 | class UserOAuth(SQLAlchemyBaseUserTableUUID, OAuthBase): 49 | first_name: Mapped[str] = mapped_column(String(255), nullable=True) 50 | oauth_accounts: Mapped[list[OAuthAccount]] = relationship( 51 | "OAuthAccount", lazy="joined" 52 | ) 53 | 54 | 55 | @pytest_asyncio.fixture 56 | async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: 57 | engine = create_async_engine(DATABASE_URL) 58 | sessionmaker = create_async_session_maker(engine) 59 | 60 | async with engine.begin() as connection: 61 | await connection.run_sync(Base.metadata.create_all) 62 | 63 | async with sessionmaker() as session: 64 | yield SQLAlchemyUserDatabase(session, User) 65 | 66 | async with engine.begin() as connection: 67 | await connection.run_sync(Base.metadata.drop_all) 68 | 69 | 70 | @pytest_asyncio.fixture 71 | async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: 72 | engine = create_async_engine(DATABASE_URL) 73 | sessionmaker = create_async_session_maker(engine) 74 | 75 | async with engine.begin() as connection: 76 | await connection.run_sync(OAuthBase.metadata.create_all) 77 | 78 | async with sessionmaker() as session: 79 | yield SQLAlchemyUserDatabase(session, UserOAuth, OAuthAccount) 80 | 81 | async with engine.begin() as connection: 82 | await connection.run_sync(OAuthBase.metadata.drop_all) 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID]): 87 | user_create = { 88 | "email": "lancelot@camelot.bt", 89 | "hashed_password": "guinevere", 90 | } 91 | 92 | # Create 93 | user = await sqlalchemy_user_db.create(user_create) 94 | assert user.id is not None 95 | assert user.is_active is True 96 | assert user.is_superuser is False 97 | assert user.email == user_create["email"] 98 | 99 | # Update 100 | updated_user = await sqlalchemy_user_db.update(user, {"is_superuser": True}) 101 | assert updated_user.is_superuser is True 102 | 103 | # Get by id 104 | id_user = await sqlalchemy_user_db.get(user.id) 105 | assert id_user is not None 106 | assert id_user.id == user.id 107 | assert id_user.is_superuser is True 108 | 109 | # Get by email 110 | email_user = await sqlalchemy_user_db.get_by_email(str(user_create["email"])) 111 | assert email_user is not None 112 | assert email_user.id == user.id 113 | 114 | # Get by uppercased email 115 | email_user = await sqlalchemy_user_db.get_by_email("Lancelot@camelot.bt") 116 | assert email_user is not None 117 | assert email_user.id == user.id 118 | 119 | # Unknown user 120 | unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt") 121 | assert unknown_user is None 122 | 123 | # Delete user 124 | await sqlalchemy_user_db.delete(user) 125 | deleted_user = await sqlalchemy_user_db.get(user.id) 126 | assert deleted_user is None 127 | 128 | # OAuth without defined table 129 | with pytest.raises(NotImplementedError): 130 | await sqlalchemy_user_db.get_by_oauth_account("foo", "bar") 131 | with pytest.raises(NotImplementedError): 132 | await sqlalchemy_user_db.add_oauth_account(user, {}) 133 | with pytest.raises(NotImplementedError): 134 | oauth_account = OAuthAccount() 135 | await sqlalchemy_user_db.update_oauth_account(user, oauth_account, {}) 136 | 137 | 138 | @pytest.mark.asyncio 139 | async def test_insert_existing_email( 140 | sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], 141 | ): 142 | user_create = { 143 | "email": "lancelot@camelot.bt", 144 | "hashed_password": "guinevere", 145 | } 146 | await sqlalchemy_user_db.create(user_create) 147 | 148 | with pytest.raises(exc.IntegrityError): 149 | await sqlalchemy_user_db.create(user_create) 150 | 151 | 152 | @pytest.mark.asyncio 153 | async def test_queries_custom_fields( 154 | sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], 155 | ): 156 | """It should output custom fields in query result.""" 157 | user_create = { 158 | "email": "lancelot@camelot.bt", 159 | "hashed_password": "guinevere", 160 | "first_name": "Lancelot", 161 | } 162 | user = await sqlalchemy_user_db.create(user_create) 163 | 164 | id_user = await sqlalchemy_user_db.get(user.id) 165 | assert id_user is not None 166 | assert id_user.id == user.id 167 | assert id_user.first_name == user.first_name 168 | 169 | 170 | @pytest.mark.asyncio 171 | async def test_queries_oauth( 172 | sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserOAuth, UUID_ID], 173 | oauth_account1: dict[str, Any], 174 | oauth_account2: dict[str, Any], 175 | ): 176 | user_create = { 177 | "email": "lancelot@camelot.bt", 178 | "hashed_password": "guinevere", 179 | } 180 | 181 | # Create 182 | user = await sqlalchemy_user_db_oauth.create(user_create) 183 | assert user.id is not None 184 | 185 | # Add OAuth account 186 | user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account1) 187 | user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account2) 188 | assert len(user.oauth_accounts) == 2 189 | assert user.oauth_accounts[1].account_id == oauth_account2["account_id"] 190 | assert user.oauth_accounts[0].account_id == oauth_account1["account_id"] 191 | 192 | # Update 193 | user = await sqlalchemy_user_db_oauth.update_oauth_account( 194 | user, user.oauth_accounts[0], {"access_token": "NEW_TOKEN"} 195 | ) 196 | assert user.oauth_accounts[0].access_token == "NEW_TOKEN" 197 | 198 | # Get by id 199 | id_user = await sqlalchemy_user_db_oauth.get(user.id) 200 | assert id_user is not None 201 | assert id_user.id == user.id 202 | assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" 203 | 204 | # Get by email 205 | email_user = await sqlalchemy_user_db_oauth.get_by_email(user_create["email"]) 206 | assert email_user is not None 207 | assert email_user.id == user.id 208 | assert len(email_user.oauth_accounts) == 2 209 | 210 | # Get by OAuth account 211 | oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( 212 | oauth_account1["oauth_name"], oauth_account1["account_id"] 213 | ) 214 | assert oauth_user is not None 215 | assert oauth_user.id == user.id 216 | 217 | # Unknown OAuth account 218 | unknown_oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( 219 | "foo", "bar" 220 | ) 221 | assert unknown_oauth_user is None 222 | --------------------------------------------------------------------------------