├── tests ├── __init__.py ├── data │ ├── __init__.py │ ├── models.py │ └── people.py ├── endpoints │ ├── __init__.py │ ├── test_logout.py │ ├── test_confirm.py │ ├── test_login.py │ └── test_register.py ├── models │ ├── __init__.py │ ├── test_mixins.py │ ├── test_events.py │ ├── test_types.py │ └── test_models.py ├── conftest.py ├── test_db_registry.py ├── test_utils.py ├── test_middleware.py ├── test_tz.py ├── test_types.py ├── test_crud.py └── test_auth.py ├── fastapi_sqlalchemy ├── __init__.py ├── endpoints │ ├── templates │ │ ├── confirmation_email.txt │ │ ├── confirmation_email.html │ │ ├── failed_email_confirmation.html │ │ ├── send_confirmation.html │ │ ├── login.html │ │ └── register.html │ ├── __init__.py │ ├── logout.py │ ├── confirm.py │ ├── login.py │ └── register.py ├── models │ ├── __init__.py │ ├── groups.py │ ├── permissions.py │ ├── base.py │ ├── users.py │ ├── mixins.py │ ├── types.py │ ├── associations.py │ └── events.py ├── types.py ├── db_registry.py ├── tz.py ├── utils.py ├── middleware.py ├── auth.py └── crud.py ├── pytest.ini ├── assets ├── logo-128x128.png └── logo.svg ├── .gitignore ├── .flake8 ├── Makefile ├── setup.py ├── README.md └── .pylintrc /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | env = 3 | D:DATABASE_URL=sqlite:///sqlite.db?check_same_thread=False 4 | -------------------------------------------------------------------------------- /assets/logo-128x128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuarbase/fastapi-sqlalchemy/HEAD/assets/logo-128x128.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.db 2 | *.egg-info 3 | *.pyc 4 | __pycache__ 5 | /build 6 | /dist 7 | /pyenv 8 | /venv 9 | /.idea 10 | /.cache 11 | /.coverage 12 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/confirmation_email.txt: -------------------------------------------------------------------------------- 1 | Welcome! Thanks for signing up. Please open the link to activate your account: 2 | 3 | ${confirm_url} 4 | 5 | Thank You! -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | """ All built-in endpoints """ 2 | from .login import LoginEndpoint 3 | from .logout import LogoutEndpoint 4 | from .register import RegisterEndpoint 5 | from .confirm import ConfirmEndpoint 6 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/confirmation_email.html: -------------------------------------------------------------------------------- 1 |

Welcome! Thanks for signing up. Please follow this link to activate your account:

2 |

${confirm_url}

3 |
4 |

Thank you!

-------------------------------------------------------------------------------- /tests/data/models.py: -------------------------------------------------------------------------------- 1 | from fastapi_sqlalchemy import models 2 | 3 | 4 | class User(models.User): 5 | pass 6 | 7 | 8 | class Group(models.Group): 9 | pass 10 | 11 | 12 | class Permission(models.Permission): 13 | pass 14 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 80 3 | inline-quotes = " 4 | exclude = .git,node_modules,__init__.py,DEBIAN 5 | ignore = 6 | # too many leading '#' for block comment 7 | E266, 8 | # comparison to None should be 'if cond is not None:' 9 | E711 10 | -------------------------------------------------------------------------------- /tests/endpoints/test_logout.py: -------------------------------------------------------------------------------- 1 | from fastapi_sqlalchemy import endpoints 2 | 3 | 4 | def test_logout_post(engine, session, app, client): 5 | endpoint = endpoints.LogoutEndpoint() 6 | 7 | @app.post("/logout") 8 | async def _post(): 9 | return await endpoint.on_post() 10 | 11 | res = client.post("/logout") 12 | assert res.status_code == 303 13 | assert res.headers.get("location") == "/login" 14 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ The SQLAlchemy model """ 2 | from .base import BASE, Session 3 | 4 | from .types import GUID, JSONEncodedDict, JSON_TYPE 5 | from .mixins import GuidMixin, TimestampMixin, DictMixin 6 | 7 | from .users import User 8 | from .groups import Group 9 | from .permissions import Permission 10 | 11 | from .associations import ( 12 | create_group_membership_table, 13 | create_user_permissions_table, 14 | 
create_group_permissions_table 15 | ) 16 | 17 | from . import events 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: flake8 pylint test 2 | 3 | flake8: flake8_pkg flake8_tests 4 | .PHONY: flake8 5 | 6 | flake8_pkg: 7 | flake8 fastapi_sqlalchemy 8 | .PHONY: flake8_pkg 9 | 10 | flake8_tests: 11 | flake8 tests 12 | .PHONY: flake8_tests 13 | 14 | pylint: pylint_pkg pylint_tests 15 | .PHONY: pylint 16 | 17 | pylint_pkg: 18 | pylint fastapi_sqlalchemy 19 | .PHONY: pylint_pkg 20 | 21 | pylint_tests: 22 | pylint tests --disable=missing-docstring,unused-argument,too-many-ancestors,unexpected-keyword-arg 23 | .PHONY: pylint_tests 24 | 25 | test: 26 | pytest -xvv tests 27 | .PHONY: test 28 | 29 | coverage: 30 | pytest --cov=fastapi_sqlalchemy --cov-report=term-missing --cov-fail-under=100 tests/ 31 | .PHONY: coverage 32 | 33 | pyenv: 34 | virtualenv -p python3 pyenv 35 | pyenv/bin/pip install -e .[dev,prod] 36 | .PHONY: pyenv 37 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/groups.py: -------------------------------------------------------------------------------- 1 | """ Group model """ 2 | import sqlalchemy 3 | from sqlalchemy import event 4 | 5 | from .base import BASE, Session, MODEL_MAPPING 6 | from . 
import mixins 7 | 8 | 9 | class Group(BASE, mixins.GuidMixin, mixins.TimestampMixin): 10 | """ The groups table """ 11 | __tablename__ = "groups" 12 | __abstract__ = True 13 | 14 | name = sqlalchemy.Column( 15 | sqlalchemy.String(255), 16 | nullable=False, 17 | ) 18 | 19 | @classmethod 20 | def get_by_name( 21 | cls, 22 | session: Session, 23 | name: str 24 | ): 25 | """ Lookup a group by name 26 | """ 27 | return session.query(cls).filter(cls.name == name).first() 28 | 29 | 30 | @event.listens_for(Group, "mapper_configured", propagate=True) 31 | def _mapper_configured(_mapper, cls): 32 | MODEL_MAPPING["Group"] = cls 33 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/permissions.py: -------------------------------------------------------------------------------- 1 | """ Permission model """ 2 | import sqlalchemy 3 | from sqlalchemy import event 4 | 5 | from .base import BASE, Session, MODEL_MAPPING 6 | from . import mixins 7 | 8 | 9 | class Permission(BASE, mixins.GuidMixin, mixins.TimestampMixin): 10 | """ The permissions table """ 11 | __tablename__ = "permissions" 12 | __abstract__ = True 13 | 14 | name = sqlalchemy.Column( 15 | sqlalchemy.String(255), 16 | nullable=False, 17 | ) 18 | 19 | @classmethod 20 | def get_by_name( 21 | cls, 22 | session: Session, 23 | name: str 24 | ): 25 | """ Lookup a group by name 26 | """ 27 | return session.query(cls).filter(cls.name == name).first() 28 | 29 | 30 | @event.listens_for(Permission, "mapper_configured", propagate=True) 31 | def _mapper_configured(_mapper, cls): 32 | MODEL_MAPPING["Permission"] = cls 33 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/failed_email_confirmation.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ${title} 8 | 9 | 10 | 11 |
12 |
13 |

Whoops! An Error Occurred

14 |

${error}

15 |

Please fill out the registration form again.

16 |
17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/logout.py: -------------------------------------------------------------------------------- 1 | """ Logout functionality """ 2 | import logging 3 | 4 | from starlette.responses import JSONResponse 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class LogoutEndpoint: 10 | """ Class-based endpoint for logout """ 11 | 12 | def __init__( 13 | self, 14 | location: str = "/login", 15 | cookie_name: str = "jwt", 16 | ): 17 | self.cookie_name = cookie_name 18 | self.location = location 19 | 20 | # NOTE: no GET handler 21 | async def on_post( 22 | self 23 | ) -> JSONResponse: 24 | """ POST /logout """ 25 | # 303 tells the browser to switch from POST to GET 26 | headers = {"location": self.location} 27 | 28 | # NOTE: use JSONResponse instead of RedirectResponse to 29 | # avoid confusing OpenAPI 30 | response = JSONResponse(status_code=303, headers=headers) 31 | response.delete_cookie(self.cookie_name) 32 | return response 33 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/send_confirmation.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ${title} 8 | 9 | 10 | 11 |
12 |
13 |

Thank you for registering!

14 |

An email confirmation was sent to ${email}.

15 |

Please check your spam folder if you don't receive an email in the next several minutes.

16 |
17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /assets/logo.svg: -------------------------------------------------------------------------------- 1 | fastapi_sqlalchemy -------------------------------------------------------------------------------- /fastapi_sqlalchemy/types.py: -------------------------------------------------------------------------------- 1 | """ Extensions to pydantic.types """ 2 | from fastapi.params import Query 3 | from pydantic import ConstrainedInt 4 | 5 | 6 | class NonNegativeInt(ConstrainedInt): 7 | """ Integer >= 0 """ 8 | ge = 0 9 | 10 | 11 | class LimitQuery(Query): 12 | """ The 'limit' query parameter """ 13 | def __init__( 14 | self, 15 | limit: int = 100, 16 | alias: str = None 17 | ): 18 | assert limit > 0 19 | super().__init__( 20 | limit, 21 | title="Maximum number of entries to return", 22 | description="Maximum number of entries to return.", 23 | alias=alias 24 | ) 25 | 26 | 27 | class OffsetQuery(Query): 28 | """ The 'offset' query parameter """ 29 | def __init__( 30 | self, 31 | offset: int = 0, 32 | alias: str = None 33 | ): 34 | assert offset >= 0 35 | super().__init__( 36 | offset, 37 | title="Index of the first entry to return", 38 | description="Index of the first entry to return.", 39 | alias=alias 40 | ) 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | PACKAGE = "fastapi-sqlalchemy" 4 | VERSION = "0.9.0" 5 | 6 | setup( 7 | name=PACKAGE, 8 | version=VERSION, 9 | author="Matthew Laue", 10 | author_email="matt@zuar.com", 11 | url="https://github.com/zuarbase/fastapi-sqlalchemy", 12 | packages=find_packages(exclude=["tests"]), 13 | include_package_data=True, 14 | install_requires=[ 15 | "email-validator", 16 | "fastapi >= 0.52.0, < 0.53.0", 17 | "pydantic >= 1.4, < 1.5", 18 | "passlib", 19 | 
"python-dateutil", 20 | "python-multipart", 21 | "pyjwt", 22 | "sqlalchemy", 23 | "sqlalchemy-filters", 24 | "tzlocal", 25 | "itsdangerous", 26 | ], 27 | extras_require={ 28 | "dev": [ 29 | "coverage", 30 | "pylint", 31 | "pytest", 32 | "pytest-cov", 33 | "pytest-env", 34 | "pytest-mock", 35 | "requests", 36 | "flake8", 37 | "flake8-quotes", 38 | ], 39 | "prod": [ 40 | "uvicorn", 41 | "gunicorn", 42 | ] 43 | } 44 | ) 45 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/db_registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database engines registry. 3 | 4 | The registry is a single source of shared sqlalchemy engines. 5 | 6 | NOTE: there are both thread-safe and non thread-safe functions. 7 | """ 8 | import typing 9 | import threading 10 | 11 | from sqlalchemy.engine import Connectable, Engine, create_engine 12 | 13 | __LOCK = threading.Lock() 14 | 15 | __ENGINE_REGISTRY: typing.Dict[str, Engine] = {} 16 | 17 | 18 | def register( 19 | bind: typing.Union[str, Connectable], 20 | pool_pre_ping=True, 21 | **engine_kwargs 22 | ) -> Engine: 23 | """Register an engine or create a new one (non thread-safe).""" 24 | if isinstance(bind, str): 25 | engine = create_engine( 26 | bind, pool_pre_ping=pool_pre_ping, **engine_kwargs) 27 | bind = engine 28 | else: 29 | engine = bind.engine 30 | 31 | __ENGINE_REGISTRY[str(engine.url)] = engine 32 | return bind 33 | 34 | 35 | def get_or_create( 36 | url: str, 37 | **engine_kwargs 38 | ) -> Engine: 39 | """Get an engine from the registry or create it if does not exist.""" 40 | with __LOCK: 41 | return __ENGINE_REGISTRY.get(url) or register(url, **engine_kwargs) 42 | -------------------------------------------------------------------------------- /tests/models/test_mixins.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy.ext.declarative 2 | 3 | from fastapi_sqlalchemy import 
models 4 | from tests.data.models import User 5 | 6 | 7 | def test_timestamp_mixin(session): 8 | assert models.TimestampMixin in User.__mro__ 9 | 10 | user = User(username="test_timestamp_mixin") 11 | session.add(user) 12 | session.commit() 13 | 14 | updated_at = user.updated_at 15 | assert user.created_at.replace(microsecond=0) == \ 16 | updated_at.replace(microsecond=0) 17 | 18 | user.username = "test_timestamp_mixin__updated" 19 | session.add(user) 20 | session.commit() 21 | assert user.updated_at != updated_at 22 | assert user.created_at < user.updated_at 23 | 24 | 25 | def test_dict_mixin(mocker): 26 | base_cls = sqlalchemy.ext.declarative.declarative_base() 27 | 28 | class TestModel(base_cls, models.DictMixin): 29 | __tablename__ = "test_dict_mixin" 30 | 31 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 32 | 33 | model = TestModel(id=1) 34 | 35 | mock_model_as_dict = mocker.patch( 36 | "fastapi_sqlalchemy.models.mixins.model_as_dict") 37 | result = model.as_dict() 38 | assert mock_model_as_dict.call_args == mocker.call(model) 39 | assert result is mock_model_as_dict.return_value 40 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | 4 | import sqlalchemy 5 | import pytest 6 | 7 | from fastapi import FastAPI 8 | from fastapi.testclient import TestClient 9 | 10 | from fastapi_sqlalchemy import models 11 | 12 | DATABASE_URL = os.environ["DATABASE_URL"] 13 | 14 | 15 | @pytest.fixture(scope="function", name="loop") 16 | def loop_fixture(): 17 | return asyncio.new_event_loop() 18 | 19 | 20 | @pytest.fixture(scope="session", name="engine") 21 | def engine_fixture() -> sqlalchemy.engine.Engine: 22 | engine = sqlalchemy.create_engine(DATABASE_URL) 23 | return engine 24 | 25 | 26 | @pytest.fixture(scope="function", name="session") 27 | def session_fixture(engine): 28 | 29 | def 
_drop_all(): 30 | meta = sqlalchemy.MetaData() 31 | meta.reflect(bind=engine) 32 | meta.drop_all(bind=engine) 33 | 34 | _drop_all() 35 | models.BASE.metadata.create_all(engine) 36 | 37 | models.Session.configure(bind=engine) 38 | session = models.Session() 39 | 40 | yield session 41 | session.close() 42 | 43 | 44 | @pytest.fixture(scope="function", name="app") 45 | def app_fixture(engine): 46 | app = FastAPI( 47 | title="fastapi_sqlalchemy", 48 | version="0.0.0" 49 | ) 50 | return app 51 | 52 | 53 | @pytest.fixture(scope="function", name="client") 54 | def client_fixture(app): 55 | client = TestClient(app) 56 | return client 57 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/login.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ${title} 8 | 9 | 10 | 11 | 12 |
13 |
14 |
15 |
16 | 17 |
18 | 19 |

${error}

20 |
21 | 22 |
23 |
24 | 25 |
26 | 27 |
28 |
29 |
30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /tests/data/people.py: -------------------------------------------------------------------------------- 1 | """ Generic test data """ 2 | from datetime import datetime 3 | from uuid import UUID 4 | 5 | import sqlalchemy 6 | from sqlalchemy.orm import Session 7 | from sqlalchemy.types import CHAR 8 | from pydantic import BaseModel, PositiveInt, constr 9 | 10 | from fastapi_sqlalchemy import models 11 | 12 | PEOPLE_DATA = [ 13 | {"name": "alice", "order": 1, "gender": "F", "age": 32}, 14 | {"name": "bob", "order": 2, "gender": "M", "age": 22}, 15 | {"name": "charlie", "order": 3, "gender": "M", "age": 60}, 16 | {"name": "david", "order": 4, "gender": "M", "age": 32}, 17 | ] 18 | 19 | 20 | class Person(models.BASE, models.GuidMixin, models.TimestampMixin): 21 | __tablename__ = "people" 22 | 23 | name = sqlalchemy.Column( 24 | sqlalchemy.String(255), 25 | nullable=False, 26 | unique=True 27 | ) 28 | 29 | order = sqlalchemy.Column( 30 | sqlalchemy.Integer, 31 | nullable=False, 32 | unique=True 33 | ) 34 | 35 | gender = sqlalchemy.Column( 36 | CHAR(1), 37 | nullable=False 38 | ) 39 | 40 | age = sqlalchemy.Column( 41 | sqlalchemy.Integer, 42 | nullable=False 43 | ) 44 | 45 | 46 | class PersonRequestModel(BaseModel): 47 | name: constr(max_length=255) 48 | order: int 49 | gender: constr(min_length=1, max_length=1) 50 | age: PositiveInt 51 | 52 | 53 | class PersonResponseModel(PersonRequestModel): 54 | id: UUID 55 | created_at: datetime 56 | updated_at: datetime 57 | 58 | 59 | def load_people(session: Session): 60 | people = [] 61 | for data in PEOPLE_DATA: 62 | person = Person(**data) 63 | people.append(person) 64 | 65 | session.add_all(people) 66 | session.commit() 67 | 68 | assert len(people) == len(PEOPLE_DATA) 69 | return people 70 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/tz.py: 
# --------------------------------------------------------------------------------
"""
This is a helper module to deal with datetime and timezones.

For convenience datetime related modules (e.g. dateutil.parser, pytz, etc) are
imported here so they can be accessed using this module.

Usage:
    >>> from fastapi_sqlalchemy import tz
    >>> dt1 = tz.utcnow()
    >>> dt2 = tz.utcdatetime(2015, 1, 2, 1, 2, 3)
    >>> dt2 = dt2.astimezone(tz.LOCAL)
    >>> dt2 = dt2.astimezone(tz.UTC)
    >>> dt2 += tz.timedelta(minutes=12)
    >>> dt3 = tz.parse("2017-01-02 02:22")
    >>> date1 = tz.date(2015, 1, 2)
"""
import typing

# pylint: disable=unused-import
from datetime import date, datetime, timedelta  # noqa
# pylint: enable=unused-import


import dateutil.parser
import pytz
import tzlocal

# Timezone of the machine running the code, as reported by tzlocal.
LOCAL = tzlocal.get_localzone()
UTC = pytz.utc


def _localize(value: datetime) -> datetime:
    """Attach the LOCAL timezone to a naive datetime.

    pytz-style zones must be applied through ``localize()``; zone objects
    without that method (zoneinfo-style, as returned by newer tzlocal
    releases) are attached with ``replace(tzinfo=...)`` instead.
    """
    localize = getattr(LOCAL, "localize", None)
    if localize is not None:
        return localize(value)
    return value.replace(tzinfo=LOCAL)


def as_datetime(value: typing.Union[str, date, datetime]) -> datetime:
    """Convert a string, date or datetime to a UTC-aware datetime.

    * ``datetime`` values are used as they are.
    * plain ``date`` values become midnight of that day (dateutil cannot
      parse ``date`` objects, so they are promoted explicitly).
    * anything else is handed to :func:`parse`.

    A result without timezone information is interpreted as LOCAL time
    before being converted to UTC.
    """
    if isinstance(value, datetime):
        result = value
    elif isinstance(value, date):
        result = datetime(value.year, value.month, value.day)
    else:
        result = parse(value)

    if not result.tzinfo:
        result = _localize(result)

    return as_utc(result)


def parse(*args, **kwargs):
    """Shortcut for dateutil.parser.parse with timezone

    All arguments are forwarded to ``dateutil.parser.parse``; the parsed
    value is then normalised by :func:`as_datetime`, so the result is
    always a UTC-aware datetime.
    """
    value = dateutil.parser.parse(*args, **kwargs)
    return as_datetime(value)


def utcnow() -> datetime:
    """return UTC datetime with UTC timezone"""
    # datetime.utcnow() is deprecated; now(tz=UTC) yields the same instant
    # already tagged with the UTC tzinfo.
    return datetime.now(tz=UTC)


def utcdatetime(*args, **kwargs) -> datetime:
    """Same as datetime.datetime() but create UTC aware datetime"""
    return datetime(*args, **kwargs, tzinfo=UTC)


def as_utc(value: typing.Union[datetime, date]):
    """Convert given date or datetime to UTC

    Plain ``date`` objects have no ``astimezone()``; they are routed through
    :func:`as_datetime`, which promotes them to a datetime first.
    """
    if not isinstance(value, datetime):
        return as_datetime(value)
    return value.astimezone(UTC)
# --------------------------------------------------------------------------------
# /fastapi_sqlalchemy/models/base.py:
# --------------------------------------------------------------------------------
""" SQLAlchemy helper function """
import uuid
import enum
# The ABC aliases were removed from ``collections`` in Python 3.10;
# ``collections.abc`` is the supported location.
from collections.abc import Mapping

import sqlalchemy

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from fastapi_sqlalchemy import tz


class Base:
    """ Custom declarative base """

    def as_dict(self):
        """ Convert object to dictionary """
        return model_as_dict(self)


BASE = declarative_base(cls=Base)
Session = sessionmaker()


class ModelMapping(dict):
    """ Class to hold model information

    Behaves like a dict, except that assigning to a key that is already
    present raises ``RuntimeError``.  Only item assignment and ``update()``
    are guarded; other dict mutators are inherited unchanged.
    """

    def __setitem__(self, key, value):
        """ Store ``value`` under ``key``; refuse to overwrite an entry """
        if key in self:
            raise RuntimeError(
                f"Duplicate '{key}' model found. "
                "There may only be one non-abstract sub-class."
            )
        super().__setitem__(key, value)

    def update(self, other=None, **kwargs):
        """ Add entries one by one so every key goes through __setitem__

        ``other`` may be a mapping or an iterable of ``(key, value)`` pairs.
        """
        if other is not None:
            pairs = other.items() if isinstance(other, Mapping) else other
            for key, value in pairs:
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value


MODEL_MAPPING = ModelMapping()


def model_as_dict(model) -> dict:
    """Convert given sqlalchemy model to dict (relationships not included).

    Only the mapper's column attributes are read.  Dates and datetimes are
    rendered with ``isoformat()``, UUIDs with ``str()`` and enum members by
    their ``name``; every other value is copied as it is.
    """
    result = {}
    for attr in sqlalchemy.inspect(model).mapper.column_attrs:
        value = getattr(model, attr.key)
        if isinstance(value, (tz.datetime, tz.date)):
            value = value.isoformat()
        elif isinstance(value, uuid.UUID):
            value = str(value)
        elif isinstance(value, enum.Enum):
            value = value.name
        result[attr.key] = value
    return result
# --------------------------------------------------------------------------------
# /fastapi_sqlalchemy/utils.py:
# --------------------------------------------------------------------------------
# 1 | """ Utility functions """ 2 |
import uuid

from string import Template
from typing import Union

import jwt
from starlette.requests import Request

try:
    from ordered_uuid import OrderedUUID
except ImportError:
    OrderedUUID = None


def ordered_uuid(value=None) -> OrderedUUID:
    """ Generate a rearranged uuid1 that is ordered by time.
    This is more efficient for use as a primary key, see:
    https://www.percona.com/blog/2014/12/19/store-uuid-optimized-way/

    When `value` is empty, a fresh uuid1 string is used.
    Raises RuntimeError when the optional ordered_uuid package is missing.
    """
    if OrderedUUID is None:
        raise RuntimeError("ordered_uuid package: not found")
    if not value:
        value = str(uuid.uuid1())
    return OrderedUUID(value)


def render(
        path_or_template: Union[str, Template],
        **kwargs,
) -> str:
    """ Render the specified template - either a file or the actual template

    A `Template` instance is used as is. A string starting with "<" is
    treated as inline template text; any other string is treated as the
    path of a template file. Placeholders without a matching keyword
    argument are left in place (`safe_substitute`).
    """
    if isinstance(path_or_template, Template):
        template = path_or_template
    elif path_or_template.startswith("<"):
        template = Template(path_or_template)
    else:
        # Explicit encoding: do not let the platform locale decide how the
        # template file is decoded.
        with open(path_or_template, "r", encoding="utf-8") as filp:
            contents = filp.read()
        template = Template(contents)
    return template.safe_substitute(**kwargs)


def get_session(request: Request):
    """Get `request.state.session`

    Usage:
        >>> from fastapi import Depends
        >>> session = Depends(get_session)
    """
    return request.state.session


def jwt_encode(payload: dict, secret: str, algorithm: str = "HS256") -> str:
    """ Encode the given payload as a JWT

    The payload must contain an "exp" claim. The token is always returned
    as `str`.
    """
    assert "exp" in payload, "payload must contain an 'exp' claim"
    token = jwt.encode(
        payload,
        str(secret),
        algorithm=algorithm,
    )
    # PyJWT 1.x returns bytes, PyJWT >= 2.0 returns str; only the former
    # needs decoding.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # FastAPI-SQLAlchemy
# Full-stack, asynchronous Python3 framework.
3 | 4 | ## Design goals 5 | * Fast, full-service framework 6 | * Modular approach that does not force any design decisions 7 | 8 | ## Getting started 9 | 10 | ```python 11 | from fastapi import FastAPI, Request 12 | 13 | from fastapi_sqlalchemy import crud, db_registry 14 | from fastapi_sqlalchemy.middleware import SessionMiddleware 15 | from fastapi_sqlalchemy.models import BASE, Session, User 16 | 17 | DATABASE_URL = "sqlite:///sqlite.db?check_same_thread=False" 18 | 19 | 20 | # Define our model 21 | class MyUser(User): 22 | pass 23 | 24 | 25 | # Instantiate the application 26 | app = FastAPI() 27 | app.add_middleware(SessionMiddleware, bind=DATABASE_URL) 28 | 29 | # Create all tables 30 | bind = db_registry.get_or_create(DATABASE_URL) 31 | BASE.metadata.create_all(bind=bind) 32 | 33 | # Load some data 34 | session = Session() 35 | for name in ["alice", "bob", "charlie", "david"]: 36 | user = MyUser.get_by_username(session, name) 37 | if user is None: 38 | user = MyUser(username=name) 39 | session.add(user) 40 | session.commit() 41 | 42 | # Add an endpoint 43 | @app.get("/users") 44 | async def list_users( 45 | request: Request 46 | ): 47 | return await crud.list_instances(MyUser, request.state.session) 48 | ``` 49 | 50 | Assuming the above code is stored in the file `main.py`, then run it via: 51 | ```bash 52 | uvicorn main:app --reload 53 | ``` 54 | 55 | Call the endpoint: 56 | ```bash 57 | curl -s localhost:8000/users | jq . 58 | ``` 59 | 60 | The output should contain a list of 4 users, 61 | each with the attributes `id`, `username`, `updated_at` and `created_at`. 
62 | 63 | 67 | -------------------------------------------------------------------------------- /tests/test_db_registry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_sqlalchemy import db_registry 4 | 5 | 6 | @pytest.fixture(name="mock_lock", autouse=True) 7 | def fixture_mock_lock(mocker): 8 | return mocker.patch("threading.Lock") 9 | 10 | 11 | @pytest.fixture(name="mock_create_engine") 12 | def fixture_mock_create_engine(mocker): 13 | return mocker.patch( 14 | "fastapi_sqlalchemy.db_registry.create_engine", 15 | side_effect=lambda url, **kwargs: mocker.Mock(url=url, **kwargs)) 16 | 17 | 18 | def test_register_existing_engine(mocker, mock_create_engine): 19 | url = "/fake/url" 20 | engine = mocker.Mock(url=url) 21 | engine.engine = engine # Follow Connectable interface 22 | 23 | db_registry.register(engine) 24 | 25 | registered_engine = db_registry.get_or_create(url) 26 | assert registered_engine is engine, "New engine instance was created." 
27 | 28 | 29 | def test_register_url(mocker, mock_create_engine): 30 | url = "/fake/url" 31 | 32 | kwargs = {"key1": "value1", "echo": True} 33 | created_engine = db_registry.register(url, **kwargs) 34 | assert created_engine 35 | 36 | registered_engine = db_registry.get_or_create(url) 37 | assert registered_engine is created_engine 38 | 39 | assert mock_create_engine.call_args_list == [ 40 | mocker.call(url, pool_pre_ping=True, **kwargs) 41 | ] 42 | 43 | 44 | def test_get_or_create_by_url(mock_create_engine): 45 | url_1 = "/fake/url/1" 46 | url_2 = "/fake/url/2" 47 | 48 | created_engine_1 = db_registry.get_or_create(url_1) 49 | assert created_engine_1 50 | 51 | created_engine_2 = db_registry.get_or_create(url_2) 52 | assert created_engine_2 53 | 54 | registered_engine_1 = db_registry.get_or_create(url_1) 55 | assert registered_engine_1 is created_engine_1 56 | 57 | registered_engine_2 = db_registry.get_or_create(url_2) 58 | assert registered_engine_2 is created_engine_2 59 | 60 | assert mock_create_engine.call_count == 2 61 | -------------------------------------------------------------------------------- /tests/models/test_events.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_sqlalchemy.models import base, events 4 | 5 | 6 | @pytest.fixture(name="mapper") 7 | def fixture_model_mapping(mocker) -> dict: 8 | return mocker.patch("fastapi_sqlalchemy.models.events.MODEL_MAPPING", {}) 9 | 10 | 11 | @pytest.mark.parametrize("table,relation", ( 12 | ["group_membership", ("User", "Group")], 13 | ["user_permissions", ("User", "Permission")], 14 | ["group_permissions", ("Group", "Permission")], 15 | )) 16 | def test_permissions_configuration_errors( 17 | table, relation, mocker, mapper 18 | ): 19 | mocker.patch.object(base.BASE.metadata, "tables", { 20 | table: mocker.Mock(__association__=table) 21 | }) 22 | 23 | # Test no first relation 24 | with pytest.raises(RuntimeError) as exc_info: 25 | 
events._after_configured() # pylint: disable=protected-access 26 | 27 | error_msg = f"'{table}' association table found, " \ 28 | f"but no {relation[0]} table defined." 29 | assert str(exc_info.value) == error_msg 30 | 31 | # Test no second relation 32 | mapper[relation[0]] = mocker.Mock() 33 | with pytest.raises(RuntimeError) as exc_info: 34 | events._after_configured() # pylint: disable=protected-access 35 | 36 | error_msg = f"'{table}' association table found, " \ 37 | f"but no {relation[1]} table defined." 38 | assert str(exc_info.value) == error_msg 39 | 40 | 41 | def test_mapper_configuration_duplicate_key(mocker): 42 | duplicate_association = "duplicate-table" 43 | 44 | mocker.patch.object(base.BASE.metadata, "tables", { 45 | "table-1": mocker.Mock(__association__=duplicate_association), 46 | "table-2": mocker.Mock(__association__=duplicate_association) 47 | }) 48 | 49 | with pytest.raises(RuntimeError) as exc_info: 50 | events._after_configured() # pylint: disable=protected-access 51 | 52 | error_msg = f"Multiple '{duplicate_association}' associations found." \ 53 | "Only a single table may have a specific __association__ value" 54 | assert str(exc_info.value) == error_msg 55 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/templates/register.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ${title} 8 | 9 | 10 | 11 | 12 |
13 |
14 |
15 | 16 |
17 | 18 |
19 | 20 |

${error}

21 |
22 | 23 |
24 |
25 | 26 |
27 |
28 | 29 |
30 |
31 | 32 |
33 | 34 |
35 |
36 |
37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from string import Template 3 | 4 | import pytest 5 | 6 | from fastapi_sqlalchemy import utils 7 | 8 | 9 | def test_ordered_uuid(mocker): 10 | ordered_uuid = mocker.patch("fastapi_sqlalchemy.utils.OrderedUUID") 11 | 12 | value = uuid.uuid4() 13 | result = utils.ordered_uuid(value) 14 | assert ordered_uuid.call_args == mocker.call(value) 15 | assert result is ordered_uuid.return_value 16 | 17 | 18 | def test_ordered_uuid_defaults(mocker): 19 | ordered_uuid = mocker.patch("fastapi_sqlalchemy.utils.OrderedUUID") 20 | 21 | expected_result = "mock-uuid-1" 22 | mocker.patch("uuid.uuid1", return_value=expected_result) 23 | 24 | result = utils.ordered_uuid() 25 | assert ordered_uuid.call_args == mocker.call(expected_result) 26 | assert result is ordered_uuid.return_value 27 | 28 | 29 | def test_ordered_uuid_pkg_not_found_error(): 30 | with pytest.raises(RuntimeError) as exc_info: 31 | utils.ordered_uuid() 32 | 33 | assert str(exc_info.value) == "ordered_uuid package: not found" 34 | 35 | 36 | def test_render_template(): 37 | template = Template("Say $word.") 38 | result = utils.render(template, word="hi") 39 | assert result == "Say hi." 
40 | 41 | 42 | def test_render_template_pass_string(): 43 | template = " str: 44 | # pylint: disable=no-self-use 45 | return value.lower() 46 | 47 | def verify( 48 | self, 49 | secret: str 50 | ) -> bool: 51 | """ Verify a provided secret against the stored hash 52 | """ 53 | if not self.hashed_password: 54 | return False 55 | return pbkdf2_sha512.verify(secret, self.hashed_password) 56 | 57 | @property 58 | def password( 59 | self 60 | ) -> None: 61 | """ password getter: throws RuntimeError 62 | """ 63 | raise RuntimeError("Invalid access: get password not allowed") 64 | 65 | @password.setter 66 | def password( 67 | self, secret 68 | ) -> None: 69 | """ password setter: update the hash using the given secret 70 | """ 71 | self.hashed_password = pbkdf2_sha512.hash(secret) 72 | 73 | @classmethod 74 | def get_by_username( 75 | cls, 76 | session: Session, 77 | name: str, 78 | ): 79 | """ Lookup a User by name 80 | """ 81 | return session.query(cls).filter(cls.username == name.lower()).first() 82 | 83 | @property 84 | def identity(self) -> str: 85 | return str(self.id) 86 | 87 | 88 | @event.listens_for(User, "mapper_configured", propagate=True) 89 | def _mapper_configured(_mapper, cls): 90 | if getattr(cls, "__model_mapping__", True): 91 | MODEL_MAPPING["User"] = cls 92 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_sqlalchemy import types 4 | 5 | 6 | def test_query_types_functional(app, client): 7 | @app.get("/types") 8 | def _get( 9 | limit: int = types.LimitQuery(50, alias="lim"), 10 | offset: int = types.OffsetQuery(20, alias="off") 11 | ): 12 | return { 13 | "limit": limit, 14 | "offset": offset, 15 | } 16 | 17 | response = client.get("/types", params={ 18 | "lim": 10, 19 | "off": 100 20 | }) 21 | assert response.status_code == 200 22 | assert response.json() == { 23 | "limit": 10, 24 | 
"offset": 100, 25 | } 26 | 27 | response = client.get("/types") 28 | assert response.status_code == 200 29 | assert response.json() == { 30 | "limit": 50, 31 | "offset": 20, 32 | } 33 | 34 | 35 | def test_limit_query(): 36 | limit = types.LimitQuery() 37 | assert limit.default == 100 38 | assert limit.alias is None 39 | assert limit.title == "Maximum number of entries to return" 40 | assert limit.description == "Maximum number of entries to return." 41 | 42 | # Verify that negative default value cannot be set 43 | with pytest.raises(AssertionError): 44 | types.LimitQuery(-100) 45 | 46 | 47 | def test_offset_query(): 48 | limit = types.OffsetQuery() 49 | assert limit.default == 0 50 | assert limit.alias is None 51 | assert limit.title == "Index of the first entry to return" 52 | assert limit.description == "Index of the first entry to return." 53 | 54 | # Verify that negative default value cannot be set 55 | with pytest.raises(AssertionError): 56 | types.OffsetQuery(-100) 57 | 58 | 59 | def test_pydantic_types(app, client): 60 | assert types.NonNegativeInt.ge == 0 61 | 62 | @app.get("/types") 63 | def _get( 64 | non_negative: types.NonNegativeInt = 0 65 | ): 66 | return { 67 | "non_negative": non_negative 68 | } 69 | 70 | response = client.get("/types", params={ 71 | "non_negative": 1 72 | }) 73 | assert response.status_code == 200 74 | assert response.json() == { 75 | "non_negative": 1, 76 | } 77 | 78 | response = client.get("/types") 79 | assert response.status_code == 200 80 | assert response.json() == { 81 | "non_negative": 0, 82 | } 83 | 84 | response = client.get("/types", params={ 85 | "non_negative": -1 86 | }) 87 | assert response.status_code == 422 88 | detail = response.json()["detail"] 89 | assert detail[0]["loc"] == ["query", "non_negative"] 90 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/mixins.py: -------------------------------------------------------------------------------- 1 | """ 
Model mixins """ 2 | import uuid 3 | 4 | import sqlalchemy 5 | from sqlalchemy.ext.declarative import declared_attr 6 | 7 | from fastapi_sqlalchemy import tz 8 | from .types import GUID 9 | from .base import Session, model_as_dict 10 | 11 | 12 | class GuidMixin: 13 | """ Mixin that add a UUID id column """ 14 | id = sqlalchemy.Column( 15 | GUID, 16 | primary_key=True, 17 | default=uuid.uuid4, 18 | ) 19 | 20 | 21 | class TimestampMixin: 22 | """ Mixin to add update_at and created_at columns 23 | 24 | The columns are added at the *end* of the table 25 | """ 26 | @declared_attr 27 | def updated_at(self): 28 | """ Last update timestamp """ 29 | column = sqlalchemy.Column( 30 | sqlalchemy.DateTime(timezone=True), 31 | default=tz.utcnow, 32 | onupdate=tz.utcnow, 33 | nullable=False, 34 | ) 35 | # pylint: disable=protected-access 36 | column._creation_order = 9800 37 | return column 38 | 39 | @declared_attr 40 | def created_at(self): 41 | """ Creation timestamp """ 42 | column = sqlalchemy.Column( 43 | sqlalchemy.DateTime(timezone=True), 44 | default=tz.utcnow, 45 | nullable=False, 46 | ) 47 | # pylint: disable=protected-access 48 | column._creation_order = 9900 49 | return column 50 | 51 | 52 | class DictMixin: 53 | """ Mixin to add as_dict() """ 54 | 55 | def as_dict(self) -> dict: 56 | """ Convert object to dictionary """ 57 | return model_as_dict(self) 58 | 59 | 60 | class ConfirmationMixin: 61 | """ Mixin to support confirmation for Users """ 62 | 63 | email = sqlalchemy.Column( 64 | sqlalchemy.String(255), 65 | nullable=False, 66 | unique=True 67 | ) 68 | 69 | @classmethod 70 | def get_by_email( 71 | cls, 72 | session: Session, 73 | email: str, 74 | ): 75 | """ Lookup a User by name 76 | """ 77 | return session.query(cls).filter(cls.email == email).first() 78 | 79 | @declared_attr 80 | def confirmed_at(self): 81 | """ Email confirmation timestamp """ 82 | column = sqlalchemy.Column( 83 | sqlalchemy.DateTime(timezone=True), 84 | nullable=True, 85 | ) 86 | # pylint: 
disable=protected-access 87 | column._creation_order = 9700 88 | return column 89 | 90 | @property 91 | def confirmed(self): 92 | """ Whether or not the email has been confirmed """ 93 | return self.confirmed_at is not None 94 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/types.py: -------------------------------------------------------------------------------- 1 | """ SQLAlchemy types - particularly for columns """ 2 | import uuid 3 | import json 4 | 5 | from sqlalchemy.ext.mutable import MutableDict 6 | 7 | from sqlalchemy.sql import operators 8 | from sqlalchemy.types import String, TypeDecorator, CHAR, TEXT 9 | from sqlalchemy.dialects.postgresql import UUID 10 | 11 | 12 | class GUID(TypeDecorator): 13 | """Platform-independent GUID type. 14 | 15 | Uses PostgreSQL's UUID type, otherwise uses 16 | CHAR(32), storing as stringified hex values. 17 | 18 | https://docs.sqlalchemy.org/en/latest/core/custom_types.html 19 | Backend-agnostic GUID Type 20 | """ 21 | impl = CHAR 22 | 23 | def load_dialect_impl(self, dialect): 24 | if dialect.name == "postgresql": 25 | return dialect.type_descriptor(UUID()) 26 | return dialect.type_descriptor(CHAR(32)) 27 | 28 | def process_bind_param(self, value, dialect): 29 | if value is None: 30 | return value 31 | if dialect.name == "postgresql": 32 | return str(value) 33 | if not isinstance(value, uuid.UUID): 34 | return "%.32x" % uuid.UUID(value).int 35 | # hexstring 36 | return "%.32x" % value.int 37 | 38 | def process_result_value(self, value, dialect): 39 | if value is None: 40 | return value 41 | if not isinstance(value, uuid.UUID): 42 | value = uuid.UUID(value) 43 | return value 44 | 45 | def process_literal_param(self, value, dialect): 46 | raise NotImplementedError() 47 | 48 | @property 49 | def python_type(self): 50 | raise NotImplementedError() 51 | 52 | 53 | class JSONEncodedDict(TypeDecorator): 54 | """Represents an immutable structure as a json-encoded 
string. 55 | """ 56 | 57 | impl = TEXT 58 | 59 | _OPERATORS_FOR_STR = ( 60 | operators.like_op, 61 | operators.notlike_op, 62 | ) 63 | 64 | def coerce_compared_value(self, op, value): 65 | if op in self._OPERATORS_FOR_STR: 66 | return String() 67 | return self 68 | 69 | def process_bind_param(self, value, dialect): 70 | if value is not None: 71 | value = json.dumps(value) 72 | 73 | return value 74 | 75 | def process_result_value(self, value, dialect): 76 | if value is not None: 77 | value = json.loads(value) 78 | return value 79 | 80 | def process_literal_param(self, value, dialect): 81 | raise NotImplementedError() 82 | 83 | @property 84 | def python_type(self): 85 | raise NotImplementedError() 86 | 87 | 88 | JSON_TYPE = MutableDict.as_mutable(JSONEncodedDict) 89 | -------------------------------------------------------------------------------- /tests/endpoints/test_confirm.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from itsdangerous import URLSafeTimedSerializer 4 | 5 | from fastapi_sqlalchemy import endpoints, models 6 | 7 | 8 | class User( 9 | models.User, models.mixins.ConfirmationMixin, models.mixins.DictMixin 10 | ): 11 | __tablename__ = "test_confirm_users" 12 | __model_mapping__ = False 13 | 14 | 15 | def test_confirm(session, app, client): 16 | endpoint = endpoints.ConfirmEndpoint( 17 | User, secret="s0secret", 18 | ) 19 | 20 | user = User(username="testuser", email="recipient@example.org") 21 | session.add(user) 22 | session.commit() 23 | 24 | token_ = URLSafeTimedSerializer("s0secret").dumps(user.email, salt=None) 25 | 26 | @app.get("/confirm/{token}") 27 | async def _get( 28 | token: str 29 | ): 30 | return await endpoint.on_get(session, token) 31 | 32 | res = client.get(f"/confirm/{token_}", allow_redirects=False) 33 | assert res.status_code == 303 34 | assert res.headers.get("location") == "/login" 35 | 36 | 37 | def test_confirm_user_already_confirmed(session, app, 
client): 38 | endpoint = endpoints.ConfirmEndpoint( 39 | User, secret="s0secret", 40 | ) 41 | 42 | user = User( 43 | username="testuser", 44 | email="recipient@example.org", 45 | confirmed_at=datetime.now() 46 | ) 47 | session.add(user) 48 | session.commit() 49 | 50 | token_ = URLSafeTimedSerializer("s0secret").dumps(user.email, salt=None) 51 | 52 | @app.get("/confirm/{token}") 53 | async def _get( 54 | token: str 55 | ): 56 | return await endpoint.on_get(session, token) 57 | 58 | res = client.get(f"/confirm/{token_}", allow_redirects=False) 59 | assert res.status_code == 303 60 | assert res.headers.get("location") == "/login" 61 | 62 | 63 | def test_confirm_user_not_found(session, app, client): 64 | endpoint = endpoints.ConfirmEndpoint( 65 | User, secret="s0secret", template="<${error}" 66 | ) 67 | 68 | nonexistent_email = "nonexistent@local.example.com" 69 | token_ = URLSafeTimedSerializer( 70 | "s0secret").dumps(nonexistent_email, salt=None) 71 | 72 | @app.get("/confirm/{token}") 73 | async def _get( 74 | token: str 75 | ): 76 | return await endpoint.on_get(session, token) 77 | 78 | res = client.get(f"/confirm/{token_}", allow_redirects=False) 79 | assert res.status_code == 400 80 | assert res.text == f" str: 43 | """ Render the template using the passed parameters """ 44 | kwargs.setdefault("error", "") 45 | kwargs.setdefault("title", "FastAPI-SQLAlchemy") 46 | kwargs.setdefault("email", "") 47 | 48 | return utils.render(self.template, **kwargs) 49 | 50 | async def on_get( 51 | self, 52 | session: models.Session, 53 | token, 54 | ) -> Union[HTMLResponse, JSONResponse]: 55 | """ Handle GET requests """ 56 | 57 | def _confirm() -> Union[dict, str]: 58 | try: 59 | serializer = URLSafeTimedSerializer(self.secret) 60 | email = serializer.loads( 61 | token, 62 | salt=self.salt, 63 | max_age=self.max_age, 64 | ) 65 | except BadData as ex: 66 | logger.warning("%s: %s", token, str(ex)) 67 | return self.render( 68 | error="The confirmation link is invalid or expired." 
69 | ) 70 | 71 | user = self.user_cls.get_by_email(session, email) 72 | if not user: 73 | logger.info("User not found: %s", email) 74 | return self.render( 75 | error=f"User not found: {email}" 76 | ) 77 | 78 | if user.confirmed: 79 | logger.info( 80 | "User %s <%s> already confirmed", user.username, user.email 81 | ) 82 | else: 83 | user.confirmed_at = tz.utcnow() 84 | session.commit() 85 | logger.info( 86 | "User %s <%s> confirmed", user.username, user.email 87 | ) 88 | return user.as_dict() 89 | 90 | result = await run_in_threadpool(_confirm) 91 | if isinstance(result, str): 92 | # Error condition 93 | return HTMLResponse(status_code=400, content=result) 94 | 95 | headers = {"location": self.location} 96 | return JSONResponse( 97 | content=result, 98 | status_code=303, 99 | headers=headers 100 | ) 101 | -------------------------------------------------------------------------------- /tests/endpoints/test_login.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | import fastapi 4 | import jwt 5 | 6 | from fastapi_sqlalchemy import endpoints 7 | 8 | from tests.data.models import User 9 | 10 | 11 | def test_login_get(session, app, client): 12 | template = "${title}" 13 | endpoint = endpoints.LoginEndpoint( 14 | User, secret="s0secret", template=template 15 | ) 16 | 17 | @app.get("/login") 18 | async def _get(): 19 | return await endpoint.on_get() 20 | 21 | res = client.get("/login") 22 | assert res.status_code == 200 23 | assert res.text == "FastAPI-SQLAlchemy" 24 | 25 | 26 | def test_login_post(mocker, engine, session, app, client): 27 | now_dt = datetime.utcnow() 28 | mocker.patch("fastapi_sqlalchemy.tz.utcnow", return_value=now_dt) 29 | 30 | user = User(username="alice") 31 | user.password = "test123" 32 | session.add(user) 33 | session.commit() 34 | user = session.merge(user) 35 | 36 | endpoint = endpoints.LoginEndpoint(User, secret="s0secret") 37 | 38 | @app.post("/login") 39 
| async def _post( 40 | username: str = fastapi.Form(None), 41 | password: str = fastapi.Form(None) 42 | ): 43 | return await endpoint.on_post(session, username, password) 44 | 45 | expiry = now_dt + timedelta(seconds=endpoint.token_expiry) 46 | 47 | expected_data = { 48 | **user.as_dict(), 49 | "exp": expiry, 50 | } 51 | expected_token = jwt.encode( 52 | expected_data, 53 | endpoint.secret, 54 | algorithm=endpoint.jwt_algorithm 55 | ).decode("utf-8") 56 | 57 | res = client.post( 58 | "/login", 59 | data={ 60 | "username": "alice", 61 | "password": "test123", 62 | } 63 | ) 64 | assert res.status_code == 303 65 | assert res.headers.get("location") == "/" 66 | assert res.cookies.items() == [ 67 | ("jwt", expected_token) 68 | ] 69 | assert res.json() == { 70 | **expected_data, 71 | "exp": expiry.isoformat(), 72 | "token": expected_token 73 | } 74 | 75 | 76 | def test_login_post_username_not_found(engine, session, app, client): 77 | endpoint = endpoints.LoginEndpoint( 78 | User, 79 | secret="s0secret", 80 | template="<${error}" 81 | ) 82 | 83 | @app.post("/login") 84 | async def _post( 85 | username: str = fastapi.Form(None), 86 | password: str = fastapi.Form(None) 87 | ): 88 | return await endpoint.on_post(session, username, password) 89 | 90 | res = client.post("/login", data={ 91 | "username": "alice", 92 | "password": "test123", 93 | }) 94 | assert res.status_code == 401 95 | assert res.text == " sqlalchemy.Table: 17 | """ Generate a user <-> group association table """ 18 | table = sqlalchemy.Table( 19 | table_name, 20 | BASE.metadata, 21 | sqlalchemy.Column( 22 | "id", 23 | GUID, 24 | primary_key=True, 25 | default=uuid.uuid4, 26 | ), 27 | sqlalchemy.Column( 28 | "group_id", 29 | GUID, 30 | sqlalchemy.ForeignKey(group_table_name + ".id") 31 | ), 32 | sqlalchemy.Column( 33 | "user_id", 34 | GUID, 35 | sqlalchemy.ForeignKey(user_table_name + ".id") 36 | ), 37 | sqlalchemy.Column( 38 | "updated_at", 39 | sqlalchemy.DateTime(timezone=True), 40 | default=tz.utcnow, 41 
| onupdate=tz.utcnow, 42 | nullable=False, 43 | ), 44 | sqlalchemy.Column( 45 | "created_at", 46 | sqlalchemy.DateTime(timezone=True), 47 | default=tz.utcnow, 48 | onupdate=tz.utcnow, 49 | nullable=False, 50 | ) 51 | ) 52 | table.__association__ = "group_membership" 53 | return table 54 | 55 | 56 | def create_user_permissions_table( 57 | table_name: str = "user_permissions", 58 | user_table_name: str = "users", 59 | permission_table_name: str = "permissions", 60 | ) -> sqlalchemy.Table: 61 | """ Generate a user <-> permission association table """ 62 | table = sqlalchemy.Table( 63 | table_name, 64 | BASE.metadata, 65 | sqlalchemy.Column( 66 | "id", 67 | GUID, 68 | primary_key=True, 69 | default=uuid.uuid4, 70 | ), 71 | sqlalchemy.Column( 72 | "user_id", 73 | GUID, 74 | sqlalchemy.ForeignKey(user_table_name + ".id") 75 | ), 76 | sqlalchemy.Column( 77 | "permission_id", 78 | GUID, 79 | sqlalchemy.ForeignKey(permission_table_name + ".id") 80 | ), 81 | sqlalchemy.Column( 82 | "updated_at", 83 | sqlalchemy.DateTime(timezone=True), 84 | default=tz.utcnow, 85 | onupdate=tz.utcnow, 86 | nullable=False, 87 | ), 88 | sqlalchemy.Column( 89 | "created_at", 90 | sqlalchemy.DateTime(timezone=True), 91 | default=tz.utcnow, 92 | onupdate=tz.utcnow, 93 | nullable=False, 94 | ) 95 | ) 96 | table.__association__ = "user_permissions" 97 | return table 98 | 99 | 100 | def create_group_permissions_table( 101 | table_name: str = "group_permissions", 102 | group_table_name: str = "groups", 103 | permission_table_name: str = "permissions", 104 | ) -> sqlalchemy.Table: 105 | """ Generate a group <-> permission association table """ 106 | table = sqlalchemy.Table( 107 | table_name, 108 | BASE.metadata, 109 | sqlalchemy.Column( 110 | "id", 111 | GUID, 112 | primary_key=True, 113 | default=uuid.uuid4, 114 | ), 115 | sqlalchemy.Column( 116 | "group_id", 117 | GUID, 118 | sqlalchemy.ForeignKey(group_table_name + ".id") 119 | ), 120 | sqlalchemy.Column( 121 | "permission_id", 122 | GUID, 123 | 
sqlalchemy.ForeignKey(permission_table_name + ".id") 124 | ), 125 | sqlalchemy.Column( 126 | "updated_at", 127 | sqlalchemy.DateTime(timezone=True), 128 | default=tz.utcnow, 129 | onupdate=tz.utcnow, 130 | nullable=False, 131 | ), 132 | sqlalchemy.Column( 133 | "created_at", 134 | sqlalchemy.DateTime(timezone=True), 135 | default=tz.utcnow, 136 | onupdate=tz.utcnow, 137 | nullable=False, 138 | ) 139 | ) 140 | table.__association__ = "group_permissions" 141 | return table 142 | -------------------------------------------------------------------------------- /tests/models/test_types.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | 5 | from sqlalchemy import Column, Integer, MetaData, Table 6 | from sqlalchemy.dialects import postgresql 7 | from sqlalchemy.ext.declarative import declarative_base 8 | 9 | from fastapi_sqlalchemy.models import types 10 | 11 | 12 | def test_guid_type_functional(session): 13 | meta = MetaData() 14 | table = Table( 15 | "test_guid_type", 16 | meta, 17 | Column("id", types.GUID, primary_key=True, default=uuid.uuid4), 18 | Column("other_id", types.GUID) 19 | ) 20 | table.create(session.bind) 21 | 22 | other_id_uuid = uuid.uuid4() 23 | other_id_str = str(uuid.uuid4()) 24 | 25 | # pylint: disable=no-value-for-parameter 26 | session.execute(table.insert(), {"other_id": other_id_uuid}) 27 | session.execute(table.insert(), {"other_id": other_id_str}) 28 | session.execute(table.insert(), {"other_id": None}) 29 | # pylint: enable=no-value-for-parameter 30 | session.commit() 31 | 32 | rows = session.execute(table.select()).fetchall() 33 | rows = list(map(dict, rows)) 34 | assert len(rows) == 3 35 | 36 | assert isinstance(rows[0]["id"], uuid.UUID) 37 | assert isinstance(rows[0]["other_id"], uuid.UUID) 38 | assert rows[0]["other_id"] == other_id_uuid 39 | 40 | assert isinstance(rows[1]["id"], uuid.UUID) 41 | assert isinstance(rows[1]["other_id"], uuid.UUID) 42 | assert 
str(rows[1]["other_id"]) == other_id_str 43 | 44 | assert isinstance(rows[2]["id"], uuid.UUID) 45 | assert rows[2]["other_id"] is None 46 | 47 | 48 | def test_guid_type(mocker): 49 | guid = types.GUID() 50 | 51 | dialect_postgresql = mocker.Mock() 52 | dialect_postgresql.name = "postgresql" 53 | dialect_postgresql.type_descriptor = lambda _val: _val 54 | 55 | dialect_impl = guid.load_dialect_impl(dialect_postgresql) 56 | assert isinstance(dialect_impl, postgresql.UUID) 57 | 58 | value = uuid.uuid4() 59 | bind_param = guid.process_bind_param(value, dialect_postgresql) 60 | assert isinstance(bind_param, str) 61 | assert bind_param == str(value) 62 | 63 | with pytest.raises(NotImplementedError): 64 | guid.process_literal_param(value=mocker.Mock(), dialect=mocker.Mock()) 65 | 66 | with pytest.raises(NotImplementedError): 67 | assert not guid.python_type 68 | 69 | 70 | def test_json_encoded_dict_type_functional(session): 71 | base_cls = declarative_base() 72 | 73 | class TestModel(base_cls): 74 | __tablename__ = "test_json_encoded_dict_type" 75 | 76 | id = Column(Integer, primary_key=True) 77 | 78 | mutable_data = Column(types.JSON_TYPE) 79 | non_mutable_data = Column(types.JSONEncodedDict) 80 | 81 | base_cls.metadata.create_all(session.bind) 82 | 83 | mutable_data = {"key": "value"} 84 | non_mutable_data = {"fixed-key": "fixed-value"} 85 | model = TestModel( 86 | mutable_data=mutable_data, 87 | non_mutable_data=non_mutable_data 88 | ) 89 | session.add(model) 90 | session.commit() 91 | model = session.merge(model) 92 | 93 | assert model.mutable_data == mutable_data 94 | assert model.non_mutable_data == non_mutable_data 95 | 96 | model.mutable_data["key"] = "updated-value" 97 | model.mutable_data["new-key"] = "new-value" 98 | model.non_mutable_data["fixed-key"] = "updated-fixed-value" 99 | model.non_mutable_data["new-fixed-key"] = "new-fixed-value" 100 | session.commit() 101 | 102 | assert model.mutable_data == { 103 | "key": "updated-value", 104 | "new-key": "new-value" 
105 | } 106 | # All changes should be reset after commit() 107 | assert model.non_mutable_data == non_mutable_data 108 | 109 | row = session.query(TestModel).filter( 110 | TestModel.mutable_data.like("%updated-value%"), 111 | TestModel.non_mutable_data.notlike("%new-fixed-value%") 112 | ).one() 113 | assert row is model 114 | 115 | row = session.query(TestModel).filter( 116 | TestModel.non_mutable_data == non_mutable_data 117 | ).one() 118 | assert row is model 119 | 120 | 121 | def test_json_encoded_dict_type(mocker): 122 | json_type = types.JSONEncodedDict() 123 | 124 | with pytest.raises(NotImplementedError): 125 | json_type.process_literal_param( 126 | value=mocker.Mock(), dialect=mocker.Mock()) 127 | 128 | with pytest.raises(NotImplementedError): 129 | assert not json_type.python_type 130 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/middleware.py: -------------------------------------------------------------------------------- 1 | """ Generic middleware """ 2 | import logging 3 | from typing import Union, Sequence 4 | 5 | import jwt 6 | from sqlalchemy.engine import Connectable 7 | from starlette.requests import Request 8 | from starlette.responses import Response 9 | from starlette.middleware.base import BaseHTTPMiddleware 10 | from starlette.types import ASGIApp 11 | 12 | from fastapi_sqlalchemy import db_registry 13 | from fastapi_sqlalchemy.models import Session 14 | 15 | 16 | PAYLOAD_HEADER_PREFIX = "x-payload-" 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SessionMiddleware(BaseHTTPMiddleware): 22 | """Add a `models.Session` instance to `request.state.session`. 23 | 24 | Close the session after each request, thus rolling back any 25 | not committed transactions. 26 | 27 | Given bind will be added to `fastapi_sqlalchemy.db_registry` and so can be 28 | accessed from there. 
29 | """ 30 | def __init__( 31 | self, 32 | app: ASGIApp, 33 | bind: Union[str, Connectable], 34 | **engine_kwargs 35 | ): 36 | super().__init__(app) 37 | bind = db_registry.register(bind, **engine_kwargs) 38 | Session.configure(bind=bind) 39 | 40 | async def dispatch( 41 | self, 42 | request: Request, 43 | call_next 44 | ) -> Response: 45 | added = False 46 | try: 47 | if not hasattr(request.state, "session"): 48 | request.state.session = Session() 49 | added = True 50 | response = await call_next(request) 51 | finally: 52 | if added: 53 | # Only close a session if we added it, useful for testing 54 | request.state.session.close() 55 | return response 56 | 57 | 58 | class UpstreamPayloadMiddleware(BaseHTTPMiddleware): 59 | """Parse upstream request headers and set the result to 60 | `request.state.payload`. 61 | 62 | NOTE: there must be an upstream service (like an API Gateway) to 63 | ensure these headers are trusted. Otherwise the client could set 64 | any desired permissions. 65 | """ 66 | 67 | PAYLOAD_HEADER_PREFIX = PAYLOAD_HEADER_PREFIX 68 | 69 | def __init__( 70 | self, 71 | app: ASGIApp, 72 | header_prefix: str = PAYLOAD_HEADER_PREFIX, 73 | ): 74 | super().__init__(app=app) 75 | self.header_prefix = header_prefix 76 | 77 | async def dispatch( 78 | self, 79 | request: Request, 80 | call_next 81 | ) -> Response: 82 | payload = {} 83 | for header_name in request.headers: 84 | if header_name.startswith(self.header_prefix): 85 | name = header_name[len(self.header_prefix):] 86 | value = request.headers.getlist(header_name) 87 | if len(value) == 1: 88 | payload[name] = value[0] 89 | else: # pragma: nocover 90 | payload[name] = value 91 | request.state.payload = payload 92 | return await call_next(request) 93 | 94 | 95 | class JwtMiddleware(BaseHTTPMiddleware): 96 | """Decode a JWT token from cookies (if present) and add the results to 97 | `request.state.payload`.""" 98 | 99 | def __init__( 100 | self, 101 | app: ASGIApp, 102 | secret: str, 103 | cookie_name: str 
= "jwt", 104 | algorithms: Sequence[str] = ("HS256", "HS512"), 105 | **kwargs, 106 | ): 107 | super().__init__(app=app) 108 | self.secret = secret 109 | self.cookie_name = cookie_name 110 | self.algorithms = algorithms 111 | self.kwargs = kwargs 112 | 113 | async def dispatch( 114 | self, request: Request, call_next 115 | ) -> Response: 116 | token = request.cookies.get(self.cookie_name) 117 | if token: 118 | try: 119 | payload = jwt.decode( 120 | token, 121 | key=self.secret, 122 | algorithms=self.algorithms, 123 | **self.kwargs 124 | ) 125 | request.state.payload = payload 126 | except jwt.exceptions.InvalidTokenError as ex: 127 | logger.info("JWT decode error: %s", str(ex)) 128 | else: 129 | logging.debug("%s: No JWT", str(request.url)) 130 | 131 | return await call_next(request) 132 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/auth.py: -------------------------------------------------------------------------------- 1 | """ authentication and authorization """ 2 | import logging 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | from fastapi import Request 6 | from fastapi.security import SecurityScopes 7 | from fastapi import status 8 | from starlette.authentication import ( 9 | AuthenticationBackend, AuthCredentials, SimpleUser 10 | ) 11 | from starlette.concurrency import run_in_threadpool 12 | from starlette.exceptions import HTTPException 13 | from starlette.requests import HTTPConnection 14 | 15 | ADMIN_SCOPE = "*" 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class PayloadAuthBackend(AuthenticationBackend): 21 | """ Get auth information from the request payload """ 22 | 23 | def __init__( 24 | self, 25 | user_cls: type = None, 26 | admin_scope: str = ADMIN_SCOPE, 27 | ): 28 | super().__init__() 29 | self.user_cls = user_cls 30 | self.admin_scope = admin_scope 31 | 32 | async def scopes(self, payload: Dict[str, str]) -> List[str]: 33 | """ Return the list of scopes """ 34 | 
if "scopes" in payload: 35 | scopes = payload["scopes"] 36 | elif "permissions" in payload: 37 | scopes = payload["permissions"] 38 | else: 39 | return [] 40 | 41 | if isinstance(scopes, str): 42 | scopes = [scopes] 43 | 44 | if self.admin_scope and self.admin_scope in scopes: 45 | return [self.admin_scope] 46 | 47 | result = [] 48 | for scope in scopes: 49 | result.extend([token.strip() for token in scope.split(",")]) 50 | 51 | return result 52 | 53 | async def authenticate( 54 | self, conn: HTTPConnection 55 | ) -> Optional[Tuple["AuthCredentials", "BaseUser"]]: 56 | try: 57 | payload = conn.state.payload 58 | except AttributeError: 59 | raise RuntimeError( 60 | "Missing 'request.state.payload': " 61 | "try adding 'middleware.UpstreamPayloadMiddleware'" 62 | ) 63 | 64 | username = payload.get("username") 65 | if not username: 66 | return 67 | 68 | if self.user_cls: 69 | try: 70 | session = conn.state.session 71 | except AttributeError: 72 | raise RuntimeError( 73 | "Missing 'request.state.session': " 74 | "try adding 'middleware.SessionMiddleware'" 75 | ) 76 | 77 | user = await run_in_threadpool( 78 | self.user_cls.get_by_username, session, username 79 | ) 80 | if not user: 81 | logger.warning("User not found: %s", username) 82 | return 83 | else: 84 | user = SimpleUser(username=username) 85 | 86 | scopes = await self.scopes(payload) 87 | return AuthCredentials(scopes), user 88 | 89 | 90 | def validate_authenticated(request: Request): 91 | """Validate that 'request.user' is authenticated. 92 | 93 | Usage: 94 | >>> from fastapi import Depends 95 | >>> @app.get("/my_name", dependencies=[ 96 | ... Depends(validate_authenticated) 97 | ... 
]) 98 | """ 99 | user: SimpleUser = getattr(request, "user", None) 100 | if user is not None and not user.is_authenticated: 101 | raise HTTPException(status.HTTP_401_UNAUTHORIZED) 102 | 103 | 104 | def validate_all_scopes(request: Request, scopes: SecurityScopes): 105 | """Validate that all defined scopes exist in 'request.auth.scopes'. 106 | 107 | Usage: 108 | >>> from fastapi import Security 109 | >>> @app.get("/my_name", dependencies=[ 110 | ... Security(validate_all_scopes, scopes=["read", "write"]) 111 | ... ]) 112 | """ 113 | req_scopes = request.auth.scopes 114 | if not all(scope in req_scopes for scope in scopes.scopes): 115 | raise HTTPException(status.HTTP_403_FORBIDDEN) 116 | 117 | 118 | def validate_any_scope(request: Request, scopes: SecurityScopes): 119 | """Validate that at least one defined scope exists in 'request.auth.scopes'. 120 | 121 | Usage: 122 | >>> from fastapi import Security 123 | >>> @app.get("/my_name", dependencies=[ 124 | ... Security(validate_any_scope, scopes=["read", "write"]) 125 | ... 
]) 126 | """ 127 | req_scopes = request.auth.scopes 128 | if not any(scope in req_scopes for scope in scopes.scopes): 129 | raise HTTPException(status.HTTP_403_FORBIDDEN) 130 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/models/events.py: -------------------------------------------------------------------------------- 1 | """ Bookkeeping """ 2 | from sqlalchemy import event, inspect 3 | from sqlalchemy.orm import mapper, relationship 4 | 5 | from .base import BASE, MODEL_MAPPING 6 | 7 | 8 | @event.listens_for(mapper, "after_configured") 9 | def _after_configured(): 10 | # pylint: disable=too-many-branches 11 | user_cls = MODEL_MAPPING.get("User") 12 | group_cls = MODEL_MAPPING.get("Group") 13 | permission_cls = MODEL_MAPPING.get("Permission") 14 | 15 | associations = {} 16 | for table in BASE.metadata.tables.values(): 17 | association = getattr(table, "__association__", None) 18 | if association is None: 19 | continue 20 | if association in associations: 21 | raise RuntimeError( 22 | f"Multiple '{association}' associations found." 23 | "Only a single table may have a specific __association__ value" 24 | ) 25 | associations[association] = table 26 | 27 | group_membership_table = associations.get("group_membership") 28 | if group_membership_table is not None: 29 | if not user_cls: 30 | raise RuntimeError( 31 | "'group_membership' association table found, " 32 | "but no User table defined." 33 | ) 34 | if not group_cls: 35 | raise RuntimeError( 36 | "'group_membership' association table found, " 37 | "but no Group table defined." 
38 | ) 39 | 40 | if not hasattr(user_cls, "groups"): 41 | user_cls.groups = relationship( 42 | group_cls, 43 | secondary=group_membership_table 44 | ) 45 | if not hasattr(group_cls, "users"): 46 | group_cls.users = relationship( 47 | user_cls, 48 | secondary=group_membership_table 49 | ) 50 | 51 | user_permissions_table = associations.get("user_permissions") 52 | if user_permissions_table is not None: 53 | if not user_cls: 54 | raise RuntimeError( 55 | "'user_permissions' association table found, " 56 | "but no User table defined." 57 | ) 58 | if not permission_cls: 59 | raise RuntimeError( 60 | "'user_permissions' association table found, " 61 | "but no Permission table defined." 62 | ) 63 | 64 | if not hasattr(user_cls, "user_permissions"): 65 | user_cls.user_permissions = relationship( 66 | permission_cls, 67 | secondary=user_permissions_table 68 | ) 69 | if not hasattr(permission_cls, "users"): 70 | permission_cls.users = relationship( 71 | user_cls, 72 | secondary=user_permissions_table 73 | ) 74 | 75 | group_permissions_table = associations.get("group_permissions") 76 | if group_permissions_table is not None: 77 | if not group_cls: 78 | raise RuntimeError( 79 | "'group_permissions' association table found, " 80 | "but no Group table defined." 81 | ) 82 | if not permission_cls: 83 | raise RuntimeError( 84 | "'group_permissions' association table found, " 85 | "but no Permission table defined." 
86 | ) 87 | 88 | if not hasattr(group_cls, "permissions"): 89 | group_cls.permissions = relationship( 90 | permission_cls, 91 | secondary=group_permissions_table 92 | ) 93 | if not hasattr(permission_cls, "groups"): 94 | permission_cls.groups = relationship( 95 | group_cls, 96 | secondary=group_permissions_table 97 | ) 98 | 99 | def _permissions(user): 100 | session = inspect(user).session 101 | return session.query(permission_cls) \ 102 | .join(user_permissions_table, user_cls) \ 103 | .filter(user_cls.id == user.id) \ 104 | .union( 105 | session.query(permission_cls) 106 | .join(group_permissions_table) 107 | .join(group_cls) 108 | .join(group_membership_table) 109 | .join(user_cls) 110 | .filter(user_cls.id == user.id)) 111 | 112 | if user_cls and not hasattr(user_cls, "permissions") and \ 113 | user_permissions_table is not None and \ 114 | group_permissions_table is not None and \ 115 | group_membership_table is not None: 116 | user_cls.permissions = property(_permissions) 117 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/crud.py: -------------------------------------------------------------------------------- 1 | """ Generic CRUD operations """ 2 | from uuid import UUID 3 | from typing import List, Dict, Any 4 | 5 | import sqlalchemy.exc 6 | from pydantic import BaseModel, PositiveInt 7 | from starlette.exceptions import HTTPException 8 | from starlette.concurrency import run_in_threadpool 9 | 10 | from sqlalchemy_filters import apply_filters, apply_sort 11 | 12 | from . import models, types 13 | 14 | # NOTE: always use the session of the caller 15 | # i.e. don't use models.Session in the thread pool synchronous functions 16 | # This is necessary in sqlite3 (at least) to ensure consistency.
17 | 18 | 19 | async def list_instances( 20 | cls: models.BASE, 21 | session: models.Session, 22 | filter_spec: List[Dict[str, Any]] = None, 23 | sort_spec: List[Dict[str, str]] = None, 24 | offset: types.NonNegativeInt = 0, 25 | limit: PositiveInt = None, 26 | options: Any = None 27 | ) -> List[dict]: 28 | """ Return all instances of cls """ 29 | query = session.query(cls) 30 | if filter_spec: 31 | query = apply_filters(query, filter_spec) 32 | if sort_spec: 33 | query = apply_sort(query, sort_spec) 34 | 35 | if options: 36 | query = query.options(options) 37 | 38 | if limit: 39 | query = query.limit(limit) 40 | query = query.offset(offset) 41 | 42 | def _list(): 43 | return [instance.as_dict() for instance in query.all()] 44 | 45 | return await run_in_threadpool(_list) 46 | 47 | 48 | async def count_instances( 49 | cls: models.BASE, 50 | session: models.Session, 51 | filter_spec: List[Dict[str, Any]] = None, 52 | sort_spec: List[Dict[str, Any]] = None, 53 | ) -> int: 54 | """ Total count of instances matching the given criteria """ 55 | query = session.query(cls) 56 | if filter_spec: 57 | query = apply_filters(query, filter_spec) 58 | if sort_spec: 59 | query = apply_sort(query, sort_spec) 60 | 61 | def _count(): 62 | return query.count() 63 | 64 | return await run_in_threadpool(_count) 65 | 66 | 67 | async def create_instance( 68 | cls: models.BASE, 69 | session: models.Session, 70 | data: BaseModel 71 | ) -> dict: 72 | """ Create an instance of cls with the provided data """ 73 | instance = cls(**data.dict()) 74 | 75 | def _create(): 76 | session.add(instance) 77 | session.commit() 78 | return session.merge(instance).as_dict() 79 | 80 | try: 81 | return await run_in_threadpool(_create) 82 | except sqlalchemy.exc.IntegrityError as ex: 83 | raise HTTPException(status_code=409, detail=str(ex.orig)) 84 | 85 | 86 | async def retrieve_instance( 87 | cls: models.BASE, 88 | session: models.Session, 89 | instance_id: UUID, 90 | options: Any = None 91 | ) -> dict: 92 |
""" Get an instance of cls by UUID """ 93 | query = session.query(cls) 94 | 95 | if options: 96 | query = query.options(options) 97 | 98 | def _retrieve(): 99 | instance = query.get(instance_id) 100 | if instance: 101 | return instance.as_dict() 102 | return None 103 | 104 | data = await run_in_threadpool(_retrieve) 105 | if data is None: 106 | raise HTTPException(status_code=404) 107 | return data 108 | 109 | 110 | async def update_instance( 111 | cls: models.BASE, 112 | session: models.Session, 113 | instance_id: UUID, 114 | data: BaseModel) -> dict: 115 | """ Fully update an instance using the provided data """ 116 | 117 | def _update(): 118 | instance = session.query(cls).get(instance_id) 119 | if not instance: 120 | return None 121 | for key, value in data.dict().items(): 122 | setattr(instance, key, value) 123 | session.commit() 124 | return session.merge(instance).as_dict() 125 | 126 | data = await run_in_threadpool(_update) 127 | if data is None: 128 | raise HTTPException(status_code=404) 129 | return data 130 | 131 | 132 | async def delete_instance( 133 | cls: models.BASE, 134 | session: models.Session, 135 | instance_id: UUID 136 | ) -> dict: 137 | """ Delete an instance by UUID """ 138 | 139 | def _delete(): 140 | instance = session.query(cls).get(instance_id) 141 | if not instance: 142 | return None 143 | result = instance.as_dict() 144 | session.delete(instance) 145 | session.commit() 146 | return result 147 | 148 | data = await run_in_threadpool(_delete) 149 | if data is None: 150 | raise HTTPException(status_code=404) 151 | return data 152 | -------------------------------------------------------------------------------- /fastapi_sqlalchemy/endpoints/login.py: -------------------------------------------------------------------------------- 1 | """ Login functionality """ 2 | import os 3 | import logging 4 | import inspect 5 | from typing import Any, Optional 6 | 7 | from starlette.responses import HTMLResponse, JSONResponse 8 | from
starlette.concurrency import run_in_threadpool 9 | 10 | from fastapi_sqlalchemy import tz, models, utils 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LoginEndpoint: 16 | """ Class-based endpoint for login """ 17 | 18 | DEFAULT_TEMPLATE = os.path.join( 19 | os.path.dirname(__file__), "templates", "login.html" 20 | ) 21 | 22 | def __init__( 23 | self, 24 | user_cls, 25 | secret, 26 | *, 27 | template: str = DEFAULT_TEMPLATE, 28 | error_status_code: int = 401, 29 | location: str = "/", 30 | token_expiry: int = 86400, # 24 hours 31 | secure: bool = True, 32 | cookie_name: str = "jwt", 33 | jwt_algorithm: str = "HS256", 34 | form_action: str = "/login", 35 | require_confirmation: bool = False, 36 | ): 37 | assert inspect.isclass(user_cls) 38 | self.secret = secret 39 | self.user_cls = user_cls 40 | self.template = template.strip() 41 | self.error_status_code = error_status_code 42 | self.location = location 43 | self.token_expiry = token_expiry 44 | self.secure = secure 45 | self.cookie_name = cookie_name 46 | self.jwt_algorithm = jwt_algorithm 47 | self.form_action = form_action 48 | self.require_confirmation = require_confirmation 49 | 50 | def render(self, **kwargs) -> str: 51 | """ Render the template using the passed parameters """ 52 | kwargs.setdefault("username", "") 53 | kwargs.setdefault("error", "") 54 | kwargs.setdefault("form_action", self.form_action) 55 | kwargs.setdefault("modal_title", "Login to your Account") 56 | kwargs.setdefault("title", "FastAPI-SQLAlchemy") 57 | 58 | return utils.render(self.template, **kwargs) 59 | 60 | async def jwt_encode(self, payload): 61 | """ Build the JWT """ 62 | assert "exp" in payload 63 | return utils.jwt_encode( 64 | payload, 65 | self.secret, 66 | algorithm=self.jwt_algorithm, 67 | ) 68 | 69 | async def payload(self, user_data): 70 | """ Determine the JWT contents (keep for sub-classes) """ 71 | user_data.pop("password", None) 72 | return user_data 73 | 74 | async def authenticate( 75 | self, 76
| session: models.Session, 77 | username: str, 78 | password: str, 79 | **kwargs # pylint: disable=unused-argument 80 | ) -> Optional[dict]: 81 | """ Perform authentication against database """ 82 | 83 | def _get_by_username(): 84 | return self.user_cls.get_by_username(session, username) 85 | 86 | user = await run_in_threadpool(_get_by_username) 87 | if not user: 88 | logger.info("Invalid user '%s'", username) 89 | return None 90 | 91 | if not user.verify(password): 92 | logger.info("Invalid password for user '%s'", user.username) 93 | return None 94 | 95 | logger.info("Authenticated user '%s'", user.username) 96 | return user.as_dict() 97 | 98 | async def on_get(self) -> HTMLResponse: 99 | """ Handle GET requests """ 100 | html = await run_in_threadpool(self.render) 101 | return HTMLResponse(content=html, status_code=200) 102 | 103 | async def on_post( 104 | self, 105 | session: models.Session, 106 | username: str, 107 | password: str, 108 | location: str = None, 109 | **kwargs 110 | ) -> Any: 111 | """ Handle POST requests """ 112 | user_data = await self.authenticate( 113 | session, 114 | username=username, 115 | password=password, 116 | **kwargs 117 | ) 118 | if not user_data: 119 | # ref: OWASP 120 | error = "Login failed; Invalid userID or password" 121 | html = await run_in_threadpool( 122 | self.render, username=username, error=error 123 | ) 124 | return HTMLResponse( 125 | content=html, 126 | status_code=self.error_status_code 127 | ) 128 | 129 | result = await self.payload(user_data) 130 | 131 | expiry = tz.utcnow() + tz.timedelta(seconds=self.token_expiry) 132 | 133 | # jwt_encode will convert this to an epoch inside the token 134 | result["exp"] = expiry 135 | 136 | token = await self.jwt_encode(result) 137 | result["token"] = token 138 | result["exp"] = expiry.isoformat() 139 | 140 | headers = {"location": location or self.location} 141 | response = JSONResponse( 142 | content=result, 143 | status_code=303, 144 | headers=headers 145 | ) 146 | 
response.set_cookie( 147 | self.cookie_name, token, 148 | path="/", expires=int(expiry.timestamp()), secure=self.secure 149 | ) 150 | return response 151 | -------------------------------------------------------------------------------- /tests/test_crud.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | import sqlalchemy.exc 5 | 6 | from pydantic import PositiveInt 7 | from starlette.exceptions import HTTPException 8 | 9 | from fastapi_sqlalchemy import crud 10 | from fastapi_sqlalchemy.types import NonNegativeInt 11 | 12 | from tests.data.people import ( 13 | load_people, Person, PersonRequestModel, PEOPLE_DATA 14 | ) 15 | 16 | 17 | @pytest.fixture(name="mock_sqlalchemy_filters") 18 | def fixture_mock_sqlalchemy_filters(mocker): 19 | def __query(query, *_args, **__kwargs): 20 | return query 21 | 22 | apply_filters = mocker.patch( 23 | "fastapi_sqlalchemy.crud.apply_filters", side_effect=__query) 24 | apply_sort = mocker.patch( 25 | "fastapi_sqlalchemy.crud.apply_sort", side_effect=__query) 26 | return apply_filters, apply_sort 27 | 28 | 29 | def test_crud_list(session, loop): 30 | expected = [person.as_dict() for person in load_people(session)] 31 | actual = loop.run_until_complete( 32 | crud.list_instances(Person, session) 33 | ) 34 | assert expected == actual 35 | 36 | 37 | def test_crud_list_query(mocker, loop, mock_sqlalchemy_filters): 38 | def __query(*_args, **__kwargs): 39 | return mock_query 40 | 41 | apply_filters, apply_sort = mock_sqlalchemy_filters 42 | 43 | mock_query = mocker.Mock() 44 | mock_query.options = mocker.Mock(side_effect=__query) 45 | mock_query.offset = mocker.Mock(side_effect=__query) 46 | mock_query.limit = mocker.Mock(side_effect=__query) 47 | mock_query.all = mocker.Mock(return_value=[]) 48 | 49 | session = mocker.Mock() 50 | session.query = mocker.Mock(return_value=mock_query) 51 | 52 | filter_spec = [{"filter1": "value1"}] 53 | sort_spec = [{"sort1": 
"value1"}] 54 | options = ("option1", "option2") 55 | offset = NonNegativeInt(10) 56 | limit = PositiveInt(50) 57 | 58 | loop.run_until_complete(crud.list_instances( 59 | Person, session, filter_spec, sort_spec, offset, limit, options 60 | )) 61 | assert apply_filters.call_args == mocker.call(mock_query, filter_spec) 62 | assert apply_sort.call_args == mocker.call(mock_query, sort_spec) 63 | assert mock_query.options.call_args == mocker.call(options) 64 | assert mock_query.offset.call_args == mocker.call(offset) 65 | assert mock_query.limit.call_args == mocker.call(limit) 66 | 67 | 68 | def test_crud_count(session, loop): 69 | data = load_people(session) 70 | actual = loop.run_until_complete( 71 | crud.count_instances(Person, session) 72 | ) 73 | assert len(data) == actual 74 | 75 | 76 | def test_crud_count_query(mocker, loop, mock_sqlalchemy_filters): 77 | apply_filters, apply_sort = mock_sqlalchemy_filters 78 | 79 | mock_query = mocker.Mock() 80 | session = mocker.Mock(query=mocker.Mock(return_value=mock_query)) 81 | 82 | filter_spec = [{"filter1": "value1"}] 83 | sort_spec = [{"sort1": "value1"}] 84 | 85 | loop.run_until_complete( 86 | crud.count_instances(Person, session, filter_spec, sort_spec) 87 | ) 88 | assert apply_filters.call_args == mocker.call(mock_query, filter_spec) 89 | assert apply_sort.call_args == mocker.call(mock_query, sort_spec) 90 | 91 | 92 | def test_crud_create(session, loop): 93 | result = loop.run_until_complete( 94 | crud.create_instance( 95 | Person, session, PersonRequestModel(**PEOPLE_DATA[0]) 96 | ) 97 | ) 98 | for key in ("id", "updated_at", "created_at"): 99 | assert result.pop(key) 100 | assert result == PEOPLE_DATA[0] 101 | 102 | 103 | def test_crud_create_409(mocker, loop): 104 | exc = sqlalchemy.exc.IntegrityError( 105 | statement="fake statement", 106 | params={}, 107 | orig=Person 108 | ) 109 | 110 | session = mocker.Mock() 111 | session.commit = mocker.Mock(side_effect=exc) 112 | 113 | with pytest.raises(HTTPException) as 
exc_info: 114 | loop.run_until_complete( 115 | crud.create_instance( 116 | Person, session, PersonRequestModel(**PEOPLE_DATA[0]) 117 | ) 118 | ) 119 | 120 | assert exc_info.value.status_code == 409 121 | assert exc_info.value.detail == str(exc.orig) 122 | 123 | 124 | def test_crud_retrieve(session, loop): 125 | person = Person(**PEOPLE_DATA[0]) 126 | session.add(person) 127 | session.commit() 128 | person = session.merge(person) 129 | 130 | result = loop.run_until_complete( 131 | crud.retrieve_instance(Person, session, person.id) 132 | ) 133 | assert result == person.as_dict() 134 | 135 | 136 | def test_crud_retrieve_404(session, loop): 137 | with pytest.raises(HTTPException) as exc_info: 138 | loop.run_until_complete( 139 | crud.retrieve_instance(Person, session, uuid.uuid4()) 140 | ) 141 | 142 | assert exc_info.value.status_code == 404 143 | 144 | 145 | def test_crud_retrieve_query(mocker, loop): 146 | def __query(*_args, **__kwargs): 147 | return mock_query 148 | 149 | mock_query = mocker.Mock() 150 | mock_query.options = mocker.Mock(side_effect=__query) 151 | 152 | session = mocker.Mock() 153 | session.query = mocker.Mock(return_value=mock_query) 154 | 155 | options = ("option1", "option2") 156 | 157 | loop.run_until_complete( 158 | crud.retrieve_instance(Person, session, uuid.uuid4(), options) 159 | ) 160 | assert mock_query.options.call_args == mocker.call(options) 161 | 162 | 163 | def test_crud_update(session, loop): 164 | person = Person(**PEOPLE_DATA[0]) 165 | session.add(person) 166 | session.commit() 167 | person = session.merge(person) 168 | 169 | assert person.name == "alice" 170 | 171 | data = person.as_dict() 172 | data["name"] = "edith" 173 | 174 | result = loop.run_until_complete( 175 | crud.update_instance( 176 | Person, session, person.id, PersonRequestModel(**data) 177 | ) 178 | ) 179 | assert result["name"] == "edith" 180 | 181 | session.refresh(person) 182 | assert person.name == "edith" 183 | 184 | 185 | def test_crud_update_404(session, 
loop): 186 | with pytest.raises(HTTPException) as exc_info: 187 | loop.run_until_complete( 188 | crud.update_instance( 189 | Person, 190 | session, 191 | uuid.uuid4(), 192 | PersonRequestModel(**PEOPLE_DATA[0]) 193 | ) 194 | ) 195 | 196 | assert exc_info.value.status_code == 404 197 | 198 | 199 | def test_crud_delete(session, loop): 200 | person = Person(**PEOPLE_DATA[0]) 201 | session.add(person) 202 | session.commit() 203 | person = session.merge(person) 204 | 205 | result = loop.run_until_complete( 206 | crud.delete_instance(Person, session, person.id) 207 | ) 208 | assert person.as_dict() == result 209 | 210 | assert session.query(Person).get(person.id) is None 211 | 212 | 213 | def test_crud_delete_404(session, loop): 214 | with pytest.raises(HTTPException) as exc_info: 215 | loop.run_until_complete( 216 | crud.delete_instance(Person, session, uuid.uuid4()) 217 | ) 218 | assert exc_info.value.status_code == 404 219 | -------------------------------------------------------------------------------- /tests/models/test_models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime, date 3 | from enum import Enum 4 | 5 | import pytest 6 | import sqlalchemy.orm 7 | import sqlalchemy.ext.declarative 8 | 9 | from fastapi_sqlalchemy import models 10 | from fastapi_sqlalchemy.models import base 11 | from fastapi_sqlalchemy.models.base import MODEL_MAPPING 12 | 13 | from tests.data.models import User, Group, Permission 14 | 15 | models.create_group_membership_table() 16 | models.create_user_permissions_table() 17 | models.create_group_permissions_table() 18 | 19 | 20 | def _create_all(session): 21 | 22 | alice = User(username="alice") 23 | assert MODEL_MAPPING["User"] is User 24 | session.add(alice) 25 | 26 | users_group = Group(name="users") 27 | assert MODEL_MAPPING["Group"] is Group 28 | session.add(users_group) 29 | 30 | admins_group = Group(name="admins") 31 | session.add(admins_group) 32 
| 33 | read_permission = Permission(name="READ") 34 | assert MODEL_MAPPING["Permission"] is Permission 35 | session.add(read_permission) 36 | 37 | write_permission = Permission(name="WRITE") 38 | session.add(write_permission) 39 | 40 | session.commit() 41 | 42 | 43 | def test_groups(engine, session): 44 | _create_all(session) 45 | 46 | alice = User.get_by_username(session, "alice") 47 | alice.groups.append(Group.get_by_name(session, "users")) 48 | alice.groups.append(Group.get_by_name(session, "admins")) 49 | session.commit() 50 | 51 | alice = User.get_by_username(session, "alice") 52 | assert ["admins", "users"] == sorted( 53 | group.name for group in alice.groups 54 | ) 55 | 56 | 57 | def test_user_case(engine, session): 58 | _create_all(session) 59 | 60 | user = User(username="Bob") 61 | session.add(user) 62 | session.commit() 63 | 64 | assert User.get_by_username(session, "BOB") is user 65 | assert user.identity == str(user.id) 66 | 67 | 68 | def test_user_permissions(engine, session): 69 | _create_all(session) 70 | 71 | alice = User.get_by_username(session, "alice") 72 | alice.user_permissions.append(Permission.get_by_name(session, "READ")) 73 | alice.user_permissions.append(Permission.get_by_name(session, "WRITE")) 74 | session.commit() 75 | 76 | alice = User.get_by_username(session, "alice") 77 | assert ["READ", "WRITE"] == sorted( 78 | permission.name for permission in alice.user_permissions 79 | ) 80 | 81 | 82 | def test_group_permissions(engine, session): 83 | _create_all(session) 84 | 85 | admins = Group.get_by_name(session, "admins") 86 | admins.permissions.append(Permission.get_by_name(session, "READ")) 87 | admins.permissions.append(Permission.get_by_name(session, "WRITE")) 88 | session.commit() 89 | 90 | admins = Group.get_by_name(session, "admins") 91 | assert ["READ", "WRITE"] == sorted( 92 | permission.name for permission in admins.permissions 93 | ) 94 | 95 | 96 | def test_permissions(engine, session): 97 | _create_all(session) 98 | 99 | alice = 
User.get_by_username(session, "alice") 100 | 101 | admins = Group.get_by_name(session, "admins") 102 | admins.users.append(alice) 103 | 104 | write_permission = Permission.get_by_name(session, "WRITE") 105 | write_permission.groups.append(admins) 106 | 107 | alice.user_permissions.append(Permission.get_by_name(session, "READ")) 108 | session.commit() 109 | 110 | alice = User.get_by_username(session, "alice") 111 | assert ["READ", "WRITE"] == sorted( 112 | permission.name for permission in alice.permissions 113 | ) 114 | 115 | 116 | def test_permission_duplicate(engine, session): 117 | _create_all(session) 118 | 119 | alice = User.get_by_username(session, "alice") 120 | 121 | admins = Group.get_by_name(session, "admins") 122 | admins.users.append(alice) 123 | 124 | read_permission = Permission.get_by_name(session, "READ") 125 | 126 | admins.permissions.append(Permission.get_by_name(session, "WRITE")) 127 | admins.permissions.append(read_permission) 128 | 129 | alice.user_permissions.append(read_permission) 130 | session.commit() 131 | 132 | alice = User.get_by_username(session, "alice") 133 | assert ["READ", "WRITE"] == sorted( 134 | permission.name for permission in alice.permissions 135 | ) 136 | 137 | 138 | def test_model_mapping_update(): 139 | mapping = base.ModelMapping() 140 | mapping.update({ 141 | "key1": "value1" 142 | }, key2="value2") 143 | mapping.update([ 144 | ("key3", "value3") 145 | ]) 146 | assert mapping == { 147 | "key1": "value1", 148 | "key2": "value2", 149 | "key3": "value3" 150 | } 151 | 152 | 153 | def test_model_mapping_duplicate_key_error(): 154 | key = "key" 155 | mapping = base.ModelMapping() 156 | mapping[key] = "value-1" 157 | 158 | with pytest.raises(RuntimeError) as exc_info: 159 | mapping[key] = "value-2" 160 | 161 | expected_msg = f"Duplicate '{key}' model found. " \ 162 | f"There may only be one non-abstract sub-class." 
163 | assert str(exc_info.value) == expected_msg 164 | 165 | 166 | def test_model_to_dict(): 167 | base_cls = sqlalchemy.ext.declarative.declarative_base() 168 | 169 | class FoodEnum(str, Enum): 170 | pizza = "pizza" 171 | pasta = "pasta" 172 | 173 | class TestRelation(base_cls): 174 | __tablename__ = "test_relation" 175 | 176 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 177 | dict_id = sqlalchemy.Column( 178 | sqlalchemy.Integer, 179 | sqlalchemy.ForeignKey("test_model.id") 180 | ) 181 | 182 | class TestModel(base_cls): 183 | __tablename__ = "test_model" 184 | 185 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 186 | dt = sqlalchemy.Column(sqlalchemy.DateTime) 187 | date = sqlalchemy.Column(sqlalchemy.Date) 188 | uuid = sqlalchemy.Column(models.GUID) 189 | enum = sqlalchemy.Column(sqlalchemy.Enum(FoodEnum)) 190 | 191 | items = sqlalchemy.orm.relationship(TestRelation) 192 | 193 | model = TestModel( 194 | id=1, 195 | dt=datetime.now(), 196 | date=date.today(), 197 | uuid=uuid.uuid4(), 198 | enum=FoodEnum.pasta, 199 | items=[ 200 | TestRelation(id=1), 201 | TestRelation(id=2), 202 | ] 203 | ) 204 | assert base.model_as_dict(model) == { 205 | "id": model.id, 206 | "dt": model.dt.isoformat(), 207 | "date": model.date.isoformat(), 208 | "uuid": str(model.uuid), 209 | "enum": model.enum.name, 210 | # Relations like "items" are not included 211 | } 212 | 213 | 214 | def test_base_as_dict(mocker): 215 | base_cls = sqlalchemy.ext.declarative.declarative_base(cls=base.Base) 216 | 217 | class TestModel(base_cls): 218 | __tablename__ = "test_base_as_dict" 219 | 220 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 221 | 222 | model = TestModel(id=1) 223 | 224 | mock_model_as_dict = mocker.patch( 225 | "fastapi_sqlalchemy.models.base.model_as_dict") 226 | result = model.as_dict() 227 | assert mock_model_as_dict.call_args == mocker.call(model) 228 | assert result is mock_model_as_dict.return_value 229 | 230 | 231 | def 
def test_auth(session, app, client):
    """ Exercise an open route, an authenticated route and a scoped route """
    user = User(username="testuser")
    session.add(user)
    session.commit()

    def _identity(request):
        # Shared response body for the two protected routes
        return {
            "user": request.user.as_dict(),
            "scopes": request.auth.scopes
        }

    @app.get("/ping")
    def _ping():
        return "pong"

    @app.get("/me", dependencies=[Depends(auth.validate_authenticated)])
    async def _get(request: Request):
        return _identity(request)

    required_scopes = ["test-scope", "write"]

    @app.post(
        "/me/scopes",
        dependencies=[
            Depends(auth.validate_authenticated),
            Security(auth.validate_all_scopes, scopes=required_scopes),
        ]
    )
    def _post(request: Request):
        return _identity(request)

    auth_backend = auth.PayloadAuthBackend(user_cls=User)
    app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
    app.add_middleware(middleware.UpstreamPayloadMiddleware)
    app.add_middleware(middleware.SessionMiddleware, bind=session.bind)

    # /ping - neither authentication nor authorization required
    response = client.get("/ping")
    assert response.status_code == 200, response.text
    assert response.text == '"pong"'

    # /me - authentication required, no authorization
    response = client.get(
        "/me", headers={"X-Payload-username": user.username})
    assert response.status_code == 200, response.text
    assert response.json() == {
        "user": user.as_dict(),
        "scopes": []
    }

    # /me/scopes - authentication and authorization required
    scoped_headers = {
        "X-Payload-username": user.username,
        "X-Payload-permissions": "test-scope,write"
    }
    response = client.post("/me/scopes", headers=scoped_headers)
    assert response.status_code == 200, response.text
    assert response.json() == {
        "user": user.as_dict(),
        "scopes": ["test-scope", "write"]
    }
def test_auth_any_scope(session, app, client):
    """ validate_any_scope accepts a request holding any one listed scope """
    user = User(username="testuser")
    session.add(user)
    session.commit()

    @app.get(
        "/me",
        dependencies=[
            Depends(auth.validate_authenticated),
            Security(auth.validate_any_scope, scopes=["admin", "read"]),
        ]
    )
    async def _get():
        return {}

    auth_backend = auth.PayloadAuthBackend(user_cls=User)
    app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
    app.add_middleware(middleware.UpstreamPayloadMiddleware)
    app.add_middleware(middleware.SessionMiddleware, bind=session.bind)

    prefix = middleware.UpstreamPayloadMiddleware.PAYLOAD_HEADER_PREFIX

    # (permissions header value, expected status code)
    cases = (
        ("admin", 200),  # only the "admin" scope
        ("read", 200),   # only the "read" scope
        ("", 403),       # no scope at all
    )
    for permissions, expected_status in cases:
        response = client.get("/me", headers={
            f"{prefix}username": user.username,
            f"{prefix}permissions": permissions
        })
        assert response.status_code == expected_status
def test_payload_auth_backend_scopes(loop):
    """ scopes() reads "scopes" or "permissions" and returns a clean list """
    backend = auth.PayloadAuthBackend(admin_scope="admin")

    # (payload handed to scopes(), expected list of scopes)
    cases = [
        ({"scopes": "read, write"}, ["read", "write"]),
        ({"permissions": "read, write"}, ["read", "write"]),
        ({}, []),
        ({"permissions": "admin"}, ["admin"]),
    ]
    for payload, expected in cases:
        scopes = loop.run_until_complete(backend.scopes(payload))
        assert scopes == expected
class RegisterEndpoint:
    """ Class-based endpoint for registration with confirmation

    The endpoint renders the registration form (``on_get``), creates or
    updates the user record (``on_post``) and emails a confirmation link
    containing a signed token derived from the email address.
    """

    FORM_TEMPLATE = os.path.join(
        os.path.dirname(__file__), "templates", "register.html"
    )

    CONFIRMATION_HTML_TEMPLATE = os.path.join(
        os.path.dirname(__file__), "templates", "confirmation_email.html"
    )

    CONFIRMATION_TEXT_TEMPLATE = os.path.join(
        os.path.dirname(__file__), "templates", "confirmation_email.txt"
    )

    SENT_TEMPLATE = os.path.join(
        os.path.dirname(__file__), "templates", "send_confirmation.html"
    )

    def __init__(
            self,
            user_cls,
            secret,
            sender,
            *,
            form_template: str = FORM_TEMPLATE,
            confirmation_html_template: str = CONFIRMATION_HTML_TEMPLATE,
            confirmation_text_template: str = CONFIRMATION_TEXT_TEMPLATE,
            sent_template: str = SENT_TEMPLATE,
            form_action: str = "/register",
            salt: str = None,
            email_subject: str = "Email confirmation",
            email_server: str = "localhost",  # local smtp server
            email_port: int = 0,  # use default
            email_use_ssl: bool = False,
            email_use_tls: bool = False,
            email_login: str = None,
            email_password: str = None,
            confirm_url: str = "/confirm"
    ):
        """ Store the configuration; no I/O is performed here.

        :param user_cls: the user model class (must be a class).
        :param secret: key handed to ``URLSafeTimedSerializer``.
        :param sender: value of the ``From`` header.
        :param salt: optional salt used when signing the token.
        :param email_port: SMTP port, ``0`` selects the smtplib default.
        :param confirm_url: URL prefix of the confirmation link; a
            trailing ``/`` is appended when missing so the token can be
            concatenated directly.
        """
        # pylint: disable=too-many-locals
        assert inspect.isclass(user_cls)
        self.user_cls = user_cls
        self.secret = secret
        self.sender = sender

        self.form_template = form_template
        self.confirmation_html_template = confirmation_html_template
        self.confirmation_text_template = confirmation_text_template
        self.sent_template = sent_template

        self.form_action = form_action
        self.salt = salt

        self.email_subject = email_subject
        self.email_server = email_server
        self.email_port = email_port
        self.email_use_ssl = email_use_ssl
        self.email_use_tls = email_use_tls
        self.email_login = email_login
        self.email_password = email_password

        if not confirm_url.endswith("/"):
            confirm_url += "/"
        self.confirm_url = confirm_url

    @staticmethod
    def render(
            path_or_template: str,
            **kwargs
    ) -> str:
        """ Render the template using the passed parameters

        Defaults are supplied for ``error``, ``title``, ``modal_title``,
        ``username`` and ``email`` so that the templates can reference
        them unconditionally.  Rendering is delegated to ``utils.render``.
        """
        kwargs.setdefault("error", "")
        kwargs.setdefault("title", "FastAPI-SQLAlchemy")
        kwargs.setdefault("modal_title", "Register")
        kwargs.setdefault("username", "")
        kwargs.setdefault("email", "")

        return utils.render(path_or_template, **kwargs)

    def render_form(
            self,
            **kwargs
    ) -> str:
        """ Render the registration form """
        kwargs["form_action"] = self.form_action
        return self.render(self.form_template, **kwargs)

    async def on_get(self) -> HTMLResponse:
        """ Handle GET requests """
        html = await run_in_threadpool(self.render_form)
        return HTMLResponse(content=html, status_code=200)

    def _confirmation_token(self, email):
        """ Return the signed token embedding the email address """
        serializer = URLSafeTimedSerializer(self.secret)
        return serializer.dumps(email, salt=self.salt)

    @staticmethod
    def validate_password(password):
        """ Validate the password format is acceptable

        :raises ValueError: if the password is empty or shorter than
            7 characters.
        """
        if password and len(password) >= 7:
            return None
        raise ValueError(
            "Invalid password - the password must be at least 7 characters."
        )

    def send_message(
            self,
            msg: EmailMessage
    ) -> None:
        """ Deliver the email message

        Errors raised by smtplib propagate to the caller; the connection
        is closed in every case.
        """
        if self.email_use_ssl:
            smtp = smtplib.SMTP_SSL(self.email_server, self.email_port)
        else:
            smtp = smtplib.SMTP(self.email_server, self.email_port)

        # try/finally rather than ``with``: the context manager would
        # issue QUIT instead of a plain close().
        try:
            if self.email_use_tls:
                context = ssl.create_default_context()
                smtp.starttls(context=context)

            if self.email_login:
                smtp.login(self.email_login, self.email_password)

            smtp.send_message(msg)
        finally:
            smtp.close()

    def send_email_confirmation(self, base_url, email, **kwargs):
        """ Send the email with a confirmation link

        Delivery is best-effort: a failure is logged and not re-raised.
        Extra keyword arguments are passed on to both email templates.
        """
        logger.info("Sending email to %s from %s", email, self.sender)

        msg = EmailMessage()
        msg["Subject"] = self.email_subject
        msg["From"] = self.sender
        msg["To"] = [email]

        confirm_url = self.confirm_url + self._confirmation_token(email)
        if not confirm_url.startswith("http"):
            # Assume relative URL
            confirm_url = base_url + confirm_url

        # Caller-supplied kwargs take precedence over the defaults.
        data = {
            "email": email,
            "sender": self.sender,
            "subject": self.email_subject,
            "base_url": base_url,
            "confirm_url": confirm_url,
            **kwargs
        }

        msg.set_content(
            self.render(self.confirmation_text_template, **data)
        )
        msg.add_alternative(
            self.render(self.confirmation_html_template, **data),
            subtype="html"
        )

        try:
            self.send_message(msg)
            logger.info(
                "Email sent to %s with confirm URL: %s", email, confirm_url
            )
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit are not swallowed.
        # pylint: disable=broad-except
        except Exception:
            logger.exception("EMAIL FAILED TO SEND: %s", email)
        # pylint: enable=broad-except

    async def on_post(
            self,
            base_url: str,
            session: models.Session,
            username: str,
            email: str,
            password: str,
            confirm_password: str = None,
            **kwargs
    ) -> HTMLResponse:
        """ Handle POST requests

        Responds with 400 when the password is invalid or does not match
        ``confirm_password``, 409 when the email belongs to a different
        username or the commit hits an integrity error, and 200 once
        the confirmation email has been attempted.
        """

        # NOTE(review): presumably raises for a malformed address -
        # confirm against the pydantic version in use.
        email = EmailStr.validate(email)

        def _register() -> (int, str):

            try:
                self.validate_password(password)
            except ValueError as ex:
                return 400, self.render_form(error=str(ex), **kwargs)

            if confirm_password is not None:
                # The only way for confirm_password to be None is if a
                # standard form doesn't include it, otherwise the value is ""
                if password != confirm_password:
                    return 400, self.render_form(
                        error="The specified passwords do not match.",
                        **kwargs
                    )

            user = self.user_cls.get_by_email(session, email)
            if user:
                if user.username != username:
                    logger.info("Email '%s' already exists.", email)
                    # Is this an information leak?
                    return 409, self.render_form(
                        error="That email address already exists.", **kwargs
                    )
                if user.confirmed:
                    logger.info("User '%s' already confirmed.", username)
                else:
                    # Unconfirmed, re-registration - update password
                    user.password = password
            else:
                user = self.user_cls(
                    username=username,
                    password=password,
                    email=email,
                )
                session.add(user)

            try:
                session.commit()
            except sqlalchemy.exc.IntegrityError:
                # pragma: no-cover
                # A failed commit leaves the session unusable until it
                # is rolled back.
                session.rollback()
                return 409, self.render_form(
                    error="That username already exists.", **kwargs
                )

            self.send_email_confirmation(
                base_url, email, username=username, **kwargs
            )

            return 200, self.render(
                self.sent_template,
                username=username, email=email, **kwargs
            )

        status_code, content = await run_in_threadpool(_register)
        return HTMLResponse(status_code=status_code, content=content)
"FastAPI-SQLAlchemy" 32 | 33 | 34 | def test_register_post(session, app, client, mocker): 35 | endpoint = endpoints.RegisterEndpoint( 36 | User, secret="s0secret", sender=SENDER, 37 | sent_template="${email}", 38 | ) 39 | mocker.patch.object(endpoint, "send_email_confirmation") 40 | 41 | @app.post("/register") 42 | async def _post( 43 | username: str = fastapi.Form(None), 44 | password: str = fastapi.Form(None), 45 | email: str = fastapi.Form(None), 46 | ): 47 | return await endpoint.on_post( 48 | BASE_URL, session, 49 | username=username, password=password, email=email 50 | ) 51 | 52 | recipient = "recipient@example.com" 53 | res = client.post("/register", data={ 54 | "username": "testuser", 55 | "password": "passw0rd", 56 | "email": recipient 57 | }) 58 | assert res.status_code == 200 59 | assert res.text == f"{recipient}" 60 | 61 | 62 | def test_register_post_no_confirmation(session, app, client, mocker): 63 | endpoint = endpoints.RegisterEndpoint( 64 | User, secret="s0secret", sender=SENDER, 65 | form_template="${error}" 66 | ) 67 | mocker.patch.object(endpoint, "send_email_confirmation") 68 | 69 | @app.post("/register") 70 | async def _post( 71 | username: str = fastapi.Form(None), 72 | password: str = fastapi.Form(None), 73 | confirm_password: str = fastapi.Form(None), 74 | email: str = fastapi.Form(None), 75 | ): 76 | return await endpoint.on_post( 77 | BASE_URL, session, 78 | username=username, email=email, 79 | password=password, confirm_password=confirm_password 80 | ) 81 | 82 | res = client.post("/register", data={ 83 | "username": "testuser", 84 | "password": "passw0rd", 85 | "confirm_password": "other", 86 | "email": "recipient@example.com", 87 | }) 88 | assert res.status_code == 400 89 | assert res.text == "The specified passwords do not match." 
def test_register_post_password_too_small(session, app, client, mocker):
    """ A six character password is rejected with a 400 and an error """
    endpoint = endpoints.RegisterEndpoint(
        User,
        secret="s0secret",
        sender=SENDER,
        form_template="${error}",
    )
    mocker.patch.object(endpoint, "send_email_confirmation")

    @app.post("/register")
    async def _post(
            username: str = fastapi.Form(None),
            password: str = fastapi.Form(None),
            confirm_password: str = fastapi.Form(None),
            email: str = fastapi.Form(None),
    ):
        form = {
            "username": username,
            "email": email,
            "password": password,
            "confirm_password": confirm_password,
        }
        return await endpoint.on_post(BASE_URL, session, **form)

    too_short = "6" * 6
    response = client.post("/register", data={
        "username": "testuser",
        "password": too_short,
        "email": "recipient@example.com",
    })
    assert response.status_code == 400
    expected = "Invalid password - the password must be at least 7 characters."
    assert response.text == expected
def test_register_post_user_already_confirmed(session, app, client, mocker):
    """ Re-registering an already confirmed user still answers 200

    The response body is the rendered ``sent_template``, i.e. the
    username followed by the email address in angle brackets.
    """
    user = User(
        username="testuser",
        email="recipient@example.com",
        confirmed_at=datetime.now()
    )
    session.add(user)
    session.commit()

    endpoint = endpoints.RegisterEndpoint(
        User, secret="s0secret", sender=SENDER,
        sent_template="${username} <${email}>"
    )
    # No real email may be sent from the test
    mocker.patch.object(endpoint, "send_email_confirmation")

    @app.post("/register")
    async def _post(
            username: str = fastapi.Form(None),
            password: str = fastapi.Form(None),
            confirm_password: str = fastapi.Form(None),
            email: str = fastapi.Form(None),
    ):
        return await endpoint.on_post(
            BASE_URL, session,
            username=username, email=email,
            password=password, confirm_password=confirm_password
        )

    res = client.post("/register", data={
        "username": "testuser",
        "password": "passw0rd",
        "email": "recipient@example.com",
    })
    assert res.status_code == 200
    # NOTE(review): assumes utils.render substitutes ${...} verbatim
    # without HTML-escaping - confirm against utils.render.
    assert res.text == "testuser <recipient@example.com>"
def test_register_post_duplicate_username(session, app, client, mocker):
    """ An existing username with a new email address answers 409 """
    existing = User(username="testuser", email="recipient@example.com")
    session.add(existing)
    session.commit()

    endpoint = endpoints.RegisterEndpoint(
        User,
        secret="s0secret",
        sender=SENDER,
        form_template="${error}",
    )
    mocker.patch.object(endpoint, "send_email_confirmation")

    @app.post("/register")
    async def _post(
            username: str = fastapi.Form(None),
            password: str = fastapi.Form(None),
            confirm_password: str = fastapi.Form(None),
            email: str = fastapi.Form(None),
    ):
        form = {
            "username": username,
            "email": email,
            "password": password,
            "confirm_password": confirm_password,
        }
        return await endpoint.on_post(BASE_URL, session, **form)

    # Same username as the stored user, different email address
    payload = {
        "username": "testuser",
        "password": "passw0rd",
        "email": "recipient2@example.com",
    }
    response = client.post("/register", data=payload)
    assert response.status_code == 409
    assert response.text == "That username already exists."
def test_send_message_ssl(mocker):
    """ With SSL on, SMTP_SSL is used: starttls, login, send, then close """
    smtp_options = {
        "email_use_ssl": True,
        "email_use_tls": True,
        "email_login": "user1",
        "email_password": "passwd1",
    }
    endpoint = endpoints.RegisterEndpoint(
        User, secret="s0secret", sender=SENDER, **smtp_options
    )

    connection = mocker.Mock()
    default_context = mocker.patch("ssl.create_default_context")
    smtp_ssl_cls = mocker.patch("smtplib.SMTP_SSL", return_value=connection)

    message = mocker.Mock()
    endpoint.send_message(message)

    expected_ctor = mocker.call(endpoint.email_server, endpoint.email_port)
    assert smtp_ssl_cls.call_args == expected_ctor

    expected_calls = [
        mocker.call.starttls(context=default_context.return_value),
        mocker.call.login(endpoint.email_login, endpoint.email_password),
        mocker.call.send_message(message),
        mocker.call.close(),
    ]
    assert connection.mock_calls == expected_calls
endpoint.send_message(msg) 342 | 343 | assert smtp.call_args == mocker.call( 344 | endpoint.email_server, endpoint.email_port) 345 | assert mock_smtp.mock_calls == [ 346 | mocker.call.starttls(context=create_default_context.return_value), 347 | mocker.call.login(endpoint.email_login, endpoint.email_password), 348 | mocker.call.send_message(msg), 349 | mocker.call.close(), 350 | ] 351 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code. 6 | extension-pkg-whitelist=pydantic 7 | 8 | # Add files or directories to the blacklist. They should be base names, not 9 | # paths. 10 | ignore=CVS 11 | 12 | # Add files or directories matching the regex patterns to the blacklist. The 13 | # regex matches against base names, not paths. 14 | ignore-patterns= 15 | 16 | # Python code to execute, usually for sys.path manipulation such as 17 | # pygtk.require(). 18 | #init-hook= 19 | 20 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 21 | # number of processors available to use. 22 | jobs=1 23 | 24 | # Control the amount of potential inferred values when inferring a single 25 | # object. This can help the performance when dealing with large functions or 26 | # complex, nested conditions. 27 | limit-inference-results=100 28 | 29 | # List of plugins (as comma separated values of python modules names) to load, 30 | # usually to register additional checkers. 31 | load-plugins= 32 | 33 | # Pickle collected data for later comparisons. 34 | persistent=yes 35 | 36 | # Specify a configuration file. 
37 | #rcfile= 38 | 39 | # When enabled, pylint would attempt to guess common misconfiguration and emit 40 | # user-friendly hints instead of false-positive error messages. 41 | suggestion-mode=yes 42 | 43 | # Allow loading of arbitrary C extensions. Extensions are imported into the 44 | # active Python interpreter and may run arbitrary code. 45 | unsafe-load-any-extension=no 46 | 47 | 48 | [MESSAGES CONTROL] 49 | 50 | # Only show warnings with the listed confidence levels. Leave empty to show 51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 52 | confidence= 53 | 54 | # Disable the message, report, category or checker with the given id(s). You 55 | # can either give multiple identifiers separated by comma (,) or put this 56 | # option multiple times (only on the command line, not in the configuration 57 | # file where it should appear only once). You can also use "--disable=all" to 58 | # disable everything first and then reenable specific checks. For example, if 59 | # you want to run only the similarities checker, you can use "--disable=all 60 | # --enable=similarities". If you want to run only the classes checker, but have 61 | # no Warning level messages displayed, use "--disable=all --enable=classes 62 | # --disable=W". 
63 | disable=print-statement, 64 | parameter-unpacking, 65 | unpacking-in-except, 66 | old-raise-syntax, 67 | backtick, 68 | long-suffix, 69 | old-ne-operator, 70 | old-octal-literal, 71 | import-star-module-level, 72 | non-ascii-bytes-literal, 73 | raw-checker-failed, 74 | bad-inline-option, 75 | locally-disabled, 76 | file-ignored, 77 | suppressed-message, 78 | useless-suppression, 79 | deprecated-pragma, 80 | use-symbolic-message-instead, 81 | apply-builtin, 82 | basestring-builtin, 83 | buffer-builtin, 84 | cmp-builtin, 85 | coerce-builtin, 86 | execfile-builtin, 87 | file-builtin, 88 | long-builtin, 89 | raw_input-builtin, 90 | reduce-builtin, 91 | standarderror-builtin, 92 | unicode-builtin, 93 | xrange-builtin, 94 | coerce-method, 95 | delslice-method, 96 | getslice-method, 97 | setslice-method, 98 | no-absolute-import, 99 | old-division, 100 | dict-iter-method, 101 | dict-view-method, 102 | next-method-called, 103 | metaclass-assignment, 104 | indexing-exception, 105 | raising-string, 106 | reload-builtin, 107 | oct-method, 108 | hex-method, 109 | nonzero-method, 110 | cmp-method, 111 | input-builtin, 112 | round-builtin, 113 | intern-builtin, 114 | unichr-builtin, 115 | map-builtin-not-iterating, 116 | zip-builtin-not-iterating, 117 | range-builtin-not-iterating, 118 | filter-builtin-not-iterating, 119 | using-cmp-argument, 120 | eq-without-hash, 121 | div-method, 122 | idiv-method, 123 | rdiv-method, 124 | exception-message-attribute, 125 | invalid-str-codec, 126 | sys-max-int, 127 | bad-python3-import, 128 | deprecated-string-function, 129 | deprecated-str-translate-call, 130 | deprecated-itertools-function, 131 | deprecated-types-field, 132 | next-method-defined, 133 | dict-items-not-iterating, 134 | dict-keys-not-iterating, 135 | dict-values-not-iterating, 136 | deprecated-operator-function, 137 | deprecated-urllib-function, 138 | xreadlines-attribute, 139 | deprecated-sys-function, 140 | exception-escape, 141 | comprehension-escape, 142 | 
duplicate-code, 143 | too-few-public-methods, 144 | too-many-instance-attributes, 145 | too-many-arguments, 146 | fixme 147 | 148 | # Enable the message, report, category or checker with the given id(s). You can 149 | # either give multiple identifier separated by comma (,) or put this option 150 | # multiple time (only on the command line, not in the configuration file where 151 | # it should appear only once). See also the "--disable" option for examples. 152 | enable=c-extension-no-member 153 | 154 | 155 | [REPORTS] 156 | 157 | # Python expression which should return a note less than 10 (10 is the highest 158 | # note). You have access to the variables errors warning, statement which 159 | # respectively contain the number of errors / warnings messages and the total 160 | # number of statements analyzed. This is used by the global evaluation report 161 | # (RP0004). 162 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 163 | 164 | # Template used to display messages. This is a python new-style format string 165 | # used to format the message information. See doc for all details. 166 | #msg-template= 167 | 168 | # Set the output format. Available formats are text, parseable, colorized, json 169 | # and msvs (visual studio). You can also give a reporter class, e.g. 170 | # mypackage.mymodule.MyReporterClass. 171 | output-format=text 172 | 173 | # Tells whether to display a full report or only the messages. 174 | reports=no 175 | 176 | # Activate the evaluation score. 177 | score=no 178 | 179 | 180 | [REFACTORING] 181 | 182 | # Maximum number of nested blocks for function / method body 183 | max-nested-blocks=5 184 | 185 | # Complete name of functions that never returns. When checking for 186 | # inconsistent-return-statements if a never returning function is called then 187 | # it will be considered as an explicit return statement and no message will be 188 | # printed. 
189 | never-returning-functions=sys.exit 190 | 191 | 192 | [FORMAT] 193 | 194 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 195 | expected-line-ending-format= 196 | 197 | # Regexp for a line that is allowed to be longer than the limit. 198 | ignore-long-lines=^\s*(# )??$ 199 | 200 | # Number of spaces of indent required inside a hanging or continued line. 201 | indent-after-paren=4 202 | 203 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 204 | # tab). 205 | indent-string=' ' 206 | 207 | # Maximum number of characters on a single line. 208 | max-line-length=100 209 | 210 | # Maximum number of lines in a module. 211 | max-module-lines=1000 212 | 213 | # List of optional constructs for which whitespace checking is disabled. `dict- 214 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 215 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 216 | # `empty-line` allows space-only lines. 217 | no-space-check=trailing-comma, 218 | dict-separator 219 | 220 | # Allow the body of a class to be on the same line as the declaration if body 221 | # contains single statement. 222 | single-line-class-stmt=no 223 | 224 | # Allow the body of an if to be on the same line as the test if there is no 225 | # else. 226 | single-line-if-stmt=no 227 | 228 | 229 | [TYPECHECK] 230 | 231 | # List of decorators that produce context managers, such as 232 | # contextlib.contextmanager. Add to this list to register other decorators that 233 | # produce valid context managers. 234 | contextmanager-decorators=contextlib.contextmanager 235 | 236 | # List of members which are set dynamically and missed by pylint inference 237 | # system, and so shouldn't trigger E1101 when accessed. Python regular 238 | # expressions are accepted. 239 | generated-members= 240 | 241 | # Tells whether missing members accessed in mixin class should be ignored. 
A 242 | # mixin class is detected if its name ends with "mixin" (case insensitive). 243 | ignore-mixin-members=yes 244 | 245 | # Tells whether to warn about missing members when the owner of the attribute 246 | # is inferred to be None. 247 | ignore-none=yes 248 | 249 | # This flag controls whether pylint should warn about no-member and similar 250 | # checks whenever an opaque object is returned when inferring. The inference 251 | # can return multiple potential results while evaluating a Python object, but 252 | # some branches might not be evaluated, which results in partial inference. In 253 | # that case, it might be useful to still emit no-member and other checks for 254 | # the rest of the inferred objects. 255 | ignore-on-opaque-inference=yes 256 | 257 | # List of class names for which member attributes should not be checked (useful 258 | # for classes with dynamically set attributes). This supports the use of 259 | # qualified names. 260 | ignored-classes=optparse.Values,thread._local,_thread._local,scoped_session 261 | 262 | # List of module names for which member attributes should not be checked 263 | # (useful for modules/projects where namespaces are manipulated during runtime 264 | # and thus existing member attributes cannot be deduced by static analysis. It 265 | # supports qualified module names, as well as Unix pattern matching. 266 | ignored-modules= 267 | 268 | # Show a hint with possible names when a member name was not found. The aspect 269 | # of finding the hint is based on edit distance. 270 | missing-member-hint=yes 271 | 272 | # The minimum edit distance a name should have in order to be considered a 273 | # similar match for a missing member name. 274 | missing-member-hint-distance=1 275 | 276 | # The total number of similar names that should be taken in consideration when 277 | # showing a hint for a missing member. 
278 | missing-member-max-choices=1 279 | 280 | 281 | [MISCELLANEOUS] 282 | 283 | # List of note tags to take in consideration, separated by a comma. 284 | notes=FIXME, 285 | XXX, 286 | TODO 287 | 288 | 289 | [BASIC] 290 | 291 | # Naming style matching correct argument names. 292 | argument-naming-style=snake_case 293 | 294 | # Regular expression matching correct argument names. Overrides argument- 295 | # naming-style. 296 | #argument-rgx= 297 | 298 | # Naming style matching correct attribute names. 299 | attr-naming-style=snake_case 300 | 301 | # Regular expression matching correct attribute names. Overrides attr-naming- 302 | # style. 303 | #attr-rgx= 304 | 305 | # Bad variable names which should always be refused, separated by a comma. 306 | bad-names=foo, 307 | bar, 308 | baz, 309 | toto, 310 | tutu, 311 | tata 312 | 313 | # Naming style matching correct class attribute names. 314 | class-attribute-naming-style=any 315 | 316 | # Regular expression matching correct class attribute names. Overrides class- 317 | # attribute-naming-style. 318 | #class-attribute-rgx= 319 | 320 | # Naming style matching correct class names. 321 | class-naming-style=PascalCase 322 | 323 | # Regular expression matching correct class names. Overrides class-naming- 324 | # style. 325 | #class-rgx= 326 | 327 | # Naming style matching correct constant names. 328 | const-naming-style=UPPER_CASE 329 | 330 | # Regular expression matching correct constant names. Overrides const-naming- 331 | # style. 332 | #const-rgx= 333 | 334 | # Minimum line length for functions/classes that require docstrings, shorter 335 | # ones are exempt. 336 | docstring-min-length=-1 337 | 338 | # Naming style matching correct function names. 339 | function-naming-style=snake_case 340 | 341 | # Regular expression matching correct function names. Overrides function- 342 | # naming-style. 343 | #function-rgx= 344 | 345 | # Good variable names which should always be accepted, separated by a comma. 
346 | good-names=i, 347 | j, 348 | k, 349 | ex, 350 | Run, 351 | _, 352 | logger, 353 | Session, 354 | FastAPI_SQLAlchemy, 355 | 356 | # Include a hint for the correct naming format with invalid-name. 357 | include-naming-hint=no 358 | 359 | # Naming style matching correct inline iteration names. 360 | inlinevar-naming-style=any 361 | 362 | # Regular expression matching correct inline iteration names. Overrides 363 | # inlinevar-naming-style. 364 | #inlinevar-rgx= 365 | 366 | # Naming style matching correct method names. 367 | method-naming-style=snake_case 368 | 369 | # Regular expression matching correct method names. Overrides method-naming- 370 | # style. 371 | #method-rgx= 372 | 373 | # Naming style matching correct module names. 374 | module-naming-style=snake_case 375 | 376 | # Regular expression matching correct module names. Overrides module-naming- 377 | # style. 378 | #module-rgx= 379 | 380 | # Colon-delimited sets of names that determine each other's naming style when 381 | # the name regexes allow several styles. 382 | name-group= 383 | 384 | # Regular expression which should only match function or class names that do 385 | # not require a docstring. 386 | no-docstring-rgx=^_ 387 | 388 | # List of decorators that produce properties, such as abc.abstractproperty. Add 389 | # to this list to register other decorators that produce valid properties. 390 | # These decorators are taken in consideration only for invalid-name. 391 | property-classes=abc.abstractproperty 392 | 393 | # Naming style matching correct variable names. 394 | variable-naming-style=snake_case 395 | 396 | # Regular expression matching correct variable names. Overrides variable- 397 | # naming-style. 398 | #variable-rgx= 399 | 400 | 401 | [SIMILARITIES] 402 | 403 | # Ignore comments when computing similarities. 404 | ignore-comments=yes 405 | 406 | # Ignore docstrings when computing similarities. 407 | ignore-docstrings=yes 408 | 409 | # Ignore imports when computing similarities. 
410 | ignore-imports=no 411 | 412 | # Minimum number of lines of a similarity. 413 | min-similarity-lines=4 414 | 415 | 416 | [SPELLING] 417 | 418 | # Limits the count of emitted suggestions for spelling mistakes. 419 | max-spelling-suggestions=4 420 | 421 | # Spelling dictionary name. Available dictionaries: none. To make it work, 422 | # install the python-enchant package. 423 | spelling-dict= 424 | 425 | # List of comma separated words that should not be checked. 426 | spelling-ignore-words= 427 | 428 | # A path to a file that contains a private dictionary; one word per line. 429 | spelling-private-dict-file= 430 | 431 | # Tells whether to store unknown words to the private dictionary indicated by the 432 | # --spelling-private-dict-file option instead of raising a message. 433 | spelling-store-unknown-words=no 434 | 435 | 436 | [STRING] 437 | 438 | # This flag controls whether the implicit-str-concat-in-sequence should 439 | # generate a warning on implicit string concatenation in sequences defined over 440 | # several lines. 441 | check-str-concat-over-line-jumps=no 442 | 443 | 444 | [LOGGING] 445 | 446 | # Format style used to check logging format string. `old` means using % 447 | # formatting, while `new` is for `{}` formatting. 448 | logging-format-style=old 449 | 450 | # Logging modules to check that the string format arguments are in logging 451 | # function parameter format. 452 | logging-modules=logging 453 | 454 | 455 | [VARIABLES] 456 | 457 | # List of additional names supposed to be defined in builtins. Remember that 458 | # you should avoid defining new builtins when possible. 459 | additional-builtins= 460 | 461 | # Tells whether unused global variables should be treated as a violation. 462 | allow-global-unused-variables=yes 463 | 464 | # List of strings which can identify a callback function by name. A callback 465 | # name must start or end with one of those strings. 
466 | callbacks=cb_, 467 | _cb 468 | 469 | # A regular expression matching the name of dummy variables (i.e. expected to 470 | # not be used). 471 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 472 | 473 | # Argument names that match this expression will be ignored. Defaults to names 474 | # with a leading underscore. 475 | ignored-argument-names=_.*|^ignored_|^unused_ 476 | 477 | # Tells whether we should check for unused imports in __init__ files. 478 | init-import=no 479 | 480 | # List of qualified module names which can have objects that can redefine 481 | # builtins. 482 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 483 | 484 | 485 | [DESIGN] 486 | 487 | # Maximum number of arguments for function / method. 488 | max-args=5 489 | 490 | # Maximum number of attributes for a class (see R0902). 491 | max-attributes=7 492 | 493 | # Maximum number of boolean expressions in an if statement. 494 | max-bool-expr=5 495 | 496 | # Maximum number of branches for function / method body. 497 | max-branches=12 498 | 499 | # Maximum number of locals for function / method body. 500 | max-locals=15 501 | 502 | # Maximum number of parents for a class (see R0901). 503 | max-parents=7 504 | 505 | # Maximum number of public methods for a class (see R0904). 506 | max-public-methods=20 507 | 508 | # Maximum number of return / yield for function / method body. 509 | max-returns=6 510 | 511 | # Maximum number of statements in function / method body. 512 | max-statements=50 513 | 514 | # Minimum number of public methods for a class (see R0903). 515 | min-public-methods=2 516 | 517 | 518 | [CLASSES] 519 | 520 | # List of method names used to declare (i.e. assign) instance attributes. 521 | defining-attr-methods=__init__, 522 | __new__, 523 | setUp 524 | 525 | # List of member names, which should be excluded from the protected access 526 | # warning. 
527 | exclude-protected=_asdict, 528 | _fields, 529 | _replace, 530 | _source, 531 | _make 532 | 533 | # List of valid names for the first argument in a class method. 534 | valid-classmethod-first-arg=cls 535 | 536 | # List of valid names for the first argument in a metaclass class method. 537 | valid-metaclass-classmethod-first-arg=cls 538 | 539 | 540 | [IMPORTS] 541 | 542 | # Allow wildcard imports from modules that define __all__. 543 | allow-wildcard-with-all=no 544 | 545 | # Analyse import fallback blocks. This can be used to support both Python 2 and 546 | # 3 compatible code, which means that the block might have code that exists 547 | # only in one or another interpreter, leading to false positives when analysed. 548 | analyse-fallback-blocks=no 549 | 550 | # Deprecated modules which should not be used, separated by a comma. 551 | deprecated-modules=optparse,tkinter.tix 552 | 553 | # Create a graph of external dependencies in the given file (report RP0402 must 554 | # not be disabled). 555 | ext-import-graph= 556 | 557 | # Create a graph of every (i.e. internal and external) dependencies in the 558 | # given file (report RP0402 must not be disabled). 559 | import-graph= 560 | 561 | # Create a graph of internal dependencies in the given file (report RP0402 must 562 | # not be disabled). 563 | int-import-graph= 564 | 565 | # Force import order to recognize a module as part of the standard 566 | # compatibility libraries. 567 | known-standard-library= 568 | 569 | # Force import order to recognize a module as part of a third party library. 570 | known-third-party=enchant 571 | 572 | 573 | [EXCEPTIONS] 574 | 575 | # Exceptions that will emit a warning when being caught. Defaults to 576 | # "BaseException, Exception". 577 | overgeneral-exceptions=BaseException, 578 | Exception 579 | --------------------------------------------------------------------------------