├── .flake8 ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE.MD ├── README.md ├── docs ├── common │ ├── badges.md │ ├── dependencies.md │ ├── dependencies.ru.md │ ├── install001.md │ └── install002.md ├── index.md ├── index.ru.md ├── managers.md ├── setup.md └── setup.ru.md ├── docs_src ├── managers │ ├── passwords001.py │ ├── passwords002.py │ ├── tokens001.py │ └── tokens002.py └── setup001.py ├── fastapi_mongodb ├── __init__.py ├── config.py ├── db.py ├── dependencies.py ├── exceptions.py ├── helpers.py ├── logging.py ├── managers.py ├── middlewares.py ├── models.py ├── repositories.py ├── schemas.py └── types.py ├── mkdocs.yml ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── test_db.py ├── test_helpers.py ├── test_logging.py ├── test_managers.py ├── test_models.py ├── test_repositories.py └── test_types.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | extend-ignore = E203, W503 4 | exclude = __init__.py,.tox,dist,docs_src -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.9] 16 | mongodb-version: [5.0] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install flake8 pytest poetry 29 | poetry install --no-interaction --no-ansi 30 | 31 | - name: MongoDB in GitHub Actions 32 | uses: supercharge/mongodb-github-action@1.6.0 33 | with: 34 | mongodb-version: ${{ matrix.mongodb-version}} 35 | mongodb-replica-set: rs-test 36 | 37 | - name: Lint with flake8 38 | run: | 39 | poetry run flake8 40 | 41 | - name: Test with pytest 42 | run: | 43 | poetry run pytest 44 | 45 | - name: Generate Report 46 | run: | 47 | poetry run pytest --cov=./ --cov-report=xml 48 | 49 | - name: Upload Coverage to Codecov 50 | uses: codecov/codecov-action@v2 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .env 3 | .venv 4 | env/ 5 | venv/ 6 | ENV/ 7 | env.bak/ 8 | venv.bak/ 9 | .pytest_cache 10 | .vscode 11 | .coverage 12 | coverate_html_report 13 | dist 14 | poetry.lock 15 | .tox 16 | .mypy_cache 17 | /site 18 | */__pycache__/* 19 | *.py[cod] 20 | build/ 21 | dist/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.mo 26 | *.pot 27 | -------------------------------------------------------------------------------- /LICENSE.MD: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kostiantyn Salnykov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI❤MongoDB 2 | ![GitHub](https://img.shields.io/github/license/Kostiantyn-Salnykov/fastapi-mongodb) 3 | ![GitHub Workflow Status (branch)](https://img.shields.io/github/workflow/status/Kostiantyn-Salnykov/fastapi-mongodb/Python%20package/master) 4 | ![GitHub last commit (branch)](https://img.shields.io/github/last-commit/kostiantyn-salnykov/fastapi-mongodb/master) 5 | [![codecov](https://codecov.io/gh/Kostiantyn-Salnykov/fastapi-mongodb/branch/master/graph/badge.svg?token=77Z4DQVIU5)](https://codecov.io/gh/Kostiantyn-Salnykov/fastapi-mongodb) 6 | 7 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/fastapi-mongodb) 8 | ![PyPI](https://img.shields.io/pypi/v/fastapi-mongodb) 9 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/fastapi-mongodb) 10 | 11 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 12 | [![](https://img.shields.io/badge/code%20style-black-000000?style=flat)](https://github.com/psf/black) 13 | 14 | # Installation 15 | 16 | ```shell 17 | pip install fastapi-mongodb 18 | ``` 19 | 20 | # FastAPI❤MongoDB documentation: 21 | 22 | ## [Documentation](https://Kostiantyn-Salnykov.github.io/fastapi-mongodb/) 23 | -------------------------------------------------------------------------------- /docs/common/badges.md: -------------------------------------------------------------------------------- 1 | ![GitHub](https://img.shields.io/github/license/Kostiantyn-Salnykov/fastapi-mongodb) 2 | ![GitHub Workflow Status (branch)](https://img.shields.io/github/workflow/status/Kostiantyn-Salnykov/fastapi-mongodb/Python%20package/master) 3 | ![GitHub last commit (branch)](https://img.shields.io/github/last-commit/kostiantyn-salnykov/fastapi-mongodb/master) 4 | [![codecov](https://codecov.io/gh/Kostiantyn-Salnykov/fastapi-mongodb/branch/master/graph/badge.svg?token=77Z4DQVIU5)](https://codecov.io/gh/Kostiantyn-Salnykov/fastapi-mongodb) 5 | 6 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/fastapi-mongodb) 7 | ![PyPI](https://img.shields.io/pypi/v/fastapi-mongodb) 8 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/fastapi-mongodb) 9 | 10 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 11 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000?style=flat)](https://github.com/psf/black) 12 | -------------------------------------------------------------------------------- /docs/common/dependencies.md: 
-------------------------------------------------------------------------------- 1 | - [Python 3.9+](https://www.python.org/downloads/) 2 | - [fastapi](https://fastapi.tiangolo.com/) 3 | - [motor](https://motor.readthedocs.io/en/stable/index.html) 4 | - [pymongo](https://pymongo.readthedocs.io/en/stable/) 5 | - [pydantic](https://pydantic-docs.helpmanual.io/) 6 | - [orjson](https://github.com/ijl/orjson) (optional, need for BaseConfiguration and Pydantic's **Config** options: * 7 | json_dumps*, *json_loads*) 8 | - [pyjwt](https://pyjwt.readthedocs.io/en/stable/) -------------------------------------------------------------------------------- /docs/common/dependencies.ru.md: -------------------------------------------------------------------------------- 1 | - [Python 3.9+](https://www.python.org/downloads/) 2 | - [fastapi](https://fastapi.tiangolo.com/) 3 | - [motor](https://motor.readthedocs.io/en/stable/index.html) 4 | - [pymongo](https://pymongo.readthedocs.io/en/stable/) 5 | - [pydantic](https://pydantic-docs.helpmanual.io/) 6 | - [orjson](https://github.com/ijl/orjson) (не обязательная к установке, используется для BaseConfiguration и Pydantic ** 7 | Config** параметров: *json_dumps*, *json_loads*) 8 | - [pyjwt](https://pyjwt.readthedocs.io/en/stable/) -------------------------------------------------------------------------------- /docs/common/install001.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | pip install fastapi-mongodb 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/common/install002.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | pip install fastapi-mongodb[orjson] 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # FastAPI:heart:MongoDB documentation 2 | 3 | --8<-- "docs/common/badges.md" 4 | 5 | ## Description 6 | 7 | A library that simplifies work and integration with MongoDB for a FastAPI project. 8 | 9 | ???+ danger 10 | The project is under development, so there is a possibility of breaking changes. 11 | 12 | ??? info 13 | When the key points in development are completed, I will take care of the documentation. 14 | 15 | ## Installation 16 | 17 | ### Basic installation 18 | 19 | --8<-- "docs/common/install001.md" 20 | 21 | ### Full installation 22 | 23 | --8<-- "docs/common/install002.md" 24 | 25 | ??? info "Dependencies" 26 | --8<-- "docs/common/dependencies.md" 27 | -------------------------------------------------------------------------------- /docs/index.ru.md: -------------------------------------------------------------------------------- 1 | # FastAPI:heart:MongoDB документация 2 | 3 | --8<-- "docs/common/badges.md" 4 | 5 | ## Описание 6 | 7 | Библиотека, которая упрощает работу и интеграцию с MongoDB для проекта на FastAPI. 8 | 9 | ???+ danger 10 | Проект находится в активной стадии разработки, поэтому возможны критические изменения. 11 | 12 | ??? info 13 | Когда ключевые моменты в разработке будут выполнены, я займусь документацией. 14 | 15 | ## Установка 16 | 17 | ### Базовая установка 18 | 19 | --8<-- "docs/common/install001.md" 20 | 21 | ### Полная установка 22 | 23 | --8<-- "docs/common/install002.md" 24 | 25 | ??? 
info "Зависимости" 26 | --8<-- "docs/common/dependencies.ru.md" 27 | -------------------------------------------------------------------------------- /docs/managers.md: -------------------------------------------------------------------------------- 1 | ## PasswordsManager 2 | ```python 3 | --8<-- "docs_src/managers/passwords001.py" 4 | ``` 5 | 6 | ### PASSWORD_ALGORITHMS 7 | ```python 8 | --8<-- "docs_src/managers/passwords002.py" 9 | ``` 10 | 11 | ## TokensManager 12 | ```python 13 | --8<-- "docs_src/managers/tokens001.py" 14 | ``` 15 | 16 | ### TOKENS_ALGORITHMS 17 | ```python 18 | --8<-- "docs_src/managers/tokens002.py" 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/setup.md: -------------------------------------------------------------------------------- 1 | ## Setup 2 | 3 | ```python hl_lines="5 10 11" 4 | --8<-- "docs_src/setup001.py" 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/setup.ru.md: -------------------------------------------------------------------------------- 1 | ## Настройка 2 | 3 | ```python hl_lines="5 10 11" 4 | --8<-- "docs_src/setup001.py" 5 | ``` 6 | -------------------------------------------------------------------------------- /docs_src/managers/passwords001.py: -------------------------------------------------------------------------------- 1 | import fastapi_mongodb 2 | 3 | passwords_manager = fastapi_mongodb.PasswordsManager( 4 | algorithm=fastapi_mongodb.PASSWORD_ALGORITHMS.SHA512, iterations=524288 5 | ) 6 | 7 | raw_password = "SUPER SECURE PASSWORD!" 8 | password_hash = passwords_manager.make_password(password=raw_password) 9 | if passwords_manager.check_password( 10 | password=raw_password, password_hash=password_hash 11 | ): 12 | print("""ALLOW ACCESS!""") 13 | else: 14 | print("""ACCESS DENIED!""") 15 | -------------------------------------------------------------------------------- /docs_src/managers/passwords002.py: -------------------------------------------------------------------------------- 1 | class PASSWORD_ALGORITHMS(str, enum.Enum): 2 | SHA256 = "sha256" 3 | SHA384 = "sha384" 4 | SHA512 = "sha512" 5 | -------------------------------------------------------------------------------- /docs_src/managers/tokens001.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import fastapi_mongodb 4 | 5 | SECRET_KEY = "SUPER SECURE TOKEN FROM ENVIRONMENT VARIABLES!!!" 
6 | 7 | tokens_manager = fastapi_mongodb.TokensManager( 8 | secret_key=SECRET_KEY, 9 | algorithm=fastapi_mongodb.TOKEN_ALGORITHMS.HS256, 10 | default_token_lifetime=datetime.timedelta(minutes=30), 11 | ) 12 | 13 | token_data = { 14 | "anyKey": "Not Secure", 15 | "anotherKey": "All data available to Front-end and JWT clients!!!", 16 | } 17 | 18 | jwt_token = tokens_manager.create_code( 19 | data=token_data, aud="test_audience", iss="test_issuer" 20 | ) 21 | 22 | parsed_token = tokens_manager.read_code( 23 | code=jwt_token, aud="test_audience", iss="test_issuer" 24 | ) 25 | -------------------------------------------------------------------------------- /docs_src/managers/tokens002.py: -------------------------------------------------------------------------------- 1 | class TOKEN_ALGORITHMS(str, enum.Enum): 2 | HS256 = "HS256" 3 | HS384 = "HS384" 4 | HS512 = "HS512" 5 | -------------------------------------------------------------------------------- /docs_src/setup001.py: -------------------------------------------------------------------------------- 1 | import fastapi 2 | 3 | import fastapi_mongodb 4 | 5 | db_manager = fastapi_mongodb.BaseDBManager( 6 | db_url="mongodb://0.0.0.0:27017/", default_db_name="test_db" 7 | ) 8 | 9 | app = fastapi.FastAPI( 10 | on_startup=[db_manager.create_client], 11 | on_shutdown=[db_manager.delete_client], 12 | ) 13 | 14 | app.state.db_manager = db_manager 15 | app.state.users_repository = fastapi_mongodb.BaseRepository( 16 | db_manager=db_manager, db_name="test_db", col_name="users" 17 | ) 18 | -------------------------------------------------------------------------------- /fastapi_mongodb/__init__.py: -------------------------------------------------------------------------------- 1 | """fastapi_mongodb library entrypoint.""" 2 | from .config import * 3 | from .db import * 4 | from .dependencies import * 5 | from .exceptions import * 6 | from .helpers import * 7 | from .logging import * 8 | from .managers import * 9 | from .middlewares import * 10 | from .models import * 11 | from .repositories import * 12 | from .schemas import * 13 | from .types import * 14 | -------------------------------------------------------------------------------- /fastapi_mongodb/config.py: -------------------------------------------------------------------------------- 1 | """Customized configuration from pydantic to work with models and schemas.""" 2 | import datetime 3 | 4 | import bson 5 | import pydantic 6 | import pydantic.json 7 | 8 | try: 9 | import orjson 10 | 11 | json_dumps = orjson.dumps 12 | json_loads = orjson.loads 13 | except ImportError: 14 | import json 15 | 16 | json_dumps = json.dumps 17 | json_loads = json.loads 18 | 19 | __all__ = ["BaseConfiguration"] 20 | 21 | 22 | class BaseConfiguration(pydantic.BaseConfig): 23 | """base configuration class for Models and Schemas.""" 24 | 25 | allow_population_by_field_name = True 26 | use_enum_values = True 27 | json_encoders = { # pragma: no cover 28 | datetime.datetime: lambda date_time: pydantic.json.isoformat(date_time), 29 | datetime.timedelta: lambda time_delta: pydantic.json.timedelta_isoformat(time_delta), 30 | bson.ObjectId: str, 31 | } 32 | json_dumps = json_dumps 33 | json_loads = json_loads 34 | -------------------------------------------------------------------------------- /fastapi_mongodb/db.py: -------------------------------------------------------------------------------- 1 | """MongoDB base logic.""" 2 | import collections.abc 3 | import datetime 4 | import decimal 5 | import re 6 | import typing 7 
| 8 | import bson 9 | import motor.motor_asyncio 10 | import pydantic.json 11 | import pymongo 12 | import pymongo.client_session 13 | import pymongo.database 14 | import pymongo.errors 15 | import pymongo.monitoring 16 | import pymongo.read_concern 17 | from bson import UuidRepresentation 18 | 19 | import fastapi_mongodb.helpers 20 | from fastapi_mongodb.logging import simple_logger as logger 21 | 22 | __all__ = [ 23 | "CommandLogger", 24 | "ConnectionPoolLogger", 25 | "BaseDBManager", 26 | "HeartbeatLogger", 27 | "ServerLogger", 28 | "TopologyLogger", 29 | "BaseDocument", 30 | "DECIMAL_CODEC", 31 | "TIMEDELTA_CODEC", 32 | "CODEC_OPTIONS", 33 | ] 34 | 35 | 36 | class BaseDocument(collections.abc.MutableMapping): 37 | def __init__(self, data: dict = None): 38 | self._data = data or {} 39 | 40 | def __len__(self): 41 | """Get length of BaseDocument.""" 42 | return len(self._data) 43 | 44 | def __getitem__(self, item): 45 | """Retrieve key in BaseDocument.""" 46 | return self._data.__getitem__(item) 47 | 48 | def __setitem__(self, key, value): 49 | """Set key in BaseDocument.""" 50 | return self._data.__setitem__(key, value) 51 | 52 | def __delitem__(self, key): 53 | """Delete key in BaseDocument.""" 54 | return self._data.__delitem__(key) 55 | 56 | def __iter__(self): 57 | """Return iterable from BaseDocument.""" 58 | return iter(self._data) 59 | 60 | def __repr__(self): 61 | """Representation of BaseDocument.""" 62 | data = f"oid={doc_id}" if (doc_id := self.id) else "NO DOCUMENT id" 63 | return f"{self.__class__.__name__}({data})" 64 | 65 | def __eq__(self, other): 66 | """Check equality between BaseDocument and dict.""" 67 | if isinstance(other, self.__class__): 68 | return self.data == other.data 69 | elif isinstance(other, dict): 70 | return self.data == other 71 | else: 72 | return False 73 | 74 | @property 75 | def data(self) -> dict[str, typing.Any]: 76 | return self._data 77 | 78 | @property 79 | def oid(self) -> bson.ObjectId: 80 | return self._data.get("_id", None) 81 | 82 | @property 83 | def id(self) -> str: 84 | return str(self.oid) if self.oid else None 85 | 86 | @property 87 | def generated_at(self): 88 | return self.oid.generation_time.astimezone(tz=fastapi_mongodb.helpers.get_utc_timezone()) 89 | 90 | 91 | class DecimalCodec(bson.codec_options.TypeCodec): 92 | python_type = decimal.Decimal 93 | bson_type = bson.Decimal128 94 | 95 | def transform_python(self, value): 96 | return bson.Decimal128(value=value) 97 | 98 | def transform_bson(self, value): 99 | return value.to_decimal() 100 | 101 | 102 | class TimeDeltaCodec(bson.codec_options.TypeCodec): 103 | python_type = datetime.timedelta 104 | bson_type = bson.string_type 105 | 106 | def transform_python(self, value): 107 | return pydantic.json.timedelta_isoformat(value) 108 | 109 | def transform_bson(self, value): 110 | def get_iso_split(s, split): 111 | if split in s: 112 | n, s = s.split(split) 113 | else: 114 | n = 0 115 | return n, s 116 | 117 | def get_int(s): 118 | try: 119 | return int(s) 120 | except ValueError: 121 | return 0 122 | 123 | def parse_iso_duration(string: str) -> typing.Union[str, datetime.timedelta]: 124 | valid_duration = re.match( 125 | pattern=r"^P(?!$)((\d+Y)|(\d+\.\d+Y$))?((\d+M)|(\d+\.\d+M$))?((\d+W)|(\d+\.\d+W$))?((\d+D)|" 126 | r"(\d+\.\d+D$))?(T(?=\d)((\d+H)|(\d+\.\d+H$))?((\d+M)|(\d+\.\d+M$))?(\d+(\.\d+)?S)?)??$", 127 | string=string, 128 | ) 129 | if not valid_duration: 130 | return string 131 | # Remove prefix 132 | s = string.removeprefix("P") 133 | # Step through letter dividers 134 | 
weeks, s = get_iso_split(s, "W") 135 | days, s = get_iso_split(s, "D") 136 | _, s = get_iso_split(s, "T") 137 | hours, s = get_iso_split(s, "H") 138 | minutes, s = get_iso_split(s, "M") 139 | full_seconds, s = get_iso_split(s, "S") 140 | secs, _, micros = full_seconds.partition(".") 141 | seconds = get_int(secs) 142 | microseconds = get_int(micros) 143 | return datetime.timedelta( 144 | weeks=int(weeks), 145 | days=int(days), 146 | hours=int(hours), 147 | minutes=int(minutes), 148 | seconds=seconds, 149 | microseconds=microseconds, 150 | ) 151 | 152 | return parse_iso_duration(string=value) 153 | 154 | 155 | DECIMAL_CODEC = DecimalCodec() 156 | TIMEDELTA_CODEC = TimeDeltaCodec() 157 | 158 | CODEC_OPTIONS = bson.codec_options.CodecOptions( 159 | document_class=BaseDocument, 160 | uuid_representation=UuidRepresentation.STANDARD, 161 | tz_aware=True, 162 | tzinfo=fastapi_mongodb.helpers.get_utc_timezone(), 163 | type_registry=bson.codec_options.TypeRegistry(type_codecs=[DECIMAL_CODEC, TIMEDELTA_CODEC]), 164 | ) 165 | 166 | 167 | # TODO: Make loggers customizable 168 | class CommandLogger(pymongo.monitoring.CommandListener): 169 | def started(self, event): 170 | logger.debug( 171 | f"Command '{event.command_name}' with request id {event.request_id} started on server {event.connection_id}" 172 | ) 173 | 174 | def succeeded(self, event): 175 | logger.debug( 176 | f"Command '{event.command_name}' with request id {event.request_id} on server {event.connection_id} " 177 | f"succeeded in {event.duration_micros} microseconds" 178 | ) 179 | 180 | def failed(self, event): 181 | logger.debug( 182 | f"Command {event.command_name} with request id {event.request_id} on server {event.connection_id} failed " 183 | f"in {event.duration_micros} microseconds" 184 | ) 185 | 186 | 187 | class ConnectionPoolLogger(pymongo.monitoring.ConnectionPoolListener): 188 | def pool_created(self, event): 189 | logger.debug(f"[pool {event.address}] pool created") 190 | 191 | def pool_cleared(self, event): 192 | logger.debug(f"[pool {event.address}] pool cleared") 193 | 194 | def pool_closed(self, event): 195 | logger.debug(f"[pool {event.address}] pool closed") 196 | 197 | def connection_created(self, event): 198 | logger.debug(f"[pool {event.address}][conn #{event.connection_id}] connection created") 199 | 200 | def connection_ready(self, event): 201 | logger.debug(f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded") 202 | 203 | def connection_closed(self, event): 204 | logger.debug(f"[pool {event.address}][conn #{event.connection_id}] connection closed, reason: {event.reason}") 205 | 206 | def connection_check_out_started(self, event): 207 | logger.debug(f"[pool {event.address}] connection check out started") 208 | 209 | def connection_check_out_failed(self, event): 210 | logger.debug(f"[pool {event.address}] connection check out failed, reason: {event.reason}") 211 | 212 | def connection_checked_out(self, event): 213 | logger.debug(f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool") 214 | 215 | def connection_checked_in(self, event): 216 | logger.debug(f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool") 217 | 218 | 219 | class ServerLogger(pymongo.monitoring.ServerListener): 220 | def opened(self, event): 221 | logger.debug(f"Server {event.server_address} added to topology {event.topology_id}") 222 | 223 | def description_changed(self, event): 224 | previous_server_type = event.previous_description.server_type 225 | 
new_server_type = event.new_description.server_type 226 | if new_server_type != previous_server_type: 227 | logger.debug( 228 | f"Server {event.server_address} changed type from {event.previous_description.server_type_name} to " 229 | f"{event.new_description.server_type_name}" 230 | ) 231 | 232 | def closed(self, event): 233 | logger.debug(f"Server {event.server_address} removed from topology {event.topology_id}") 234 | 235 | 236 | class HeartbeatLogger(pymongo.monitoring.ServerHeartbeatListener): 237 | def started(self, event): 238 | logger.debug(f"Heartbeat sent to server {event.connection_id}") 239 | 240 | def succeeded(self, event): 241 | logger.debug(f"Heartbeat to server {event.connection_id} succeeded with reply {event.reply.document}") 242 | 243 | def failed(self, event): 244 | logger.debug(f"Heartbeat to server {event.connection_id} failed with error {event.reply}") 245 | 246 | 247 | class TopologyLogger(pymongo.monitoring.TopologyListener): 248 | def opened(self, event): 249 | logger.debug(f"Topology with id {event.topology_id} opened") 250 | 251 | def description_changed(self, event): 252 | logger.debug(f"Topology description updated for topology id {event.topology_id}") 253 | previous_topology_type = event.previous_description.topology_type 254 | new_topology_type = event.new_description.topology_type 255 | if new_topology_type != previous_topology_type: 256 | logger.debug( 257 | f"Topology {event.topology_id} changed type from {event.previous_description.topology_type_name} to " 258 | f"{event.new_description.topology_type_name}" 259 | ) 260 | if not event.new_description.has_writable_server(): 261 | logger.debug("No writable servers available.") 262 | if not event.new_description.has_readable_server(): 263 | logger.debug("No readable servers available.") 264 | 265 | def closed(self, event): 266 | logger.debug(f"Topology with id {event.topology_id} closed") 267 | 268 | 269 | class BaseDBManager: 270 | """Class hold MongoDB client connection.""" 271 | 272 | client: pymongo.MongoClient = None 273 | 274 | def __init__( 275 | self, db_url: str, default_db_name: str = "main", code_options: bson.codec_options.CodecOptions = CODEC_OPTIONS 276 | ): 277 | self.db_url = db_url 278 | self.default_db_name = default_db_name 279 | self.codec_options = code_options 280 | 281 | def retrieve_client(self) -> pymongo.MongoClient: 282 | """Retrieve existing MongoDB client or create it (at first call).""" 283 | if self.__class__.client is None: # pragma: no cover 284 | logger.debug(msg="Initialization of MongoDB") 285 | self.create_client() 286 | return self.__class__.client 287 | 288 | def create_client(self): 289 | """Create MongoDB client.""" 290 | logger.debug(msg="Creating MongoDB client") 291 | self.__class__.client = motor.motor_asyncio.AsyncIOMotorClient(self.db_url) 292 | 293 | def delete_client(self): 294 | """Close MongoDB client.""" 295 | logger.debug(msg="Disconnecting from MongoDB") 296 | self.__class__.client.close() 297 | self.__class__.client = None # noqa 298 | 299 | async def get_server_info(self, *, session: pymongo.client_session.ClientSession = None) -> dict: 300 | client = self.retrieve_client() 301 | return await client.server_info(session=session) 302 | 303 | def retrieve_database( 304 | self, 305 | *, 306 | name: str = None, 307 | code_options: bson.codec_options.CodecOptions = None, 308 | read_preference: pymongo.ReadPreference = pymongo.ReadPreference.SECONDARY_PREFERRED, # default None 309 | write_concern: pymongo.write_concern.WriteConcern = 
pymongo.write_concern.WriteConcern( 310 | w="majority", j=True 311 | ), # default None 312 | read_concern: pymongo.read_concern.ReadConcern = pymongo.read_concern.ReadConcern( 313 | level="majority" 314 | ), # default None 315 | ) -> pymongo.database.Database: 316 | """Retrieve Database by name.""" 317 | client = self.retrieve_client() 318 | if name is None: # pragma: no cover 319 | name = self.default_db_name 320 | if code_options is None: # pragma: no cover 321 | code_options = self.codec_options 322 | return client.get_database( 323 | name=name, 324 | codec_options=code_options, 325 | read_preference=read_preference, 326 | write_concern=write_concern, 327 | read_concern=read_concern, 328 | ) 329 | 330 | async def list_databases( 331 | self, 332 | *, 333 | only_names: bool = False, 334 | session: pymongo.client_session.ClientSession = None, 335 | ) -> typing.Union[list[str], list[dict[str, typing.Any]]]: 336 | client = self.retrieve_client() 337 | if only_names: 338 | return [database for database in await client.list_database_names(session=session)] 339 | return [database async for database in await client.list_databases(session=session)] 340 | 341 | async def delete_database(self, *, name: str, session: pymongo.client_session.ClientSession = None): 342 | client = self.retrieve_client() 343 | await client.drop_database(name_or_database=name, session=session) 344 | 345 | async def get_profiling_info(self, *, db_name: str, session: pymongo.client_session.ClientSession = None) -> list: 346 | database = self.retrieve_database(name=db_name) 347 | return [item async for item in database["system.profile"].find(session=session)] 348 | 349 | async def get_profiling_level(self, *, db_name: str, session: pymongo.client_session.ClientSession = None) -> int: 350 | database = self.retrieve_database(name=db_name) 351 | result = await database.command("profile", -1, session=session) 352 | return result["was"] 353 | 354 | async def set_profiling_level( 355 | self, *, db_name: str, level, slow_ms: int = None, session: pymongo.client_session.ClientSession = None 356 | ): 357 | database = self.retrieve_database(name=db_name) 358 | return await database.command("profile", level, slowms=slow_ms, session=session) 359 | 360 | async def create_collection( 361 | self, *, name: str, db_name: str = None, safe: bool = True, session: pymongo.client_session.ClientSession = None 362 | ) -> pymongo.collection.Collection: 363 | """Create collection by name.""" 364 | database = self.retrieve_database(name=db_name) 365 | try: 366 | return await database.create_collection(name=name, session=session) 367 | except pymongo.errors.CollectionInvalid as error: 368 | if not safe: 369 | raise error 370 | 371 | async def delete_collection( 372 | self, *, name: str, db_name: str = None, session: pymongo.client_session.ClientSession = None 373 | ): 374 | """Delete collection by name.""" 375 | database = self.retrieve_database(name=db_name) 376 | await database.drop_collection(name_or_collection=name, session=session) 377 | 378 | async def list_collections( 379 | self, 380 | *, 381 | db_name: str = None, 382 | only_names: bool = False, 383 | session: pymongo.client_session.ClientSession = None, 384 | ) -> typing.Union[list[str], list[dict[str, typing.Any]]]: 385 | database = self.retrieve_database(name=db_name) 386 | if only_names: 387 | return [collection for collection in await database.list_collection_names(session=session)] 388 | return [collection for collection in await database.list_collections(session=session)] 389 | 390 | 
async def create_index( 391 | self, 392 | *, 393 | col_name: str, 394 | db_name: str = None, 395 | index: list[tuple[str, int]], 396 | name: str, 397 | background: bool = True, 398 | unique: bool = False, 399 | sparse: bool = False, 400 | session: pymongo.client_session.ClientSession = None, 401 | **kwargs, 402 | ) -> str: 403 | """Create an index for collection.""" 404 | database = self.retrieve_database(name=db_name) 405 | return await database[col_name].create_index( 406 | keys=index, name=name, background=background, unique=unique, sparse=sparse, session=session, **kwargs 407 | ) 408 | 409 | async def create_indexes( 410 | self, 411 | *, 412 | col_name: str, 413 | indexes: list[pymongo.IndexModel], 414 | db_name: str = None, 415 | session: pymongo.client_session.ClientSession = None, 416 | ): 417 | """Create indexes for collection by list of IndexModel.""" 418 | database: pymongo.database.Database = self.retrieve_database(name=db_name) 419 | return await database[col_name].create_indexes(indexes=indexes, session=session) 420 | 421 | async def delete_index( 422 | self, 423 | *, 424 | name: str, 425 | col_name: str, 426 | db_name: str = None, 427 | safe: bool = True, 428 | session: pymongo.client_session.ClientSession = None, 429 | ): 430 | """Remove index for specific collection.""" 431 | database = self.retrieve_database(name=db_name) 432 | try: 433 | await database[col_name].drop_index(index_or_name=name, session=session) 434 | except pymongo.errors.OperationFailure as error: 435 | if not safe: 436 | raise error 437 | 438 | async def list_indexes( 439 | self, 440 | *, 441 | col_name: str, 442 | db_name: str = None, 443 | only_names: bool = False, 444 | session: pymongo.client_session.ClientSession = None, 445 | ): 446 | database = self.retrieve_database(name=db_name) 447 | if only_names: 448 | return [index["name"] async for index in database[col_name].list_indexes(session=session)] 449 | return [index async for index in database[col_name].list_indexes(session=session)] 450 | -------------------------------------------------------------------------------- /fastapi_mongodb/dependencies.py: -------------------------------------------------------------------------------- 1 | """Library code for working with FastAPI dependencies.""" 2 | import fastapi 3 | import pymongo.client_session 4 | 5 | import fastapi_mongodb.db 6 | 7 | __all__ = ["DBSession"] 8 | 9 | 10 | class DBSession: 11 | def __init__(self, db_manager: fastapi_mongodb.db.BaseDBManager = None, state_attr_name: str = "db_manager"): 12 | self.db_manager = db_manager 13 | self.state_attr_name = state_attr_name 14 | 15 | async def __call__(self, request: fastapi.Request) -> pymongo.client_session.ClientSession: 16 | db_manager = self.db_manager or getattr(request.app.state, self.state_attr_name, None) 17 | if not db_manager: 18 | raise NotImplementedError( 19 | "Provide 'db_manager' parameter to dependency initializer OR set '.state.db_manager' " 20 | "to global app state" 21 | ) 22 | db_client = db_manager.retrieve_client() 23 | async with await db_client.start_session() as session: 24 | yield session 25 | -------------------------------------------------------------------------------- /fastapi_mongodb/exceptions.py: -------------------------------------------------------------------------------- 1 | """Common apps exceptions.""" 2 | 3 | __all__ = ["ManagerException", "NotFoundManagerException"] 4 | 5 | 6 | class ManagerException(Exception): 7 | """Exception that raises in managers.""" 8 | 9 | 10 | class 
NotFoundManagerException(ManagerException): 11 | """Exception that raises in managers if required object wasn't found.""" 12 | -------------------------------------------------------------------------------- /fastapi_mongodb/helpers.py: -------------------------------------------------------------------------------- 1 | """Classes and functions to provide common timezones, profiling etc.""" 2 | import datetime 3 | import functools 4 | import time 5 | import tracemalloc 6 | import typing 7 | import zoneinfo 8 | 9 | from fastapi_mongodb.logging import simple_logger 10 | 11 | __all__ = ["get_utc_timezone", "utc_now", "as_utc", "BaseProfiler"] 12 | 13 | 14 | @functools.lru_cache() 15 | def get_utc_timezone() -> zoneinfo.ZoneInfo: 16 | """Return UTC zone info.""" 17 | return zoneinfo.ZoneInfo(key="UTC") 18 | 19 | 20 | def utc_now() -> datetime.datetime: 21 | """Return current datetime with UTC zone info.""" 22 | return datetime.datetime.now(tz=get_utc_timezone()) 23 | 24 | 25 | def as_utc(date_time: datetime.datetime) -> datetime.datetime: 26 | """Get datetime object and convert it to datetime with UTC zone info.""" 27 | return date_time.astimezone(tz=get_utc_timezone()) 28 | 29 | 30 | class BaseProfiler: 31 | def __init__( 32 | self, 33 | number_frames: int = 10, 34 | include_files: list[str] = None, 35 | exclude_files: list[str] = None, 36 | show_timing: bool = True, 37 | show_memory: bool = True, 38 | ): 39 | self.number_frames = number_frames 40 | self.include_files = include_files or [] 41 | self.exclude_files = exclude_files or [] 42 | self.show_timing = show_timing 43 | self.show_memory = show_memory 44 | self._start_time = None 45 | self._end_time = None 46 | 47 | def __call__(self, func): 48 | """Call function to work with profiler as decorator.""" 49 | 50 | @functools.wraps(func) 51 | def decorated(*args, **kwargs): 52 | self._start_trace_malloc() 53 | self._set_start_time() 54 | result = func(*args, **kwargs) 55 | self._set_end_time() 56 | self._print_timing(name=func.__name__) 57 | self._end_trace_malloc() 58 | return result 59 | 60 | return decorated 61 | 62 | def _set_start_time(self): 63 | """Set time start point.""" 64 | self._start_time = time.time() 65 | 66 | def _set_end_time(self): 67 | """Set time end point.""" 68 | self._end_time = time.time() 69 | 70 | def __enter__(self): 71 | """Start profiling.""" 72 | self._start_trace_malloc() 73 | self._set_start_time() 74 | 75 | def __exit__(self, exc_type, exc_val, exc_tb): 76 | """End profiling.""" 77 | self._set_end_time() 78 | self._print_timing(name="CODE BLOCK") 79 | self._end_trace_malloc() 80 | 81 | def _start_trace_malloc(self): 82 | if not self.show_memory: 83 | return 84 | 85 | tracemalloc.start(self.number_frames) if self.number_frames else tracemalloc.start() 86 | 87 | def _end_trace_malloc(self): 88 | if not self.show_memory: 89 | return 90 | 91 | simple_logger.debug(msg="=== START SNAPSHOT ===") 92 | snapshot = tracemalloc.take_snapshot() 93 | snapshot = snapshot.filter_traces(filters=self._get_trace_malloc_filters()) 94 | for stat in snapshot.statistics(key_type="lineno", cumulative=True): 95 | simple_logger.debug(msg=f"{stat}") 96 | size, peak = tracemalloc.get_traced_memory() 97 | snapshot_size = tracemalloc.get_tracemalloc_memory() 98 | simple_logger.debug( 99 | msg=f"❕size={self._bytes_to_megabytes(size=size)}, " 100 | f"❗peak={self._bytes_to_megabytes(size=peak)}, " 101 | f"💾snapshot_size={self._bytes_to_megabytes(size=snapshot_size)}" 102 | ) 103 | tracemalloc.clear_traces() 104 | simple_logger.debug(msg="=== 
END SNAPSHOT ===") 105 | 106 | def _get_trace_malloc_filters( 107 | self, 108 | ) -> list[typing.Union[tracemalloc.Filter, tracemalloc.DomainFilter]]: 109 | filters = [tracemalloc.Filter(inclusive=True, filename_pattern=file_name) for file_name in self.include_files] 110 | 111 | for file_name in self.exclude_files: 112 | filters.append(tracemalloc.Filter(inclusive=False, filename_pattern=file_name)) 113 | return filters 114 | 115 | @staticmethod 116 | def _bytes_to_megabytes(size: int, precision: int = 3): 117 | return f"{size / 1024.0 / 1024.0:.{precision}f} MB" 118 | 119 | def _print_timing(self, name: str, precision: int = 6): 120 | if not self.show_timing: 121 | return 122 | simple_logger.debug( 123 | f"📊Execution timing of: '{name}' ⏱: {self._end_time - self._start_time:.{precision}f} seconds" 124 | ) 125 | -------------------------------------------------------------------------------- /fastapi_mongodb/logging.py: -------------------------------------------------------------------------------- 1 | """Classes to provide setup logging and colors.""" 2 | import copy 3 | import enum 4 | import functools 5 | import logging 6 | import typing 7 | from typing import Union 8 | 9 | import click 10 | 11 | __all__ = ["logger", "simple_logger", "setup_logging"] 12 | 13 | TRACE = 5 14 | SUCCESS = 25 15 | 16 | 17 | class Colors(str, enum.Enum): 18 | RED = "red" 19 | GREEN = "green" 20 | ORANGE = "yellow" 21 | BLUE = "blue" 22 | MAGENTA = "magenta" 23 | CYAN = "cyan" 24 | GRAY = "black" 25 | LIGHT_GRAY = "white" 26 | BRIGHT_BLACK = "bright_black" 27 | BRIGHT_RED = "bright_red" 28 | BRIGHT_GREEN = "bright_green" 29 | BRIGHT_YELLOW = "bright_yellow" 30 | BRIGHT_BLUE = "bright_blue" 31 | BRIGHT_MAGENTA = "bright_magenta" 32 | BRIGHT_CYAN = "bright_cyan" 33 | WHITE = "bright_white" 34 | CLEAR = "reset" 35 | 36 | 37 | class Red(tuple[int], enum.Enum): 38 | RED = (244, 67, 54) 39 | RED_1 = (255, 235, 238) 40 | RED_2 = (254, 205, 210) 41 | RED_3 = (239, 154, 154) 42 | RED_4 = (229, 115, 115) 43 | RED_5 = (239, 83, 80) 44 | RED_6 = (244, 67, 54) 45 | RED_7 = (229, 57, 52) 46 | RED_8 = (211, 47, 47) 47 | RED_9 = (198, 40, 40) 48 | RED_10 = (183, 28, 28) 49 | 50 | 51 | class Pink(tuple[int], enum.Enum): 52 | PINK = (233, 31, 99) 53 | PINK_1 = (252, 228, 236) 54 | PINK_2 = (248, 187, 208) 55 | PINK_3 = (244, 143, 177) 56 | PINK_4 = (239, 98, 146) 57 | PINK_5 = (236, 64, 121) 58 | PINK_6 = (233, 31, 99) 59 | PINK_7 = (216, 28, 96) 60 | PINK_8 = (194, 25, 92) 61 | PINK_9 = (173, 21, 87) 62 | PINK_10 = (136, 15, 79) 63 | 64 | 65 | class Purple(tuple[int], enum.Enum): 66 | PURPLE = (156, 39, 176) 67 | PURPLE_1 = (243, 228, 245) 68 | PURPLE_2 = (225, 190, 231) 69 | PURPLE_3 = (206, 147, 216) 70 | PURPLE_4 = (186, 104, 200) 71 | PURPLE_5 = (171, 73, 188) 72 | PURPLE_6 = (156, 39, 176) 73 | PURPLE_7 = (142, 36, 170) 74 | PURPLE_8 = (123, 31, 163) 75 | PURPLE_9 = (106, 27, 154) 76 | PURPLE_10 = (74, 20, 140) 77 | 78 | 79 | class DeepPurple(tuple[int], enum.Enum): 80 | DEEP_PURPLE = (103, 58, 183) 81 | DEEP_PURPLE_1 = (237, 231, 246) 82 | DEEP_PURPLE_2 = (209, 196, 233) 83 | DEEP_PURPLE_3 = (178, 157, 219) 84 | DEEP_PURPLE_4 = (149, 117, 205) 85 | DEEP_PURPLE_5 = (126, 87, 194) 86 | DEEP_PURPLE_6 = (103, 58, 183) 87 | DEEP_PURPLE_7 = (94, 53, 177) 88 | DEEP_PURPLE_8 = (81, 45, 167) 89 | DEEP_PURPLE_9 = (69, 39, 160) 90 | DEEP_PURPLE_10 = (49, 27, 146) 91 | 92 | 93 | class Indigo(tuple[int], enum.Enum): 94 | INDIGO = (63, 81, 181) 95 | INDIGO_1 = (232, 234, 246) 96 | INDIGO_2 = (197, 202, 233) 97 | INDIGO_3 = (159, 168, 
218) 98 | INDIGO_4 = (121, 134, 203) 99 | INDIGO_5 = (92, 107, 192) 100 | INDIGO_6 = (63, 81, 181) 101 | INDIGO_7 = (57, 73, 171) 102 | INDIGO_8 = (48, 63, 159) 103 | INDIGO_9 = (39, 53, 147) 104 | INDIGO_10 = (26, 35, 126) 105 | 106 | 107 | class Blue(tuple[int], enum.Enum): 108 | BLUE = (33, 150, 243) 109 | BLUE_1 = (227, 242, 253) 110 | BLUE_2 = (187, 222, 251) 111 | BLUE_3 = (143, 202, 249) 112 | BLUE_4 = (99, 181, 246) 113 | BLUE_5 = (66, 165, 245) 114 | BLUE_6 = (33, 150, 243) 115 | BLUE_7 = (30, 136, 229) 116 | BLUE_8 = (25, 118, 210) 117 | BLUE_9 = (21, 101, 192) 118 | BLUE_10 = (13, 71, 161) 119 | 120 | 121 | class LightBlue(tuple[int], enum.Enum): 122 | LIGHT_BLUE = (7, 169, 244) 123 | LIGHT_BLUE_1 = (225, 245, 254) 124 | LIGHT_BLUE_2 = (179, 229, 252) 125 | LIGHT_BLUE_3 = (129, 212, 250) 126 | LIGHT_BLUE_4 = (79, 195, 247) 127 | LIGHT_BLUE_5 = (41, 182, 246) 128 | LIGHT_BLUE_6 = (7, 169, 244) 129 | LIGHT_BLUE_7 = (4, 155, 229) 130 | LIGHT_BLUE_8 = (1, 136, 209) 131 | LIGHT_BLUE_9 = (2, 119, 189) 132 | LIGHT_BLUE_10 = (1, 87, 155) 133 | 134 | 135 | class Cyan(tuple[int], enum.Enum): 136 | CYAN = (25, 188, 212) 137 | CYAN_1 = (223, 247, 250) 138 | CYAN_2 = (178, 235, 242) 139 | CYAN_3 = (128, 222, 234) 140 | CYAN_4 = (77, 208, 225) 141 | CYAN_5 = (38, 198, 218) 142 | CYAN_6 = (25, 188, 212) 143 | CYAN_7 = (22, 172, 193) 144 | CYAN_8 = (18, 151, 166) 145 | CYAN_9 = (15, 131, 143) 146 | CYAN_10 = (9, 96, 100) 147 | 148 | 149 | class Gray(tuple[int], enum.Enum): 150 | GRAY = (158, 158, 158) 151 | GRAY_1 = (250, 250, 250) 152 | GRAY_2 = (245, 245, 245) 153 | GRAY_3 = (238, 238, 238) 154 | GRAY_4 = (224, 224, 224) 155 | GRAY_5 = (189, 189, 189) 156 | GRAY_6 = (158, 158, 158) 157 | GRAY_7 = (117, 117, 117) 158 | GRAY_8 = (97, 97, 97) 159 | GRAY_9 = (66, 66, 66) 160 | GRAY_10 = (33, 33, 33) 161 | 162 | 163 | class Brown(tuple[int], enum.Enum): 164 | BROWN = (121, 85, 72) 165 | BROWN_1 = (239, 235, 233) 166 | BROWN_2 = (215, 204, 200) 167 | BROWN_3 = (188, 170, 164) 168 | BROWN_4 = (161, 136, 127) 169 | BROWN_5 = (141, 110, 99) 170 | BROWN_6 = (121, 85, 72) 171 | BROWN_7 = (109, 76, 65) 172 | BROWN_8 = (94, 64, 55) 173 | BROWN_9 = (78, 52, 46) 174 | BROWN_10 = (62, 39, 35) 175 | 176 | 177 | class Palette: 178 | COLORS = Colors 179 | GRAY = Gray 180 | RED = Red 181 | PINK = Pink 182 | PURPLE = Purple 183 | DEEP_PURPLE = DeepPurple 184 | INDIGO = Indigo 185 | BLUE = Blue 186 | LIGHT_BLUE = LightBlue 187 | CYAN = Cyan 188 | BROWN = Brown 189 | 190 | 191 | class Styler: 192 | def __init__(self, override_kwargs: list[dict[typing.Union[int, str], typing.Any]] = None): 193 | self.colors_map: dict[int, functools.partial] = {} # {: , ...} 194 | _default_kwargs = [ 195 | {"level": TRACE, "fg": Palette.GRAY.GRAY_6}, 196 | {"level": logging.DEBUG, "fg": Palette.BROWN.BROWN_6}, 197 | {"level": logging.INFO, "fg": Palette.COLORS.BRIGHT_BLUE}, 198 | {"level": SUCCESS, "fg": Palette.COLORS.GREEN}, 199 | {"level": logging.WARNING, "fg": Palette.COLORS.BRIGHT_YELLOW}, 200 | {"level": logging.ERROR, "fg": Palette.COLORS.BRIGHT_RED}, 201 | {"level": logging.CRITICAL, "fg": Palette.DEEP_PURPLE.DEEP_PURPLE_5, "bold": True, "underline": True}, 202 | ] 203 | 204 | for kwargs in _default_kwargs: 205 | self.set_style(**kwargs) 206 | 207 | if override_kwargs: 208 | for kwargs in override_kwargs: 209 | self.set_style(**kwargs) 210 | 211 | def get_style(self, level: int) -> functools.partial: 212 | return self.colors_map.get(level, None) 213 | 214 | def set_style( 215 | self, 216 | level: int, 217 | fg=None, 218 | 
bg=None, 219 | bold: bool = None, 220 | dim: bool = None, 221 | underline: bool = None, 222 | overline: bool = None, 223 | italic: bool = None, 224 | blink: bool = None, 225 | reverse: bool = None, 226 | strikethrough: bool = None, 227 | reset: bool = True, 228 | ): 229 | self.colors_map[level] = functools.partial( 230 | click.style, 231 | fg=fg, 232 | bg=bg, 233 | bold=bold, 234 | dim=dim, 235 | underline=underline, 236 | overline=overline, 237 | italic=italic, 238 | blink=blink, 239 | reverse=reverse, 240 | strikethrough=strikethrough, 241 | reset=reset, 242 | ) 243 | 244 | 245 | class DebugFormatter(logging.Formatter): 246 | def __init__( 247 | self, 248 | fmt: str = "{levelname} {message} ({asctime})", 249 | datefmt: str = "%Y-%m-%dT%H:%M:%S%z", 250 | style: str = "{", 251 | validate: bool = True, 252 | accent_color: str = "bright_cyan", 253 | styler: typing.ClassVar[Styler] = Styler, 254 | ): 255 | self.accent_color = accent_color 256 | self.styler = styler() 257 | super().__init__(fmt=fmt, datefmt=datefmt, style=style, validate=validate) 258 | 259 | def formatMessage(self, record: logging.LogRecord) -> str: 260 | record_copy = copy.copy(record) 261 | for key in record_copy.__dict__: 262 | if key == "message": 263 | record_copy.__dict__["message"] = self.styler.get_style(level=record_copy.levelno)( 264 | text=record_copy.message 265 | ) 266 | elif key == "levelname": 267 | separator = " " * (8 - len(record_copy.levelname)) 268 | record_copy.__dict__["levelname"] = ( 269 | self.styler.get_style(level=record_copy.levelno)(text=record_copy.levelname) 270 | + click.style(text=":", fg=self.accent_color) 271 | + separator 272 | ) 273 | elif key == "levelno": 274 | continue # set it after iterations (because using in other cases) 275 | else: 276 | record_copy.__dict__[key] = click.style(text=str(record.__dict__[key]), fg=self.accent_color) 277 | 278 | record_copy.__dict__["levelno"] = click.style(text=str(record.__dict__["levelno"]), fg=self.accent_color) 279 | 280 | return super().formatMessage(record=record_copy) 281 | 282 | 283 | class PyCharmDebugLogger(logging.getLoggerClass()): 284 | def trace(self, msg, *args, **kwargs): 285 | if self.isEnabledFor(level=TRACE): 286 | self._log(level=TRACE, msg=msg, args=args, **kwargs) 287 | 288 | def success(self, msg, *args, **kwargs): 289 | if self.isEnabledFor(level=SUCCESS): 290 | self._log(level=SUCCESS, msg=msg, args=args, **kwargs) 291 | 292 | 293 | @functools.lru_cache() 294 | def setup_logging( 295 | name: str = "Debug Logger", 296 | raw_format: str = "{levelname} {message} {asctime}", 297 | file_link_formatter: bool = True, 298 | color_formatter: bool = True, 299 | date_format: str = "%Y-%m-%dT%H:%M:%S%z", 300 | accent_color: Union[tuple[int, int, int], str] = Palette.COLORS.CYAN, 301 | styler: typing.ClassVar[Styler] = Styler, 302 | ): 303 | """Class for logging formatter.""" 304 | if file_link_formatter and color_formatter: 305 | # make PyCharm available link to file where's log occurs 306 | # example of link creation: "File {pathname}, line {lineno}" 307 | file_format = click.style(text='╰───📑File "', fg="bright_white", bold=True) 308 | line_format = click.style(text='", line ', fg="bright_white", bold=True) 309 | debug_format = raw_format + f"\n{file_format}{{pathname}}{line_format}{{lineno}}" 310 | formatter = DebugFormatter(fmt=debug_format, datefmt=date_format, accent_color=accent_color, styler=styler) 311 | elif not file_link_formatter and color_formatter: 312 | # use colored formatter 313 | formatter = 
DebugFormatter(fmt=raw_format, datefmt=date_format, accent_color=accent_color, styler=styler) 314 | else: 315 | # use default Formatter 316 | formatter = logging.Formatter(fmt=raw_format, datefmt=date_format, style="{") 317 | 318 | logging.addLevelName(level=TRACE, levelName="TRACE") 319 | logging.addLevelName(level=SUCCESS, levelName="SUCCESS") 320 | handler = logging.StreamHandler() 321 | handler.setFormatter(fmt=formatter) 322 | handler.setLevel(level=TRACE) 323 | 324 | main_logger = PyCharmDebugLogger(name=name) 325 | main_logger.setLevel(level=TRACE) 326 | main_logger.addHandler(hdlr=handler) 327 | 328 | return main_logger 329 | 330 | 331 | logger = setup_logging() 332 | simple_logger = setup_logging(file_link_formatter=False) 333 | raw_logger = setup_logging(file_link_formatter=False, color_formatter=False) 334 | -------------------------------------------------------------------------------- /fastapi_mongodb/managers.py: -------------------------------------------------------------------------------- 1 | """Manager implementations for common tasks.""" 2 | import binascii 3 | import copy 4 | import datetime 5 | import enum 6 | import hashlib 7 | import secrets 8 | import typing 9 | 10 | import jwt 11 | 12 | import fastapi_mongodb.exceptions 13 | import fastapi_mongodb.helpers 14 | import fastapi_mongodb.schemas 15 | 16 | __all__ = ["PASSWORD_ALGORITHMS", "TOKEN_ALGORITHMS", "PasswordsManager", "TokensManager"] 17 | 18 | 19 | class PASSWORD_ALGORITHMS(str, enum.Enum): 20 | SHA256 = "sha256" 21 | SHA384 = "sha384" 22 | SHA512 = "sha512" 23 | 24 | 25 | ALGORITHMS_LENGTH_MAP = { 26 | PASSWORD_ALGORITHMS.SHA256: 64, 27 | PASSWORD_ALGORITHMS.SHA384: 96, 28 | PASSWORD_ALGORITHMS.SHA512: 128, 29 | } 30 | ALGORITHMS_METHODS_MAP = { 31 | PASSWORD_ALGORITHMS.SHA256: hashlib.sha256, 32 | PASSWORD_ALGORITHMS.SHA384: hashlib.sha384, 33 | PASSWORD_ALGORITHMS.SHA512: hashlib.sha512, 34 | } 35 | 36 | 37 | class TOKEN_ALGORITHMS(str, enum.Enum): 38 | HS256 = "HS256" 39 | HS384 = "HS384" 40 | HS512 = "HS512" 41 | 42 | 43 | class PasswordsManager: 44 | """Manager for generating and checking passwords.""" 45 | 46 | def __init__(self, *, algorithm: PASSWORD_ALGORITHMS = PASSWORD_ALGORITHMS.SHA512, iterations: int = 524288): 47 | self._algorithm = algorithm 48 | self._algorithm_length = ALGORITHMS_LENGTH_MAP[self._algorithm] 49 | self.iterations = iterations 50 | 51 | def __repr__(self): 52 | """Representation for PasswordsManager.""" 53 | return f"{self.__class__.__name__}(algorithm={self._algorithm}, iterations={self.iterations})" 54 | 55 | def __hash(self, *, salt: str, password: str, encoding: str = "UTF-8") -> str: 56 | """Make hash for password N length string (default length from algorithm).""" 57 | password_hash = hashlib.pbkdf2_hmac( 58 | hash_name=self._algorithm, 59 | password=password.encode(encoding=encoding), 60 | salt=salt.encode(encoding=encoding), 61 | iterations=self.iterations, 62 | ) 63 | password_hash = binascii.hexlify(password_hash).decode(encoding) 64 | return password_hash 65 | 66 | def __generate_salt(self) -> str: 67 | """Return N length string (default length from algorithm).""" 68 | method = ALGORITHMS_METHODS_MAP[self._algorithm] 69 | return method(string=secrets.token_bytes(nbytes=64)).hexdigest() # noqa 70 | 71 | def check_password(self, *, password: str, password_hash: str) -> bool: 72 | """Check password and hash.""" 73 | salt = password_hash[: self._algorithm_length] 74 | stored_password_hash = password_hash[self._algorithm_length :] 75 | check_password_hash = 
self.__hash(salt=salt, password=password) 76 | return check_password_hash == stored_password_hash 77 | 78 | def make_password(self, *, password: str) -> str: 79 | """Make hash from password.""" 80 | salt: str = self.__generate_salt() 81 | new_password_hash: str = self.__hash(salt=salt, password=password) 82 | return salt + new_password_hash 83 | 84 | 85 | class TokensManager: 86 | """Manager for generating and checking JWT tokens.""" 87 | 88 | def __init__( 89 | self, 90 | *, 91 | secret_key: str, 92 | algorithm: TOKEN_ALGORITHMS = TOKEN_ALGORITHMS.HS256, 93 | default_token_lifetime: datetime.timedelta = datetime.timedelta(minutes=30), 94 | ): 95 | self._secret_key = secret_key 96 | self._algorithm = algorithm 97 | self.default_token_lifetime = default_token_lifetime 98 | 99 | def create_code( 100 | self, 101 | *, 102 | data: dict[str, typing.Union[str, int, float, dict, list, bool]] = None, 103 | aud: str = "access", # Audience 104 | iat: datetime.datetime = None, # Issued at datetime 105 | exp: datetime.datetime = None, # Expired at datetime 106 | nbf: datetime.datetime = None, # Not before datetime 107 | iss: str = "", # Issuer 108 | ) -> str: 109 | """Method for generation of JWT token.""" 110 | if data is None: 111 | data = {} 112 | now = fastapi_mongodb.helpers.utc_now() 113 | if iat is None: 114 | iat = now 115 | if exp is None: 116 | exp = now + self.default_token_lifetime 117 | if nbf is None: 118 | nbf = now 119 | payload = copy.deepcopy(data) 120 | payload |= {"iat": iat, "aud": aud, "exp": exp, "nbf": nbf, "iss": iss} 121 | return jwt.encode(payload=payload, key=self._secret_key, algorithm=self._algorithm) 122 | 123 | def read_code( 124 | self, 125 | *, 126 | code: str, 127 | aud: str = "access", # Audience 128 | iss: str = "", # Issuer 129 | leeway: int = 0, # provide extra time in seconds to validate (iat, exp, nbf) 130 | convert_to: typing.Type[fastapi_mongodb.schemas.BaseSchema] = None, 131 | ): 132 | """Method for parse and validate JWT token.""" 133 | try: 134 | payload = jwt.decode( 135 | jwt=code, 136 | key=self._secret_key, 137 | algorithms=[self._algorithm], 138 | leeway=leeway, 139 | audience=aud, 140 | issuer=iss, 141 | ) 142 | if convert_to is not None: 143 | payload = convert_to(**payload) # noqa 144 | except jwt.exceptions.InvalidIssuerError as error: 145 | raise fastapi_mongodb.exceptions.ManagerException("Invalid JWT issuer.") from error 146 | except jwt.exceptions.InvalidAudienceError as error: 147 | raise fastapi_mongodb.exceptions.ManagerException("Invalid JWT audience.") from error 148 | except jwt.exceptions.ExpiredSignatureError as error: 149 | raise fastapi_mongodb.exceptions.ManagerException("Expired JWT token.") from error 150 | except jwt.exceptions.ImmatureSignatureError as error: 151 | raise fastapi_mongodb.exceptions.ManagerException("The token is not valid yet.") from error 152 | # base error exception from pyjwt 153 | except jwt.exceptions.PyJWTError as error: 154 | raise fastapi_mongodb.exceptions.ManagerException("Invalid JWT.") from error 155 | else: 156 | return payload 157 | -------------------------------------------------------------------------------- /fastapi_mongodb/middlewares.py: -------------------------------------------------------------------------------- 1 | """Application middleware classes.""" 2 | 3 | import fastapi 4 | from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint 5 | from starlette.types import ASGIApp 6 | 7 | from fastapi_mongodb.db import BaseDBManager 8 | 9 | __all__ = 
["DBSessionMiddleware"] 10 | 11 | 12 | class DBSessionMiddleware(BaseHTTPMiddleware): 13 | """Append 'db_session' for every request.state.""" 14 | 15 | def __init__(self, app: ASGIApp, db_manager: BaseDBManager, dispatch: DispatchFunction = None) -> None: 16 | self.app = app 17 | self.db_manager = db_manager 18 | self.dispatch_func = self.dispatch if dispatch is None else dispatch 19 | 20 | async def dispatch(self, request: fastapi.Request, call_next: RequestResponseEndpoint) -> fastapi.Response: 21 | db_client = self.db_manager.retrieve_client() 22 | async with await db_client.start_session() as session: 23 | request.state.db_session = session 24 | return await call_next(request) 25 | -------------------------------------------------------------------------------- /fastapi_mongodb/models.py: -------------------------------------------------------------------------------- 1 | """Pydantic based and custom classes to interact with MongoDB database.""" 2 | import datetime 3 | import typing 4 | 5 | import bson 6 | import pydantic.typing 7 | import pymongo.client_session 8 | import pymongo.database 9 | import pymongo.results 10 | 11 | import fastapi_mongodb.db 12 | import fastapi_mongodb.helpers 13 | from fastapi_mongodb.config import BaseConfiguration 14 | from fastapi_mongodb.repositories import BaseRepository 15 | from fastapi_mongodb.types import OID 16 | 17 | __all__ = ["BaseDBModel", "BaseCreatedUpdatedModel", "BaseActiveRecord", "BaseDataMapper"] 18 | 19 | 20 | class BaseDBModel(pydantic.BaseModel): 21 | """Class for MongoDB (class data view).""" 22 | 23 | oid: typing.Optional[OID] = pydantic.Field(alias="_id") 24 | 25 | class Config(BaseConfiguration): 26 | """configuration class.""" 27 | 28 | id = property(fget=lambda self: str(self.oid)) 29 | 30 | @classmethod 31 | def from_db(cls, *, data: dict): 32 | """Convert result from MongoDB to BaseDBModel.""" 33 | if not data: 34 | return data 35 | return cls(**data) 36 | 37 | def dict( 38 | self, 39 | *, 40 | include: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 41 | exclude: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 42 | by_alias: bool = False, 43 | skip_defaults: bool = None, 44 | exclude_unset: bool = False, 45 | exclude_defaults: bool = False, 46 | exclude_none: bool = False, 47 | _to_db: bool = False, 48 | ) -> "pydantic.typing.DictStrAny": 49 | if _to_db: 50 | override_by_alias = True 51 | else: 52 | override_by_alias = by_alias # override value to work with BaseSchema 53 | 54 | return super().dict( 55 | include=include, 56 | exclude=exclude, 57 | by_alias=override_by_alias, 58 | skip_defaults=skip_defaults, 59 | exclude_unset=exclude_unset, 60 | exclude_defaults=exclude_defaults, 61 | exclude_none=exclude_none, 62 | ) 63 | 64 | def to_db( 65 | self, 66 | *, 67 | include: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 68 | exclude: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 69 | skip_defaults: bool = None, 70 | exclude_unset: bool = False, 71 | exclude_defaults: bool = False, 72 | exclude_none: bool = True, # pydantic default is False 73 | ) -> dict: 74 | """Prepare data for MongoDB.""" 75 | result: dict = self.dict( 76 | include=include, 77 | exclude=exclude, 78 | skip_defaults=skip_defaults, 79 | exclude_unset=exclude_unset, 80 | exclude_defaults=exclude_defaults, 81 | exclude_none=exclude_none, 82 | _to_db=True, 83 | ) 84 | 85 | # if no 
"oid" and "_id" -> creates it 86 | if "_id" not in result and "oid" not in result: 87 | result["_id"] = bson.ObjectId() 88 | 89 | return result 90 | 91 | 92 | class BaseCreatedUpdatedModel(BaseDBModel): 93 | created_at: typing.Optional[datetime.datetime] = pydantic.Field(default=None) 94 | updated_at: typing.Optional[datetime.datetime] = pydantic.Field(default=None) 95 | 96 | def to_db( 97 | self, 98 | *, 99 | include: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 100 | exclude: typing.Union["pydantic.typing.AbstractSetIntStr", "pydantic.typing.MappingIntStrAny"] = None, 101 | skip_defaults: bool = None, 102 | exclude_unset: bool = False, 103 | exclude_defaults: bool = False, 104 | exclude_none: bool = True, # pydantic default is False 105 | ) -> dict: 106 | result = super(BaseCreatedUpdatedModel, self).to_db( 107 | include=include, 108 | exclude=exclude, 109 | skip_defaults=skip_defaults, 110 | exclude_unset=exclude_unset, 111 | exclude_defaults=exclude_defaults, 112 | exclude_none=exclude_none, 113 | ) 114 | result["updated_at"] = fastapi_mongodb.helpers.utc_now() 115 | result["created_at"] = result["_id"].generation_time.replace(tzinfo=fastapi_mongodb.helpers.get_utc_timezone()) 116 | return result 117 | 118 | 119 | class BaseActiveRecord: 120 | def __init__( 121 | self, 122 | document: typing.Union[dict, fastapi_mongodb.db.BaseDocument], 123 | repository: BaseRepository, 124 | ): 125 | self._document = document 126 | self._repository: BaseRepository = repository 127 | self._unset_data = {} 128 | 129 | def __repr__(self): 130 | """Representation of BaseActiveRecord.""" 131 | data = f"oid={doc_id}" if (doc_id := self.id) else "NO DOCUMENT id" 132 | return f"{self.__class__.__name__}({data})" 133 | 134 | def __setitem__(self, key, value): 135 | """Set key to BaseActiveRecord.""" 136 | self._document.__setitem__(key, value) 137 | self._unset_data.pop(key, None) 138 | 139 | def __getitem__(self, item): 140 | """Retrieve key from BaseActiveRecord.""" 141 | return self._document.__getitem__(item) 142 | 143 | def __delitem__(self, key): 144 | """Delete key from BaseActiveRecord.""" 145 | self._document.__delitem__(key) 146 | self._unset_data[key] = None 147 | 148 | def __contains__(self, item): 149 | """Check if key exists in BaseActiveRecord.""" 150 | if item in self._document.keys(): 151 | return True 152 | return False 153 | 154 | @property 155 | def oid(self): 156 | try: 157 | return self._document["_id"] 158 | except KeyError: 159 | fastapi_mongodb.simple_logger.warning(msg="Document already deleted or has not been created yet!") 160 | return None 161 | 162 | @property 163 | def id(self): 164 | return str(self.oid) if self.oid else None 165 | 166 | @property 167 | def generated_at(self) -> typing.Optional[datetime.datetime]: 168 | try: 169 | return self._document["_id"].generation_time.astimezone(tz=fastapi_mongodb.helpers.get_utc_timezone()) 170 | except KeyError: 171 | fastapi_mongodb.simple_logger.warning(msg="Document already deleted or has not been created yet!") 172 | return None 173 | 174 | def get_collection(self) -> pymongo.collection.Collection: 175 | return self._repository.col 176 | 177 | def get_db(self) -> pymongo.database.Database: 178 | return self._repository.db 179 | 180 | @classmethod 181 | async def create( 182 | cls, 183 | document: dict, 184 | repository: BaseRepository, 185 | session: pymongo.client_session.ClientSession = None, 186 | ) -> "BaseActiveRecord": 187 | result = await 
repository.insert_one(document=document, session=session) 188 | return await cls.read(query={"_id": result.inserted_id}, repository=repository, session=session) 189 | 190 | @classmethod 191 | async def insert( 192 | cls, 193 | document: dict, 194 | repository: BaseRepository, 195 | session: pymongo.client_session.ClientSession = None, 196 | ) -> "BaseActiveRecord": 197 | return await cls.create(document=document, repository=repository, session=session) 198 | 199 | async def update(self, session: pymongo.client_session.ClientSession = None) -> "BaseActiveRecord": 200 | updated_document = await self._repository.find_one_and_update( 201 | query={"_id": self.oid}, 202 | update={"$set": self._document, "$unset": self._unset_data}, 203 | session=session, 204 | ) 205 | self._document = updated_document 206 | self._unset_data = {} 207 | return self 208 | 209 | @classmethod 210 | async def read( 211 | cls, 212 | query: dict, 213 | repository: BaseRepository, 214 | session: pymongo.client_session.ClientSession = None, 215 | ) -> "BaseActiveRecord": 216 | result = await repository.find_one(query=query, session=session) 217 | return cls(document=result or {}, repository=repository) 218 | 219 | async def delete(self, session: pymongo.client_session.ClientSession = None) -> "BaseActiveRecord": 220 | await self._repository.delete_one(query={"_id": self.oid}, session=session) 221 | self._document = {} 222 | return self 223 | 224 | async def refresh(self, session: pymongo.client_session.ClientSession = None) -> "BaseActiveRecord": 225 | result = await self._repository.find_one(query={"_id": self.oid}, session=session) if self.oid else None 226 | self._document = result or {} 227 | return self 228 | 229 | async def reload(self, session: pymongo.client_session.ClientSession = None) -> "BaseActiveRecord": 230 | return await self.refresh(session=session) 231 | 232 | 233 | class BaseDataMapper: 234 | def __init__( 235 | self, 236 | repository: BaseRepository, 237 | model: typing.Type[BaseDBModel], 238 | db_session: pymongo.client_session.ClientSession = None, 239 | ): 240 | self._repository = repository 241 | self._model = model 242 | self._db_session = db_session 243 | 244 | @property 245 | def db_session(self): 246 | return self._db_session 247 | 248 | @db_session.setter 249 | def db_session(self, value): 250 | self._db_session = value 251 | 252 | def __repr__(self): 253 | return ( 254 | f"{self.__class__.__name__}(repository={self._repository.__class__.__name__}, model={self._model.__name__})" 255 | ) 256 | 257 | async def create( 258 | self, model: BaseDBModel, raw_result: bool = False, session: pymongo.client_session.ClientSession = None 259 | ) -> typing.Union[pymongo.results.InsertOneResult, BaseDBModel]: 260 | session = session or self._db_session 261 | insertion_result = await self._repository.insert_one(document=model.to_db(), session=session) 262 | if raw_result: 263 | return insertion_result 264 | else: 265 | return await self.retrieve(query={"_id": insertion_result.inserted_id}, session=session) 266 | 267 | async def list(self, query: dict, session: pymongo.client_session.ClientSession = None): 268 | return ( 269 | self._model.from_db(data=result) 270 | async for result in await self._repository.find(query=query, session=session or self._db_session) 271 | ) 272 | 273 | async def retrieve(self, query: dict, session: pymongo.client_session.ClientSession = None): 274 | document = await self._repository.find_one(query=query, session=session or self._db_session) 275 | return 
self._model.from_db(data=document) 276 | 277 | async def partial_update(self, model: BaseDBModel, session: pymongo.client_session.ClientSession = None): 278 | document = await self._repository.find_one_and_replace( 279 | query={"_id": model.oid}, replacement=model.to_db(), session=session or self._db_session 280 | ) 281 | return self._model.from_db(data=document) 282 | 283 | async def update(self, model: BaseDBModel, session: pymongo.client_session.ClientSession = None): 284 | return await self._repository.replace_one( 285 | query={"_id": model.oid}, replacement=model.to_db(), session=session or self._db_session 286 | ) 287 | 288 | async def delete(self, model: BaseDBModel, session: pymongo.client_session.ClientSession = None): 289 | return await self._repository.delete_one(query={"_id": model.oid}, session=session or self._db_session) 290 | -------------------------------------------------------------------------------- /fastapi_mongodb/repositories.py: -------------------------------------------------------------------------------- 1 | """Repository pattern to work with MongoDB.""" 2 | import typing 3 | from functools import cached_property 4 | 5 | import motor.motor_asyncio 6 | import pymongo.client_session 7 | import pymongo.results 8 | 9 | from fastapi_mongodb.db import BaseDBManager 10 | 11 | 12 | class BaseRepository: 13 | def __init__(self, db_manager: BaseDBManager, db_name: str, col_name: str): 14 | """Repository initializer.""" 15 | self._db_manager = db_manager 16 | self._db_name = db_name 17 | self._col_name = col_name 18 | 19 | @cached_property 20 | def db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: 21 | """Retrieve database of this repository.""" 22 | return self._db_manager.retrieve_database(name=self._db_name) 23 | 24 | @cached_property 25 | def col(self) -> motor.motor_asyncio.AsyncIOMotorCollection: 26 | """Retrieve collection of this repository.""" 27 | return self.db[self._col_name] 28 | 29 | async def insert_one( 30 | self, 31 | *, 32 | document: dict, 33 | session: pymongo.client_session.ClientSession = None, 34 | **kwargs, 35 | ) -> pymongo.results.InsertOneResult: 36 | """Insert one document to MongoDB.""" 37 | return await self.col.insert_one(document=document, session=session, **kwargs) 38 | 39 | async def insert_many( 40 | self, 41 | *, 42 | documents: list[dict], 43 | ordered: bool = False, 44 | session: pymongo.client_session.ClientSession = None, 45 | **kwargs, 46 | ) -> pymongo.results.InsertManyResult: 47 | """Insert many documents to MongoDB.""" 48 | return await self.col.insert_many(documents=documents, ordered=ordered, session=session, **kwargs) 49 | 50 | async def replace_one( 51 | self, 52 | *, 53 | query: dict, 54 | replacement: dict, 55 | upsert: bool = False, 56 | session: pymongo.client_session.ClientSession = None, 57 | **kwargs, 58 | ) -> pymongo.results.UpdateResult: 59 | """Replace one document in MongoDB.""" 60 | return await self.col.replace_one( 61 | filter=query, 62 | replacement=replacement, 63 | upsert=upsert, 64 | session=session, 65 | **kwargs, 66 | ) 67 | 68 | async def update_one( 69 | self, 70 | *, 71 | query: dict, 72 | update: dict, 73 | upsert: bool = False, 74 | session: pymongo.client_session.ClientSession = None, 75 | **kwargs, 76 | ) -> pymongo.results.UpdateResult: 77 | """Update one document in MongoDB.""" 78 | return await self.col.update_one(filter=query, update=update, upsert=upsert, session=session, **kwargs) 79 | 80 | async def update_many( 81 | self, 82 | *, 83 | query: dict, 84 | update: dict, 85 | upsert: bool = 
False, 86 | session: pymongo.client_session.ClientSession = None, 87 | **kwargs, 88 | ) -> pymongo.results.UpdateResult: 89 | """Update many documents to MongoDB.""" 90 | return await self.col.update_many(filter=query, update=update, upsert=upsert, session=session, **kwargs) 91 | 92 | async def delete_one( 93 | self, 94 | *, 95 | query: dict, 96 | session: pymongo.client_session.ClientSession = None, 97 | **kwargs, 98 | ) -> pymongo.results.DeleteResult: 99 | """Delete one document from MongoDB.""" 100 | return await self.col.delete_one(filter=query, session=session, **kwargs) 101 | 102 | async def delete_many( 103 | self, 104 | *, 105 | query: dict, 106 | session: pymongo.client_session.ClientSession = None, 107 | **kwargs, 108 | ) -> pymongo.results.DeleteResult: 109 | """Delete many documents from MongoDB.""" 110 | return await self.col.delete_many(filter=query, session=session, **kwargs) 111 | 112 | async def find( 113 | self, 114 | *, 115 | query: dict, 116 | sort: list[tuple[str, int]] = None, 117 | skip: int = 0, 118 | limit: int = 0, 119 | projection: typing.Union[list[str], dict[str, bool]] = None, 120 | session: pymongo.client_session.ClientSession = None, 121 | **kwargs, 122 | ) -> motor.motor_asyncio.AsyncIOMotorCursor: 123 | """Find documents from MongoDB.""" 124 | return self.col.find( 125 | filter=query, 126 | sort=sort, 127 | skip=skip, 128 | limit=limit, 129 | projection=projection, 130 | session=session, 131 | **kwargs, 132 | ) 133 | 134 | async def find_one( 135 | self, 136 | *, 137 | query: dict, 138 | sort: list[tuple[str, int]] = None, 139 | projection: typing.Union[list[str], dict[str, bool]] = None, 140 | session: pymongo.client_session.ClientSession = None, 141 | **kwargs, 142 | ): 143 | """Find one document from MongoDB.""" 144 | return await self.col.find_one(filter=query, sort=sort, projection=projection, session=session, **kwargs) 145 | 146 | async def find_one_and_delete( 147 | self, 148 | *, 149 | query: dict, 150 | sort: list[tuple[str, int]] = None, 151 | projection: typing.Union[list[str], dict[str, bool]] = None, 152 | session: pymongo.client_session.ClientSession = None, 153 | **kwargs, 154 | ): 155 | """Find one and delete a document from MongoDB.""" 156 | return await self.col.find_one_and_delete( 157 | filter=query, projection=projection, sort=sort, session=session, **kwargs 158 | ) 159 | 160 | async def find_one_and_replace( 161 | self, 162 | *, 163 | query: dict, 164 | replacement: dict, 165 | upsert: bool = False, 166 | sort: list[tuple[str, int]] = None, 167 | projection: typing.Union[list[str], dict[str, bool]] = None, 168 | return_document: pymongo.ReturnDocument = pymongo.ReturnDocument.AFTER, # default BEFORE 169 | session: pymongo.client_session.ClientSession = None, 170 | **kwargs, 171 | ): 172 | """Find one and replace a document from MongoDB.""" 173 | return await self.col.find_one_and_replace( 174 | filter=query, 175 | replacement=replacement, 176 | upsert=upsert, 177 | sort=sort, 178 | projection=projection, 179 | session=session, 180 | return_document=return_document, 181 | **kwargs, 182 | ) 183 | 184 | async def find_one_and_update( 185 | self, 186 | *, 187 | query: dict, 188 | update: dict, 189 | sort: list[tuple[str, int]] = None, 190 | projection: typing.Union[list[str], dict[str, bool]] = None, 191 | return_document: pymongo.ReturnDocument = pymongo.ReturnDocument.AFTER, # default BEFORE 192 | session: pymongo.client_session.ClientSession = None, 193 | **kwargs, 194 | ): 195 | """Find one and update document from MongoDB.""" 196 | 
return await self.col.find_one_and_update( 197 | filter=query, 198 | update=update, 199 | sort=sort, 200 | projection=projection, 201 | session=session, 202 | return_document=return_document, 203 | **kwargs, 204 | ) 205 | 206 | async def count_documents( 207 | self, 208 | *, 209 | query: dict, 210 | session: pymongo.client_session.ClientSession = None, 211 | **kwargs, 212 | ) -> int: 213 | """Count documents in MongoDB collection.""" 214 | return await self.col.count_documents(filter=query, session=session, **kwargs) 215 | 216 | async def estimated_document_count(self, **kwargs): 217 | """Count documents in MongoDB from collection metadata.""" 218 | return await self.col.estimated_document_count(**kwargs) 219 | 220 | async def aggregate( 221 | self, 222 | *, 223 | pipeline: list, 224 | session: pymongo.client_session.ClientSession = None, 225 | **kwargs, 226 | ) -> motor.motor_asyncio.AsyncIOMotorCommandCursor: 227 | """Run aggregation pipeline against collection.""" 228 | return self.col.aggregate(pipeline=pipeline, session=session, **kwargs) 229 | 230 | async def bulk_write( 231 | self, 232 | *, 233 | operations: list[ 234 | typing.Union[ 235 | pymongo.operations.InsertOne, 236 | pymongo.operations.UpdateOne, 237 | pymongo.operations.UpdateMany, 238 | pymongo.operations.ReplaceOne, 239 | pymongo.operations.DeleteOne, 240 | pymongo.operations.DeleteMany, 241 | ] 242 | ], 243 | ordered: bool = False, # default True 244 | session: pymongo.client_session.ClientSession = None, 245 | ) -> pymongo.results.BulkWriteResult: 246 | """Run multiple operations in one db call.""" 247 | return await self.col.bulk_write(requests=operations, ordered=ordered, session=session) 248 | 249 | async def watch( 250 | self, *, pipeline: list[dict], session: pymongo.client_session.ClientSession = None, **kwargs 251 | ) -> motor.motor_asyncio.AsyncIOMotorChangeStream: 252 | """Blocking client stream to get operations on collection.""" 253 | return self.col.watch(pipeline=pipeline, session=session, **kwargs) 254 | -------------------------------------------------------------------------------- /fastapi_mongodb/schemas.py: -------------------------------------------------------------------------------- 1 | """Base classes for schemas (aka serializers) for in/out responses.""" 2 | import datetime 3 | import typing 4 | 5 | import pydantic 6 | 7 | from fastapi_mongodb.config import BaseConfiguration 8 | from fastapi_mongodb.types import OID 9 | 10 | 11 | class BaseSchema(pydantic.BaseModel): 12 | """Class using as a base class for schemas.py.""" 13 | 14 | oid: typing.Optional[OID] = pydantic.Field(alias="oid") 15 | 16 | class Config(BaseConfiguration): 17 | """Configuration class.""" 18 | 19 | 20 | class BaseCreatedUpdatedSchema(BaseSchema): 21 | """Append datetime fields for schema.""" 22 | 23 | created_at: typing.Optional[datetime.datetime] 24 | updated_at: typing.Optional[datetime.datetime] 25 | 26 | class Config(BaseConfiguration): 27 | """Configuration class.""" 28 | -------------------------------------------------------------------------------- /fastapi_mongodb/types.py: -------------------------------------------------------------------------------- 1 | """Pydantic types to work with bson.ObjectId instance.""" 2 | import bson 3 | 4 | 5 | class OID(str): 6 | """ObjectId type for BaseMongoDBModels and BaseSchemas.""" 7 | 8 | @classmethod 9 | def __modify_schema__(cls, field_schema): 10 | """Update OpenAPI docs schema.""" 11 | field_schema.update( 12 | pattern="^[a-f0-9]{24}$", 13 | 
example="5f5cf6f50cde9ec07786b294", 14 | title="ObjectId", 15 | type="string", 16 | ) 17 | 18 | @classmethod 19 | def __get_validators__(cls): 20 | """Require method for Pydantic Types.""" 21 | yield cls.validate 22 | 23 | @classmethod 24 | def validate(cls, v) -> bson.ObjectId: 25 | """Require validation for Pydantic Types.""" 26 | try: 27 | value = bson.ObjectId(oid=str(v)) 28 | except bson.errors.InvalidId as error: 29 | raise ValueError(error) from error 30 | else: 31 | return value 32 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: FastAPI❤MongoDB 2 | site_description: FastAPI❤MongoDB Documentation 3 | site_author: Kostiantyn Salnykov 4 | repo_name: fastapi-mongodb 5 | repo_url: https://github.com/Kostiantyn-Salnykov/fastapi-mongodb 6 | edit_uri: "" 7 | theme: 8 | icon: 9 | repo: fontawesome/brands/git-alt 10 | name: material 11 | language: en 12 | palette: 13 | - media: "(prefers-color-scheme: light)" 14 | scheme: default 15 | primary: green 16 | accent: green 17 | toggle: 18 | icon: material/weather-night 19 | name: Switch to dark mode 20 | - media: "(prefers-color-scheme: dark)" 21 | scheme: slate 22 | primary: black 23 | accent: green 24 | toggle: 25 | icon: material/weather-sunny 26 | name: Switch to light mode 27 | features: 28 | # - navigation.expand 29 | - navigation.indexes 30 | # - navigation.sections 31 | - navigation.top 32 | - search.suggest 33 | - search.highlight 34 | - search.share 35 | # - toc.integrate 36 | - content.tabs.link 37 | # - content.code.annotate 38 | 39 | plugins: 40 | - search: 41 | lang: 42 | - ru 43 | - en 44 | - macros: 45 | include_dir: "docs" 46 | - i18n: 47 | languages: 48 | en: English 49 | ru: Русский 50 | default_language: "en" 51 | nav_translations: 52 | ru: 53 | Main: Главная 54 | Setup: Настройка 55 | Managers: Менеджеры 56 | 57 | nav: 58 | - Main: index.md 59 | - Setup: setup.md 60 | - Managers: managers.md 61 | 62 | markdown_extensions: 63 | - footnotes 64 | - pymdownx.tabbed 65 | - pymdownx.superfences 66 | - pymdownx.inlinehilite 67 | - pymdownx.emoji: 68 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 69 | - pymdownx.snippets: 70 | base_path: "." 71 | check_paths: true 72 | - pymdownx.caret 73 | - pymdownx.mark 74 | - pymdownx.tilde 75 | - pymdownx.smartsymbols 76 | - admonition 77 | - pymdownx.details 78 | - abbr 79 | - toc: 80 | permalink: ⚓ 81 | toc_depth: 3 82 | - pymdownx.highlight: 83 | linenums: true 84 | linenums_style: pymdownx.inline 85 | 86 | extra: 87 | generator: false 88 | social: 89 | - icon: fontawesome/brands/github 90 | link: https://github.com/KosT-NavySky/fastapi_mongodb 91 | alternate: 92 | - link: ./ 93 | name: 🇬🇧 - English 94 | lang: en 95 | - link: ./ru/ 96 | name: 🇷🇺 - Русский 97 | lang: ru 98 | 99 | copyright: Copyright © 2020 - 2021 Kostiatyn Salnykov -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fastapi-mongodb" 3 | version = "0.0.1b4" 4 | description = "A library that simplifies work and integration with MongoDB for a FastAPI project." 
5 | license = "MIT" 6 | authors = ["Kostiantyn Salnykov "] 7 | maintainers = ["Kostiantyn Salnykov "] 8 | readme = "README.md" 9 | homepage = "https://kostiantyn-salnykov.github.io/fastapi-mongodb/" 10 | repository = "https://github.com/Kostiantyn-Salnykov/fastapi-mongodb/" 11 | documentation = "https://kostiantyn-salnykov.github.io/fastapi-mongodb/" 12 | keywords = ["fastapi", "mongodb", "pydantic", "motor", "pymongo", "orjson", "pyjwt"] 13 | packages = [ 14 | { "include" = "fastapi_mongodb" }, 15 | ] 16 | classifiers = [ 17 | "Topic :: Software Development", 18 | "Topic :: Software Development :: Libraries", 19 | "Topic :: Software Development :: Libraries :: Python Modules" 20 | ] 21 | 22 | [tool.poetry.urls] 23 | "Issues" = "https://github.com/Kostiantyn-Salnykov/fastapi-mongodb/issues" 24 | "Discussions" = "https://github.com/Kostiantyn-Salnykov/fastapi-mongodb/discussions" 25 | 26 | [tool.poetry.extras] 27 | orjson = ["orjson"] 28 | 29 | [tool.poetry.dependencies] 30 | python = "^3.9" 31 | fastapi = "^0.68.1" 32 | motor = "^2.5.1" 33 | click = "^8.0.1" 34 | pymongo = { extras = ["tls", "srv"], version = "^3.12.0" } 35 | pydantic = { extras = ["email", "dotenv"], version = "^1.8.2" } 36 | pyjwt = "^2.1.0" 37 | orjson = { optional = true, version = "^3.6.3" } 38 | 39 | [tool.poetry.dev-dependencies] 40 | uvicorn = "*" 41 | httpx = "*" 42 | ipython = "*" 43 | pylint = "*" 44 | black = "*" 45 | flake8 = "*" 46 | factory-boy = "*" 47 | xenon = "*" 48 | pytest = "*" 49 | pytest-asyncio = "*" 50 | pytest-cov = "*" 51 | coverage = "*" 52 | faker = "*" 53 | pyupgrade = "^2.25.0" 54 | bandit = "^1.7.0" 55 | tox = "^3.24.3" 56 | pydocstyle = { extras = ["toml"], version = "^6.1.1" } 57 | mkdocs-material = "^7.3.0" 58 | pymdown-extensions = "^8.2" 59 | mkdocs-macros-plugin = "^0.6.0" 60 | mkdocs-i18n = "^0.4.2" 61 | mkdocs-static-i18n = "^0.19" 62 | 63 | [build-system] 64 | requires = ["poetry-core>=1.0.0"] 65 | build-backend = "poetry.core.masonry.api" 66 | 67 | # === Pytest === 68 | [tool.pytest.ini_options] 69 | #addopts = "--maxfail=1 -rf" 70 | norecursedirs = ["venv", "mongodb",] 71 | testpaths = ["tests",] 72 | python_files = ["test*.py"] 73 | python_functions = ["*_test", "test_*"] 74 | console_output_style = "progress" 75 | 76 | # === Coverage === 77 | [tool.coverage.run] 78 | branch = true 79 | source = ["fastapi_mongodb"] 80 | #omit = ["*/.local/*", "*/.idea/*", "*./venv/*", "*test*", "*__init__*", "*run.py"] 81 | 82 | [tool.coverage.report] 83 | # show_missing = true 84 | ignore_errors = true 85 | sort = "Cover" 86 | precision = 4 87 | omit = ["*/.local/*", "*/.idea/*", "*./venv/*", "*test*", "*__init__*", "*run.py"] 88 | exclude_lines = ["pragma: no cover", "def __repr__", "if __name__ = .__main__.:"] 89 | 90 | [tool.coverage.html] 91 | directory = "coverage_html_report" 92 | 93 | # === Black === 94 | [tool.black] 95 | line-length = 120 96 | target-version = ['py39'] 97 | include = '\.pyi?$' 98 | extend-exclude = ''' 99 | /( 100 | | dist 101 | | .pytest_cache 102 | | .tox 103 | | mongodb 104 | | tests 105 | | docs 106 | | docs_src 107 | )/ 108 | ''' 109 | 110 | # === Isort === 111 | [tool.isort] 112 | multi_line_output = 3 113 | include_trailing_comma = true 114 | force_grid_wrap = 0 115 | use_parentheses = true 116 | ensure_newline_before_comments = true 117 | line_length = 120 118 | 119 | # === Tox === 120 | [tool.tox] 121 | legacy_tox_ini = """ 122 | [tox] 123 | isolated_build = true 124 | envlist = py39 125 | 126 | [testenv] 127 | whitelist_externals = poetry 128 | commands = 129 
| poetry install -v 130 | poetry run pytest -v 131 | """ 132 | 133 | # === Pylint === 134 | [tool.pylint.messages_control] 135 | disable = "C0330, C0326" 136 | 137 | [tool.pylint.format] 138 | max-line-length = "120" 139 | 140 | # === Bandit === 141 | [tool.bandit] 142 | targets = ["fastapi_mongodb"] 143 | 144 | # === Pydocstyle === 145 | [tool.pydocstyle] 146 | match = '(?!test_).*\.py' 147 | match_dir = "^fastapi_mongodb" 148 | ignore = "D107,D203" 149 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kostiantyn-Salnykov/fastapi-mongodb/e8d0edf59632912fe8854df951d81d63eaf3478d/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest.mock 3 | 4 | import motor.motor_asyncio 5 | import pytest 6 | from _pytest.monkeypatch import MonkeyPatch 7 | from faker import Faker 8 | 9 | import fastapi_mongodb.db 10 | 11 | pytestmark = [pytest.mark.asyncio] 12 | 13 | 14 | @pytest.fixture() 15 | def faker() -> "Faker": 16 | return Faker() 17 | 18 | 19 | class Patcher: 20 | def __init__(self): 21 | self.current_patcher = None 22 | 23 | def patch_obj(self, target: str, **kwargs): 24 | self.current_patcher = unittest.mock.patch(target=target, **kwargs) 25 | return self.current_patcher.start() 26 | 27 | def patch_attr(self, target, attribute: str, **kwargs): 28 | self.current_patcher = unittest.mock.patch.object(target=target, attribute=attribute, **kwargs) 29 | return self.current_patcher.start() 30 | 31 | def patch_property(self, target, attribute: str, return_value, **kwargs): 32 | self.current_patcher = unittest.mock.patch.object( 33 | target=target, attribute=attribute, new=unittest.mock.PropertyMock(return_value=return_value), **kwargs 34 | ) 35 | return self.current_patcher.start() 36 | 37 | def stop(self): 38 | if self.current_patcher is not None: 39 | self.current_patcher.stop() 40 | 41 | 42 | @pytest.fixture(scope="function") 43 | def patcher() -> "Patcher": 44 | _patcher = Patcher() 45 | yield _patcher 46 | _patcher.current_patcher.stop() 47 | 48 | 49 | @pytest.fixture(scope="session", autouse=True) 50 | def monkeypatch_session() -> "MonkeyPatch": 51 | monkey_patch = MonkeyPatch() 52 | yield monkey_patch 53 | monkey_patch.undo() 54 | 55 | 56 | @pytest.fixture(scope="session", autouse=True) 57 | def event_loop(): 58 | loop = asyncio.get_event_loop() 59 | yield loop 60 | loop.close() 61 | 62 | 63 | @pytest.fixture(scope="session", autouse=True) 64 | async def db_manager() -> fastapi_mongodb.db.BaseDBManager: 65 | handler = fastapi_mongodb.db.BaseDBManager( 66 | db_url="mongodb://0.0.0.0:27017/", 67 | default_db_name="test_db", 68 | ) 69 | yield handler 70 | await handler.delete_database(name="test_db") 71 | handler.delete_client() 72 | 73 | 74 | @pytest.fixture(scope="session", autouse=True) 75 | def mongodb_client(monkeypatch_session, db_manager) -> motor.motor_asyncio.AsyncIOMotorClient: 76 | client = db_manager.retrieve_client() 77 | monkeypatch_session.setattr( 78 | target=fastapi_mongodb.db.BaseDBManager, name="retrieve_client", value=lambda *args, **kwargs: client 79 | ) 80 | db = db_manager.retrieve_database(name="test_db", code_options=fastapi_mongodb.db.CODEC_OPTIONS) 81 | monkeypatch_session.setattr( 82 | target=fastapi_mongodb.db.BaseDBManager, 
name="retrieve_database", value=lambda *args, **kwargs: db 83 | ) 84 | yield client 85 | client.close() 86 | 87 | 88 | @pytest.fixture(scope="session", autouse=True) 89 | def mongodb_db(mongodb_client) -> motor.motor_asyncio.AsyncIOMotorDatabase: 90 | yield mongodb_client["test_db"] 91 | 92 | 93 | @pytest.fixture() 94 | async def mongodb_session(mongodb_client) -> motor.motor_asyncio.AsyncIOMotorClientSession: 95 | async with await mongodb_client.start_session() as session: 96 | yield session 97 | 98 | 99 | @pytest.fixture(scope="class") 100 | def repository(db_manager): 101 | return fastapi_mongodb.repositories.BaseRepository(db_manager=db_manager, db_name="test_db", col_name="test_col") 102 | -------------------------------------------------------------------------------- /tests/test_db.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import decimal 3 | import random 4 | import typing 5 | import unittest.mock 6 | 7 | import bson 8 | import motor.motor_asyncio 9 | import pymongo.errors 10 | import pytest 11 | 12 | import fastapi_mongodb.db 13 | import fastapi_mongodb.helpers 14 | import fastapi_mongodb.logging 15 | 16 | pytestmark = [pytest.mark.asyncio] 17 | 18 | 19 | class TestBaseDocument: 20 | def setup_method(self) -> None: 21 | self.base_document = fastapi_mongodb.db.BaseDocument() 22 | 23 | def test__eq__(self, faker): 24 | data_1, data_2 = faker.pydict(), faker.pydict() 25 | fake_type = random.choice( 26 | [faker.pystr, faker.pybool, faker.pyfloat, faker.pyint, faker.pylist, faker.pyset, faker.pydecimal] 27 | )() 28 | document_1 = fastapi_mongodb.db.BaseDocument(data=data_1) 29 | document_2 = fastapi_mongodb.db.BaseDocument(data=data_2) 30 | 31 | assert data_1 == document_1 32 | assert data_2 == document_2 33 | assert data_1 != document_2 34 | assert data_2 != document_1 35 | assert document_1 != document_2 36 | assert fake_type != document_1 37 | 38 | def test_properties(self): 39 | assert self.base_document.oid is None 40 | assert self.base_document.id is None 41 | assert self.base_document.data == {} 42 | oid = bson.ObjectId() 43 | 44 | self.base_document["_id"] = oid 45 | 46 | assert self.base_document.oid == oid 47 | assert self.base_document.id == str(oid) 48 | assert self.base_document.data == {"_id": oid} 49 | assert self.base_document.generated_at == oid.generation_time.astimezone( 50 | tz=fastapi_mongodb.helpers.get_utc_timezone() 51 | ) 52 | 53 | def test_getitem_setitem_delitem(self, faker): 54 | # test getitem 55 | with pytest.raises(KeyError) as exception_context: 56 | _ = self.base_document["test"] 57 | assert str(exception_context.value) == str(KeyError("test")) 58 | assert self.base_document.get("test", None) is None 59 | fake_test = faker.pystr() 60 | 61 | # test setitem 62 | self.base_document["test"] = fake_test 63 | assert self.base_document["test"] == fake_test 64 | 65 | # test delitem 66 | del self.base_document["test"] 67 | assert self.base_document.get("test", None) is None 68 | 69 | 70 | class TestDecimalCode: 71 | def setup_method(self) -> None: 72 | self.codec = fastapi_mongodb.db.DecimalCodec() 73 | 74 | def test_transform(self, faker): 75 | value = faker.pydecimal() 76 | 77 | bson_value = self.codec.transform_python(value=value) 78 | python_value = self.codec.transform_bson(value=bson_value) 79 | 80 | assert isinstance(bson_value, bson.Decimal128) 81 | assert isinstance(python_value, decimal.Decimal) 82 | assert value == python_value 83 | 84 | 85 | class TestTimeDeltaCodec: 86 | def 
setup_method(self) -> None: 87 | self.codec = fastapi_mongodb.db.TimeDeltaCodec() 88 | 89 | def test_transform(self, faker): 90 | value = faker.time_delta(end_datetime=faker.future_datetime()) 91 | 92 | bson_value = self.codec.transform_python(value=value) 93 | python_value = self.codec.transform_bson(value=bson_value) 94 | 95 | assert isinstance(bson_value, str) 96 | assert isinstance(python_value, datetime.timedelta) 97 | assert value == python_value 98 | 99 | def test_transform_0_micros(self): 100 | weeks, days, hours, minutes, seconds = 1, 2, 3, 4, 5 101 | bson_value = f"P{weeks}W{days}DT{hours}H{minutes}M{seconds}S" 102 | expected_time_delta = datetime.timedelta(weeks=weeks, days=days, hours=hours, minutes=minutes, seconds=seconds) 103 | 104 | python_value = self.codec.transform_bson(value=bson_value) 105 | 106 | assert expected_time_delta == python_value 107 | 108 | 109 | @pytest.fixture() 110 | def patch_logging(patcher): 111 | yield patcher.patch_attr(target=fastapi_mongodb.logging.simple_logger, attribute="debug") 112 | 113 | 114 | @pytest.fixture() 115 | def event(): 116 | return unittest.mock.MagicMock() 117 | 118 | 119 | class TestCommandLogger: 120 | @classmethod 121 | def setup_class(cls) -> None: 122 | cls.logger = fastapi_mongodb.db.CommandLogger() 123 | 124 | def test_started(self, patch_logging, event): 125 | self.logger.started(event=event) 126 | 127 | patch_logging.assert_called_once_with( 128 | f"Command '{event.command_name}' with request id {event.request_id} started on server " 129 | f"{event.connection_id}" 130 | ) 131 | 132 | def test_succeeded(self, patch_logging, event): 133 | self.logger.succeeded(event=event) 134 | 135 | patch_logging.assert_called_once_with( 136 | f"Command '{event.command_name}' with request id {event.request_id} on server " 137 | f"{event.connection_id} succeeded in {event.duration_micros} microseconds" 138 | ) 139 | 140 | def test_failed(self, patch_logging, event): 141 | self.logger.failed(event=event) 142 | 143 | patch_logging.assert_called_once_with( 144 | f"Command {event.command_name} with request id {event.request_id} on server " 145 | f"{event.connection_id} failed in {event.duration_micros} microseconds" 146 | ) 147 | 148 | 149 | class TestConnectionPoolLogger: 150 | @classmethod 151 | def setup_class(cls) -> None: 152 | cls.logger = fastapi_mongodb.db.ConnectionPoolLogger() 153 | 154 | def test_pool_created(self, patch_logging, event): 155 | self.logger.pool_created(event=event) 156 | 157 | patch_logging.assert_called_once_with(f"[pool {event.address}] pool created") 158 | 159 | def test_pool_cleared(self, patch_logging, event): 160 | self.logger.pool_cleared(event=event) 161 | 162 | patch_logging.assert_called_once_with(f"[pool {event.address}] pool cleared") 163 | 164 | def test_pool_closed(self, patch_logging, event): 165 | self.logger.pool_closed(event=event) 166 | 167 | patch_logging.assert_called_once_with(f"[pool {event.address}] pool closed") 168 | 169 | def test_connection_created(self, patch_logging, event): 170 | self.logger.connection_created(event=event) 171 | 172 | patch_logging.assert_called_once_with(f"[pool {event.address}][conn #{event.connection_id}] connection created") 173 | 174 | def test_connection_ready(self, patch_logging, event): 175 | self.logger.connection_ready(event=event) 176 | 177 | patch_logging.assert_called_once_with( 178 | f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" 179 | ) 180 | 181 | def test_connection_closed(self, patch_logging, event): 182 | 
self.logger.connection_closed(event=event) 183 | 184 | patch_logging.assert_called_once_with( 185 | f"[pool {event.address}][conn #{event.connection_id}] connection closed, reason: {event.reason}" 186 | ) 187 | 188 | def test_connection_check_out_started(self, patch_logging, event): 189 | self.logger.connection_check_out_started(event=event) 190 | 191 | patch_logging.assert_called_once_with(f"[pool {event.address}] connection check out started") 192 | 193 | def test_connection_check_out_failed(self, patch_logging, event): 194 | self.logger.connection_check_out_failed(event=event) 195 | 196 | patch_logging.assert_called_once_with( 197 | f"[pool {event.address}] connection check out failed, reason: {event.reason}" 198 | ) 199 | 200 | def test_connection_checked_out(self, patch_logging, event): 201 | self.logger.connection_checked_out(event=event) 202 | 203 | patch_logging.assert_called_once_with( 204 | f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" 205 | ) 206 | 207 | def test_connection_checked_in(self, patch_logging, event): 208 | self.logger.connection_checked_in(event=event) 209 | 210 | patch_logging.assert_called_once_with( 211 | f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" 212 | ) 213 | 214 | 215 | class TestServerLogger: 216 | @classmethod 217 | def setup_class(cls) -> None: 218 | cls.logger = fastapi_mongodb.db.ServerLogger() 219 | 220 | def test_opened(self, patch_logging, event): 221 | self.logger.opened(event=event) 222 | 223 | patch_logging.assert_called_once_with(f"Server {event.server_address} added to topology {event.topology_id}") 224 | 225 | def test_description_changed_called(self, patch_logging, event): 226 | new_mock = unittest.mock.MagicMock() 227 | event.new_description.server_type = new_mock 228 | self.logger.description_changed(event=event) 229 | 230 | patch_logging.assert_called_once_with( 231 | f"Server {event.server_address} changed type from {event.previous_description.server_type_name} to " 232 | f"{event.new_description.server_type_name}" 233 | ) 234 | 235 | def test_description_changed_not_called(self, patch_logging, event): 236 | new_mock = unittest.mock.MagicMock() 237 | event.previous_description.server_type = new_mock 238 | event.new_description.server_type = new_mock 239 | self.logger.description_changed(event=event) 240 | 241 | patch_logging.assert_not_called() 242 | 243 | def test_closed(self, patch_logging, event): 244 | self.logger.closed(event=event) 245 | 246 | patch_logging.assert_called_once_with( 247 | f"Server {event.server_address} removed from topology {event.topology_id}" 248 | ) 249 | 250 | 251 | class TestHeartbeatLogger: 252 | @classmethod 253 | def setup_class(cls) -> None: 254 | cls.logger = fastapi_mongodb.db.HeartbeatLogger() 255 | 256 | def test_started(self, patch_logging, event): 257 | self.logger.started(event=event) 258 | 259 | patch_logging.assert_called_once_with(f"Heartbeat sent to server {event.connection_id}") 260 | 261 | def test_succeeded(self, patch_logging, event): 262 | self.logger.succeeded(event=event) 263 | 264 | patch_logging.assert_called_once_with( 265 | f"Heartbeat to server {event.connection_id} succeeded with reply {event.reply.document}" 266 | ) 267 | 268 | def test_failed(self, patch_logging, event): 269 | self.logger.failed(event=event) 270 | 271 | patch_logging.assert_called_once_with( 272 | f"Heartbeat to server {event.connection_id} failed with error {event.reply}" 273 | ) 274 | 275 | 276 | class TestTopologyLogger: 277 | 
@classmethod 278 | def setup_class(cls) -> None: 279 | cls.logger = fastapi_mongodb.db.TopologyLogger() 280 | 281 | def test_opened(self, patch_logging, event): 282 | self.logger.opened(event=event) 283 | 284 | patch_logging.assert_called_once_with(f"Topology with id {event.topology_id} opened") 285 | 286 | def test_description_changed(self, patch_logging, event): 287 | event.new_description.has_writable_server.return_value = False 288 | event.new_description.has_readable_server.return_value = False 289 | self.logger.description_changed(event=event) 290 | 291 | patch_logging.assert_has_calls( 292 | calls=[ 293 | unittest.mock.call(f"Topology description updated for topology id {event.topology_id}"), 294 | unittest.mock.call( 295 | f"Topology {event.topology_id} changed type from {event.previous_description.topology_type_name} " 296 | f"to {event.new_description.topology_type_name}" 297 | ), 298 | unittest.mock.call("No writable servers available."), 299 | unittest.mock.call("No readable servers available."), 300 | ] 301 | ) 302 | 303 | def test_description_changed_not_changed(self, patch_logging, event): 304 | mock_topology_type = unittest.mock.MagicMock() 305 | event.previous_description.topology_type = mock_topology_type 306 | event.new_description.topology_type = mock_topology_type 307 | self.logger.description_changed(event=event) 308 | 309 | patch_logging.assert_called_once_with(f"Topology description updated for topology id {event.topology_id}") 310 | 311 | def test_closed(self, patch_logging, event): 312 | self.logger.closed(event=event) 313 | 314 | patch_logging.assert_called_once_with(f"Topology with id {event.topology_id} closed") 315 | 316 | 317 | class TestDBHandler: 318 | @classmethod 319 | def setup_class(cls) -> None: 320 | cls.test_db = "test_db" 321 | 322 | @pytest.fixture() 323 | async def setup_indexes(self, db_manager, faker, mongodb_session): 324 | index_name, col_name = faker.pystr(), faker.pystr() 325 | await db_manager.create_index( 326 | col_name=col_name, 327 | db_name=self.test_db, 328 | name=index_name, 329 | index=[("test", pymongo.ASCENDING)], 330 | session=mongodb_session, 331 | ) 332 | indexes_names = await db_manager.list_indexes( 333 | col_name=col_name, db_name=self.test_db, only_names=True, session=mongodb_session 334 | ) 335 | assert index_name in indexes_names 336 | return index_name, col_name 337 | 338 | def test_create_client(self, db_manager): 339 | result = db_manager.create_client() 340 | 341 | assert result is None 342 | 343 | def test_delete_client(self, db_manager): 344 | result = db_manager.delete_client() 345 | 346 | assert result is None 347 | 348 | def test_retrieve_client(self, db_manager): 349 | result = db_manager.retrieve_client() 350 | 351 | assert result.__class__ == motor.motor_asyncio.AsyncIOMotorClient 352 | 353 | def test_retrieve_database(self, db_manager): 354 | result = db_manager.retrieve_database() 355 | 356 | assert result.__class__ == motor.motor_asyncio.AsyncIOMotorDatabase 357 | assert result.name == "test_db" 358 | 359 | async def test_get_server_info(self, db_manager, mongodb_session): 360 | result = await db_manager.get_server_info(session=mongodb_session) 361 | 362 | assert dict == result.__class__ 363 | assert 1.0 == result["ok"] 364 | 365 | async def test_list_databases(self, db_manager, mongodb_session): 366 | required_dbs = ["admin", "local"] 367 | 368 | result: list[dict[str, typing.Any()]] = await db_manager.list_databases(session=mongodb_session) 369 | result_2: list[str] = await 
db_manager.list_databases(only_names=True, session=mongodb_session) 370 | 371 | assert all(required_db in [db["name"] for db in result] for required_db in required_dbs) 372 | assert all(required_db in result_2 for required_db in required_dbs) 373 | 374 | async def test_delete_database(self, db_manager, faker, mongodb_session): 375 | test_db = self.test_db 376 | await db_manager.create_collection(name=faker.pystr(), db_name=faker.pystr()) 377 | db_names = await db_manager.list_databases(only_names=True, session=mongodb_session) 378 | assert test_db in db_names 379 | 380 | await db_manager.delete_database(name=test_db, session=mongodb_session) 381 | 382 | updated_db_names = await db_manager.list_databases(only_names=True, session=mongodb_session) 383 | 384 | assert test_db not in updated_db_names 385 | 386 | async def test_set_get_profiling_level(self, db_manager, mongodb_session): 387 | default_level = 0 # OFF 388 | new_level = 2 # ALL 389 | assert default_level == await db_manager.get_profiling_level(db_name=self.test_db, session=mongodb_session) 390 | 391 | result = await db_manager.set_profiling_level( 392 | db_name=self.test_db, level=new_level, slow_ms=0, session=mongodb_session 393 | ) 394 | 395 | assert default_level == result["was"] 396 | assert 1.0 == result["ok"] 397 | assert new_level == await db_manager.get_profiling_level(db_name=self.test_db, session=mongodb_session) 398 | 399 | async def test_get_profiling_info(self, db_manager, faker, mongodb_session): 400 | col_name = faker.pystr() 401 | level = 2 402 | await db_manager.set_profiling_level(db_name=self.test_db, level=level, session=mongodb_session) 403 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 404 | result = await db_manager.get_profiling_info(db_name=self.test_db, session=mongodb_session) 405 | 406 | assert list == result.__class__ 407 | assert "command" == result[0]["op"] 408 | 409 | async def test_create_collection(self, db_manager, faker, mongodb_session): 410 | col_name = faker.pystr() 411 | col_names = await db_manager.list_collections(db_name=self.test_db, only_names=True, session=mongodb_session) 412 | assert col_name not in col_names 413 | 414 | collection = await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 415 | 416 | updated_col_names = await db_manager.list_collections( 417 | db_name=self.test_db, only_names=True, session=mongodb_session 418 | ) 419 | assert col_name in updated_col_names 420 | assert motor.motor_asyncio.AsyncIOMotorCollection == collection.__class__ 421 | assert col_name == collection.name 422 | assert self.test_db == collection.database.name 423 | 424 | async def test_create_collection_not_safe(self, db_manager, faker, mongodb_session): 425 | col_name = faker.pystr() 426 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 427 | 428 | with pytest.raises(pymongo.errors.CollectionInvalid) as exception_context: 429 | await db_manager.create_collection(name=col_name, db_name=self.test_db, safe=False, session=mongodb_session) 430 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 431 | 432 | assert f"collection {col_name} already exists" == str(exception_context.value) 433 | 434 | async def test_delete_collection(self, db_manager, faker, mongodb_session): 435 | col_name = faker.pystr() 436 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 437 | col_names = 
await db_manager.list_collections(db_name=self.test_db, only_names=True, session=mongodb_session) 438 | assert col_name in col_names 439 | 440 | await db_manager.delete_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 441 | 442 | updated_col_names = await db_manager.list_collections( 443 | db_name=self.test_db, only_names=True, session=mongodb_session 444 | ) 445 | assert col_name not in updated_col_names 446 | 447 | async def test_list_collections(self, db_manager, faker, mongodb_session): 448 | col_name, col_type = faker.pystr(), "collection" 449 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 450 | 451 | result: list[dict[str, typing.Any]] = await db_manager.list_collections( 452 | db_name=self.test_db, session=mongodb_session 453 | ) 454 | 455 | for i, col in enumerate(result): 456 | if col["name"] == col_name: 457 | assert col_type == result[i]["type"] 458 | 459 | async def test_list_collections_only_names(self, db_manager, faker, mongodb_session): 460 | col_name = faker.pystr() 461 | await db_manager.create_collection(name=col_name, db_name=self.test_db, session=mongodb_session) 462 | 463 | result = await db_manager.list_collections(db_name=self.test_db, only_names=True, session=mongodb_session) 464 | 465 | assert col_name in result 466 | 467 | async def test_create_index(self, db_manager, faker, mongodb_session): 468 | index_name, col_name = faker.pystr(), faker.pystr() 469 | indexes_names = await db_manager.list_indexes( 470 | col_name=col_name, db_name=self.test_db, only_names=True, session=mongodb_session 471 | ) 472 | assert index_name not in indexes_names 473 | 474 | result = await db_manager.create_index( 475 | col_name=col_name, 476 | db_name=self.test_db, 477 | name=index_name, 478 | index=[("test", pymongo.ASCENDING)], 479 | session=mongodb_session, 480 | ) 481 | 482 | updated_indexes_names = await db_manager.list_indexes( 483 | col_name=col_name, db_name=self.test_db, only_names=True, session=mongodb_session 484 | ) 485 | assert index_name in updated_indexes_names 486 | assert index_name == result 487 | 488 | async def test_create_indexes(self, db_manager, faker, mongodb_session): 489 | index_name, index_name2, col_name = faker.pystr(), faker.pystr(), faker.pystr() 490 | indexes = [ 491 | pymongo.IndexModel(name=index_name, keys=[("test", pymongo.ASCENDING)]), 492 | pymongo.IndexModel(name=index_name2, keys=[("test2", pymongo.DESCENDING)]), 493 | ] 494 | 495 | result = await db_manager.create_indexes( 496 | col_name=col_name, db_name=self.test_db, indexes=indexes, session=mongodb_session 497 | ) 498 | 499 | assert [index_name, index_name2] == result 500 | 501 | async def test_delete_index(self, db_manager, setup_indexes, mongodb_session): 502 | index_name, col_name = setup_indexes 503 | await db_manager.delete_index(col_name=col_name, db_name=self.test_db, name=index_name, session=mongodb_session) 504 | 505 | updated_indexes_names = await db_manager.list_indexes( 506 | col_name=col_name, db_name=self.test_db, only_names=True, session=mongodb_session 507 | ) 508 | assert index_name not in updated_indexes_names 509 | 510 | async def test_delete_index_not_safe(self, db_manager, faker, mongodb_session): 511 | index_name, col_name = faker.pystr(), faker.pystr() 512 | expected_exception_details = { 513 | "ok": 0.0, 514 | "errmsg": f"index not found with name [{index_name}]", 515 | "code": 27, 516 | "codeName": "IndexNotFound", 517 | } 518 | await db_manager.create_collection(name=col_name, 
db_name=self.test_db, session=mongodb_session) 519 | 520 | with pytest.raises(pymongo.errors.OperationFailure) as exception_context: 521 | await db_manager.delete_index( 522 | col_name=col_name, db_name=self.test_db, name=index_name, safe=False, session=mongodb_session 523 | ) 524 | await db_manager.delete_index(col_name=col_name, db_name=self.test_db, name=index_name, session=mongodb_session) 525 | 526 | for key, value in expected_exception_details.items(): 527 | assert value == exception_context.value.details[key] 528 | 529 | async def test_list_indexes_names(self, db_manager, faker, setup_indexes): 530 | """Same logic as in setup_indexes""" 531 | 532 | async def test_list_indexes(self, db_manager, faker, mongodb_session): 533 | index_name, col_name = faker.pystr(), faker.pystr() 534 | index_key, index_order = "test", pymongo.ASCENDING 535 | index_keys = [(index_key, index_order)] 536 | await db_manager.create_index( 537 | col_name=col_name, db_name=self.test_db, name=index_name, index=index_keys, session=mongodb_session 538 | ) 539 | 540 | result = await db_manager.list_indexes(col_name=col_name, db_name=self.test_db, session=mongodb_session) 541 | 542 | created_index_son = result[-1] 543 | assert index_name == created_index_son["name"] 544 | assert index_order == created_index_son["key"][index_key] 545 | assert created_index_son["background"] 546 | assert not created_index_son["sparse"] 547 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import zoneinfo 4 | 5 | import fastapi_mongodb.helpers 6 | from tests.conftest import Patcher 7 | 8 | 9 | class TestHelpers: 10 | def test_get_utc_timezone(self): 11 | result = fastapi_mongodb.helpers.get_utc_timezone() 12 | 13 | assert result.key == "UTC" 14 | assert isinstance(result, zoneinfo.ZoneInfo) 15 | 16 | def test_utc_now(self, patcher: Patcher): 17 | expected_zoneinfo = zoneinfo.ZoneInfo("UTC") 18 | expected_now = datetime.datetime.now(tz=expected_zoneinfo) 19 | datetime_mock = patcher.patch_obj(target="fastapi_mongodb.helpers.datetime.datetime") 20 | datetime_mock.now.return_value = expected_now 21 | 22 | result = fastapi_mongodb.helpers.utc_now() 23 | 24 | assert result == expected_now 25 | assert result.tzinfo == expected_zoneinfo 26 | 27 | def test_as_utc(self): 28 | tz = zoneinfo.ZoneInfo(key=random.choice(list(zoneinfo.available_timezones()))) 29 | dt = datetime.datetime.now(tz=tz) 30 | 31 | result = fastapi_mongodb.helpers.as_utc(date_time=dt) 32 | 33 | assert result.tzinfo == fastapi_mongodb.helpers.get_utc_timezone() 34 | assert result.astimezone(tz=tz) == dt 35 | 36 | 37 | class TestBaseProfiler: 38 | @classmethod 39 | def setup_class(cls): 40 | cls.base_profiler = fastapi_mongodb.helpers.BaseProfiler(include_files=["test_helpers.py"]) 41 | 42 | def test_profiler_decorator(self): 43 | @self.base_profiler 44 | def something(): 45 | return True 46 | 47 | something() 48 | 49 | def test_profiler_context_manager(self): 50 | def something_new(): 51 | return False 52 | 53 | with self.base_profiler: 54 | something_new() 55 | -------------------------------------------------------------------------------- /tests/test_logging.py: -------------------------------------------------------------------------------- 1 | from fastapi_mongodb.logging import logger 2 | 3 | 4 | def test_logging(): 5 | logger.trace(msg="TRACE MESSAGE") 6 | logger.debug(msg="DEBUG MESSAGE") 7 
| logger.info(msg="INFO MESSAGE") 8 | logger.success(msg="SUCCESS MESSAGE") 9 | logger.warning(msg="WARNING MESSAGE") 10 | logger.error(msg="ERROR MESSAGE") 11 | logger.critical(msg="CRITICAL MESSAGE") 12 | -------------------------------------------------------------------------------- /tests/test_managers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | import fastapi_mongodb.exceptions 6 | import fastapi_mongodb.helpers 7 | import fastapi_mongodb.managers 8 | import fastapi_mongodb.schemas 9 | 10 | 11 | class TestPasswordsHandler: 12 | def _manager_factory(self, algorithm: fastapi_mongodb.managers.PASSWORD_ALGORITHMS): 13 | return fastapi_mongodb.managers.PasswordsManager(algorithm=algorithm, iterations=1) 14 | 15 | @pytest.mark.parametrize( 16 | argnames=["algorithm", "faker"], 17 | argvalues=[(algorithm, "faker") for algorithm in fastapi_mongodb.managers.PASSWORD_ALGORITHMS], 18 | indirect=["faker"], 19 | ) 20 | def test_make_password(self, algorithm: fastapi_mongodb.managers.PASSWORD_ALGORITHMS, faker): 21 | fake_password = faker.pystr() 22 | hash_length = fastapi_mongodb.managers.ALGORITHMS_LENGTH_MAP[algorithm] * 2 # SALT + HASH 23 | manager = self._manager_factory(algorithm=algorithm) 24 | 25 | password_hash = manager.make_password(password=fake_password) 26 | 27 | assert len(password_hash) == hash_length 28 | assert isinstance(password_hash, str) 29 | 30 | @pytest.mark.parametrize( 31 | argnames=["algorithm", "faker"], 32 | argvalues=[(algorithm, "faker") for algorithm in fastapi_mongodb.managers.PASSWORD_ALGORITHMS], 33 | indirect=["faker"], 34 | ) 35 | def test_check_password(self, algorithm: fastapi_mongodb.managers.PASSWORD_ALGORITHMS, faker): 36 | fake_password = faker.pystr() 37 | fake_password_2 = faker.pystr() 38 | manager = self._manager_factory(algorithm=algorithm) 39 | fake_password_hash = manager.make_password(password=fake_password) 40 | fake_password_hash_2 = manager.make_password(password=fake_password_2) 41 | 42 | assert manager.check_password(password=fake_password, password_hash=fake_password_hash) is True 43 | assert manager.check_password(password=fake_password_2, password_hash=fake_password_hash_2) is True 44 | assert manager.check_password(password=fake_password, password_hash=fake_password_hash_2) is False 45 | assert manager.check_password(password=fake_password_2, password_hash=fake_password_hash) is False 46 | 47 | 48 | class TestTokensManager: 49 | class MockPayload(fastapi_mongodb.schemas.BaseSchema): 50 | # JWT options 51 | iss: str 52 | iat: datetime.datetime 53 | exp: datetime.datetime 54 | nbf: datetime.datetime 55 | 56 | int_key: int 57 | float_key: float 58 | bool_key: bool 59 | str_key: str 60 | list_key: list 61 | dict_key: dict 62 | 63 | def test_data(self, faker): 64 | return { 65 | "int_key": faker.pyint(), 66 | "float_key": faker.pyfloat(), 67 | "bool_key": faker.pybool(), 68 | "str_key": faker.pystr(), 69 | "list_key": faker.pylist(nb_elements=10, variable_nb_elements=True, value_types=[int, float, bool, str]), 70 | "dict_key": faker.pydict(nb_elements=10, variable_nb_elements=True, value_types=[int, float, bool, str]), 71 | } 72 | 73 | @staticmethod 74 | def generate_issuer(faker): 75 | return {"iss": faker.pystr()} 76 | 77 | @staticmethod 78 | def _manager_factory(algorithm): 79 | return fastapi_mongodb.managers.TokensManager(secret_key="TEST", algorithm=algorithm) 80 | 81 | @staticmethod 82 | def _get_default_claims(): 83 | return ["aud", "exp", "iat", 
"iss", "nbf"] 84 | 85 | @pytest.mark.parametrize( 86 | argnames=["algorithm"], argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS] 87 | ) 88 | def test_create_read_code_default(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 89 | manager = self._manager_factory(algorithm=algorithm) 90 | code = manager.create_code() 91 | 92 | parsed = manager.read_code(code=code) 93 | 94 | assert isinstance(code, str) 95 | assert isinstance(parsed, dict) 96 | assert all(field in parsed for field in self._get_default_claims()) 97 | 98 | @pytest.mark.parametrize( 99 | argnames=["algorithm", "faker"], 100 | argvalues=[(algorithm, "faker") for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 101 | indirect=["faker"], 102 | ) 103 | def test_create_read_code_custom(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS, faker): 104 | now = fastapi_mongodb.helpers.utc_now() 105 | test_data = self.test_data(faker=faker) 106 | custom_claims = self.generate_issuer(faker=faker) 107 | audience = faker.pystr() 108 | test_data |= custom_claims 109 | manager = self._manager_factory(algorithm=algorithm) 110 | 111 | code = manager.create_code(data=test_data, iat=now, exp=now, nbf=now, aud=audience, **custom_claims) 112 | 113 | parsed = manager.read_code( 114 | code=code, aud=audience, leeway=faker.pyint(), convert_to=self.MockPayload, **custom_claims 115 | ) 116 | 117 | assert isinstance(code, str) 118 | assert isinstance(parsed, self.MockPayload) 119 | assert test_data == parsed.dict(include={key for key in test_data.keys()}) 120 | 121 | @pytest.mark.parametrize( 122 | argnames=["algorithm"], 123 | argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 124 | ) 125 | def test_read_code_exception_exp(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 126 | manager = self._manager_factory(algorithm=algorithm) 127 | code = manager.create_code(exp=fastapi_mongodb.helpers.utc_now() - datetime.timedelta(seconds=5)) 128 | 129 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 130 | manager.read_code(code=code) 131 | 132 | assert "Expired JWT token." == str(exception_context.value) 133 | 134 | @pytest.mark.parametrize( 135 | argnames=["algorithm"], 136 | argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 137 | ) 138 | def test_read_code_exception_nbf(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 139 | manager = self._manager_factory(algorithm=algorithm) 140 | code = manager.create_code(nbf=fastapi_mongodb.helpers.utc_now() + datetime.timedelta(seconds=5)) 141 | 142 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 143 | manager.read_code(code=code) 144 | 145 | assert "The token is not valid yet." 
== str(exception_context.value) 146 | 147 | @pytest.mark.parametrize( 148 | argnames=["algorithm"], 149 | argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 150 | ) 151 | def test_read_code_exception_leeway(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 152 | manager = self._manager_factory(algorithm=algorithm) 153 | code = manager.create_code(exp=fastapi_mongodb.helpers.utc_now() - datetime.timedelta(seconds=5)) 154 | 155 | parsed = manager.read_code(code=code, leeway=5) 156 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 157 | manager.read_code(code=code) 158 | 159 | assert "Expired JWT token." == str(exception_context.value) 160 | assert isinstance(parsed, dict) 161 | 162 | @pytest.mark.parametrize( 163 | argnames=["algorithm"], 164 | argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 165 | ) 166 | def test_read_code_exception_aud(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 167 | manager = self._manager_factory(algorithm=algorithm) 168 | code = manager.create_code(aud="TEST") 169 | 170 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 171 | manager.read_code(code=code) 172 | 173 | assert "Invalid JWT audience." == str(exception_context.value) 174 | 175 | @pytest.mark.parametrize( 176 | argnames=["algorithm"], 177 | argvalues=[(algorithm,) for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 178 | ) 179 | def test_read_code_exception_iss(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS): 180 | manager = self._manager_factory(algorithm=algorithm) 181 | code = manager.create_code(iss="TEST") 182 | 183 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 184 | manager.read_code(code=code) 185 | 186 | assert "Invalid JWT issuer." == str(exception_context.value) 187 | 188 | @pytest.mark.parametrize( 189 | argnames=["algorithm", "faker"], 190 | argvalues=[(algorithm, "faker") for algorithm in fastapi_mongodb.managers.TOKEN_ALGORITHMS], 191 | indirect=["faker"], 192 | ) 193 | def test_read_code_exception_invalid_jwt(self, algorithm: fastapi_mongodb.managers.TOKEN_ALGORITHMS, faker): 194 | manager = self._manager_factory(algorithm=algorithm) 195 | 196 | with pytest.raises(expected_exception=fastapi_mongodb.exceptions.ManagerException) as exception_context: 197 | manager.read_code(code=faker.pystr()) 198 | 199 | assert "Invalid JWT." 
== str(exception_context.value) 200 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import unittest.mock 3 | 4 | import bson 5 | import motor.motor_asyncio 6 | import pytest 7 | 8 | import fastapi_mongodb 9 | 10 | pytestmark = [pytest.mark.asyncio] 11 | 12 | 13 | class TestBaseDBModel: 14 | class MyCreatedUpdatedModel(fastapi_mongodb.models.BaseCreatedUpdatedModel): 15 | test: str 16 | 17 | class MyModel(fastapi_mongodb.models.BaseDBModel): 18 | test: str 19 | 20 | @classmethod 21 | def setup_class(cls): 22 | cls.model_class = fastapi_mongodb.models.BaseDBModel 23 | 24 | def test_from_db_default(self): 25 | data = {"_id": bson.ObjectId()} 26 | oid = data["_id"] 27 | 28 | result = self.model_class.from_db(data=data) 29 | 30 | assert oid == result.oid 31 | assert str(oid) == result.id 32 | 33 | def test_from_db_none(self): 34 | data = None 35 | 36 | result = self.model_class.from_db(data=data) 37 | 38 | assert result is None 39 | 40 | def test_from_db_no_id(self, faker): 41 | data = {faker.pystr(): faker.pystr()} 42 | 43 | result = self.model_class.from_db(data=data) 44 | 45 | assert self.model_class(id=None) == result 46 | 47 | def test_to_db(self, faker, monkeypatch, patcher): 48 | expected_object_id = bson.ObjectId() 49 | monkeypatch.setattr(target=bson, name="ObjectId", value=lambda *args, **kwargs: expected_object_id) 50 | expected_utc_now = fastapi_mongodb.helpers.utc_now() 51 | patcher.patch_obj(target="fastapi_mongodb.helpers.utc_now", return_value=expected_utc_now) 52 | model = self.MyCreatedUpdatedModel(test=faker.pystr()) 53 | 54 | result = model.to_db() 55 | 56 | assert expected_object_id == result["_id"] 57 | assert ( 58 | expected_object_id.generation_time.replace(tzinfo=fastapi_mongodb.helpers.get_utc_timezone()) 59 | == result["created_at"] 60 | ) 61 | assert expected_utc_now == result["updated_at"] 62 | assert model.test == result["test"] 63 | 64 | def test_to_db_from_db_flow(self, monkeypatch: pytest.MonkeyPatch, faker): 65 | expected_object_id = bson.ObjectId() 66 | monkeypatch.setattr(target=bson, name="ObjectId", value=lambda *args, **kwargs: expected_object_id) 67 | to_db_data = {"test": faker.pystr()} 68 | expected_result = {"_id": expected_object_id} | to_db_data 69 | 70 | to_db_result = self.MyModel(**to_db_data).to_db() 71 | from_db_result = self.MyModel.from_db(data=to_db_result).to_db() 72 | 73 | assert expected_result == from_db_result 74 | 75 | def test_to_db_with_id(self, faker): 76 | oid = bson.ObjectId() 77 | 78 | to_db_result = self.MyModel(_id=oid, test=faker.pystr()).to_db() 79 | 80 | assert oid == to_db_result["_id"] 81 | 82 | 83 | # TODO: Refactor this test class 84 | class TestBaseActiveRecord: 85 | WARNING_MSG = "Document already deleted or has not been created yet!" 
86 | 87 | @staticmethod 88 | def _check_empty_active_record(active_record: fastapi_mongodb.models.BaseActiveRecord): 89 | assert active_record.id is None 90 | assert active_record.oid is None 91 | assert active_record.generated_at is None 92 | assert isinstance(active_record, fastapi_mongodb.models.BaseActiveRecord) 93 | 94 | @staticmethod 95 | def _check_filled_active_record(active_record: fastapi_mongodb.models.BaseActiveRecord, oid: bson.ObjectId): 96 | assert active_record.id == str(oid) 97 | assert active_record.oid == oid 98 | assert isinstance(active_record.generated_at, datetime.datetime) 99 | assert isinstance(active_record, fastapi_mongodb.models.BaseActiveRecord) 100 | 101 | def test_properties(self, faker, repository, patcher): 102 | logger = patcher.patch_obj(target="fastapi_mongodb.simple_logger") 103 | oid = bson.ObjectId() 104 | active_record = fastapi_mongodb.models.BaseActiveRecord(document=faker.pydict(), repository=repository) 105 | 106 | self._check_empty_active_record(active_record=active_record) 107 | logger.warning.assert_has_calls( 108 | calls=[ 109 | unittest.mock.call(msg=self.WARNING_MSG), 110 | unittest.mock.call(msg=self.WARNING_MSG), 111 | unittest.mock.call(msg=self.WARNING_MSG), 112 | ] 113 | ) 114 | 115 | active_record._document["_id"] = oid 116 | 117 | self._check_filled_active_record(active_record=active_record, oid=oid) 118 | 119 | def test_get_collection(self, repository): 120 | active_record = fastapi_mongodb.models.BaseActiveRecord(document={}, repository=repository) 121 | 122 | collection = active_record.get_collection() 123 | 124 | assert isinstance(collection, motor.motor_asyncio.AsyncIOMotorCollection) 125 | assert collection.name == "test_col" 126 | 127 | def test_get_db(self, repository): 128 | active_record = fastapi_mongodb.models.BaseActiveRecord(document={}, repository=repository) 129 | 130 | db = active_record.get_db() 131 | 132 | assert isinstance(db, motor.motor_asyncio.AsyncIOMotorDatabase) 133 | assert db.name == "test_db" 134 | 135 | async def test_insert_create(self, faker, repository, mongodb_session): 136 | oid = bson.ObjectId() 137 | document = {"_id": oid} | faker.pydict(10, True, [str, int, float, bool]) 138 | 139 | active_record = await fastapi_mongodb.models.BaseActiveRecord.insert( 140 | document=document, repository=repository, session=mongodb_session 141 | ) 142 | 143 | self._check_filled_active_record(active_record=active_record, oid=oid) 144 | assert active_record._document == document 145 | 146 | async def test_update(self, faker, repository, mongodb_session): 147 | test_key, test_value = faker.pystr(), faker.pystr() 148 | new_test_value = faker.pystr() 149 | remove_key, remove_value = faker.pystr(), faker.pystr() 150 | document = {test_key: test_value, remove_key: remove_value} 151 | 152 | active_record = await fastapi_mongodb.models.BaseActiveRecord.insert( 153 | document=document, repository=repository, session=mongodb_session 154 | ) 155 | 156 | assert active_record[test_key] == test_value 157 | assert active_record[remove_key] == remove_value 158 | 159 | active_record[test_key] = new_test_value 160 | del active_record[remove_key] 161 | new_active_record = await active_record.update(session=mongodb_session) 162 | 163 | assert active_record[test_key] == new_test_value 164 | assert remove_key not in active_record 165 | assert active_record == new_active_record 166 | 167 | async def test_read_empty(self, faker, repository, mongodb_session, patcher): 168 | logger = 
patcher.patch_obj(target="fastapi_mongodb.simple_logger") 169 | fake_oid = bson.ObjectId() 170 | empty_active_record = await fastapi_mongodb.models.BaseActiveRecord.read( 171 | query={"_id": fake_oid}, repository=repository, session=mongodb_session 172 | ) 173 | 174 | self._check_empty_active_record(active_record=empty_active_record) 175 | logger.warning.assert_has_calls( 176 | calls=[ 177 | unittest.mock.call(msg=self.WARNING_MSG), 178 | unittest.mock.call(msg=self.WARNING_MSG), 179 | unittest.mock.call(msg=self.WARNING_MSG), 180 | ] 181 | ) 182 | 183 | async def test_read_filled(self, faker, repository, mongodb_session): 184 | oid = bson.ObjectId() 185 | document = {"_id": oid} | faker.pydict(10, True, [str, int, float, bool]) 186 | active_record = await fastapi_mongodb.models.BaseActiveRecord.insert( 187 | document=document, repository=repository, session=mongodb_session 188 | ) 189 | 190 | self._check_filled_active_record(active_record=active_record, oid=oid) 191 | 192 | async def test_delete(self, faker, repository, mongodb_session, patcher): 193 | logger = patcher.patch_obj(target="fastapi_mongodb.simple_logger") 194 | oid = bson.ObjectId() 195 | document = {"_id": oid} | faker.pydict(10, True, [str, int, float, bool]) 196 | active_record = await fastapi_mongodb.models.BaseActiveRecord.insert( 197 | document=document, repository=repository, session=mongodb_session 198 | ) 199 | 200 | deleted_active_record = await active_record.delete(session=mongodb_session) 201 | 202 | self._check_empty_active_record(active_record=deleted_active_record) 203 | logger.warning.assert_has_calls( 204 | calls=[ 205 | unittest.mock.call(msg=self.WARNING_MSG), 206 | unittest.mock.call(msg=self.WARNING_MSG), 207 | unittest.mock.call(msg=self.WARNING_MSG), 208 | ] 209 | ) 210 | 211 | async def test_reload_refresh_empty(self, faker, repository, mongodb_session, patcher): 212 | logger = patcher.patch_obj(target="fastapi_mongodb.simple_logger") 213 | empty_active_record = fastapi_mongodb.models.BaseActiveRecord(document={}, repository=repository) 214 | 215 | self._check_empty_active_record(active_record=empty_active_record) 216 | assert logger.warning.call_count == 3 # 3 logs from _check method 217 | logger.warning.assert_has_calls( 218 | calls=[ 219 | unittest.mock.call(msg=self.WARNING_MSG), 220 | unittest.mock.call(msg=self.WARNING_MSG), 221 | unittest.mock.call(msg=self.WARNING_MSG), 222 | ] 223 | ) 224 | 225 | new_empty_active_record = await empty_active_record.reload(session=mongodb_session) 226 | 227 | self._check_empty_active_record(active_record=new_empty_active_record) 228 | assert logger.warning.call_count == 7 # 3 earlier + 1 from reload and another 3 from _check method 229 | -------------------------------------------------------------------------------- /tests/test_repositories.py: -------------------------------------------------------------------------------- 1 | import bson 2 | import motor.motor_asyncio 3 | import pymongo.results 4 | import pytest 5 | 6 | pytestmark = [pytest.mark.asyncio] 7 | 8 | 9 | class TestBaseRepository: 10 | def test_properties(self, repository, db_manager): 11 | 12 | assert repository.db == db_manager.retrieve_database(name="test_db") 13 | assert isinstance(repository.db, motor.motor_asyncio.AsyncIOMotorDatabase) 14 | 15 | assert repository.col == repository.db["test_col"] 16 | assert isinstance(repository.col, motor.motor_asyncio.AsyncIOMotorCollection) 17 | 18 | @pytest.fixture() 19 | def document(self, faker): 20 | return {"_id": bson.ObjectId(), faker.pystr(): 
faker.pystr()} 21 | 22 | @pytest.fixture() 23 | def documents(self, faker): 24 | count = faker.pyint(min_value=2, max_value=6) 25 | return [{"_id": bson.ObjectId(), faker.pystr(): faker.pystr()} for _ in range(count)] 26 | 27 | def _check_update_one_result(self, update_result: pymongo.results.UpdateResult): 28 | assert isinstance(update_result, pymongo.results.UpdateResult) 29 | assert 1 == update_result.matched_count 30 | assert 1 == update_result.modified_count 31 | assert update_result.upserted_id is None 32 | assert update_result.acknowledged is True 33 | 34 | async def test_insert_one(self, repository, document, mongodb_session): 35 | insert_one_result = await repository.insert_one(document=document, session=mongodb_session) 36 | 37 | assert isinstance(insert_one_result, pymongo.results.InsertOneResult) 38 | assert isinstance(insert_one_result.inserted_id, bson.ObjectId) 39 | assert insert_one_result.acknowledged is True 40 | 41 | inserted_document = await repository.find_one(query={"_id": document["_id"]}, session=mongodb_session) 42 | 43 | assert document == inserted_document 44 | 45 | async def test_insert_many(self, repository, documents, mongodb_session): 46 | ids_list = [document["_id"] for document in documents] 47 | insert_many_result = await repository.insert_many(documents=documents, session=mongodb_session) 48 | 49 | assert isinstance(insert_many_result, pymongo.results.InsertManyResult) 50 | assert ids_list == insert_many_result.inserted_ids 51 | assert insert_many_result.acknowledged is True 52 | 53 | cursor = await repository.find(query={"_id": {"$in": ids_list}}, session=mongodb_session) 54 | 55 | assert documents == [document async for document in cursor] 56 | 57 | async def test_replace_one(self, repository, faker, mongodb_session): 58 | old_data_key, new_data_key = faker.pystr(), faker.pystr() 59 | old_document = {"_id": bson.ObjectId(), old_data_key: faker.pystr()} 60 | await repository.insert_one(document=old_document, session=mongodb_session) 61 | 62 | new_document = {new_data_key: faker.pystr()} 63 | 64 | replace_one_result = await repository.replace_one( 65 | query={"_id": old_document["_id"]}, replacement=new_document, session=mongodb_session 66 | ) 67 | 68 | self._check_update_one_result(update_result=replace_one_result) 69 | 70 | updated_old_document = await repository.find_one(query={"_id": old_document["_id"]}, session=mongodb_session) 71 | 72 | assert old_document["_id"] == updated_old_document["_id"] # check _id field not changed 73 | assert new_document[new_data_key] == updated_old_document[new_data_key] 74 | 75 | async def test_update_one(self, repository, faker, mongodb_session): 76 | old_data_key, new_data_value = faker.pystr(), faker.pystr() 77 | 78 | old_document = {"_id": bson.ObjectId(), old_data_key: faker.pystr()} 79 | await repository.insert_one(document=old_document, session=mongodb_session) 80 | 81 | update_one_result = await repository.update_one( 82 | query={"_id": old_document["_id"]}, update={"$set": {old_data_key: new_data_value}}, session=mongodb_session 83 | ) 84 | 85 | self._check_update_one_result(update_result=update_one_result) 86 | 87 | updated_old_document = await repository.find_one(query={"_id": old_document["_id"]}, session=mongodb_session) 88 | 89 | assert old_document["_id"] == updated_old_document["_id"] 90 | assert new_data_value == updated_old_document[old_data_key] 91 | 92 | async def test_update_many(self, repository, faker, mongodb_session): 93 | old_key, old_value, new_value = (faker.pystr(), faker.pystr(), 
faker.pystr()) 94 | documents = [ 95 | {"_id": bson.ObjectId(), old_key: old_value}, 96 | {"_id": bson.ObjectId(), old_key: old_value}, 97 | ] 98 | ids_list = [document["_id"] for document in documents] 99 | expected_documents = [ 100 | {"_id": documents[0]["_id"], old_key: new_value}, 101 | {"_id": documents[1]["_id"], old_key: new_value}, 102 | ] 103 | await repository.insert_many(documents=documents, session=mongodb_session) 104 | 105 | update_many_result = await repository.update_many( 106 | query={"_id": {"$in": ids_list}}, update={"$set": {old_key: new_value}}, session=mongodb_session 107 | ) 108 | 109 | assert isinstance(update_many_result, pymongo.results.UpdateResult) 110 | assert 2 == update_many_result.matched_count 111 | assert 2 == update_many_result.modified_count 112 | assert update_many_result.upserted_id is None 113 | assert update_many_result.acknowledged is True 114 | 115 | find_result = await repository.find(query={"_id": {"$in": ids_list}}, session=mongodb_session) 116 | 117 | assert expected_documents == [document async for document in find_result] 118 | 119 | async def test_delete_one(self, document, repository, mongodb_session): 120 | await repository.insert_one(document=document, session=mongodb_session) 121 | 122 | delete_result = await repository.delete_one(query={"_id": document["_id"]}, session=mongodb_session) 123 | 124 | assert isinstance(delete_result, pymongo.results.DeleteResult) 125 | assert 1 == delete_result.deleted_count 126 | assert delete_result.acknowledged is True 127 | 128 | async def test_delete_many(self, documents, repository, mongodb_session): 129 | ids_list = [document["_id"] for document in documents] 130 | await repository.insert_many(documents=documents, session=mongodb_session) 131 | 132 | delete_result = await repository.delete_many(query={"_id": {"$in": ids_list}}, session=mongodb_session) 133 | 134 | assert isinstance(delete_result, pymongo.results.DeleteResult) 135 | assert delete_result.deleted_count == len(documents) 136 | assert delete_result.acknowledged is True 137 | 138 | async def test_find(self, documents, repository, mongodb_session): 139 | ids_list = [document["_id"] for document in documents] 140 | await repository.insert_many(documents=documents, session=mongodb_session) 141 | 142 | find_result = await repository.find(query={"_id": {"$in": ids_list}}, session=mongodb_session) 143 | 144 | assert documents == [document async for document in find_result] 145 | 146 | async def test_find_one(self, document, repository, mongodb_session): 147 | await repository.insert_one(document=document, session=mongodb_session) 148 | 149 | find_one_result = await repository.find_one(query={"_id": document["_id"]}, session=mongodb_session) 150 | 151 | assert document == find_one_result 152 | 153 | async def test_find_one_and_delete(self, document, repository, mongodb_session): 154 | await repository.insert_one(document=document, session=mongodb_session) 155 | 156 | deleted_document = await repository.find_one_and_delete(query={"_id": document["_id"]}, session=mongodb_session) 157 | 158 | assert document == deleted_document 159 | 160 | async def test_find_one_and_replace(self, repository, faker, mongodb_session): 161 | new_key = faker.pystr() 162 | document = {"_id": bson.ObjectId(), faker.pystr(): faker.pystr()} 163 | replacement = {new_key: faker.pystr()} 164 | replaced_document = {"_id": document["_id"], new_key: replacement[new_key]} 165 | await repository.insert_one(document=document) 166 | query = {"_id": document["_id"]} 167 | 168 | 
result = await repository.find_one_and_replace(query=query, replacement=replacement) 169 | 170 | assert result == replaced_document 171 | 172 | async def test_find_one_and_update(self, repository, faker, mongodb_session): 173 | old_data_key, new_data_value = faker.pystr(), faker.pystr() 174 | 175 | old_document = {"_id": bson.ObjectId(), old_data_key: faker.pystr()} 176 | updated_document = {"_id": old_document["_id"], old_data_key: new_data_value} 177 | await repository.insert_one(document=old_document) 178 | 179 | find_one_and_update_result = await repository.find_one_and_update( 180 | query={"_id": old_document["_id"]}, 181 | update={"$set": {old_data_key: new_data_value}}, 182 | ) 183 | 184 | assert updated_document == find_one_and_update_result 185 | 186 | async def test_count_documents(self, documents, repository, mongodb_session): 187 | await repository.insert_many(documents=documents, session=mongodb_session) 188 | ids_list = [document["_id"] for document in documents] 189 | 190 | result = await repository.count_documents(query={"_id": {"$in": ids_list}}, session=mongodb_session) 191 | 192 | assert len(documents) == result 193 | 194 | async def test_estimated_document_count(self, documents, repository, mongodb_session): 195 | await repository.delete_many(query={}, session=mongodb_session) 196 | await repository.insert_many(documents=documents, session=mongodb_session) 197 | 198 | result = await repository.estimated_document_count() 199 | 200 | assert len(documents) == result 201 | 202 | async def test_aggregate(self, documents, repository, mongodb_session): 203 | await repository.insert_many(documents=documents, session=mongodb_session) 204 | ids_list = [document["_id"] for document in documents] 205 | match_stage = {"$match": {"_id": {"$in": ids_list}}} 206 | 207 | cursor = await repository.aggregate(pipeline=[match_stage], session=mongodb_session) 208 | 209 | assert documents == [document async for document in cursor] 210 | 211 | async def test_bulk_write(self, repository, faker, mongodb_session): 212 | _id, field_name, expected_count = bson.ObjectId(), faker.pystr(), 1 213 | _check_attrs_list = [ 214 | "deleted_count", 215 | "inserted_count", 216 | "matched_count", 217 | "modified_count", 218 | "upserted_count", 219 | ] 220 | operations = [ 221 | pymongo.operations.UpdateOne( 222 | filter={"_id": _id}, 223 | update={"$set": {field_name: faker.pystr()}}, 224 | upsert=True, 225 | ), 226 | pymongo.operations.InsertOne(document={"_id": bson.ObjectId(), field_name: faker.pystr()}), 227 | pymongo.operations.UpdateOne(filter={"_id": _id}, update={"$set": {field_name: faker.pystr()}}), 228 | pymongo.operations.DeleteOne(filter={"_id": _id}), 229 | ] 230 | 231 | result = await repository.bulk_write(operations=operations, session=mongodb_session) 232 | 233 | assert isinstance(result, pymongo.results.BulkWriteResult) 234 | assert result.acknowledged is True 235 | assert _id in result.upserted_ids.values() 236 | for attr in _check_attrs_list: 237 | assert expected_count == getattr(result, attr) 238 | 239 | async def test_watch(self, document, repository, mongodb_session): 240 | pipeline = [{"$match": {"operationType": "insert"}}] 241 | change_stream_1 = await repository.watch(pipeline=pipeline, session=mongodb_session) 242 | nothing_yet = await change_stream_1.try_next() 243 | assert nothing_yet is None 244 | await repository.insert_one(document=document, session=mongodb_session) 245 | 246 | change_stream_2 = await repository.watch( 247 | pipeline=pipeline, 
resume_after=change_stream_1.resume_token, session=mongodb_session 248 | ) 249 | streamed_document = await change_stream_2.try_next() 250 | 251 | assert document == streamed_document["fullDocument"] 252 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import unittest.mock 2 | 3 | import pytest 4 | from bson import ObjectId 5 | 6 | from fastapi_mongodb.types import OID 7 | 8 | 9 | class TestOID: 10 | @classmethod 11 | def setup_class(cls) -> None: 12 | cls.type_class = OID 13 | 14 | def test__modify_schema__(self): 15 | field_schema_mock = unittest.mock.MagicMock() 16 | 17 | self.type_class.__modify_schema__(field_schema=field_schema_mock) 18 | 19 | field_schema_mock.update.assert_called_once_with( 20 | pattern="^[a-f0-9]{24}$", 21 | example="5f5cf6f50cde9ec07786b294", 22 | title="ObjectId", 23 | type="string", 24 | ) 25 | 26 | def test_validate_error(self, faker): 27 | value = faker.pystr() 28 | 29 | with pytest.raises(ValueError) as exception_context: 30 | self.type_class.validate(v=value) 31 | 32 | assert f"'{value}' is not a valid ObjectId, it must be a 12-byte input or a 24-character hex string" == str( 33 | exception_context.value 34 | ) 35 | 36 | def test_validate_success(self, faker): 37 | object_id = ObjectId() 38 | 39 | result = self.type_class.validate(v=object_id) 40 | 41 | assert result == object_id 42 | --------------------------------------------------------------------------------