├── beanie ├── py.typed ├── odm │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── general.py │ │ ├── self_validation.py │ │ ├── projection.py │ │ ├── dump.py │ │ ├── relations.py │ │ ├── pydantic.py │ │ ├── typing.py │ │ └── state.py │ ├── interfaces │ │ ├── __init__.py │ │ ├── clone.py │ │ ├── detector.py │ │ ├── inheritance.py │ │ ├── session.py │ │ ├── getters.py │ │ └── setters.py │ ├── queries │ │ ├── __init__.py │ │ ├── cursor.py │ │ └── delete.py │ ├── settings │ │ ├── __init__.py │ │ ├── union_doc.py │ │ ├── view.py │ │ ├── base.py │ │ ├── document.py │ │ └── timeseries.py │ ├── custom_types │ │ ├── bson │ │ │ ├── __init__.py │ │ │ └── binary.py │ │ ├── __init__.py │ │ ├── decimal.py │ │ └── re.py │ ├── operators │ │ ├── find │ │ │ ├── __init__.py │ │ │ ├── bitwise.py │ │ │ ├── element.py │ │ │ └── array.py │ │ ├── update │ │ │ ├── __init__.py │ │ │ ├── bitwise.py │ │ │ └── array.py │ │ └── __init__.py │ ├── enums.py │ ├── models.py │ ├── registry.py │ ├── cache.py │ ├── views.py │ └── union_doc.py ├── executors │ └── __init__.py ├── migrations │ ├── __init__.py │ ├── controllers │ │ ├── __init__.py │ │ ├── base.py │ │ └── free_fall.py │ ├── template.py │ ├── utils.py │ ├── database.py │ └── models.py ├── exceptions.py ├── __init__.py └── operators.py ├── tests ├── __init__.py ├── odm │ ├── __init__.py │ ├── query │ │ ├── __init__.py │ │ └── test_update_methods.py │ ├── documents │ │ ├── __init__.py │ │ ├── test_pydantic_config.py │ │ ├── test_exists.py │ │ ├── test_count.py │ │ ├── test_pydantic_extras.py │ │ ├── test_inspect.py │ │ ├── test_replace.py │ │ ├── test_distinct.py │ │ ├── test_sync.py │ │ ├── test_delete.py │ │ ├── test_aggregate.py │ │ ├── test_multi_model.py │ │ └── test_validation_on_save.py │ ├── operators │ │ ├── __init__.py │ │ ├── find │ │ │ ├── __init__.py │ │ │ ├── test_element.py │ │ │ ├── test_bitwise.py │ │ │ ├── test_array.py │ │ │ ├── test_logical.py │ │ │ ├── test_evaluation.py │ │ │ ├── test_comparison.py │ │ │ └── 
test_geospatial.py │ │ └── update │ │ │ ├── __init__.py │ │ │ ├── test_bitwise.py │ │ │ ├── test_array.py │ │ │ └── test_general.py │ ├── custom_types │ │ ├── __init__.py │ │ ├── test_decimal_annotation.py │ │ └── test_bson_binary.py │ ├── test_cursor.py │ ├── test_deprecated.py │ ├── test_root_models.py │ ├── views.py │ ├── test_views.py │ ├── test_concurrency.py │ ├── test_beanie_object_dumping.py │ ├── test_timesries_collection.py │ ├── test_typing_utils.py │ ├── test_id.py │ ├── test_consistency.py │ ├── test_lazy_parsing.py │ ├── test_expression_fields.py │ └── test_cache.py ├── fastapi │ ├── __init__.py │ ├── test_openapi_schema_generation.py │ ├── models.py │ ├── app.py │ ├── conftest.py │ ├── routes.py │ └── test_api.py ├── migrations │ ├── __init__.py │ ├── models.py │ ├── iterative │ │ ├── __init__.py │ │ ├── test_change_value_subfield.py │ │ ├── test_rename_field.py │ │ ├── test_change_value.py │ │ ├── test_change_subfield.py │ │ └── test_pack_unpack.py │ ├── conftest.py │ ├── migrations_for_test │ │ ├── many_migrations │ │ │ ├── 20210413170700_3_skip_backward.py │ │ │ ├── 20210413170709_3_skip_forward.py │ │ │ ├── 20210413170640_1.py │ │ │ ├── 20210413170645_2.py │ │ │ ├── 20210413170728_4_5.py │ │ │ └── 20210413170734_6_7.py │ │ ├── change_subfield_value │ │ │ └── 20210413143405_change_subfield_value.py │ │ ├── change_value │ │ │ └── 20210413115234_change_value.py │ │ ├── remove_index │ │ │ └── 20210414135045_remove_index.py │ │ ├── rename_field │ │ │ └── 20210407203225_rename_field.py │ │ ├── change_subfield │ │ │ └── 20210413152406_change_subfield.py │ │ ├── pack_unpack │ │ │ └── 20210413135927_pack_unpack.py │ │ ├── break │ │ │ └── 20210413211219_break.py │ │ └── free_fall │ │ │ └── 20210413210446_free_fall.py │ ├── test_break.py │ ├── test_free_fall.py │ └── test_remove_indexes.py ├── typing │ ├── __init__.py │ ├── models.py │ ├── aggregation.py │ ├── find.py │ └── decorators.py ├── test_beanie.py └── conftest.py ├── docs ├── CNAME ├── assets │ ├── 
favicon.png │ └── color_scheme.css ├── tutorial │ ├── delete.md │ ├── lazy_parse.md │ ├── on_save_validation.md │ ├── aggregate.md │ ├── cache.md │ ├── time_series.md │ ├── revision.md │ ├── init.md │ ├── insert.md │ ├── multi-model.md │ ├── views.md │ ├── indexes.md │ ├── actions.md │ └── update.md ├── publishing.md └── getting-started.md ├── .github ├── scripts │ └── handlers │ │ ├── __init__.py │ │ └── gh.py ├── workflows │ ├── github-actions-publish-docs.yml │ ├── github-actions-publish-project.yml │ ├── close_inactive_issues.yml │ └── github-actions-tests.yml └── ISSUE_TEMPLATE │ ├── config.yml │ └── bug_report.md ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── scripts ├── publish_docs.sh └── generate_changelog.py ├── .pypirc ├── .pre-commit-config.yaml └── assets └── logo └── logo.svg /beanie/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/executors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/utils/general.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | beanie-odm.dev -------------------------------------------------------------------------------- /tests/fastapi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/migrations/models.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/query/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/typing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/queries/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/settings/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/documents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/scripts/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/migrations/iterative/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/custom_types/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/find/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/update/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/migrations/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanie/odm/custom_types/bson/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /beanie/migrations/template.py: -------------------------------------------------------------------------------- 1 | class Forward: ...  # empty stub; presumably the skeleton for new migration files — confirm 2 | 3 | 4 | class Backward: ...  # empty stub, same shape as Forward 5 | -------------------------------------------------------------------------------- /docs/assets/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeanieODM/beanie/HEAD/docs/assets/favicon.png -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing 2 | ------------ 3 | 4 | Please check [this page in the documentation](https://beanie-odm.dev/development/). 5 | -------------------------------------------------------------------------------- /beanie/odm/settings/union_doc.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.settings.base import ItemSettings 2 | 3 | 4 | class UnionDocSettings(ItemSettings): ...  # adds no fields beyond ItemSettings 5 | -------------------------------------------------------------------------------- /beanie/odm/interfaces/clone.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | 4 | class CloneInterface: 5 | def clone(self): 6 | return deepcopy(self)  # recursive copy: nested objects are duplicated, not shared 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | Code of Conduct 2 | --------------- 3 | 4 | Please check [this page in the documentation](https://roman-right.github.io/beanie/code-of-conduct/).
5 | -------------------------------------------------------------------------------- /beanie/odm/operators/find/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from beanie.odm.operators import BaseOperator 4 | 5 | 6 | class BaseFindOperator(BaseOperator, ABC): ... 7 | -------------------------------------------------------------------------------- /beanie/migrations/utils.py: -------------------------------------------------------------------------------- 1 | def update_dict(d, u): 2 | for k, v in u.items(): 3 | if isinstance(v, dict): 4 | d[k] = update_dict(d.get(k, {}), v) 5 | else: 6 | d[k] = v 7 | return d 8 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_bitwise.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.bitwise import Bit 2 | from tests.odm.models import Sample 3 | 4 | 5 | def test_bit(): 6 | q = Bit({Sample.integer: 2}) 7 | assert q == {"$bit": {"integer": 2}} 8 | -------------------------------------------------------------------------------- /docs/assets/color_scheme.css: -------------------------------------------------------------------------------- 1 | [data-md-color-scheme="slate"] { 2 | 3 | --md-primary-fg-color: #2e303e; /* Panel */ 4 | --md-default-bg-color: #2e303e; /* BG */ 5 | --md-typeset-a-color: #8e91aa; 6 | --md-accent-fg-color: #658eda 7 | } -------------------------------------------------------------------------------- /scripts/publish_docs.sh: -------------------------------------------------------------------------------- 1 | pydoc-markdown 2 | cd docs/build 3 | 4 | remote_repo="https://x-access-token:${GITHUB_TOKEN}@${GITHUB_DOMAIN:-"github.com"}/${GITHUB_REPOSITORY}.git" 5 | git remote rm origin 6 | git remote add origin "${remote_repo}" 7 | mkdocs gh-deploy --force 
-------------------------------------------------------------------------------- /tests/typing/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from beanie import Document 4 | 5 | 6 | class Test(Document): 7 | foo: str 8 | bar: str 9 | baz: str 10 | 11 | 12 | class ProjectionTest(BaseModel): 13 | foo: str 14 | bar: int 15 | -------------------------------------------------------------------------------- /beanie/odm/operators/update/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Mapping 3 | 4 | from beanie.odm.operators import BaseOperator 5 | 6 | 7 | class BaseUpdateOperator(BaseOperator): 8 | @property 9 | @abstractmethod 10 | def query(self) -> Mapping[str, Any]: ... 11 | -------------------------------------------------------------------------------- /beanie/odm/custom_types/__init__.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 2 | 3 | if IS_PYDANTIC_V2: 4 | from beanie.odm.custom_types.decimal import DecimalAnnotation 5 | else: 6 | from decimal import Decimal as DecimalAnnotation 7 | 8 | __all__ = [ 9 | "DecimalAnnotation", 10 | ] 11 | -------------------------------------------------------------------------------- /.pypirc: -------------------------------------------------------------------------------- 1 | [distutils] 2 | index-servers = 3 | pypi 4 | testpypi 5 | 6 | [pypi] 7 | repository = https://upload.pypi.org/legacy/ 8 | username = __token__ 9 | password = ${PYPI_TOKEN} 10 | 11 | [testpypi] 12 | repository = https://test.pypi.org/legacy/ 13 | username = __token__ 14 | password = ${TESTPYPI_TOKEN} -------------------------------------------------------------------------------- /beanie/odm/interfaces/detector.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ModelType(str, Enum): 5 | Document = "Document" 6 | View = "View" 7 | UnionDoc = "UnionDoc" 8 | 9 | 10 | class DetectionInterface: 11 | @classmethod 12 | def get_model_type(cls) -> ModelType: 13 | return ModelType.Document  # base default; View/UnionDoc models presumably override this — confirm 14 | -------------------------------------------------------------------------------- /beanie/odm/custom_types/decimal.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | 3 | import bson 4 | import pydantic 5 | from typing_extensions import Annotated 6 | 7 | DecimalAnnotation = Annotated[ 8 | decimal.Decimal, 9 | pydantic.BeforeValidator( 10 | lambda v: v.to_decimal() if isinstance(v, bson.Decimal128) else v  # bson.Decimal128 -> decimal.Decimal; other values pass through 11 | ), 12 | ] 13 | -------------------------------------------------------------------------------- /tests/odm/documents/test_pydantic_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import ValidationError 3 | 4 | from tests.odm.models import DocumentWithPydanticConfig 5 | 6 | 7 | def test_pydantic_config(): 8 | doc = DocumentWithPydanticConfig(num_1=2) 9 | with pytest.raises(ValidationError): 10 | doc.num_1 = "wrong" 11 | -------------------------------------------------------------------------------- /beanie/odm/settings/view.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Type, Union 2 | 3 | from pydantic import Field 4 | 5 | from beanie.odm.settings.base import ItemSettings 6 | 7 | 8 | class ViewSettings(ItemSettings): 9 | source: Union[str, Type] 10 | pipeline: List[Dict[str, Any]] 11 | 12 | max_nesting_depths_per_field: dict = Field(default_factory=dict)  # default_factory gives each instance its own dict 13 | max_nesting_depth: int = 3 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/charliermarsh/ruff-pre-commit 3 | rev: v0.11.6 4 | hooks: 5 | - id: ruff 6 | args: [ --fix ] 7 | - id: ruff-format 8 | - repo: https://github.com/pre-commit/mirrors-mypy 9 | rev: v1.15.0 10 | hooks: 11 | - id: mypy 12 | additional_dependencies: 13 | - types-click 14 | exclude: ^tests/ 15 | -------------------------------------------------------------------------------- /beanie/odm/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import pymongo 4 | 5 | 6 | class SortDirection(int, Enum): 7 | """ 8 | Sorting directions 9 | """ 10 | 11 | ASCENDING = pymongo.ASCENDING 12 | DESCENDING = pymongo.DESCENDING 13 | 14 | 15 | class InspectionStatuses(str, Enum): 16 | """ 17 | Statuses of the collection inspection 18 | """ 19 | 20 | FAIL = "FAIL" 21 | OK = "OK" 22 | -------------------------------------------------------------------------------- /tests/odm/test_cursor.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_to_list(documents): 5 | await documents(10) 6 | result = await DocumentTestModel.find_all().to_list() 7 | assert len(result) == 10 8 | 9 | 10 | async def test_async_for(documents): 11 | await documents(10) 12 | async for document in DocumentTestModel.find_all(): 13 | assert document.test_int in list(range(10)) 14 | -------------------------------------------------------------------------------- /tests/migrations/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie import init_beanie 4 | from beanie.migrations.models import MigrationLog 5 | 6 | 7 | @pytest.fixture(autouse=True) 8 | async def init(db): 9 | await init_beanie( 10 | database=db, 11 | document_models=[ 12 | 
MigrationLog, 13 | ], 14 | ) 15 | 16 | 17 | @pytest.fixture(autouse=True) 18 | async def remove_migrations_log(db, init): 19 | await MigrationLog.delete_all() 20 | -------------------------------------------------------------------------------- /.github/workflows/github-actions-publish-docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | publish_docs: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.10.9 14 | - name: install dependencies 15 | run: pip3 install .[doc] 16 | - name: publish docs 17 | run: bash scripts/publish_docs.sh -------------------------------------------------------------------------------- /.github/workflows/github-actions-publish-project.yml: -------------------------------------------------------------------------------- 1 | name: Publish project 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | publish_project: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: install flit 12 | run: pip3 install flit 13 | - name: publish project 14 | env: 15 | FLIT_USERNAME: __token__ 16 | FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }} 17 | run: flit publish -------------------------------------------------------------------------------- /beanie/migrations/controllers/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Type 3 | 4 | from beanie.odm.documents import Document 5 | 6 | 7 | class BaseMigrationController(ABC): 8 | def __init__(self, function): 9 | self.function = function 10 | 11 | @abstractmethod 12 | async def run(self, session): 13 | pass 14 | 15 | @property 16 | @abstractmethod 17 | def models(self) -> List[Type[Document]]: 18 | pass 19 | 
-------------------------------------------------------------------------------- /beanie/migrations/database.py: -------------------------------------------------------------------------------- 1 | from pymongo import AsyncMongoClient 2 | 3 | 4 | class DBHandler: 5 | @classmethod 6 | def set_db(cls, uri, db_name): 7 | cls.client = AsyncMongoClient(uri)  # stored on the class, so every DBHandler caller shares this client 8 | cls.database = cls.client[db_name] 9 | 10 | @classmethod 11 | def get_cli(cls): 12 | return cls.client if hasattr(cls, "client") else None  # None until set_db() has been called 13 | 14 | @classmethod 15 | def get_db(cls): 16 | return cls.database if hasattr(cls, "database") else None  # None until set_db() has been called 17 | -------------------------------------------------------------------------------- /tests/odm/test_deprecated.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie import init_beanie 4 | from beanie.exceptions import Deprecation 5 | from tests.odm.models import DocWithCollectionInnerClass 6 | 7 | 8 | class TestDeprecations: 9 | async def test_doc_with_inner_collection_class_init(self, db): 10 | with pytest.raises(Deprecation): 11 | await init_beanie( 12 | database=db, 13 | document_models=[DocWithCollectionInnerClass], 14 | ) 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Question 4 | url: 'https://github.com/roman-right/beanie/discussions/new?category=question' 5 | about: Ask a question about how to use Beanie using github discussions 6 | 7 | - name: Feature Request 8 | url: 'https://github.com/roman-right/beanie/discussions/new?category=feature-request' 9 | about: > 10 | If you think we should add a new feature to Beanie, please start a discussion, once it attracts 11 | wider support, it can be migrated to an issue 
-------------------------------------------------------------------------------- /tests/test_beanie.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | if sys.version_info >= (3, 11): 5 | import tomllib 6 | else: 7 | import tomli as tomllib 8 | 9 | from beanie import __version__ 10 | 11 | 12 | def parse_version_from_pyproject(): 13 | pyproject = Path(__file__).parent.parent / "pyproject.toml" 14 | with pyproject.open("rb") as f: 15 | toml_data = tomllib.load(f) 16 | return toml_data["project"]["version"] 17 | 18 | 19 | def test_version(): 20 | assert __version__ == parse_version_from_pyproject() 21 | -------------------------------------------------------------------------------- /docs/tutorial/delete.md: -------------------------------------------------------------------------------- 1 | # Delete documents 2 | 3 | Beanie supports single and batch deletions: 4 | 5 | ## Single 6 | 7 | ```python 8 | 9 | await Product.find_one(Product.name == "Milka").delete() 10 | 11 | # Or 12 | bar = await Product.find_one(Product.name == "Milka") 13 | await bar.delete() 14 | ``` 15 | 16 | ## Many 17 | 18 | ```python 19 | 20 | await Product.find(Product.category.name == "Chocolate").delete() 21 | ``` 22 | 23 | ## All 24 | 25 | ```python 26 | 27 | await Product.delete_all() 28 | # Or 29 | await Product.all().delete() 30 | 31 | ``` -------------------------------------------------------------------------------- /beanie/odm/interfaces/inheritance.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | ClassVar, 3 | Dict, 4 | Optional, 5 | Type, 6 | ) 7 | 8 | 9 | class InheritanceInterface: 10 | _children: ClassVar[Dict[str, Type]] 11 | _parent: ClassVar[Optional[Type]] 12 | _inheritance_inited: ClassVar[bool] 13 | _class_id: ClassVar[Optional[str]] = None 14 | 15 | @classmethod 16 | def add_child(cls, name: str, clas: Type): 17 | 
cls._children[name] = clas 18 | if cls._parent is not None: 19 | cls._parent.add_child(name, clas) 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | ```python 15 | Please add a code snippet here, that reproduces the problem completely 16 | ``` 17 | 18 | **Expected behavior** 19 | A clear and concise description of what you expected to happen. 20 | 21 | **Additional context** 22 | Add any other context about the problem here. 23 | -------------------------------------------------------------------------------- /beanie/odm/interfaces/session.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pymongo.asynchronous.client_session import AsyncClientSession 4 | 5 | 6 | class SessionMethods: 7 | """ 8 | Session methods 9 | """ 10 | 11 | def set_session(self, session: Optional[AsyncClientSession] = None): 12 | """ 13 | Set session 14 | :param session: Optional[AsyncClientSession] - pymongo session 15 | :return: 16 | """ 17 | if session is not None: 18 | self.session: Optional[AsyncClientSession] = session 19 | return self 20 | -------------------------------------------------------------------------------- /beanie/odm/models.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | from beanie.odm.enums import InspectionStatuses 6 | from beanie.odm.fields import PydanticObjectId 7 | 8 | 9 | class InspectionError(BaseModel): 10 | """ 11 | Inspection error details 12 | """ 13 | 14 | 
document_id: PydanticObjectId 15 | error: str 16 | 17 | 18 | class InspectionResult(BaseModel): 19 | """ 20 | Collection inspection result 21 | """ 22 | 23 | status: InspectionStatuses = InspectionStatuses.OK 24 | errors: List[InspectionError] = [] 25 | -------------------------------------------------------------------------------- /tests/odm/documents/test_exists.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_count_with_filter_query(documents): 5 | await documents(4, "uno", True) 6 | await documents(2, "dos", True) 7 | await documents(1, "cuatro", True) 8 | e = await DocumentTestModel.find_many({"test_str": "dos"}).exists() 9 | assert e is True 10 | 11 | e = await DocumentTestModel.find_one({"test_str": "dos"}).exists() 12 | assert e is True 13 | 14 | e = await DocumentTestModel.find_many({"test_str": "wrong"}).exists() 15 | assert e is False 16 | -------------------------------------------------------------------------------- /tests/fastapi/test_openapi_schema_generation.py: -------------------------------------------------------------------------------- 1 | from json import dumps 2 | 3 | from fastapi.openapi.utils import get_openapi 4 | 5 | from tests.fastapi.app import app 6 | 7 | 8 | def test_openapi_schema_generation(): 9 | openapi_schema_json_str = dumps( 10 | get_openapi( 11 | title=app.title, 12 | version=app.version, 13 | openapi_version=app.openapi_version, 14 | description=app.description, 15 | routes=app.routes, 16 | ), 17 | ) 18 | 19 | assert openapi_schema_json_str is not None 20 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_element.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.element import Exists, Type 2 | from tests.odm.models import Sample 3 | 4 | 5 | async def test_exists(): 6 | q = 
Exists(Sample.integer, True) 7 | assert q == {"integer": {"$exists": True}} 8 | 9 | q = Exists(Sample.integer, False) 10 | assert q == {"integer": {"$exists": False}} 11 | 12 | q = Exists(Sample.integer) 13 | assert q == {"integer": {"$exists": True}} 14 | 15 | 16 | async def test_type(): 17 | q = Type(Sample.integer, "smth") 18 | assert q == {"integer": {"$type": "smth"}} 19 | -------------------------------------------------------------------------------- /beanie/odm/operators/update/bitwise.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from beanie.odm.operators.update import BaseUpdateOperator 4 | 5 | 6 | class BaseUpdateBitwiseOperator(BaseUpdateOperator, ABC): ... 7 | 8 | 9 | class Bit(BaseUpdateBitwiseOperator): 10 | """ 11 | `$bit` update query operator 12 | 13 | MongoDB doc: 14 | 15 | """ 16 | 17 | def __init__(self, expression: dict): 18 | self.expression = expression 19 | 20 | @property 21 | def query(self): 22 | return {"$bit": self.expression} 23 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title(self, input_document: Note, output_document: Note): 22 | if input_document.title == "3": 23 | output_document.title = "three" 24 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py: 
-------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Backward: 20 | @iterative_migration() 21 | async def change_title(self, input_document: Note, output_document: Note): 22 | if input_document.title == "three": 23 | output_document.title = "3" 24 | -------------------------------------------------------------------------------- /tests/odm/custom_types/test_decimal_annotation.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | 3 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 4 | from tests.odm.models import DocumentWithDecimalField 5 | 6 | 7 | def test_decimal_deserialize(): 8 | m = DocumentWithDecimalField(amt=Decimal("1.4")) 9 | if IS_PYDANTIC_V2: 10 | m_json = m.model_dump_json() 11 | m_from_json = DocumentWithDecimalField.model_validate_json(m_json) 12 | else: 13 | m_json = m.json() 14 | m_from_json = DocumentWithDecimalField.parse_raw(m_json) 15 | assert isinstance(m_from_json.amt, Decimal) 16 | assert m_from_json.amt == Decimal("1.4") 17 | -------------------------------------------------------------------------------- /beanie/odm/utils/self_validation.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import TYPE_CHECKING, TypeVar 3 | 4 | from typing_extensions import ParamSpec 5 | 6 | if TYPE_CHECKING: 7 | from beanie.odm.documents import AsyncDocMethod, DocType 8 | 9 | P = ParamSpec("P") 10 | R = TypeVar("R") 11 | 12 | 13 | def validate_self_before( 14 | f: "AsyncDocMethod[DocType, P, R]", 15 | ) -> "AsyncDocMethod[DocType, P, R]": 16 | @wraps(f) 17 | async def wrapper(self: "DocType", 
*args: P.args, **kwargs: P.kwargs) -> R: 18 | await self.validate_self(*args, **kwargs) 19 | return await f(self, *args, **kwargs) 20 | 21 | return wrapper 22 | -------------------------------------------------------------------------------- /beanie/odm/custom_types/bson/binary.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import bson 4 | import pydantic 5 | from typing_extensions import Annotated 6 | 7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 8 | 9 | 10 | def _to_bson_binary(value: Any) -> bson.Binary: 11 | return value if isinstance(value, bson.Binary) else bson.Binary(value) 12 | 13 | 14 | if IS_PYDANTIC_V2: 15 | BsonBinary = Annotated[ 16 | bson.Binary, pydantic.PlainValidator(_to_bson_binary) 17 | ] 18 | else: 19 | 20 | class BsonBinary(bson.Binary): # type: ignore[no-redef] 21 | @classmethod 22 | def __get_validators__(cls): 23 | yield _to_bson_binary 24 | -------------------------------------------------------------------------------- /tests/typing/aggregation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from tests.typing.models import ProjectionTest, Test 4 | 5 | 6 | async def aggregate() -> List[Dict[str, Any]]: 7 | result = await Test.aggregate([]).to_list() 8 | result_2 = await Test.find().aggregate([]).to_list() 9 | return result or result_2 10 | 11 | 12 | async def aggregate_with_projection() -> List[ProjectionTest]: 13 | result = ( 14 | await Test.find() 15 | .aggregate([], projection_model=ProjectionTest) 16 | .to_list() 17 | ) 18 | result_2 = await Test.aggregate( 19 | [], projection_model=ProjectionTest 20 | ).to_list() 21 | return result or result_2 22 | -------------------------------------------------------------------------------- /beanie/odm/custom_types/re.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import 
bson 4 | import pydantic 5 | from typing_extensions import Annotated 6 | 7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 8 | 9 | 10 | def _to_bson_regex(v): 11 | return v.try_compile() if isinstance(v, bson.Regex) else v 12 | 13 | 14 | if IS_PYDANTIC_V2: 15 | Pattern = Annotated[ 16 | re.Pattern, 17 | pydantic.BeforeValidator( 18 | lambda v: v.try_compile() if isinstance(v, bson.Regex) else v 19 | ), 20 | ] 21 | else: 22 | 23 | class Pattern(bson.Regex): # type: ignore[no-redef] 24 | @classmethod 25 | def __get_validators__(cls): 26 | yield _to_bson_regex 27 | -------------------------------------------------------------------------------- /tests/odm/test_root_models.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 2 | from tests.odm.models import DocumentWithRootModelAsAField 3 | 4 | if IS_PYDANTIC_V2: 5 | 6 | class TestRootModels: 7 | async def test_insert(self): 8 | doc = DocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) 9 | await doc.insert() 10 | 11 | new_doc = await DocumentWithRootModelAsAField.get(doc.id) 12 | assert new_doc.pets.root == ["dog", "cat", "fish"] 13 | 14 | collection = DocumentWithRootModelAsAField.get_pymongo_collection() 15 | raw_doc = await collection.find_one({"_id": doc.id}) 16 | assert raw_doc["pets"] == ["dog", "cat", "fish"] 17 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pymongo import AsyncMongoClient 3 | 4 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 5 | 6 | if IS_PYDANTIC_V2: 7 | from pydantic_settings import BaseSettings 8 | else: 9 | from pydantic import BaseSettings 10 | 11 | 12 | class Settings(BaseSettings): 13 | mongodb_dsn: str = "mongodb://localhost:27017/beanie_db" 14 | mongodb_db_name: str = "beanie_db" 15 | 16 | 17 | @pytest.fixture 18 | 
def settings(): 19 | return Settings() 20 | 21 | 22 | @pytest.fixture 23 | async def cli(settings): 24 | client = AsyncMongoClient(settings.mongodb_dsn) 25 | 26 | yield client 27 | 28 | await client.close() 29 | 30 | 31 | @pytest.fixture 32 | def db(cli, settings): 33 | return cli[settings.mongodb_db_name] 34 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_color(self, input_document: Note, output_document: Note): 22 | output_document.tag.color = "blue" 23 | 24 | 25 | class Backward: 26 | @iterative_migration() 27 | async def change_title(self, input_document: Note, output_document: Note): 28 | output_document.tag.color = "red" 29 | -------------------------------------------------------------------------------- /tests/odm/documents/test_count.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_count(documents): 5 | await documents(4, "uno", True) 6 | c = await DocumentTestModel.count() 7 | assert c == 4 8 | 9 | 10 | async def test_count_with_filter_query(documents): 11 | await documents(4, "uno", True) 12 | await documents(2, "dos", True) 13 | await documents(1, "cuatro", True) 14 | c = await DocumentTestModel.find_many({"test_str": "dos"}).count() 15 | assert c == 2 16 | 17 | 18 | async def test_count_with_limit(documents): 19 | await documents(5, "five", True) 20 | c = await 
DocumentTestModel.find_all().limit(1).count() 21 | assert c == 1 22 | d = await DocumentTestModel.find_all().count() 23 | assert d == 5 24 | -------------------------------------------------------------------------------- /tests/odm/views.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.fields import Link 2 | from beanie.odm.views import View 3 | from tests.odm.models import DocumentTestModel, DocumentTestModelWithLink 4 | 5 | 6 | class ViewForTest(View): 7 | number: int 8 | string: str 9 | 10 | class Settings: 11 | view_name = "test_view" 12 | source = DocumentTestModel 13 | pipeline = [ 14 | {"$match": {"$expr": {"$gt": ["$test_int", 8]}}}, 15 | {"$project": {"number": "$test_int", "string": "$test_str"}}, 16 | ] 17 | 18 | 19 | class ViewForTestWithLink(View): 20 | link: Link[DocumentTestModel] 21 | 22 | class Settings: 23 | view_name = "test_view_with_link" 24 | source = DocumentTestModelWithLink 25 | pipeline = [{"$set": {"link": "$test_link"}}] 26 | -------------------------------------------------------------------------------- /beanie/odm/interfaces/getters.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from pymongo.asynchronous.collection import AsyncCollection 4 | 5 | from beanie.odm.settings.base import ItemSettings 6 | 7 | 8 | class OtherGettersInterface: 9 | @classmethod 10 | @abstractmethod 11 | def get_settings(cls) -> ItemSettings: 12 | pass 13 | 14 | @classmethod 15 | def get_pymongo_collection(cls) -> AsyncCollection: 16 | return cls.get_settings().pymongo_collection 17 | 18 | @classmethod 19 | def get_collection_name(cls) -> str: 20 | return cls.get_settings().name # type: ignore 21 | 22 | @classmethod 23 | def get_bson_encoders(cls): 24 | return cls.get_settings().bson_encoders 25 | 26 | @classmethod 27 | def get_link_fields(cls): 28 | return None 29 | 
-------------------------------------------------------------------------------- /tests/odm/operators/find/test_bitwise.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.bitwise import ( 2 | BitsAllClear, 3 | BitsAllSet, 4 | BitsAnyClear, 5 | BitsAnySet, 6 | ) 7 | from tests.odm.models import Sample 8 | 9 | 10 | async def test_bits_all_clear(): 11 | q = BitsAllClear(Sample.integer, "smth") 12 | assert q == {"integer": {"$bitsAllClear": "smth"}} 13 | 14 | 15 | async def test_bits_all_set(): 16 | q = BitsAllSet(Sample.integer, "smth") 17 | assert q == {"integer": {"$bitsAllSet": "smth"}} 18 | 19 | 20 | async def test_any_clear(): 21 | q = BitsAnyClear(Sample.integer, "smth") 22 | assert q == {"integer": {"$bitsAnyClear": "smth"}} 23 | 24 | 25 | async def test_any_set(): 26 | q = BitsAnySet(Sample.integer, "smth") 27 | assert q == {"integer": {"$bitsAnySet": "smth"}} 28 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_array.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.array import ( 2 | AddToSet, 3 | Pop, 4 | Pull, 5 | PullAll, 6 | Push, 7 | ) 8 | from tests.odm.models import Sample 9 | 10 | 11 | def test_add_to_set(): 12 | q = AddToSet({Sample.integer: 2}) 13 | assert q == {"$addToSet": {"integer": 2}} 14 | 15 | 16 | def test_pop(): 17 | q = Pop({Sample.integer: 2}) 18 | assert q == {"$pop": {"integer": 2}} 19 | 20 | 21 | def test_pull(): 22 | q = Pull({Sample.integer: 2}) 23 | assert q == {"$pull": {"integer": 2}} 24 | 25 | 26 | def test_push(): 27 | q = Push({Sample.integer: 2}) 28 | assert q == {"$push": {"integer": 2}} 29 | 30 | 31 | def test_pull_all(): 32 | q = PullAll({Sample.integer: 2}) 33 | assert q == {"$pullAll": {"integer": 2}} 34 | -------------------------------------------------------------------------------- /beanie/migrations/models.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import Enum 3 | from typing import List, Optional 4 | 5 | from pydantic import Field 6 | from pydantic.main import BaseModel 7 | 8 | from beanie.odm.documents import Document 9 | 10 | 11 | class MigrationLog(Document): 12 | ts: datetime = Field(default_factory=datetime.now) 13 | name: str 14 | is_current: bool 15 | 16 | class Settings: 17 | name = "migrations_log" 18 | 19 | 20 | class RunningDirections(str, Enum): 21 | FORWARD = "FORWARD" 22 | BACKWARD = "BACKWARD" 23 | 24 | 25 | class RunningMode(BaseModel): 26 | direction: RunningDirections 27 | distance: int = 0 28 | 29 | 30 | class ParsedMigrations(BaseModel): 31 | path: str 32 | names: List[str] 33 | current: Optional[MigrationLog] = None 34 | -------------------------------------------------------------------------------- /beanie/odm/interfaces/setters.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, Optional 2 | 3 | from beanie.odm.settings.document import DocumentSettings 4 | 5 | 6 | class SettersInterface: 7 | _document_settings: ClassVar[Optional[DocumentSettings]] 8 | 9 | @classmethod 10 | def set_collection(cls, collection): 11 | """ 12 | Collection setter 13 | """ 14 | cls._document_settings.pymongo_collection = collection 15 | 16 | @classmethod 17 | def set_database(cls, database): 18 | """ 19 | Database setter 20 | """ 21 | cls._document_settings.pymongo_db = database 22 | 23 | @classmethod 24 | def set_collection_name(cls, name: str): 25 | """ 26 | Collection name setter 27 | """ 28 | cls._document_settings.name = name # type: ignore 29 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import 
BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title(self, input_document: Note, output_document: Note): 22 | if input_document.title == "1": 23 | output_document.title = "one" 24 | 25 | 26 | class Backward: 27 | @iterative_migration() 28 | async def change_title(self, input_document: Note, output_document: Note): 29 | if input_document.title == "one": 30 | output_document.title = "1" 31 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title(self, input_document: Note, output_document: Note): 22 | if input_document.title == "2": 23 | output_document.title = "two" 24 | 25 | 26 | class Backward: 27 | @iterative_migration() 28 | async def change_title(self, input_document: Note, output_document: Note): 29 | if input_document.title == "two": 30 | output_document.title = "2" 31 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | 
name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title(self, input_document: Note, output_document: Note): 22 | if input_document.title == "5": 23 | output_document.title = "five" 24 | 25 | 26 | class Backward: 27 | @iterative_migration() 28 | async def change_title(self, input_document: Note, output_document: Note): 29 | if input_document.title == "five": 30 | output_document.title = "5" 31 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_array.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.array import All, ElemMatch, Size 2 | from tests.odm.models import PackageElemMatch, Sample 3 | 4 | 5 | async def test_all(): 6 | q = All(Sample.integer, [1, 2, 3]) 7 | assert q == {"integer": {"$all": [1, 2, 3]}} 8 | 9 | 10 | async def test_elem_match(): 11 | q = ElemMatch(Sample.integer, {"a": "b"}) 12 | assert q == {"integer": {"$elemMatch": {"a": "b"}}} 13 | 14 | 15 | async def test_size(): 16 | q = Size(Sample.integer, 4) 17 | assert q == {"integer": {"$size": 4}} 18 | 19 | 20 | async def test_elem_match_nested(): 21 | q = ElemMatch( 22 | PackageElemMatch.releases, major_ver=7, minor_ver=1, build_ver=0 23 | ) 24 | assert q == { 25 | "releases": { 26 | "$elemMatch": {"major_ver": 7, "minor_ver": 1, "build_ver": 0} 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /tests/fastapi/models.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import Field 4 | 5 | from beanie import Document, Indexed, Link 6 | from beanie.odm.fields import BackLink 7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 8 | 9 | 10 | class WindowAPI(Document): 11 | x: int 12 | y: 
int 13 | 14 | 15 | class DoorAPI(Document): 16 | t: int = 10 17 | 18 | 19 | class RoofAPI(Document): 20 | r: int = 100 21 | 22 | 23 | class HouseAPI(Document): 24 | windows: List[Link[WindowAPI]] 25 | name: Indexed(str) 26 | height: Indexed(int) = 2 27 | 28 | 29 | class House(Document): 30 | name: str 31 | owner: Link["Person"] 32 | 33 | 34 | class Person(Document): 35 | name: str 36 | house: BackLink[House] = ( 37 | Field(json_schema_extra={"original_field": "owner"}) 38 | if IS_PYDANTIC_V2 39 | else Field(original_field="owner") 40 | ) 41 | -------------------------------------------------------------------------------- /docs/tutorial/lazy_parse.md: -------------------------------------------------------------------------------- 1 | ## Using Lazy Parsing in Queries 2 | Lazy parsing allows you to skip the parsing and validation process for documents and instead call it on demand for each field separately. This can be useful for optimizing performance in certain scenarios. 3 | 4 | To use lazy parsing in your queries, you can pass the `lazy_parse=True` parameter to your find method. 5 | 6 | Here's an example of how to use lazy parsing in a find query: 7 | 8 | ```python 9 | await Sample.find(Sample.number == 10, lazy_parse=True).to_list() 10 | ``` 11 | 12 | By setting lazy_parse=True, the parsing and validation process will be skipped and be called on demand when the respective fields will be used. This can potentially improve the performance of your query by reducing the amount of processing required upfront. However, keep in mind that using lazy parsing may also introduce some additional overhead when accessing the fields later on. 
-------------------------------------------------------------------------------- /tests/fastapi/app.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | 3 | from fastapi import FastAPI 4 | from pymongo import AsyncMongoClient 5 | 6 | from beanie import init_beanie 7 | from tests.conftest import Settings 8 | from tests.fastapi.models import ( 9 | DoorAPI, 10 | House, 11 | HouseAPI, 12 | Person, 13 | RoofAPI, 14 | WindowAPI, 15 | ) 16 | from tests.fastapi.routes import house_router 17 | 18 | 19 | @asynccontextmanager 20 | async def live_span(_: FastAPI): 21 | # CREATE ASYNC PYMONGO CLIENT 22 | client = AsyncMongoClient(Settings().mongodb_dsn) 23 | 24 | # INIT BEANIE 25 | await init_beanie( 26 | client.beanie_db, 27 | document_models=[House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI], 28 | ) 29 | 30 | yield 31 | 32 | 33 | app = FastAPI(lifespan=live_span) 34 | 35 | # ADD ROUTES 36 | app.include_router(house_router, prefix="/v1", tags=["house"]) 37 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class OldNote(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | indexes = ["title"] 18 | 19 | 20 | class Note(Document): 21 | title: str 22 | tag: Tag 23 | 24 | class Settings: 25 | name = "notes" 26 | indexes = [ 27 | "_id", 28 | ] 29 | 30 | 31 | class Forward: 32 | @iterative_migration() 33 | async def name_to_title( 34 | self, input_document: OldNote, output_document: Note 35 | ): ... 
36 | 37 | 38 | class Backward: 39 | @iterative_migration() 40 | async def title_to_name( 41 | self, input_document: Note, output_document: OldNote 42 | ): ... 43 | -------------------------------------------------------------------------------- /docs/tutorial/on_save_validation.md: -------------------------------------------------------------------------------- 1 | # On save validation 2 | 3 | Pydantic has a very useful config to validate values on assignment - `validate_assignment = True`. 4 | But, unfortunately, this is an expensive operation and doesn't fit some use cases. 5 | You can validate all the values before saving the document (`insert`, `replace`, `save`, `save_changes`) 6 | with beanie config `validate_on_save` instead. 7 | 8 | This feature must be turned on in the `Settings` inner class explicitly: 9 | 10 | ```python 11 | class Sample(Document): 12 | num: int 13 | name: str 14 | 15 | class Settings: 16 | validate_on_save = True 17 | ``` 18 | 19 | If any field has a wrong value, 20 | it will raise an error on write operations (`insert`, `replace`, `save`, `save_changes`). 
21 | 22 | ```python 23 | sample = Sample.find_one(Sample.name == "Test") 24 | sample.num = "wrong value type" 25 | 26 | # Next call will raise an error 27 | await sample.replace() 28 | ``` 29 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class OldNote(Document): 12 | name: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Note(Document): 20 | title: str 21 | tag: Tag 22 | 23 | class Settings: 24 | name = "notes" 25 | 26 | 27 | class Forward: 28 | @iterative_migration() 29 | async def name_to_title( 30 | self, input_document: OldNote, output_document: Note 31 | ): 32 | output_document.title = input_document.name 33 | 34 | 35 | class Backward: 36 | @iterative_migration() 37 | async def title_to_name( 38 | self, input_document: Note, output_document: OldNote 39 | ): 40 | output_document.name = input_document.title 41 | -------------------------------------------------------------------------------- /tests/odm/documents/test_pydantic_extras.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.models import ( 4 | DocumentWithExtras, 5 | DocumentWithExtrasKw, 6 | DocumentWithPydanticConfig, 7 | ) 8 | 9 | 10 | async def test_pydantic_extras(): 11 | doc = DocumentWithExtras(num_1=2) 12 | doc.extra_value = "foo" 13 | await doc.save() 14 | 15 | loaded_doc = await DocumentWithExtras.get(doc.id) 16 | 17 | assert loaded_doc.extra_value == "foo" 18 | 19 | 20 | @pytest.mark.skip(reason="setting extra to allow via class kwargs not working") 21 | async def test_pydantic_extras_kw(): 22 | doc = 
DocumentWithExtrasKw(num_1=2) 23 | doc.extra_value = "foo" 24 | await doc.save() 25 | 26 | loaded_doc = await DocumentWithExtras.get(doc.id) 27 | 28 | assert loaded_doc.extra_value == "foo" 29 | 30 | 31 | async def test_fail_with_no_extras(): 32 | doc = DocumentWithPydanticConfig(num_1=2) 33 | with pytest.raises(ValueError): 34 | doc.extra_value = "foo" 35 | -------------------------------------------------------------------------------- /tests/odm/custom_types/test_bson_binary.py: -------------------------------------------------------------------------------- 1 | import bson 2 | import pytest 3 | 4 | from beanie import BsonBinary 5 | from beanie.odm.utils.pydantic import get_model_dump, parse_model 6 | from tests.odm.models import DocumentWithBsonBinaryField 7 | 8 | 9 | @pytest.mark.parametrize("binary_field", [bson.Binary(b"test"), b"test"]) 10 | async def test_bson_binary(binary_field): 11 | doc = DocumentWithBsonBinaryField(binary_field=binary_field) 12 | await doc.insert() 13 | assert doc.binary_field == BsonBinary(b"test") 14 | 15 | new_doc = await DocumentWithBsonBinaryField.get(doc.id) 16 | assert new_doc.binary_field == BsonBinary(b"test") 17 | 18 | 19 | @pytest.mark.parametrize("binary_field", [bson.Binary(b"test"), b"test"]) 20 | def test_bson_binary_roundtrip(binary_field): 21 | doc = DocumentWithBsonBinaryField(binary_field=binary_field) 22 | doc_dict = get_model_dump(doc) 23 | new_doc = parse_model(DocumentWithBsonBinaryField, doc_dict) 24 | assert new_doc == doc 25 | -------------------------------------------------------------------------------- /beanie/odm/operators/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections.abc import Mapping 3 | from copy import copy, deepcopy 4 | from typing import Any, Dict 5 | from typing import Mapping as MappingType 6 | 7 | 8 | class BaseOperator(Mapping): 9 | """ 10 | Base operator. 
11 | """ 12 | 13 | @property 14 | @abstractmethod 15 | def query(self) -> MappingType[str, Any]: ... 16 | 17 | def __getitem__(self, item: str): 18 | return self.query[item] 19 | 20 | def __iter__(self): 21 | return iter(self.query) 22 | 23 | def __len__(self): 24 | return len(self.query) 25 | 26 | def __repr__(self): 27 | return repr(self.query) 28 | 29 | def __str__(self): 30 | return str(self.query) 31 | 32 | def __copy__(self): 33 | return copy(self.query) 34 | 35 | def __deepcopy__(self, memodict: Dict[str, Any] = {}): 36 | return deepcopy(self.query) 37 | 38 | def copy(self): 39 | return copy(self) 40 | -------------------------------------------------------------------------------- /beanie/odm/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, ForwardRef, Type, Union 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class DocsRegistry: 7 | _registry: Dict[str, Type[BaseModel]] = {} 8 | 9 | @classmethod 10 | def register(cls, name: str, doc_type: Type[BaseModel]): 11 | cls._registry[name] = doc_type 12 | 13 | @classmethod 14 | def get(cls, name: str) -> Type[BaseModel]: 15 | return cls._registry[name] 16 | 17 | @classmethod 18 | def evaluate_fr(cls, forward_ref: Union[ForwardRef, Type]): 19 | """ 20 | Evaluate forward ref 21 | 22 | :param forward_ref: ForwardRef - forward ref to evaluate 23 | :return: Type[BaseModel] - class of the forward ref 24 | """ 25 | if ( 26 | isinstance(forward_ref, ForwardRef) 27 | and forward_ref.__forward_arg__ in cls._registry 28 | ): 29 | return cls._registry[forward_ref.__forward_arg__] 30 | else: 31 | return forward_ref 32 | -------------------------------------------------------------------------------- /.github/workflows/close_inactive_issues.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: 
ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity.' 16 | stale-pr-message: 'This PR is stale because it has been open 45 days with no activity.' 17 | close-issue-message: 'This issue was closed because it has been stalled for 14 days with no activity.' 18 | close-pr-message: 'This PR was closed because it has been stalled for 14 days with no activity.' 19 | exempt-issue-labels: 'bug,feature-request,typing bug,feature request,doc,documentation' 20 | days-before-issue-stale: 30 21 | days-before-pr-stale: 45 22 | days-before-issue-close: 14 23 | days-before-pr-close: 14 -------------------------------------------------------------------------------- /tests/fastapi/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from asgi_lifespan import LifespanManager 3 | from httpx import ASGITransport, AsyncClient 4 | 5 | from tests.fastapi.app import app 6 | from tests.fastapi.models import ( 7 | DoorAPI, 8 | House, 9 | HouseAPI, 10 | Person, 11 | RoofAPI, 12 | WindowAPI, 13 | ) 14 | 15 | 16 | @pytest.fixture(autouse=True) 17 | async def api_client(clean_db): 18 | """api client fixture.""" 19 | async with LifespanManager(app, startup_timeout=100, shutdown_timeout=100): 20 | server_name = "http://localhost" 21 | async with AsyncClient( 22 | transport=ASGITransport(app=app), base_url=server_name 23 | ) as ac: 24 | yield ac 25 | 26 | 27 | @pytest.fixture(autouse=True) 28 | async def clean_db(db): 29 | models = [House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI] 30 | yield None 31 | 32 | for model in models: 33 | await model.get_pymongo_collection().drop() 34 | await model.get_pymongo_collection().drop_indexes() 35 | -------------------------------------------------------------------------------- /tests/typing/find.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from tests.typing.models import ProjectionTest, Test 4 | 5 | 6 | async def find_many() -> List[Test]: 7 | return await Test.find().to_list() 8 | 9 | 10 | async def find_many_with_projection() -> List[ProjectionTest]: 11 | return await Test.find().project(projection_model=ProjectionTest).to_list() 12 | 13 | 14 | async def find_many_generator() -> List[Test]: 15 | docs: List[Test] = [] 16 | async for doc in Test.find(): 17 | docs.append(doc) 18 | return docs 19 | 20 | 21 | async def find_many_generator_with_projection() -> List[ProjectionTest]: 22 | docs: List[ProjectionTest] = [] 23 | async for doc in Test.find().project(projection_model=ProjectionTest): 24 | docs.append(doc) 25 | return docs 26 | 27 | 28 | async def find_one() -> Optional[Test]: 29 | return await Test.find_one() 30 | 31 | 32 | async def find_one_with_projection() -> Optional[ProjectionTest]: 33 | return await Test.find_one().project(projection_model=ProjectionTest) 34 | -------------------------------------------------------------------------------- /tests/odm/test_views.py: -------------------------------------------------------------------------------- 1 | from tests.odm.views import ViewForTest, ViewForTestWithLink 2 | 3 | 4 | class TestViews: 5 | async def test_simple(self, documents): 6 | await documents(number=15) 7 | results = await ViewForTest.all().to_list() 8 | assert len(results) == 6 9 | 10 | async def test_aggregate(self, documents): 11 | await documents(number=15) 12 | results = await ViewForTest.aggregate( 13 | [ 14 | {"$set": {"test_field": 1}}, 15 | {"$match": {"$expr": {"$lt": ["$number", 12]}}}, 16 | ] 17 | ).to_list() 18 | assert len(results) == 3 19 | assert results[0]["test_field"] == 1 20 | 21 | async def test_link(self, documents_with_links): 22 | await documents_with_links() 23 | results = await ViewForTestWithLink.all().to_list() 24 | for document 
in results: 25 | await document.fetch_all_links() 26 | 27 | for i, document in enumerate(results): 28 | assert document.link.test_int == i 29 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class OldTag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Tag(BaseModel): 12 | color: str 13 | title: str 14 | 15 | 16 | class OldNote(Document): 17 | title: str 18 | tag: OldTag 19 | 20 | class Settings: 21 | name = "notes" 22 | 23 | 24 | class Note(Document): 25 | title: str 26 | tag: Tag 27 | 28 | class Settings: 29 | name = "notes" 30 | 31 | 32 | class Forward: 33 | @iterative_migration() 34 | async def change_color( 35 | self, input_document: OldNote, output_document: Note 36 | ): 37 | output_document.tag.title = input_document.tag.name 38 | 39 | 40 | class Backward: 41 | @iterative_migration() 42 | async def change_title( 43 | self, input_document: Note, output_document: OldNote 44 | ): 45 | output_document.tag.name = input_document.tag.title 46 | -------------------------------------------------------------------------------- /tests/odm/documents/test_inspect.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.models import InspectionStatuses 2 | from tests.odm.models import DocumentTestModel, DocumentTestModelFailInspection 3 | 4 | 5 | async def test_inspect_ok(documents): 6 | await documents(10, "smth") 7 | result = await DocumentTestModel.inspect_collection() 8 | assert result.status == InspectionStatuses.OK 9 | assert result.errors == [] 10 | 11 | 12 | async def test_inspect_fail(documents): 13 | await documents(10, "smth") 14 | result = await 
DocumentTestModelFailInspection.inspect_collection() 15 | assert result.status == InspectionStatuses.FAIL 16 | assert len(result.errors) == 10 17 | assert ( 18 | "1 validation error for DocumentTestModelFailInspection" 19 | in result.errors[0].error 20 | ) 21 | 22 | 23 | async def test_inspect_ok_with_session(documents, session): 24 | await documents(10, "smth") 25 | result = await DocumentTestModel.inspect_collection(session=session) 26 | assert result.status == InspectionStatuses.OK 27 | assert result.errors == [] 28 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class OldNote(Document): 12 | title: str 13 | tag_name: str 14 | tag_color: str 15 | 16 | class Settings: 17 | name = "notes" 18 | 19 | 20 | class Note(Document): 21 | title: str 22 | tag: Tag 23 | 24 | class Settings: 25 | name = "notes" 26 | 27 | 28 | class Forward: 29 | @iterative_migration() 30 | async def pack(self, input_document: OldNote, output_document: Note): 31 | output_document.tag = Tag( 32 | name=input_document.tag_name, color=input_document.tag_color 33 | ) 34 | 35 | 36 | class Backward: 37 | @iterative_migration() 38 | async def unpack(self, input_document: Note, output_document: OldNote): 39 | output_document.tag_name = input_document.tag.name 40 | output_document.tag_color = input_document.tag.color 41 | -------------------------------------------------------------------------------- /beanie/migrations/controllers/free_fall.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | from typing import Any, List, Type 3 | 4 | from 
beanie.migrations.controllers.base import BaseMigrationController 5 | from beanie.odm.documents import Document 6 | 7 | 8 | def free_fall_migration(document_models: List[Type[Document]]): 9 | class FreeFallMigrationController(BaseMigrationController): 10 | def __init__(self, function): 11 | self.function = function 12 | self.function_signature = signature(function) 13 | self.document_models = document_models 14 | 15 | def __call__(self, *args: Any, **kwargs: Any): 16 | pass 17 | 18 | @property 19 | def models(self) -> List[Type[Document]]: 20 | return self.document_models 21 | 22 | async def run(self, session): 23 | function_kwargs = {"session": session} 24 | if "self" in self.function_signature.parameters: 25 | function_kwargs["self"] = None 26 | await self.function(**function_kwargs) 27 | 28 | return FreeFallMigrationController 29 | -------------------------------------------------------------------------------- /docs/tutorial/aggregate.md: -------------------------------------------------------------------------------- 1 | # Aggregations 2 | 3 | You can perform aggregation queries through Beanie as well. For example, to calculate the average: 4 | 5 | ```python 6 | # With a search: 7 | avg_price = await Product.find( 8 | Product.category.name == "Chocolate" 9 | ).avg(Product.price) 10 | 11 | # Over the whole collection: 12 | avg_price = await Product.avg(Product.price) 13 | ``` 14 | 15 | A full list of available methods can be found [here](../api-documentation/interfaces.md/#aggregatemethods). 16 | 17 | You can also use the native PyMongo syntax by calling the `aggregate` method. 18 | However, since Beanie cannot know what output to expect, you should supply a projection model yourself. 19 | If you do not supply a projection model, each result will be returned as a dictionary.
20 | 21 | ```python 22 | class OutputItem(BaseModel): 23 | id: str = Field(None, alias="_id") 24 | total: float 25 | 26 | 27 | result = await Product.find( 28 | Product.category.name == "Chocolate").aggregate( 29 | [{"$group": {"_id": "$category.name", "total": {"$avg": "$price"}}}], 30 | projection_model=OutputItem 31 | ).to_list() 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /.github/workflows/github-actions-tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | pull_request: 4 | 5 | jobs: 6 | run-tests: 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] 11 | mongodb-version: [ "4.4", "5.0", "6.0", "7.0", "8.0" ] 12 | pydantic-version: [ "1.10.18", "2.9.2" , "2.10.4", "2.11"] 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | cache: pip 20 | cache-dependency-path: pyproject.toml 21 | - name: Start MongoDB 22 | uses: supercharge/mongodb-github-action@1.11.0 23 | with: 24 | mongodb-version: ${{ matrix.mongodb-version }} 25 | mongodb-replica-set: test-rs 26 | - name: install dependencies 27 | run: pip install .[test,ci] 28 | - name: install pydantic 29 | run: pip install pydantic==${{ matrix.pydantic-version }} 30 | - name: run tests 31 | env: 32 | PYTHON_JIT: 1 33 | run: pytest -v 34 | -------------------------------------------------------------------------------- /tests/odm/test_concurrency.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from pymongo import AsyncMongoClient 4 | 5 | from beanie import Document, init_beanie 6 | 7 | 8 | class SampleModel(Document): 9 | s: str = "TEST" 10 | i: int = 10 11 | 12 | 13 | class SampleModel2(SampleModel): ... 14 | 15 | 16 | class SampleModel3(SampleModel2): ... 
17 | 18 | 19 | class TestConcurrency: 20 | async def test_without_init(self, settings): 21 | clients = [] 22 | for _ in range(10): 23 | client = AsyncMongoClient(settings.mongodb_dsn) 24 | clients.append(client) 25 | db = client[settings.mongodb_db_name] 26 | await init_beanie( 27 | db, document_models=[SampleModel3, SampleModel, SampleModel2] 28 | ) 29 | 30 | async def insert_find(): 31 | await SampleModel2().insert() 32 | docs = await SampleModel2.find(SampleModel2.i == 10).to_list() 33 | return docs 34 | 35 | await asyncio.gather(*[insert_find() for _ in range(10)]) 36 | 37 | await SampleModel2.delete_all() 38 | 39 | [await client.close() for client in clients] 40 | -------------------------------------------------------------------------------- /docs/tutorial/cache.md: -------------------------------------------------------------------------------- 1 | # Cache 2 | All query results can be cached locally. 3 | 4 | This feature must be explicitly turned on in the `Settings` inner class. 5 | 6 | ```python 7 | class Sample(Document): 8 | num: int 9 | name: str 10 | 11 | class Settings: 12 | use_cache = True 13 | ``` 14 | 15 | Beanie uses an LRU cache with an expiration time. 16 | You can set the capacity (`cache_capacity`, the maximum number of cached queries) and the expiration time (`cache_expiration_time`) in the `Settings` inner class. 17 | 18 | ```python 19 | class Sample(Document): 20 | num: int 21 | name: str 22 | 23 | class Settings: 24 | use_cache = True 25 | cache_expiration_time = datetime.timedelta(seconds=10) 26 | cache_capacity = 5 27 | ``` 28 | 29 | Any query will be cached for this document class.
30 | 31 | ```python 32 | # on the first call it will go to the database 33 | samples = await Sample.find(Sample.num > 10).to_list() 34 | 35 | # on the second call it will use the cache instead 36 | samples = await Sample.find(Sample.num > 10).to_list() 37 | 38 | await asyncio.sleep(15) 39 | 40 | # once the expiration time has passed, it will go to the database again 41 | samples = await Sample.find(Sample.num > 10).to_list() 42 | ``` -------------------------------------------------------------------------------- /docs/tutorial/time_series.md: -------------------------------------------------------------------------------- 1 | # Time series 2 | 3 | You can set up a timeseries collection using the inner `Settings` class. 4 | 5 | **Be aware: timeseries collections are supported by MongoDB 5.0 and higher only. The fields `bucket_max_span_seconds` and `bucket_rounding_seconds`, however, require MongoDB 6.3 or higher.** 6 | 7 | ```python 8 | from datetime import datetime 9 | 10 | from beanie import Document, TimeSeriesConfig, Granularity 11 | from pydantic import Field 12 | 13 | 14 | class Sample(Document): 15 | ts: datetime = Field(default_factory=datetime.now) 16 | meta: str 17 | 18 | class Settings: 19 | timeseries = TimeSeriesConfig( 20 | time_field="ts", # Required 21 | meta_field="meta", # Optional 22 | granularity=Granularity.hours, # Optional 23 | bucket_max_span_seconds=3600, # Optional 24 | bucket_rounding_seconds=3600, # Optional 25 | expire_after_seconds=2 # Optional 26 | ) 27 | ``` 28 | 29 | TimeSeriesConfig fields reflect the respective parameters of the MongoDB timeseries creation function.
30 | 31 | MongoDB documentation: https://docs.mongodb.com/manual/core/timeseries-collections/ -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/break/20210413211219_break.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, Indexed, PydanticObjectId, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class OldNote(Document): 12 | name: Indexed(str, unique=True) 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Note(Document): 20 | name: Indexed(str, unique=True) 21 | title: str 22 | tag: Tag 23 | 24 | class Settings: 25 | name = "notes" 26 | 27 | 28 | fixed_id = PydanticObjectId("6076f1f3e4b7f6b7a0f6e5a0") 29 | 30 | 31 | class Forward: 32 | @iterative_migration(batch_size=2) 33 | async def name_to_title( 34 | self, input_document: OldNote, output_document: Note 35 | ): 36 | output_document.title = input_document.name 37 | if output_document.title > "5": 38 | output_document.name = "5" 39 | 40 | 41 | class Backward: 42 | @iterative_migration() 43 | async def title_to_name( 44 | self, input_document: Note, output_document: OldNote 45 | ): 46 | output_document.name = input_document.title 47 | -------------------------------------------------------------------------------- /beanie/odm/settings/base.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Dict, Optional, Type 3 | 4 | from pydantic import BaseModel, Field 5 | from pymongo.asynchronous.collection import AsyncCollection 6 | from pymongo.asynchronous.database import AsyncDatabase 7 | 8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 9 | 10 | if IS_PYDANTIC_V2: 11 | from pydantic import ConfigDict 12 | 13 | 14 | class ItemSettings(BaseModel): 15 | name: 
Optional[str] = None 16 | 17 | use_cache: bool = False 18 | cache_capacity: int = 32 19 | cache_expiration_time: timedelta = timedelta(minutes=10) 20 | bson_encoders: Dict[Any, Any] = Field(default_factory=dict) 21 | projection: Optional[Dict[str, Any]] = None 22 | 23 | pymongo_db: Optional[AsyncDatabase] = None 24 | pymongo_collection: Optional[AsyncCollection] = None 25 | 26 | union_doc: Optional[Type] = None 27 | union_doc_alias: Optional[str] = None 28 | class_id: str = "_class_id" 29 | 30 | is_root: bool = False 31 | 32 | if IS_PYDANTIC_V2: 33 | model_config = ConfigDict( 34 | arbitrary_types_allowed=True, 35 | ) 36 | else: 37 | 38 | class Config: 39 | arbitrary_types_allowed = True 40 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, free_fall_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class OldNote(Document): 12 | name: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Note(Document): 20 | title: str 21 | tag: Tag 22 | 23 | class Settings: 24 | name = "notes" 25 | 26 | 27 | class Forward: 28 | @free_fall_migration(document_models=[OldNote, Note]) 29 | async def name_to_title(self, session): 30 | async for old_note in OldNote.find_all(): 31 | new_note = Note( 32 | id=old_note.id, title=old_note.name, tag=old_note.tag 33 | ) 34 | await new_note.replace(session=session) 35 | 36 | 37 | class Backward: 38 | @free_fall_migration(document_models=[OldNote, Note]) 39 | async def title_to_name(self, session): 40 | async for old_note in Note.find_all(): 41 | new_note = OldNote( 42 | id=old_note.id, name=old_note.title, tag=old_note.tag 43 | ) 44 | await new_note.replace(session=session) 45 | 
-------------------------------------------------------------------------------- /tests/odm/documents/test_replace.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import Sample 2 | 3 | 4 | async def test_replace_one(preset_documents): 5 | count_1_before = await Sample.find_many(Sample.integer == 1).count() 6 | count_2_before = await Sample.find_many(Sample.integer == 2).count() 7 | 8 | a_2 = await Sample.find_one(Sample.integer == 2) 9 | await Sample.find_one(Sample.integer == 1).replace_one(a_2) 10 | 11 | count_1_after = await Sample.find_many(Sample.integer == 1).count() 12 | count_2_after = await Sample.find_many(Sample.integer == 2).count() 13 | 14 | assert count_1_after == count_1_before - 1 15 | assert count_2_after == count_2_before + 1 16 | 17 | 18 | async def test_replace_self(preset_documents): 19 | count_1_before = await Sample.find_many(Sample.integer == 1).count() 20 | count_2_before = await Sample.find_many(Sample.integer == 2).count() 21 | 22 | a_1 = await Sample.find_one(Sample.integer == 1) 23 | a_1.integer = 2 24 | await a_1.replace() 25 | 26 | count_1_after = await Sample.find_many(Sample.integer == 1).count() 27 | count_2_after = await Sample.find_many(Sample.integer == 2).count() 28 | 29 | assert count_1_after == count_1_before - 1 30 | assert count_2_after == count_2_before + 1 31 | -------------------------------------------------------------------------------- /beanie/exceptions.py: -------------------------------------------------------------------------------- 1 | class WrongDocumentUpdateStrategy(Exception): 2 | pass 3 | 4 | 5 | class DocumentNotFound(Exception): 6 | pass 7 | 8 | 9 | class DocumentAlreadyCreated(Exception): 10 | pass 11 | 12 | 13 | class DocumentWasNotSaved(Exception): 14 | pass 15 | 16 | 17 | class CollectionWasNotInitialized(Exception): 18 | pass 19 | 20 | 21 | class MigrationException(Exception): 22 | pass 23 | 24 | 25 | class ReplaceError(Exception): 26 | 
pass 27 | 28 | 29 | class StateManagementIsTurnedOff(Exception): 30 | pass 31 | 32 | 33 | class StateNotSaved(Exception): 34 | pass 35 | 36 | 37 | class RevisionIdWasChanged(Exception): 38 | pass 39 | 40 | 41 | class NotSupported(Exception): 42 | pass 43 | 44 | 45 | class MongoDBVersionError(Exception): 46 | pass 47 | 48 | 49 | class ViewWasNotInitialized(Exception): 50 | pass 51 | 52 | 53 | class ViewHasNoSettings(Exception): 54 | pass 55 | 56 | 57 | class UnionHasNoRegisteredDocs(Exception): 58 | pass 59 | 60 | 61 | class UnionDocNotInited(Exception): 62 | pass 63 | 64 | 65 | class DocWasNotRegisteredInUnionClass(Exception): 66 | pass 67 | 68 | 69 | class Deprecation(Exception): 70 | pass 71 | 72 | 73 | class ApplyChangesException(Exception): 74 | pass 75 | -------------------------------------------------------------------------------- /beanie/odm/settings/document.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from pydantic import Field 4 | 5 | from beanie.odm.fields import IndexModelField 6 | from beanie.odm.settings.base import ItemSettings 7 | from beanie.odm.settings.timeseries import TimeSeriesConfig 8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 9 | 10 | if IS_PYDANTIC_V2: 11 | from pydantic import ConfigDict 12 | 13 | 14 | class DocumentSettings(ItemSettings): 15 | use_state_management: bool = False 16 | state_management_replace_objects: bool = False 17 | state_management_save_previous: bool = False 18 | validate_on_save: bool = False 19 | use_revision: bool = False 20 | single_root_inheritance: bool = False 21 | 22 | indexes: List[IndexModelField] = Field(default_factory=list) 23 | merge_indexes: bool = False 24 | timeseries: Optional[TimeSeriesConfig] = None 25 | 26 | lazy_parsing: bool = False 27 | 28 | keep_nulls: bool = True 29 | 30 | max_nesting_depths_per_field: dict = Field(default_factory=dict) 31 | max_nesting_depth: int = 3 32 | 33 | if 
IS_PYDANTIC_V2: 34 | model_config = ConfigDict( 35 | arbitrary_types_allowed=True, 36 | ) 37 | else: 38 | 39 | class Config: 40 | arbitrary_types_allowed = True 41 | -------------------------------------------------------------------------------- /beanie/odm/utils/projection.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Type, TypeVar 2 | 3 | from pydantic import BaseModel 4 | 5 | from beanie.odm.interfaces.detector import ModelType 6 | from beanie.odm.utils.pydantic import get_config_value, get_model_fields 7 | 8 | ProjectionModelType = TypeVar("ProjectionModelType", bound=BaseModel) 9 | 10 | 11 | def get_projection( 12 | model: Type[ProjectionModelType], 13 | ) -> Optional[Dict[str, int]]: 14 | if hasattr(model, "get_model_type") and ( 15 | model.get_model_type() == ModelType.UnionDoc # type: ignore 16 | or ( # type: ignore 17 | model.get_model_type() == ModelType.Document # type: ignore 18 | and model._inheritance_inited # type: ignore 19 | ) 20 | ): # type: ignore 21 | return None 22 | 23 | if hasattr(model, "Settings"): # MyPy checks 24 | settings = getattr(model, "Settings") 25 | 26 | if hasattr(settings, "projection"): 27 | return getattr(settings, "projection") 28 | 29 | if get_config_value(model, "extra") == "allow": 30 | return None 31 | 32 | document_projection: Dict[str, int] = {} 33 | 34 | for name, field in get_model_fields(model).items(): 35 | document_projection[field.alias or name] = 1 36 | return document_projection 37 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_logical.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.odm.operators.find.logical import And, Nor, Not, Or 4 | from tests.odm.models import Sample 5 | 6 | 7 | async def test_and(): 8 | q = And(Sample.integer == 1) 9 | assert q == {"integer": 1} 10 | 11 | q = 
And(Sample.integer == 1, Sample.nested.integer > 3) 12 | assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 13 | 14 | 15 | async def test_not(preset_documents): 16 | q = Not(Sample.integer == 1) 17 | assert q == {"integer": {"$not": {"$eq": 1}}} 18 | 19 | docs = await Sample.find(q).to_list() 20 | assert len(docs) == 7 21 | 22 | with pytest.raises(AttributeError): 23 | q = Not(And(Sample.integer == 1, Sample.nested.integer > 3)) 24 | await Sample.find(q).to_list() 25 | 26 | 27 | async def test_nor(): 28 | q = Nor(Sample.integer == 1) 29 | assert q == {"$nor": [{"integer": 1}]} 30 | 31 | q = Nor(Sample.integer == 1, Sample.nested.integer > 3) 32 | assert q == {"$nor": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 33 | 34 | 35 | async def test_or(): 36 | q = Or(Sample.integer == 1) 37 | assert q == {"integer": 1} 38 | 39 | q = Or(Sample.integer == 1, Sample.nested.integer > 3) 40 | assert q == {"$or": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 41 | -------------------------------------------------------------------------------- /docs/tutorial/revision.md: -------------------------------------------------------------------------------- 1 | # Revision 2 | 3 | This feature helps with concurrent operations. 4 | It stores `revision_id` together with the document and changes it on each document update. 5 | If an application holding an outdated local copy of the document tries to change it, an exception will be raised. 6 | The application will be allowed to change the data only after its local copy is synced with the database. 7 | This helps to avoid data loss. 8 | 9 | ### Be aware 10 | The revision id feature may work incorrectly with BulkWriter.
11 | 12 | ### Usage 13 | 14 | This feature must be explicitly turned on in the `Settings` inner class: 15 | 16 | ```python 17 | class Sample(Document): 18 | num: int 19 | name: str 20 | 21 | class Settings: 22 | use_revision = True 23 | ``` 24 | 25 | Any modifying operation will check whether the local copy of the document has the up-to-date `revision_id` value: 26 | 27 | ```python 28 | s = await Sample.find_one(Sample.name == "TestName") 29 | s.num = 10 30 | 31 | # If a concurrent process already changed the doc, the next operation will raise an error 32 | await s.replace() 33 | ``` 34 | 35 | If you want to skip the revision check and apply all the changes even if the local copy is outdated, 36 | you can use the `ignore_revision` parameter: 37 | 38 | ```python 39 | await s.replace(ignore_revision=True) 40 | ``` 41 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title_4( 22 | self, input_document: Note, output_document: Note 23 | ): 24 | if input_document.title == "4": 25 | output_document.title = "four" 26 | 27 | @iterative_migration() 28 | async def change_title_5( 29 | self, input_document: Note, output_document: Note 30 | ): 31 | if input_document.title == "5": 32 | output_document.title = "five" 33 | 34 | 35 | class Backward: 36 | @iterative_migration() 37 | async def change_title_5( 38 | self, input_document: Note, output_document: Note 39 | ): 40 | if input_document.title == "five": 41 | output_document.title = "5" 42 | 43 |
@iterative_migration() 44 | async def change_title_4( 45 | self, input_document: Note, output_document: Note 46 | ): 47 | if input_document.title == "four": 48 | output_document.title = "4" 49 | -------------------------------------------------------------------------------- /tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from beanie import Document, iterative_migration 4 | 5 | 6 | class Tag(BaseModel): 7 | color: str 8 | name: str 9 | 10 | 11 | class Note(Document): 12 | title: str 13 | tag: Tag 14 | 15 | class Settings: 16 | name = "notes" 17 | 18 | 19 | class Forward: 20 | @iterative_migration() 21 | async def change_title_6( 22 | self, input_document: Note, output_document: Note 23 | ): 24 | if input_document.title == "6": 25 | output_document.title = "six" 26 | 27 | @iterative_migration() 28 | async def change_title_7( 29 | self, input_document: Note, output_document: Note 30 | ): 31 | if input_document.title == "7": 32 | output_document.title = "seven" 33 | 34 | 35 | class Backward: 36 | @iterative_migration() 37 | async def change_title_7( 38 | self, input_document: Note, output_document: Note 39 | ): 40 | if input_document.title == "seven": 41 | output_document.title = "7" 42 | 43 | @iterative_migration() 44 | async def change_title_6( 45 | self, input_document: Note, output_document: Note 46 | ): 47 | if input_document.title == "six": 48 | output_document.title = "6" 49 | -------------------------------------------------------------------------------- /tests/odm/documents/test_distinct.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_distinct_unique(documents, document_not_inserted): 5 | await documents(1, "uno") 6 | await documents(2, "dos") 7 | await documents(3, "cuatro") 8 | 
expected_result = ["cuatro", "dos", "uno"] 9 | unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 10 | assert unique_test_strs == expected_result 11 | document_not_inserted.test_str = "uno" 12 | await document_not_inserted.insert() 13 | another_unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 14 | assert another_unique_test_strs == expected_result 15 | 16 | 17 | async def test_distinct_different_value(documents, document_not_inserted): 18 | await documents(1, "uno") 19 | await documents(2, "dos") 20 | await documents(3, "cuatro") 21 | expected_result = ["cuatro", "dos", "uno"] 22 | unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 23 | assert unique_test_strs == expected_result 24 | document_not_inserted.test_str = "diff_val" 25 | await document_not_inserted.insert() 26 | another_unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 27 | assert not another_unique_test_strs == expected_result 28 | assert another_unique_test_strs == ["cuatro", "diff_val", "dos", "uno"] 29 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_general.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.general import ( 2 | CurrentDate, 3 | Inc, 4 | Max, 5 | Min, 6 | Mul, 7 | Rename, 8 | Set, 9 | SetOnInsert, 10 | Unset, 11 | ) 12 | from tests.odm.models import Sample 13 | 14 | 15 | def test_set(): 16 | q = Set({Sample.integer: 2}) 17 | assert q == {"$set": {"integer": 2}} 18 | 19 | 20 | def test_current_date(): 21 | q = CurrentDate({Sample.integer: 2}) 22 | assert q == {"$currentDate": {"integer": 2}} 23 | 24 | 25 | def test_inc(): 26 | q = Inc({Sample.integer: 2}) 27 | assert q == {"$inc": {"integer": 2}} 28 | 29 | 30 | def test_min(): 31 | q = Min({Sample.integer: 2}) 32 | assert q == {"$min": {"integer": 2}} 33 | 34 | 35 | def test_max(): 36 | q = Max({Sample.integer: 2}) 37 | assert q 
== {"$max": {"integer": 2}} 38 | 39 | 40 | def test_mul(): 41 | q = Mul({Sample.integer: 2}) 42 | assert q == {"$mul": {"integer": 2}} 43 | 44 | 45 | def test_rename(): 46 | q = Rename({Sample.integer: 2}) 47 | assert q == {"$rename": {"integer": 2}} 48 | 49 | 50 | def test_set_on_insert(): 51 | q = SetOnInsert({Sample.integer: 2}) 52 | assert q == {"$setOnInsert": {"integer": 2}} 53 | 54 | 55 | def test_unset(): 56 | q = Unset({Sample.integer: 2}) 57 | assert q == {"$unset": {"integer": 2}} 58 | -------------------------------------------------------------------------------- /tests/odm/test_beanie_object_dumping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel, Field 3 | 4 | from beanie import Link, PydanticObjectId 5 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 6 | from tests.odm.models import DocumentTestModelWithSoftDelete 7 | 8 | 9 | class TestModel(BaseModel): 10 | my_id: PydanticObjectId = Field(default_factory=PydanticObjectId) 11 | fake_doc: Link[DocumentTestModelWithSoftDelete] 12 | 13 | 14 | def data_maker(): 15 | return TestModel( 16 | my_id="5f4e3f3b7c0c9d001f7d4c8e", 17 | fake_doc=DocumentTestModelWithSoftDelete( 18 | test_int=1, test_str="test", id="5f4e3f3b7c0c9d001f7d4c8f" 19 | ), 20 | ) 21 | 22 | 23 | @pytest.mark.skipif( 24 | not IS_PYDANTIC_V2, 25 | reason="model dumping support is more complete with pydantic v2", 26 | ) 27 | def test_id_types_preserved_when_dumping_to_python(): 28 | dumped = data_maker().model_dump(mode="python") 29 | assert isinstance(dumped["my_id"], PydanticObjectId) 30 | assert isinstance(dumped["fake_doc"]["id"], PydanticObjectId) 31 | 32 | 33 | @pytest.mark.skipif( 34 | not IS_PYDANTIC_V2, 35 | reason="model dumping support is more complete with pydantic v2", 36 | ) 37 | def test_id_types_serialized_when_dumping_to_json(): 38 | dumped = data_maker().model_dump(mode="json") 39 | assert isinstance(dumped["my_id"], str) 40 | 
assert isinstance(dumped["fake_doc"]["id"], str) 41 | -------------------------------------------------------------------------------- /tests/odm/test_timesries_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie import init_beanie 4 | from beanie.exceptions import MongoDBVersionError 5 | from tests.odm.models import DocumentWithTimeseries 6 | 7 | 8 | async def test_timeseries_collection(db): 9 | build_info = await db.command({"buildInfo": 1}) 10 | mongo_version = build_info["version"] 11 | major_version = int(mongo_version.split(".")[0]) 12 | if major_version < 5: 13 | with pytest.raises(MongoDBVersionError): 14 | await init_beanie( 15 | database=db, document_models=[DocumentWithTimeseries] 16 | ) 17 | 18 | if major_version >= 5: 19 | await init_beanie( 20 | database=db, document_models=[DocumentWithTimeseries] 21 | ) 22 | info = await db.command( 23 | { 24 | "listCollections": 1, 25 | "filter": {"name": "DocumentWithTimeseries"}, 26 | } 27 | ) 28 | 29 | assert info["cursor"]["firstBatch"][0] == { 30 | "name": "DocumentWithTimeseries", 31 | "type": "timeseries", 32 | "options": { 33 | "expireAfterSeconds": 2, 34 | "timeseries": { 35 | "timeField": "ts", 36 | "granularity": "seconds", 37 | "bucketMaxSpanSeconds": 3600, 38 | }, 39 | }, 40 | "info": {"readOnly": False}, 41 | } 42 | -------------------------------------------------------------------------------- /beanie/odm/cache.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | from datetime import timedelta, timezone 4 | from typing import Any, Optional 5 | 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class CachedItem(BaseModel): 10 | timestamp: datetime.datetime = Field( 11 | default_factory=lambda: datetime.datetime.now(tz=timezone.utc) 12 | ) 13 | value: Any 14 | 15 | 16 | class LRUCache: 17 | def __init__(self, capacity: int, expiration_time: 
timedelta): 18 | self.capacity: int = capacity 19 | self.expiration_time: timedelta = expiration_time 20 | self.cache: collections.OrderedDict = collections.OrderedDict() 21 | 22 | def get(self, key) -> Optional[CachedItem]: 23 | try: 24 | item: CachedItem = self.cache.pop(key) 25 | if ( 26 | datetime.datetime.now(tz=timezone.utc) - item.timestamp 27 | > self.expiration_time 28 | ): 29 | return None 30 | self.cache[key] = item 31 | return item.value 32 | except KeyError: 33 | return None 34 | 35 | def set(self, key, value) -> None: 36 | try: 37 | self.cache.pop(key) 38 | except KeyError: 39 | if len(self.cache) >= self.capacity: 40 | self.cache.popitem(last=False) 41 | self.cache[key] = CachedItem(value=value) 42 | 43 | @staticmethod 44 | def create_key(*args): 45 | return str(args) # TODO think about this 46 | -------------------------------------------------------------------------------- /docs/tutorial/init.md: -------------------------------------------------------------------------------- 1 | Beanie uses Async PyMongo as an async database engine. 2 | To initialize previously created documents, you should provide an Async PyMongo database instance 3 | and a list of your document models to the `init_beanie(...)` function, as it is shown in the example: 4 | 5 | ```python 6 | from beanie import init_beanie, Document 7 | from pymongo import AsyncMongoClient 8 | 9 | class Sample(Document): 10 | name: str 11 | 12 | async def init(): 13 | # Create Async PyMongo client 14 | client = AsyncMongoClient( 15 | "mongodb://user:pass@host:27017" 16 | ) 17 | 18 | # Initialize beanie with the Sample document class and a database 19 | await init_beanie(database=client.db_name, document_models=[Sample]) 20 | ``` 21 | 22 | This creates the collection (if necessary) and sets up any indexes that are defined. 
23 | 24 | 25 | `init_beanie` supports not only a list of classes as the document_models argument, 26 | but also strings with dot-separated paths: 27 | 28 | ```python 29 | await init_beanie( 30 | database=client.db_name, 31 | document_models=[ 32 | "app.models.DemoDocument", 33 | ], 34 | ) 35 | ``` 36 | 37 | ### Warning 38 | 39 | `init_beanie` supports the parameter named `allow_index_dropping` that will drop indexes from your collections. 40 | `allow_index_dropping` is by default set to `False`. If you set this to `True`, 41 | ensure that you are not managing your indexes in another manner. 42 | If you are, these will be deleted when setting `allow_index_dropping=True`. -------------------------------------------------------------------------------- /beanie/odm/utils/dump.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional, Set 2 | 3 | from beanie.odm.utils.encoder import Encoder 4 | 5 | if TYPE_CHECKING: 6 | from beanie.odm.documents import Document 7 | 8 | 9 | def get_dict( 10 | document: "Document", 11 | to_db: bool = False, 12 | exclude: Optional[Set[str]] = None, 13 | keep_nulls: bool = True, 14 | ): 15 | if exclude is None: 16 | exclude = set() 17 | if document.id is None: 18 | exclude.add("_id") 19 | include = set() 20 | if document.get_settings().use_revision: 21 | include.add("revision_id") 22 | encoder = Encoder( 23 | exclude=exclude, include=include, to_db=to_db, keep_nulls=keep_nulls 24 | ) 25 | return encoder.encode(document) 26 | 27 | 28 | def get_nulls( 29 | document: "Document", 30 | exclude: Optional[Set[str]] = None, 31 | ): 32 | dictionary = get_dict(document, exclude=exclude, keep_nulls=True) 33 | return filter_none(dictionary) 34 | 35 | 36 | def get_top_level_nones( 37 | document: "Document", 38 | exclude: Optional[Set[str]] = None, 39 | ): 40 | dictionary = get_dict(document, exclude=exclude, keep_nulls=True) 41 | return {k: v for k, v in dictionary.items() if v is 
None} 42 | 43 | 44 | def filter_none(d): 45 | result = {} 46 | for k, v in d.items(): 47 | if isinstance(v, dict): 48 | filtered = filter_none(v) 49 | if filtered: 50 | result[k] = filtered 51 | elif v is None: 52 | result[k] = v 53 | return result 54 | -------------------------------------------------------------------------------- /beanie/odm/operators/find/bitwise.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from beanie.odm.fields import ExpressionField 4 | from beanie.odm.operators.find import BaseFindOperator 5 | 6 | 7 | class BaseFindBitwiseOperator(BaseFindOperator): 8 | operator = "" 9 | 10 | def __init__(self, field: Union[str, ExpressionField], bitmask): 11 | self.field = field 12 | self.bitmask = bitmask 13 | 14 | @property 15 | def query(self): 16 | return {self.field: {self.operator: self.bitmask}} 17 | 18 | 19 | class BitsAllClear(BaseFindBitwiseOperator): 20 | """ 21 | `$bitsAllClear` query operator 22 | 23 | MongoDB doc: 24 | https://docs.mongodb.com/manual/reference/operator/query/bitsAllClear/ 25 | """ 26 | 27 | operator = "$bitsAllClear" 28 | 29 | 30 | class BitsAllSet(BaseFindBitwiseOperator): 31 | """ 32 | `$bitsAllSet` query operator 33 | 34 | MongoDB doc: 35 | https://docs.mongodb.com/manual/reference/operator/query/bitsAllSet/ 36 | """ 37 | 38 | operator = "$bitsAllSet" 39 | 40 | 41 | class BitsAnyClear(BaseFindBitwiseOperator): 42 | """ 43 | `$bitsAnyClear` query operator 44 | 45 | MongoDB doc: 46 | https://docs.mongodb.com/manual/reference/operator/query/bitsAnyClear/ 47 | """ 48 | 49 | operator = "$bitsAnyClear" 50 | 51 | 52 | class BitsAnySet(BaseFindBitwiseOperator): 53 | """ 54 | `$bitsAnySet` query operator 55 | 56 | MongoDB doc: 57 | https://docs.mongodb.com/manual/reference/operator/query/bitsAnySet/ 58 | """ 59 | 60 | operator = "$bitsAnySet" 61 | -------------------------------------------------------------------------------- /beanie/odm/utils/relations.py:
-------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from typing import TYPE_CHECKING, Any, Dict 3 | from typing import Mapping as MappingType 4 | 5 | from beanie.odm.fields import ( 6 | ExpressionField, 7 | ) 8 | 9 | # from pydantic.fields import ModelField 10 | # from pydantic.typing import get_origin 11 | 12 | if TYPE_CHECKING: 13 | from beanie import Document 14 | 15 | 16 | def convert_ids( 17 | query: MappingType[str, Any], doc: "Document", fetch_links: bool 18 | ) -> Dict[str, Any]: 19 | # TODO add all the cases 20 | new_query = {} 21 | for k, v in query.items(): 22 | k_splitted = k.split(".") 23 | if ( 24 | isinstance(k, ExpressionField) 25 | and doc.get_link_fields() is not None 26 | and len(k_splitted) == 2 27 | and k_splitted[0] in doc.get_link_fields().keys() # type: ignore 28 | and k_splitted[1] == "id" 29 | ): 30 | if fetch_links: 31 | new_k = f"{k_splitted[0]}._id" 32 | else: 33 | new_k = f"{k_splitted[0]}.$id" 34 | else: 35 | new_k = k 36 | new_v: Any 37 | if isinstance(v, Mapping): 38 | new_v = convert_ids(v, doc, fetch_links) 39 | elif isinstance(v, list): 40 | new_v = [ 41 | convert_ids(ele, doc, fetch_links) 42 | if isinstance(ele, Mapping) 43 | else ele 44 | for ele in v 45 | ] 46 | else: 47 | new_v = v 48 | 49 | new_query[new_k] = new_v 50 | return new_query 51 | -------------------------------------------------------------------------------- /tests/odm/test_typing_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | from typing_extensions import Annotated 6 | 7 | from beanie import Document, Link 8 | from beanie.odm.fields import Indexed 9 | from beanie.odm.utils.pydantic import get_model_fields 10 | from beanie.odm.utils.typing import extract_id_class, get_index_attributes 11 | 12 | 13 | class Lock(Document): 14 | k: int 15 | 16 | 17 | 
class TestTyping: 18 | def test_extract_id_class(self): 19 | # Union 20 | assert extract_id_class(Union[str, int]) is str 21 | assert extract_id_class(Union[str, None]) is str 22 | assert extract_id_class(Union[str, None, int]) is str 23 | # Optional 24 | assert extract_id_class(Optional[str]) is str 25 | # Link 26 | assert extract_id_class(Link[Lock]) == Lock 27 | 28 | @pytest.mark.parametrize( 29 | "type,result", 30 | ( 31 | (str, None), 32 | (Indexed(str), (1, {})), 33 | (Indexed(str, "text", unique=True), ("text", {"unique": True})), 34 | (Annotated[str, Indexed()], (1, {})), 35 | ( 36 | Annotated[str, "other metadata", Indexed(unique=True)], 37 | (1, {"unique": True}), 38 | ), 39 | (Annotated[str, "other metadata"], None), 40 | ), 41 | ) 42 | def test_get_index_attributes(self, type, result): 43 | class Foo(BaseModel): 44 | bar: type 45 | 46 | field = get_model_fields(Foo)["bar"] 47 | assert get_index_attributes(field) == result 48 | -------------------------------------------------------------------------------- /tests/odm/documents/test_sync.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.exceptions import ApplyChangesException 4 | from beanie.odm.documents import MergeStrategy 5 | from tests.odm.models import DocumentToTestSync 6 | 7 | 8 | class TestSync: 9 | async def test_merge_remote(self): 10 | doc = DocumentToTestSync() 11 | await doc.insert() 12 | 13 | doc2 = await DocumentToTestSync.get(doc.id) 14 | doc2.s = "foo" 15 | 16 | doc.i = 100 17 | await doc.save() 18 | 19 | await doc2.sync() 20 | 21 | assert doc2.s == "TEST" 22 | assert doc2.i == 100 23 | 24 | async def test_merge_local(self): 25 | doc = DocumentToTestSync(d={"option_1": {"s": "foo"}}) 26 | await doc.insert() 27 | 28 | doc2 = await DocumentToTestSync.get(doc.id) 29 | doc2.s = "foo" 30 | doc2.n.option_1.s = "bar" 31 | doc2.d["option_1"]["s"] = "bar" 32 | 33 | doc.i = 100 34 | await doc.save() 35 | 36 | await 
doc2.sync(merge_strategy=MergeStrategy.local) 37 | 38 | assert doc2.s == "foo" 39 | assert doc2.n.option_1.s == "bar" 40 | assert doc2.d["option_1"]["s"] == "bar" 41 | 42 | assert doc2.i == 100 43 | 44 | async def test_merge_local_impossible_apply_changes(self): 45 | doc = DocumentToTestSync(d={"option_1": {"s": "foo"}}) 46 | await doc.insert() 47 | 48 | doc2 = await DocumentToTestSync.get(doc.id) 49 | doc2.d["option_1"]["s"] = {"foo": "bar"} 50 | 51 | doc.d = {"option_1": "nothing"} 52 | await doc.save() 53 | with pytest.raises(ApplyChangesException): 54 | await doc2.sync(merge_strategy=MergeStrategy.local) 55 | -------------------------------------------------------------------------------- /beanie/odm/operators/find/element.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import List, Union 3 | 4 | from beanie.odm.operators.find import BaseFindOperator 5 | 6 | 7 | class BaseFindElementOperator(BaseFindOperator, ABC): ... 
8 | 9 | 10 | class Exists(BaseFindElementOperator): 11 | """ 12 | `$exists` query operator 13 | 14 | Example: 15 | 16 | ```python 17 | class Product(Document): 18 | price: float 19 | 20 | Exists(Product.price, True) 21 | ``` 22 | 23 | Will return query object like 24 | 25 | ```python 26 | {"price": {"$exists": True}} 27 | ``` 28 | 29 | MongoDB doc: 30 | https://docs.mongodb.com/manual/reference/operator/query/exists/ 31 | """ 32 | 33 | def __init__( 34 | self, 35 | field, 36 | value: bool = True, 37 | ): 38 | self.field = field 39 | self.value = value 40 | 41 | @property 42 | def query(self): 43 | return {self.field: {"$exists": self.value}} 44 | 45 | 46 | class Type(BaseFindElementOperator): 47 | """ 48 | `$type` query operator 49 | 50 | Example: 51 | 52 | ```python 53 | class Product(Document): 54 | price: float 55 | 56 | Type(Product.price, "decimal") 57 | ``` 58 | 59 | Will return query object like 60 | 61 | ```python 62 | {"price": {"$type": "decimal"}} 63 | ``` 64 | 65 | MongoDB doc: 66 | https://docs.mongodb.com/manual/reference/operator/query/type/ 67 | """ 68 | 69 | def __init__(self, field, types: Union[List[str], str]): 70 | self.field = field 71 | self.types = types 72 | 73 | @property 74 | def query(self): 75 | return {self.field: {"$type": self.types}} 76 | -------------------------------------------------------------------------------- /tests/migrations/iterative/test_change_value_subfield.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class Tag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class Note(Document): 17 | title: str 18 | tag: Tag 19 | 20 | class Settings: 21 | name = "notes" 22 | 23 | 24 | @pytest.fixture() 25 | async def notes(db): 26 | await init_beanie(database=db,
document_models=[Note]) 27 | await Note.delete_all() 28 | for i in range(10): 29 | note = Note(title=str(i), tag=Tag(name="test", color="red")) 30 | await note.insert() 31 | yield 32 | await Note.delete_all() 33 | 34 | 35 | async def test_migration_change_subfield_value(settings, notes, db): 36 | migration_settings = MigrationSettings( 37 | connection_uri=settings.mongodb_dsn, 38 | database_name=settings.mongodb_db_name, 39 | path="tests/migrations/migrations_for_test/change_subfield_value", 40 | ) 41 | await run_migrate(migration_settings) 42 | 43 | await init_beanie(database=db, document_models=[Note]) 44 | inspection = await Note.inspect_collection() 45 | assert inspection.status == InspectionStatuses.OK 46 | note = await Note.find_one({}) 47 | assert note.tag.color == "blue" 48 | 49 | migration_settings.direction = RunningDirections.BACKWARD 50 | await run_migrate(migration_settings) 51 | inspection = await Note.inspect_collection() 52 | assert inspection.status == InspectionStatuses.OK 53 | note = await Note.find_one({}) 54 | assert note.tag.color == "red" 55 | -------------------------------------------------------------------------------- /tests/odm/test_id.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | from beanie import PydanticObjectId 7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 8 | from tests.odm.models import DocumentWithCustomIdInt, DocumentWithCustomIdUUID 9 | 10 | 11 | class A(BaseModel): 12 | id: PydanticObjectId 13 | 14 | 15 | async def test_uuid_id(): 16 | doc = DocumentWithCustomIdUUID(name="TEST") 17 | await doc.insert() 18 | new_doc = await DocumentWithCustomIdUUID.get(doc.id) 19 | assert isinstance(new_doc.id, UUID) 20 | 21 | 22 | async def test_integer_id(): 23 | doc = DocumentWithCustomIdInt(name="TEST", id=1) 24 | await doc.insert() 25 | new_doc = await DocumentWithCustomIdInt.get(doc.id) 26 | assert 
isinstance(new_doc.id, int) 27 | 28 | 29 | @pytest.mark.skipif( 30 | not IS_PYDANTIC_V2, 31 | reason="supports only pydantic v2", 32 | ) 33 | async def test_pydantic_object_id_validation_json(): 34 | deserialized = A.model_validate_json('{"id": "5eb7cf5a86d9755df3a6c593"}') 35 | assert isinstance(deserialized.id, PydanticObjectId) 36 | assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593" 37 | assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593") 38 | 39 | 40 | @pytest.mark.skipif( 41 | not IS_PYDANTIC_V2, 42 | reason="supports only pydantic v2", 43 | ) 44 | @pytest.mark.parametrize( 45 | "data", 46 | [ 47 | "5eb7cf5a86d9755df3a6c593", 48 | PydanticObjectId("5eb7cf5a86d9755df3a6c593"), 49 | ], 50 | ) 51 | async def test_pydantic_object_id_serialization(data): 52 | deserialized = A(**{"id": data}) 53 | assert isinstance(deserialized.id, PydanticObjectId) 54 | assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593" 55 | assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593") 56 | -------------------------------------------------------------------------------- /docs/tutorial/insert.md: -------------------------------------------------------------------------------- 1 | # Insert the documents 2 | 3 | Beanie documents behave just like pydantic models (because they subclass `pydantic.BaseModel`). 
4 | Hence, a document can be created in a similar fashion to pydantic: 5 | 6 | ```python 7 | from typing import Optional 8 | 9 | from pydantic import BaseModel 10 | 11 | from beanie import Document, Indexed 12 | 13 | 14 | class Category(BaseModel): 15 | name: str 16 | description: str 17 | 18 | 19 | class Product(Document): # This is the model 20 | name: str 21 | description: Optional[str] = None 22 | price: Indexed(float) 23 | category: Category 24 | 25 | class Settings: 26 | name = "products" 27 | 28 | 29 | chocolate = Category(name="Chocolate", description="A preparation of roasted and ground cacao seeds.") 30 | tonybar = Product(name="Tony's", price=5.95, category=chocolate) 31 | marsbar = Product(name="Mars", price=1, category=chocolate) 32 | ``` 33 | 34 | This however does not save the documents to the database yet. 35 | 36 | ## Insert a single document 37 | 38 | To insert a document into the database, you can call either `insert()` or `create()` on it (they are synonyms): 39 | 40 | ```python 41 | await tonybar.insert() 42 | await marsbar.create() # does exactly the same as insert() 43 | ``` 44 | You can also call `save()`, which behaves in the same manner for new documents, but will also update existing documents. 45 | See the [section on updating](updating-&-deleting.md) of this tutorial for more details. 
46 | 47 | If you prefer, you can also call the `insert_one` class method: 48 | 49 | ```python 50 | await Product.insert_one(tonybar) 51 | ``` 52 | 53 | ## Inserting many documents 54 | 55 | To reduce the number of database queries, 56 | similarly typed documents should be inserted together by calling the class method `insert_many`: 57 | 58 | ```python 59 | await Product.insert_many([tonybar,marsbar]) 60 | ``` 61 | -------------------------------------------------------------------------------- /tests/odm/documents/test_delete.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_delete_one(documents): 5 | await documents(4, "uno") 6 | await documents(2, "dos") 7 | await documents(1, "cuatro") 8 | await DocumentTestModel.find_one({"test_str": "uno"}).delete() 9 | documents = await DocumentTestModel.find_all().to_list() 10 | assert len(documents) == 6 11 | 12 | 13 | async def test_delete_one_not_found(documents): 14 | await documents(4, "uno") 15 | await documents(2, "dos") 16 | await documents(1, "cuatro") 17 | await DocumentTestModel.find_one({"test_str": "wrong"}).delete() 18 | documents = await DocumentTestModel.find_all().to_list() 19 | assert len(documents) == 7 20 | 21 | 22 | async def test_delete_many(documents): 23 | await documents(4, "uno") 24 | await documents(2, "dos") 25 | await documents(1, "cuatro") 26 | await DocumentTestModel.find_many({"test_str": "uno"}).delete() 27 | documents = await DocumentTestModel.find_all().to_list() 28 | assert len(documents) == 3 29 | 30 | 31 | async def test_delete_many_not_found(documents): 32 | await documents(4, "uno") 33 | await documents(2, "dos") 34 | await documents(1, "cuatro") 35 | await DocumentTestModel.find_many({"test_str": "wrong"}).delete() 36 | documents = await DocumentTestModel.find_all().to_list() 37 | assert len(documents) == 7 38 | 39 | 40 | async def test_delete_all(documents): 41 | await 
documents(4, "uno") 42 | await documents(2, "dos") 43 | await documents(1, "cuatro") 44 | await DocumentTestModel.delete_all() 45 | documents = await DocumentTestModel.find_all().to_list() 46 | assert len(documents) == 0 47 | 48 | 49 | async def test_delete(document): 50 | doc_id = document.id 51 | await document.delete() 52 | new_document = await DocumentTestModel.get(doc_id) 53 | assert new_document is None 54 | -------------------------------------------------------------------------------- /tests/odm/test_consistency.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.general import Set 2 | from tests.odm.models import ( 3 | DocumentTestModel, 4 | DocumentWithTimeStampToTestConsistency, 5 | ) 6 | 7 | 8 | class TestResponseOfTheChangingOperations: 9 | async def test_insert(self, document_not_inserted): 10 | result = await document_not_inserted.insert() 11 | assert isinstance(result, DocumentTestModel) 12 | 13 | async def test_update(self, document): 14 | result = await document.update(Set({"test_int": 43})) 15 | assert isinstance(result, DocumentTestModel) 16 | 17 | async def test_save(self, document, document_not_inserted): 18 | document.test_int = 43 19 | result = await document.save() 20 | assert isinstance(result, DocumentTestModel) 21 | 22 | document_not_inserted.test_int = 43 23 | result = await document_not_inserted.save() 24 | assert isinstance(result, DocumentTestModel) 25 | 26 | async def test_save_changes(self, document): 27 | document.test_int = 43 28 | result = await document.save_changes() 29 | assert isinstance(result, DocumentTestModel) 30 | 31 | async def test_replace(self, document): 32 | result = await document.replace() 33 | assert isinstance(result, DocumentTestModel) 34 | 35 | async def test_set(self, document): 36 | result = await document.set({"test_int": 43}) 37 | assert isinstance(result, DocumentTestModel) 38 | 39 | async def test_inc(self, document): 40 | result = 
await document.inc({"test_int": 1}) 41 | assert isinstance(result, DocumentTestModel) 42 | 43 | async def test_current_date(self): 44 | document = DocumentWithTimeStampToTestConsistency() 45 | await document.insert() 46 | result = await document.current_date({"ts": True}) 47 | assert isinstance(result, DocumentWithTimeStampToTestConsistency) 48 | -------------------------------------------------------------------------------- /tests/migrations/iterative/test_rename_field.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class Tag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class OldNote(Document): 17 | name: str 18 | tag: Tag 19 | 20 | class Settings: 21 | name = "notes" 22 | 23 | 24 | class Note(Document): 25 | title: str 26 | tag: Tag 27 | 28 | class Settings: 29 | name = "notes" 30 | 31 | 32 | @pytest.fixture() 33 | async def notes(db): 34 | await init_beanie(database=db, document_models=[OldNote]) 35 | await OldNote.delete_all() 36 | for i in range(10): 37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red")) 38 | await note.insert() 39 | yield 40 | # await OldNote.delete_all() 41 | 42 | 43 | async def test_migration_rename_field(settings, notes, db): 44 | migration_settings = MigrationSettings( 45 | connection_uri=settings.mongodb_dsn, 46 | database_name=settings.mongodb_db_name, 47 | path="tests/migrations/migrations_for_test/rename_field", 48 | ) 49 | await run_migrate(migration_settings) 50 | 51 | await init_beanie(database=db, document_models=[Note]) 52 | inspection = await Note.inspect_collection() 53 | assert inspection.status == InspectionStatuses.OK 54 | note 
= await Note.find_one({}) 55 | assert note.title == "0" 56 | 57 | migration_settings.direction = RunningDirections.BACKWARD 58 | await run_migrate(migration_settings) 59 | inspection = await OldNote.inspect_collection() 60 | assert inspection.status == InspectionStatuses.OK 61 | note = await OldNote.find_one({}) 62 | assert note.name == "0" 63 | -------------------------------------------------------------------------------- /beanie/odm/utils/pydantic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type 2 | 3 | import pydantic 4 | from pydantic import BaseModel 5 | 6 | IS_PYDANTIC_V2 = int(pydantic.VERSION.split(".")[0]) >= 2 7 | IS_PYDANTIC_V2_10 = ( 8 | IS_PYDANTIC_V2 and int(pydantic.VERSION.split(".")[1]) >= 10 9 | ) 10 | 11 | if IS_PYDANTIC_V2: 12 | from pydantic import TypeAdapter 13 | else: 14 | from pydantic import parse_obj_as 15 | 16 | 17 | def parse_object_as(object_type: Type, data: Any): 18 | if IS_PYDANTIC_V2: 19 | return TypeAdapter(object_type).validate_python(data) 20 | else: 21 | return parse_obj_as(object_type, data) 22 | 23 | 24 | def get_field_type(field): 25 | if IS_PYDANTIC_V2: 26 | return field.annotation 27 | else: 28 | return field.outer_type_ 29 | 30 | 31 | def get_model_fields(model): 32 | if IS_PYDANTIC_V2: 33 | if not isinstance(model, type): 34 | model = model.__class__ 35 | return model.model_fields 36 | else: 37 | return model.__fields__ 38 | 39 | 40 | def parse_model(model_type: Type[BaseModel], data: Any): 41 | if IS_PYDANTIC_V2: 42 | return model_type.model_validate(data) 43 | else: 44 | return model_type.parse_obj(data) 45 | 46 | 47 | def get_extra_field_info(field, parameter: str): 48 | if IS_PYDANTIC_V2: 49 | if isinstance(field.json_schema_extra, dict): 50 | return field.json_schema_extra.get(parameter) 51 | return None 52 | else: 53 | return field.field_info.extra.get(parameter) 54 | 55 | 56 | def get_config_value(model, parameter: str): 57 | if IS_PYDANTIC_V2: 
58 | return model.model_config.get(parameter) 59 | else: 60 | return getattr(model.Config, parameter, None) 61 | 62 | 63 | def get_model_dump(model, *args: Any, **kwargs: Any): 64 | if IS_PYDANTIC_V2: 65 | return model.model_dump(*args, **kwargs) 66 | else: 67 | return model.dict(*args, **kwargs) 68 | -------------------------------------------------------------------------------- /beanie/odm/settings/timeseries.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Annotated, Any, Dict, Optional 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class Granularity(str, Enum): 8 | """ 9 | Time Series Granuality 10 | """ 11 | 12 | seconds = "seconds" 13 | minutes = "minutes" 14 | hours = "hours" 15 | 16 | 17 | class TimeSeriesConfig(BaseModel): 18 | """ 19 | Time Series Collection config 20 | """ 21 | 22 | time_field: str 23 | meta_field: Optional[str] = None 24 | granularity: Optional[Granularity] = None 25 | bucket_max_span_seconds: Optional[int] = None 26 | bucket_rounding_second: Annotated[ 27 | Optional[int], 28 | Field( 29 | deprecated="This field is deprecated in favor of " 30 | "'bucket_rounding_seconds'.", 31 | ), 32 | ] = None 33 | bucket_rounding_seconds: Optional[int] = None 34 | expire_after_seconds: Optional[int] = None 35 | 36 | def build_query(self, collection_name: str) -> Dict[str, Any]: 37 | res: Dict[str, Any] = {"name": collection_name} 38 | timeseries: Dict[str, Any] = {"timeField": self.time_field} 39 | if self.meta_field is not None: 40 | timeseries["metaField"] = self.meta_field 41 | if self.granularity is not None: 42 | timeseries["granularity"] = self.granularity 43 | if self.bucket_max_span_seconds is not None: 44 | timeseries["bucketMaxSpanSeconds"] = self.bucket_max_span_seconds 45 | if self.bucket_rounding_second is not None: # Deprecated field 46 | timeseries["bucketRoundingSeconds"] = self.bucket_rounding_second 47 | if self.bucket_rounding_seconds 
is not None: 48 | timeseries["bucketRoundingSeconds"] = self.bucket_rounding_seconds 49 | res["timeseries"] = timeseries 50 | if self.expire_after_seconds is not None: 51 | res["expireAfterSeconds"] = self.expire_after_seconds 52 | return res 53 | -------------------------------------------------------------------------------- /tests/migrations/iterative/test_change_value.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class Tag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class Note(Document): 17 | title: str 18 | tag: Tag 19 | 20 | class Settings: 21 | name = "notes" 22 | 23 | 24 | @pytest.fixture() 25 | async def notes(db): 26 | await init_beanie(database=db, document_models=[Note]) 27 | await Note.delete_all() 28 | for i in range(10): 29 | note = Note(title=str(i), tag=Tag(name="test", color="red")) 30 | await note.insert() 31 | yield 32 | await Note.delete_all() 33 | 34 | 35 | async def test_migration_change_value(settings, notes, db): 36 | migration_settings = MigrationSettings( 37 | connection_uri=settings.mongodb_dsn, 38 | database_name=settings.mongodb_db_name, 39 | path="tests/migrations/migrations_for_test/change_value", 40 | ) 41 | await run_migrate(migration_settings) 42 | 43 | await init_beanie(database=db, document_models=[Note]) 44 | inspection = await Note.inspect_collection() 45 | assert inspection.status == InspectionStatuses.OK 46 | note = await Note.find_one({"title": "five"}) 47 | assert note is not None 48 | 49 | note = await Note.find_one({"title": "5"}) 50 | assert note is None 51 | 52 | migration_settings.direction = RunningDirections.BACKWARD 53 | await 
run_migrate(migration_settings) 54 | inspection = await Note.inspect_collection() 55 | assert inspection.status == InspectionStatuses.OK 56 | note = await Note.find_one({"title": "5"}) 57 | assert note is not None 58 | 59 | note = await Note.find_one({"title": "five"}) 60 | assert note is None 61 | -------------------------------------------------------------------------------- /tests/migrations/test_break.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import Indexed, init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.odm.documents import Document 7 | from beanie.odm.models import InspectionStatuses 8 | 9 | 10 | class Tag(BaseModel): 11 | color: str 12 | name: str 13 | 14 | 15 | class OldNote(Document): 16 | name: Indexed(str, unique=True) 17 | tag: Tag 18 | 19 | class Settings: 20 | name = "notes" 21 | 22 | 23 | class Note(Document): 24 | name: Indexed(str, unique=True) 25 | title: str 26 | tag: Tag 27 | 28 | class Settings: 29 | name = "notes" 30 | 31 | 32 | @pytest.fixture() 33 | async def notes(db): 34 | await init_beanie(database=db, document_models=[OldNote]) 35 | await OldNote.delete_all() 36 | for i in range(10): 37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red")) 38 | await note.insert() 39 | yield 40 | await OldNote.delete_all() 41 | await OldNote.get_pymongo_collection().drop() 42 | await OldNote.get_pymongo_collection().drop_indexes() 43 | 44 | 45 | @pytest.mark.skip("TODO: Fix this test") 46 | async def test_migration_break(settings, notes, db): 47 | migration_settings = MigrationSettings( 48 | connection_uri=settings.mongodb_dsn, 49 | database_name=settings.mongodb_db_name, 50 | path="tests/migrations/migrations_for_test/break", 51 | ) 52 | with pytest.raises(Exception): 53 | await run_migrate(migration_settings) 54 | 55 | await init_beanie(database=db, 
document_models=[OldNote]) 56 | inspection = await OldNote.inspect_collection() 57 | assert inspection.status == InspectionStatuses.OK 58 | notes = await OldNote.get_pymongo_collection().find().to_list(length=100) 59 | names = set(n["name"] for n in notes) 60 | assert names == {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} 61 | for note in notes: 62 | assert "title" not in note 63 | -------------------------------------------------------------------------------- /tests/migrations/iterative/test_change_subfield.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class OldTag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class Tag(BaseModel): 17 | color: str 18 | title: str 19 | 20 | 21 | class OldNote(Document): 22 | title: str 23 | tag: OldTag 24 | 25 | class Settings: 26 | name = "notes" 27 | 28 | 29 | class Note(Document): 30 | title: str 31 | tag: Tag 32 | 33 | class Settings: 34 | name = "notes" 35 | 36 | 37 | @pytest.fixture() 38 | async def notes(db): 39 | await init_beanie(database=db, document_models=[OldNote]) 40 | await OldNote.delete_all() 41 | for i in range(10): 42 | note = OldNote(title=str(i), tag=OldTag(name="test", color="red")) 43 | await note.insert() 44 | yield 45 | await OldNote.delete_all() 46 | 47 | 48 | async def test_migration_change_subfield_value(settings, notes, db): 49 | migration_settings = MigrationSettings( 50 | connection_uri=settings.mongodb_dsn, 51 | database_name=settings.mongodb_db_name, 52 | path="tests/migrations/migrations_for_test/change_subfield", 53 | ) 54 | await run_migrate(migration_settings) 55 | 56 | await init_beanie(database=db, 
document_models=[Note]) 57 | inspection = await Note.inspect_collection() 58 | assert inspection.status == InspectionStatuses.OK 59 | note = await Note.find_one({}) 60 | assert note.tag.title == "test" 61 | 62 | migration_settings.direction = RunningDirections.BACKWARD 63 | await run_migrate(migration_settings) 64 | inspection = await OldNote.inspect_collection() 65 | assert inspection.status == InspectionStatuses.OK 66 | note = await OldNote.find_one({}) 67 | assert note.tag.name == "test" 68 | -------------------------------------------------------------------------------- /tests/migrations/iterative/test_pack_unpack.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class OldTag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class Tag(BaseModel): 17 | color: str 18 | name: str 19 | 20 | 21 | class OldNote(Document): 22 | title: str 23 | tag_name: str 24 | tag_color: str 25 | 26 | class Settings: 27 | name = "notes" 28 | 29 | 30 | class Note(Document): 31 | title: str 32 | tag: Tag 33 | 34 | class Settings: 35 | name = "notes" 36 | 37 | 38 | @pytest.fixture() 39 | async def notes(db): 40 | await init_beanie(database=db, document_models=[OldNote]) 41 | await OldNote.delete_all() 42 | for i in range(10): 43 | note = OldNote(title=str(i), tag_name="test", tag_color="red") 44 | await note.insert() 45 | yield 46 | await OldNote.delete_all() 47 | 48 | 49 | async def test_migration_pack_unpack(settings, notes, db): 50 | migration_settings = MigrationSettings( 51 | connection_uri=settings.mongodb_dsn, 52 | database_name=settings.mongodb_db_name, 53 | 
path="tests/migrations/migrations_for_test/pack_unpack", 54 | ) 55 | await run_migrate(migration_settings) 56 | 57 | await init_beanie(database=db, document_models=[Note]) 58 | inspection = await Note.inspect_collection() 59 | assert inspection.status == InspectionStatuses.OK 60 | note = await Note.find_one({}) 61 | assert note.tag.name == "test" 62 | 63 | migration_settings.direction = RunningDirections.BACKWARD 64 | await run_migrate(migration_settings) 65 | inspection = await OldNote.inspect_collection() 66 | assert inspection.status == InspectionStatuses.OK 67 | note = await OldNote.find_one({}) 68 | assert note.tag_name == "test" 69 | -------------------------------------------------------------------------------- /docs/tutorial/multi-model.md: -------------------------------------------------------------------------------- 1 | # Multi-model pattern 2 | 3 | Documents with different schemas could be stored in a single collection and managed correctly. 4 | `UnionDoc` class is used for this. 5 | 6 | It supports `find` and `aggregate` methods. 7 | For `find`, it will fetch all the found documents into the respective `Document` classes. 8 | 9 | Documents with `union_doc` in their settings can still be used in `find` and other queries. 10 | Queries of one such class will not see the data of others. 
11 | 12 | ## Example 13 | 14 | Create documents: 15 | 16 | ```python 17 | from beanie import Document, UnionDoc 18 | 19 | 20 | class Parent(UnionDoc): # Union 21 | class Settings: 22 | name = "union_doc_collection" # Collection name 23 | class_id = "_class_id" # _class_id is the default beanie internal field used to filter child Documents 24 | 25 | 26 | class One(Document): 27 | int_field: int = 0 28 | shared: int = 0 29 | 30 | class Settings: 31 | name = "One" # Name used to filter union document 'One'; defaults to the class name 32 | union_doc = Parent 33 | 34 | 35 | class Two(Document): 36 | str_field: str = "test" 37 | shared: int = 0 38 | 39 | class Settings: 40 | union_doc = Parent 41 | ``` 42 | 43 | The schemas can be incompatible with each other. 44 | 45 | Insert some documents: 46 | 47 | ```python 48 | await One().insert() 49 | await One().insert() 50 | await One().insert() 51 | 52 | await Two().insert() 53 | ``` 54 | 55 | Find all the documents of the first type: 56 | 57 | ```python 58 | docs = await One.all().to_list() 59 | print(len(docs)) 60 | 61 | >> 3 # It found only documents of class One 62 | ``` 63 | 64 | Of the second type: 65 | 66 | ```python 67 | docs = await Two.all().to_list() 68 | print(len(docs)) 69 | 70 | >> 1 # It found only documents of class Two 71 | ``` 72 | 73 | Of both: 74 | 75 | ```python 76 | docs = await Parent.all().to_list() 77 | print(len(docs)) 78 | 79 | >> 4 # instances of both classes will be in the output here 80 | ``` 81 | 82 | Aggregations will work separately for these two document classes too.
83 | -------------------------------------------------------------------------------- /tests/odm/test_lazy_parsing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.odm.utils.dump import get_dict 4 | from beanie.odm.utils.parsing import parse_obj 5 | from tests.odm.models import SampleLazyParsing 6 | 7 | 8 | @pytest.fixture 9 | async def docs(): 10 | for i in range(10): 11 | await SampleLazyParsing(i=i, s=str(i)).insert() 12 | 13 | 14 | class TestLazyParsing: 15 | async def test_find_all(self, docs): 16 | found_docs = await SampleLazyParsing.all(lazy_parse=True).to_list() 17 | saved_state = found_docs[0].get_saved_state() 18 | assert "_id" in saved_state 19 | del saved_state["_id"] 20 | assert found_docs[0].get_saved_state() == {} 21 | assert found_docs[0].i == 0 22 | assert found_docs[0].s == "0" 23 | assert found_docs[1]._store["i"] == 1 24 | assert found_docs[1]._store["s"] == "1" 25 | assert get_dict(found_docs[2])["i"] == 2 26 | assert get_dict(found_docs[2])["s"] == "2" 27 | 28 | async def test_find_many(self, docs): 29 | found_docs = await SampleLazyParsing.find( 30 | SampleLazyParsing.i <= 5, lazy_parse=True 31 | ).to_list() 32 | saved_state = found_docs[0].get_saved_state() 33 | assert "_id" in saved_state 34 | del saved_state["_id"] 35 | assert found_docs[0].get_saved_state() == {} 36 | assert found_docs[0].i == 0 37 | assert found_docs[0].s == "0" 38 | assert found_docs[1]._store["i"] == 1 39 | assert found_docs[1]._store["s"] == "1" 40 | assert get_dict(found_docs[2])["i"] == 2 41 | assert get_dict(found_docs[2])["s"] == "2" 42 | 43 | async def test_save_changes(self, docs): 44 | found_docs = await SampleLazyParsing.all(lazy_parse=True).to_list() 45 | doc = found_docs[0] 46 | doc.i = 1000 47 | await doc.save_changes() 48 | new_doc = await SampleLazyParsing.find_one(SampleLazyParsing.s == "0") 49 | assert new_doc.i == 1000 50 | 51 | async def test_default_list(self): 52 | res = 
parse_obj(SampleLazyParsing, {"i": "1", "s": "1"}) 53 | assert res.lst == [] 54 | -------------------------------------------------------------------------------- /beanie/__init__.py: -------------------------------------------------------------------------------- 1 | from beanie.migrations.controllers.free_fall import free_fall_migration 2 | from beanie.migrations.controllers.iterative import iterative_migration 3 | from beanie.odm.actions import ( 4 | After, 5 | Before, 6 | Delete, 7 | Insert, 8 | Replace, 9 | Save, 10 | SaveChanges, 11 | Update, 12 | ValidateOnSave, 13 | after_event, 14 | before_event, 15 | ) 16 | from beanie.odm.bulk import BulkWriter 17 | from beanie.odm.custom_types import DecimalAnnotation 18 | from beanie.odm.custom_types.bson.binary import BsonBinary 19 | from beanie.odm.documents import ( 20 | Document, 21 | DocumentWithSoftDelete, 22 | MergeStrategy, 23 | ) 24 | from beanie.odm.enums import SortDirection 25 | from beanie.odm.fields import ( 26 | BackLink, 27 | BeanieObjectId, 28 | DeleteRules, 29 | Indexed, 30 | Link, 31 | PydanticObjectId, 32 | WriteRules, 33 | ) 34 | from beanie.odm.queries.update import UpdateResponse 35 | from beanie.odm.settings.timeseries import Granularity, TimeSeriesConfig 36 | from beanie.odm.union_doc import UnionDoc 37 | from beanie.odm.utils.init import init_beanie 38 | from beanie.odm.views import View 39 | 40 | __version__ = "2.0.1" 41 | __all__ = [ 42 | # ODM 43 | "Document", 44 | "DocumentWithSoftDelete", 45 | "View", 46 | "UnionDoc", 47 | "init_beanie", 48 | "PydanticObjectId", 49 | "BeanieObjectId", 50 | "Indexed", 51 | "TimeSeriesConfig", 52 | "Granularity", 53 | "SortDirection", 54 | "MergeStrategy", 55 | # Actions 56 | "before_event", 57 | "after_event", 58 | "Insert", 59 | "Replace", 60 | "Save", 61 | "SaveChanges", 62 | "ValidateOnSave", 63 | "Delete", 64 | "Before", 65 | "After", 66 | "Update", 67 | # Bulk Write 68 | "BulkWriter", 69 | # Migrations 70 | "iterative_migration", 71 | 
"free_fall_migration", 72 | # Relations 73 | "Link", 74 | "BackLink", 75 | "WriteRules", 76 | "DeleteRules", 77 | # Custom Types 78 | "DecimalAnnotation", 79 | "BsonBinary", 80 | # UpdateResponse 81 | "UpdateResponse", 82 | ] 83 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_evaluation.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.evaluation import ( 2 | Expr, 3 | JsonSchema, 4 | Mod, 5 | RegEx, 6 | Text, 7 | Where, 8 | ) 9 | from tests.odm.models import Sample 10 | 11 | 12 | async def test_expr(): 13 | q = Expr({"a": "B"}) 14 | assert q == {"$expr": {"a": "B"}} 15 | 16 | 17 | async def test_json_schema(): 18 | q = JsonSchema({"a": "B"}) 19 | assert q == {"$jsonSchema": {"a": "B"}} 20 | 21 | 22 | async def test_mod(): 23 | q = Mod(Sample.integer, 3, 2) 24 | assert q == {"integer": {"$mod": [3, 2]}} 25 | 26 | 27 | async def test_regex(): 28 | q = RegEx(Sample.integer, "smth") 29 | assert q == {"integer": {"$regex": "smth"}} 30 | 31 | q = RegEx(Sample.integer, "smth", "options") 32 | assert q == {"integer": {"$regex": "smth", "$options": "options"}} 33 | 34 | 35 | async def test_text(): 36 | q = Text("something") 37 | assert q == { 38 | "$text": { 39 | "$search": "something", 40 | "$caseSensitive": False, 41 | "$diacriticSensitive": False, 42 | } 43 | } 44 | q = Text("something", case_sensitive=True) 45 | assert q == { 46 | "$text": { 47 | "$search": "something", 48 | "$caseSensitive": True, 49 | "$diacriticSensitive": False, 50 | } 51 | } 52 | q = Text("something", diacritic_sensitive=True) 53 | assert q == { 54 | "$text": { 55 | "$search": "something", 56 | "$caseSensitive": False, 57 | "$diacriticSensitive": True, 58 | } 59 | } 60 | q = Text("something", diacritic_sensitive=None) 61 | assert q == { 62 | "$text": { 63 | "$search": "something", 64 | "$caseSensitive": False, 65 | } 66 | } 67 | q = Text("something", 
language="test") 68 | assert q == { 69 | "$text": { 70 | "$search": "something", 71 | "$caseSensitive": False, 72 | "$diacriticSensitive": False, 73 | "$language": "test", 74 | } 75 | } 76 | 77 | 78 | async def test_where(): 79 | q = Where("test") 80 | assert q == {"$where": "test"} 81 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_comparison.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.comparison import ( 2 | GT, 3 | GTE, 4 | LT, 5 | LTE, 6 | NE, 7 | Eq, 8 | In, 9 | NotIn, 10 | ) 11 | from tests.odm.models import Sample 12 | 13 | 14 | async def test_eq(): 15 | q = Sample.integer == 1 16 | assert q == {"integer": 1} 17 | 18 | q = Eq(Sample.integer, 1) 19 | assert q == {"integer": 1} 20 | 21 | q = Eq("integer", 1) 22 | assert q == {"integer": 1} 23 | 24 | 25 | async def test_gt(): 26 | q = Sample.integer > 1 27 | assert q == {"integer": {"$gt": 1}} 28 | 29 | q = GT(Sample.integer, 1) 30 | assert q == {"integer": {"$gt": 1}} 31 | 32 | q = GT("integer", 1) 33 | assert q == {"integer": {"$gt": 1}} 34 | 35 | 36 | async def test_gte(): 37 | q = Sample.integer >= 1 38 | assert q == {"integer": {"$gte": 1}} 39 | 40 | q = GTE(Sample.integer, 1) 41 | assert q == {"integer": {"$gte": 1}} 42 | 43 | q = GTE("integer", 1) 44 | assert q == {"integer": {"$gte": 1}} 45 | 46 | 47 | async def test_in(): 48 | q = In(Sample.integer, [1]) 49 | assert q == {"integer": {"$in": [1]}} 50 | 51 | q = In(Sample.integer, [1]) 52 | assert q == {"integer": {"$in": [1]}} 53 | 54 | 55 | async def test_lt(): 56 | q = Sample.integer < 1 57 | assert q == {"integer": {"$lt": 1}} 58 | 59 | q = LT(Sample.integer, 1) 60 | assert q == {"integer": {"$lt": 1}} 61 | 62 | q = LT("integer", 1) 63 | assert q == {"integer": {"$lt": 1}} 64 | 65 | 66 | async def test_lte(): 67 | q = Sample.integer <= 1 68 | assert q == {"integer": {"$lte": 1}} 69 | 70 | q = 
LTE(Sample.integer, 1) 71 | assert q == {"integer": {"$lte": 1}} 72 | 73 | q = LTE("integer", 1) 74 | assert q == {"integer": {"$lte": 1}} 75 | 76 | 77 | async def test_ne(): 78 | q = Sample.integer != 1 79 | assert q == {"integer": {"$ne": 1}} 80 | 81 | q = NE(Sample.integer, 1) 82 | assert q == {"integer": {"$ne": 1}} 83 | 84 | q = NE("integer", 1) 85 | assert q == {"integer": {"$ne": 1}} 86 | 87 | 88 | async def test_nin(): 89 | q = NotIn(Sample.integer, [1]) 90 | assert q == {"integer": {"$nin": [1]}} 91 | 92 | q = NotIn(Sample.integer, [1]) 93 | assert q == {"integer": {"$nin": [1]}} 94 | -------------------------------------------------------------------------------- /beanie/odm/views.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, ClassVar, Dict, Optional, Union 3 | 4 | from pydantic import BaseModel 5 | 6 | from beanie.exceptions import ViewWasNotInitialized 7 | from beanie.odm.fields import Link, LinkInfo 8 | from beanie.odm.interfaces.aggregate import AggregateInterface 9 | from beanie.odm.interfaces.detector import DetectionInterface, ModelType 10 | from beanie.odm.interfaces.find import FindInterface 11 | from beanie.odm.interfaces.getters import OtherGettersInterface 12 | from beanie.odm.settings.view import ViewSettings 13 | 14 | 15 | class View( 16 | BaseModel, 17 | FindInterface, 18 | AggregateInterface, 19 | OtherGettersInterface, 20 | DetectionInterface, 21 | ): 22 | """ 23 | What is needed: 24 | 25 | Source collection or view 26 | pipeline 27 | 28 | """ 29 | 30 | # Relations 31 | _link_fields: ClassVar[Optional[Dict[str, LinkInfo]]] = None 32 | 33 | # Settings 34 | _settings: ClassVar[ViewSettings] 35 | 36 | @classmethod 37 | def get_settings(cls) -> ViewSettings: 38 | """ 39 | Get view settings, which was created on 40 | the initialization step 41 | 42 | :return: ViewSettings class 43 | """ 44 | if cls._settings is None: 45 | raise ViewWasNotInitialized 46 | 
return cls._settings 47 | 48 | async def fetch_link(self, field: Union[str, Any]): 49 | ref_obj = getattr(self, field, None) 50 | if isinstance(ref_obj, Link): 51 | value = await ref_obj.fetch(fetch_links=True) 52 | setattr(self, field, value) 53 | if isinstance(ref_obj, list) and ref_obj: 54 | values = await Link.fetch_list(ref_obj, fetch_links=True) 55 | setattr(self, field, values) 56 | 57 | async def fetch_all_links(self): 58 | coros = [] 59 | link_fields = self.get_link_fields() 60 | if link_fields is not None: 61 | for ref in link_fields.values(): 62 | coros.append(self.fetch_link(ref.field_name)) # TODO lists 63 | await asyncio.gather(*coros) 64 | 65 | @classmethod 66 | def get_link_fields(cls) -> Optional[Dict[str, LinkInfo]]: 67 | return cls._link_fields 68 | 69 | @classmethod 70 | def get_model_type(cls) -> ModelType: 71 | return ModelType.View 72 | -------------------------------------------------------------------------------- /tests/fastapi/routes.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fastapi import APIRouter, Body, status 4 | from pydantic import BaseModel 5 | 6 | from beanie import PydanticObjectId, WriteRules 7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 8 | from tests.fastapi.models import House, HouseAPI, Person, WindowAPI 9 | 10 | house_router = APIRouter() 11 | if not IS_PYDANTIC_V2: 12 | from fastapi.encoders import ENCODERS_BY_TYPE 13 | from pydantic.json import ENCODERS_BY_TYPE as PYDANTIC_ENCODERS_BY_TYPE 14 | 15 | ENCODERS_BY_TYPE.update(PYDANTIC_ENCODERS_BY_TYPE) 16 | 17 | 18 | class WindowInput(BaseModel): 19 | id: PydanticObjectId 20 | 21 | 22 | @house_router.post("/windows/", response_model=WindowAPI) 23 | async def create_window(window: WindowAPI): 24 | await window.create() 25 | return window 26 | 27 | 28 | @house_router.post("/windows_2/") 29 | async def create_window_2(window: WindowAPI): 30 | return await window.save() 31 | 32 | 
33 | @house_router.get("/windows/{id}", response_model=Optional[WindowAPI]) 34 | async def get_window(id: PydanticObjectId): 35 | return await WindowAPI.get(id) 36 | 37 | 38 | @house_router.post("/houses/", response_model=HouseAPI) 39 | async def create_house(window: WindowAPI): 40 | house = HouseAPI(name="test_name", windows=[window]) 41 | await house.insert(link_rule=WriteRules.WRITE) 42 | return house 43 | 44 | 45 | @house_router.post("/houses_with_window_link/", response_model=HouseAPI) 46 | async def create_houses_with_window_link(window: WindowInput): 47 | validator = ( 48 | HouseAPI.model_validate if IS_PYDANTIC_V2 else HouseAPI.parse_obj 49 | ) 50 | house = validator( 51 | dict(name="test_name", windows=[WindowAPI.link_from_id(window.id)]) 52 | ) 53 | await house.insert(link_rule=WriteRules.WRITE) 54 | return house 55 | 56 | 57 | @house_router.post("/houses_2/", response_model=HouseAPI) 58 | async def create_houses_2(house: HouseAPI): 59 | await house.insert(link_rule=WriteRules.WRITE) 60 | return house 61 | 62 | 63 | @house_router.post( 64 | "/house", 65 | response_model=House, 66 | status_code=status.HTTP_201_CREATED, 67 | ) 68 | async def create_house_new(house: House = Body(...)): 69 | person = Person(name="Bob") 70 | house.owner = person 71 | await house.save(link_rule=WriteRules.WRITE) 72 | await house.sync() 73 | return house 74 | -------------------------------------------------------------------------------- /beanie/odm/queries/cursor.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import ( 3 | Any, 4 | Dict, 5 | Generic, 6 | List, 7 | Optional, 8 | Type, 9 | TypeVar, 10 | cast, 11 | ) 12 | 13 | from pydantic.main import BaseModel 14 | 15 | from beanie.odm.utils.parsing import parse_obj 16 | 17 | CursorResultType = TypeVar("CursorResultType") 18 | 19 | 20 | class BaseCursorQuery(Generic[CursorResultType]): 21 | """ 22 | BaseCursorQuery class. 
Wrapper over AsyncCursor, 23 | which parse result with model 24 | """ 25 | 26 | cursor = None 27 | lazy_parse = False 28 | 29 | @abstractmethod 30 | def get_projection_model(self) -> Optional[Type[BaseModel]]: ... 31 | 32 | @abstractmethod 33 | async def get_cursor(self): ... 34 | 35 | def _cursor_params(self): ... 36 | 37 | def __aiter__(self): 38 | return self 39 | 40 | async def __anext__(self) -> CursorResultType: 41 | if self.cursor is None: 42 | self.cursor = await self.get_cursor() 43 | next_item = await self.cursor.__anext__() 44 | projection = self.get_projection_model() 45 | if projection is None: 46 | return next_item 47 | return parse_obj(projection, next_item, lazy_parse=self.lazy_parse) # type: ignore 48 | 49 | @abstractmethod 50 | def _get_cache(self) -> List[Dict[str, Any]]: ... 51 | 52 | @abstractmethod 53 | def _set_cache(self, data): ... 54 | 55 | async def to_list( 56 | self, length: Optional[int] = None 57 | ) -> List[CursorResultType]: # noqa 58 | """ 59 | Get list of documents 60 | 61 | :param length: Optional[int] - length of the list 62 | :return: Union[List[BaseModel], List[Dict[str, Any]]] 63 | """ 64 | cursor = await self.get_cursor() 65 | pymongo_list: List[Dict[str, Any]] = self._get_cache() 66 | 67 | if pymongo_list is None: 68 | pymongo_list = await cursor.to_list(length) 69 | self._set_cache(pymongo_list) 70 | projection = self.get_projection_model() 71 | if projection is not None: 72 | return cast( 73 | List[CursorResultType], 74 | [ 75 | parse_obj(projection, i, lazy_parse=self.lazy_parse) 76 | for i in pymongo_list 77 | ], 78 | ) 79 | return cast(List[CursorResultType], pymongo_list) 80 | -------------------------------------------------------------------------------- /tests/odm/query/test_update_methods.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.general import Max 2 | from beanie.odm.queries.update import UpdateMany, UpdateQuery 3 | from 
tests.odm.models import Sample 4 | 5 | 6 | async def test_set(session): 7 | q = Sample.find_many(Sample.integer == 1).set( 8 | {Sample.integer: 100}, session=session 9 | ) 10 | 11 | assert isinstance(q, UpdateQuery) 12 | assert isinstance(q, UpdateMany) 13 | assert q.session == session 14 | 15 | assert q.update_query == {"$set": {"integer": 100}} 16 | 17 | q = ( 18 | Sample.find_many(Sample.integer == 1) 19 | .update(Max({Sample.integer: 10})) 20 | .set({Sample.integer: 100}) 21 | ) 22 | 23 | assert isinstance(q, UpdateQuery) 24 | assert isinstance(q, UpdateMany) 25 | 26 | assert q.update_query == { 27 | "$max": {"integer": 10}, 28 | "$set": {"integer": 100}, 29 | } 30 | 31 | 32 | async def test_current_date(session): 33 | q = Sample.find_many(Sample.integer == 1).current_date( 34 | {Sample.timestamp: "timestamp"}, session=session 35 | ) 36 | 37 | assert isinstance(q, UpdateQuery) 38 | assert isinstance(q, UpdateMany) 39 | assert q.session == session 40 | 41 | assert q.update_query == {"$currentDate": {"timestamp": "timestamp"}} 42 | 43 | q = ( 44 | Sample.find_many(Sample.integer == 1) 45 | .update(Max({Sample.integer: 10})) 46 | .current_date({Sample.timestamp: "timestamp"}) 47 | ) 48 | 49 | assert isinstance(q, UpdateQuery) 50 | assert isinstance(q, UpdateMany) 51 | 52 | assert q.update_query == { 53 | "$max": {"integer": 10}, 54 | "$currentDate": {"timestamp": "timestamp"}, 55 | } 56 | 57 | 58 | async def test_inc(session): 59 | q = Sample.find_many(Sample.integer == 1).inc( 60 | {Sample.integer: 100}, session=session 61 | ) 62 | 63 | assert isinstance(q, UpdateQuery) 64 | assert isinstance(q, UpdateMany) 65 | assert q.session == session 66 | 67 | assert q.update_query == {"$inc": {"integer": 100}} 68 | 69 | q = ( 70 | Sample.find_many(Sample.integer == 1) 71 | .update(Max({Sample.integer: 10})) 72 | .inc({Sample.integer: 100}) 73 | ) 74 | 75 | assert isinstance(q, UpdateQuery) 76 | assert isinstance(q, UpdateMany) 77 | 78 | assert q.update_query == { 79 | 
"$max": {"integer": 10}, 80 | "$inc": {"integer": 100}, 81 | } 82 | -------------------------------------------------------------------------------- /tests/fastapi/test_api.py: -------------------------------------------------------------------------------- 1 | from tests.fastapi.models import WindowAPI 2 | 3 | 4 | async def test_create_window(api_client): 5 | payload = {"x": 10, "y": 20} 6 | resp = await api_client.post("/v1/windows/", json=payload) 7 | resp_json = resp.json() 8 | assert resp_json["x"] == 10 9 | assert resp_json["y"] == 20 10 | 11 | 12 | async def test_get_window(api_client): 13 | payload = {"x": 10, "y": 20} 14 | resp = await api_client.post("/v1/windows/", json=payload) 15 | data1 = resp.json() 16 | window_id = data1["_id"] 17 | resp2 = await api_client.get(f"/v1/windows/{window_id}") 18 | data2 = resp2.json() 19 | assert data2 == data1 20 | 21 | 22 | async def test_create_house(api_client): 23 | payload = {"x": 10, "y": 20} 24 | resp = await api_client.post("/v1/houses/", json=payload) 25 | resp_json = resp.json() 26 | assert len(resp_json["windows"]) == 1 27 | 28 | 29 | async def test_create_house_with_window_link(api_client): 30 | payload = {"x": 10, "y": 20} 31 | resp = await api_client.post("/v1/windows/", json=payload) 32 | 33 | window_id = resp.json()["_id"] 34 | 35 | payload = {"id": window_id} 36 | resp = await api_client.post("/v1/houses_with_window_link/", json=payload) 37 | resp_json = resp.json() 38 | assert resp_json["windows"][0]["collection"] == "WindowAPI" 39 | 40 | 41 | async def test_create_house_2(api_client): 42 | window = WindowAPI(x=10, y=10) 43 | await window.insert() 44 | payload = {"name": "TEST", "windows": [str(window.id)]} 45 | resp = await api_client.post("/v1/houses_2/", json=payload) 46 | resp_json = resp.json() 47 | assert len(resp_json["windows"]) == 1 48 | 49 | 50 | async def test_revision_id(api_client): 51 | payload = {"x": 10, "y": 20} 52 | resp = await api_client.post("/v1/windows_2/", json=payload) 53 | 
resp_json = resp.json() 54 | assert "revision_id" not in resp_json 55 | assert resp_json == {"x": 10, "y": 20, "_id": resp_json["_id"]} 56 | 57 | 58 | async def test_create_house_new(api_client): 59 | payload = { 60 | "name": "FreshHouse", 61 | "owner": {"name": "will_be_overridden_to_Bob"}, 62 | } 63 | resp = await api_client.post("/v1/house", json=payload) 64 | resp_json = resp.json() 65 | 66 | assert resp_json["name"] == payload["name"] 67 | assert resp_json["owner"]["name"] == payload["owner"]["name"][-3:] 68 | assert resp_json["owner"]["house"]["collection"] == "House" 69 | -------------------------------------------------------------------------------- /beanie/operators.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.array import All, ElemMatch, Size 2 | from beanie.odm.operators.find.bitwise import ( 3 | BitsAllClear, 4 | BitsAllSet, 5 | BitsAnyClear, 6 | BitsAnySet, 7 | ) 8 | from beanie.odm.operators.find.comparison import ( 9 | GT, 10 | GTE, 11 | LT, 12 | LTE, 13 | NE, 14 | Eq, 15 | In, 16 | NotIn, 17 | ) 18 | from beanie.odm.operators.find.element import Exists, Type 19 | from beanie.odm.operators.find.evaluation import ( 20 | Expr, 21 | JsonSchema, 22 | Mod, 23 | RegEx, 24 | Text, 25 | Where, 26 | ) 27 | from beanie.odm.operators.find.geospatial import ( 28 | Box, 29 | GeoIntersects, 30 | GeoWithin, 31 | GeoWithinTypes, 32 | Near, 33 | NearSphere, 34 | ) 35 | from beanie.odm.operators.find.logical import And, Nor, Not, Or 36 | from beanie.odm.operators.update.array import ( 37 | AddToSet, 38 | Pop, 39 | Pull, 40 | PullAll, 41 | Push, 42 | ) 43 | from beanie.odm.operators.update.bitwise import Bit 44 | from beanie.odm.operators.update.general import ( 45 | CurrentDate, 46 | Inc, 47 | Max, 48 | Min, 49 | Mul, 50 | Rename, 51 | Set, 52 | SetOnInsert, 53 | Unset, 54 | ) 55 | 56 | __all__ = [ 57 | # Find 58 | # Array 59 | "All", 60 | "ElemMatch", 61 | "Size", 62 | # Bitwise 63 | 
"BitsAllClear", 64 | "BitsAllSet", 65 | "BitsAnyClear", 66 | "BitsAnySet", 67 | # Comparison 68 | "Eq", 69 | "GT", 70 | "GTE", 71 | "In", 72 | "NotIn", 73 | "LT", 74 | "LTE", 75 | "NE", 76 | # Element 77 | "Exists", 78 | "Type", 79 | "Type", 80 | # Evaluation 81 | "Expr", 82 | "JsonSchema", 83 | "Mod", 84 | "RegEx", 85 | "Text", 86 | "Where", 87 | # Geospatial 88 | "GeoIntersects", 89 | "GeoWithinTypes", 90 | "GeoWithin", 91 | "Box", 92 | "Near", 93 | "NearSphere", 94 | # Logical 95 | "Or", 96 | "And", 97 | "Nor", 98 | "Not", 99 | # Update 100 | # Array 101 | "AddToSet", 102 | "Pop", 103 | "Pull", 104 | "Push", 105 | "PullAll", 106 | # Bitwise 107 | "Bit", 108 | # General 109 | "Set", 110 | "CurrentDate", 111 | "Inc", 112 | "Min", 113 | "Max", 114 | "Mul", 115 | "Rename", 116 | "SetOnInsert", 117 | "Unset", 118 | ] 119 | -------------------------------------------------------------------------------- /tests/odm/documents/test_aggregate.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | from pydantic.main import BaseModel 3 | 4 | from tests.odm.models import DocumentTestModel 5 | 6 | 7 | async def test_aggregate(documents): 8 | await documents(4, "uno") 9 | await documents(2, "dos") 10 | await documents(1, "cuatro") 11 | result = await DocumentTestModel.aggregate( 12 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 13 | ).to_list() 14 | assert len(result) == 3 15 | assert {"_id": "cuatro", "total": 0} in result 16 | assert {"_id": "dos", "total": 1} in result 17 | assert {"_id": "uno", "total": 6} in result 18 | 19 | 20 | async def test_aggregate_with_filter(documents): 21 | await documents(4, "uno") 22 | await documents(2, "dos") 23 | await documents(1, "cuatro") 24 | result = ( 25 | await DocumentTestModel.find(DocumentTestModel.test_int >= 1) 26 | .aggregate( 27 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 28 | ) 29 | .to_list() 30 | ) 31 | assert 
len(result) == 2 32 | assert {"_id": "dos", "total": 1} in result 33 | assert {"_id": "uno", "total": 6} in result 34 | 35 | 36 | async def test_aggregate_with_item_model(documents): 37 | class OutputItem(BaseModel): 38 | id: str = Field(None, alias="_id") 39 | total: int 40 | 41 | await documents(4, "uno") 42 | await documents(2, "dos") 43 | await documents(1, "cuatro") 44 | ids = [] 45 | async for i in DocumentTestModel.aggregate( 46 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 47 | projection_model=OutputItem, 48 | ): 49 | if i.id == "cuatro": 50 | assert i.total == 0 51 | elif i.id == "dos": 52 | assert i.total == 1 53 | elif i.id == "uno": 54 | assert i.total == 6 55 | else: 56 | raise KeyError 57 | ids.append(i.id) 58 | assert set(ids) == {"cuatro", "dos", "uno"} 59 | 60 | 61 | async def test_aggregate_with_session(documents, session): 62 | await documents(4, "uno") 63 | await documents(2, "dos") 64 | await documents(1, "cuatro") 65 | result = await DocumentTestModel.aggregate( 66 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 67 | session=session, 68 | ).to_list() 69 | assert len(result) == 3 70 | assert {"_id": "cuatro", "total": 0} in result 71 | assert {"_id": "dos", "total": 1} in result 72 | assert {"_id": "uno", "total": 6} in result 73 | -------------------------------------------------------------------------------- /tests/odm/test_expression_fields.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.enums import SortDirection 2 | from beanie.odm.operators.find.comparison import In, NotIn 3 | from tests.odm.models import Sample 4 | 5 | 6 | def test_nesting(): 7 | assert Sample.id == "_id" 8 | 9 | q = Sample.find_many(Sample.integer == 1) 10 | assert q.get_filter_query() == {"integer": 1} 11 | assert Sample.integer == "integer" 12 | 13 | q = Sample.find_many(Sample.nested.integer == 1) 14 | assert q.get_filter_query() == {"nested.integer": 1} 15 | 
assert Sample.nested.integer == "nested.integer" 16 | 17 | q = Sample.find_many(Sample.union.s == "test") 18 | assert q.get_filter_query() == {"union.s": "test"} 19 | assert Sample.union.s == "union.s" 20 | 21 | q = Sample.find_many(Sample.nested.optional == None) # noqa 22 | assert q.get_filter_query() == {"nested.optional": None} 23 | assert Sample.nested.optional == "nested.optional" 24 | 25 | q = Sample.find_many(Sample.nested.integer == 1).find_many( 26 | Sample.nested.union.s == "test" 27 | ) 28 | assert q.get_filter_query() == { 29 | "$and": [{"nested.integer": 1}, {"nested.union.s": "test"}] 30 | } 31 | 32 | 33 | def test_eq(): 34 | q = Sample.find_many(Sample.integer == 1) 35 | assert q.get_filter_query() == {"integer": 1} 36 | 37 | 38 | def test_gt(): 39 | q = Sample.find_many(Sample.integer > 1) 40 | assert q.get_filter_query() == {"integer": {"$gt": 1}} 41 | 42 | 43 | def test_gte(): 44 | q = Sample.find_many(Sample.integer >= 1) 45 | assert q.get_filter_query() == {"integer": {"$gte": 1}} 46 | 47 | 48 | def test_in(): 49 | q = Sample.find_many(In(Sample.integer, [1, 2, 3, 4])) 50 | assert dict(q.get_filter_query()) == {"integer": {"$in": [1, 2, 3, 4]}} 51 | 52 | 53 | def test_lt(): 54 | q = Sample.find_many(Sample.integer < 1) 55 | assert q.get_filter_query() == {"integer": {"$lt": 1}} 56 | 57 | 58 | def test_lte(): 59 | q = Sample.find_many(Sample.integer <= 1) 60 | assert q.get_filter_query() == {"integer": {"$lte": 1}} 61 | 62 | 63 | def test_ne(): 64 | q = Sample.find_many(Sample.integer != 1) 65 | assert q.get_filter_query() == {"integer": {"$ne": 1}} 66 | 67 | 68 | def test_nin(): 69 | q = Sample.find_many(NotIn(Sample.integer, [1, 2, 3, 4])) 70 | assert dict(q.get_filter_query()) == {"integer": {"$nin": [1, 2, 3, 4]}} 71 | 72 | 73 | def test_pos(): 74 | q = +Sample.integer 75 | assert q == ("integer", SortDirection.ASCENDING) 76 | 77 | 78 | def test_neg(): 79 | q = -Sample.integer 80 | assert q == ("integer", SortDirection.DESCENDING) 81 | 
-------------------------------------------------------------------------------- /docs/publishing.md: -------------------------------------------------------------------------------- 1 | # Publishing a New Version of Beanie 2 | 3 | This guide provides step-by-step instructions on how to prepare and publish a new version of Beanie. Before starting, ensure that you have the necessary permissions to update the repository. 4 | 5 | ## 1. Prepare a New Version PR 6 | 7 | To publish a new version of Beanie, you need to create a pull request (PR) with the following updates: 8 | 9 | ### 1.1 Update the Version in `pyproject.toml` 10 | 11 | 1. Open the [`pyproject.toml`](https://github.com/BeanieODM/beanie/blob/main/pyproject.toml) file. 12 | 2. Update the `version` field to the new version number. 13 | 14 | ### 1.2 Update the Version in the `__init__.py` File 15 | 16 | 1. Open the [`__init__.py`](https://github.com/BeanieODM/beanie/blob/main/beanie/__init__.py) file. 17 | 2. Update the `__version__` variable to the new version number. 18 | 19 | ### 1.3 Update the Changelog 20 | 21 | To update the changelog, follow these steps: 22 | 23 | #### 1.3.1 Set the Version in the Changelog Script 24 | 25 | 1. Open the [`scripts/generate_changelog.py`](https://github.com/BeanieODM/beanie/blob/main/scripts/generate_changelog.py) file. 26 | 2. Set the `current_version` to the current version and `new_version` to the new version in the script. 27 | 28 | #### 1.3.2 Run the Changelog Script 29 | 30 | 1. Run the following command to generate the updated changelog: 31 | 32 | ```bash 33 | python scripts/generate_changelog.py 34 | ``` 35 | 36 | 2. The script will generate the changelog for the new version. 37 | 38 | #### 1.3.3 Update the Changelog File 39 | 40 | 1. Open the [`changelog.md`](https://github.com/BeanieODM/beanie/blob/main/docs/changelog.md) file. 41 | 2. Copy the generated changelog and paste it at the top of the `changelog.md` file. 
42 | 43 | ### 1.4 Create and Submit the PR 44 | 45 | Once you have made the necessary updates, create a PR with a descriptive title and summary of the changes. Ensure that all checks pass before merging the PR. 46 | 47 | ## 2. Publishing the Version 48 | 49 | After the PR has been merged, the respective GitHub Action will publish the new version to PyPI. 50 | 51 | ## 3. Create a Git Tag and GitHub Release 52 | 53 | After the version has been published: 54 | 55 | 1. Pull the latest changes from the `main` branch: 56 | ```bash 57 | git pull origin main 58 | ``` 59 | 60 | 2. Create a new Git tag with the version number (replace `X.Y.Z` with the new version): 61 | ```bash 62 | git tag -a vX.Y.Z -m "Release vX.Y.Z" 63 | ``` 64 | 65 | 3. Push the tag to the remote repository: 66 | ```bash 67 | git push origin vX.Y.Z 68 | ``` 69 | 70 | 4. Create a new release on GitHub using the GitHub interface. -------------------------------------------------------------------------------- /docs/tutorial/views.md: -------------------------------------------------------------------------------- 1 | # Views 2 | 3 | Virtual views are aggregation pipelines stored in MongoDB that act as collections for reading operations. 4 | You can use the `View` class the same way as `Document` for `find` and `aggregate` operations. 5 | 6 | ## Here are some examples.
7 | 8 | Create a view: 9 | 10 | ```python 11 | from pydantic import Field 12 | 13 | from beanie import Document, View 14 | 15 | 16 | class Bike(Document): 17 | type: str 18 | frame_size: int 19 | is_new: bool 20 | 21 | 22 | class Metrics(View): 23 | type: str = Field(alias="_id") 24 | number: int 25 | new: int 26 | 27 | class Settings: 28 | source = Bike 29 | pipeline = [ 30 | { 31 | "$group": { 32 | "_id": "$type", 33 | "number": {"$sum": 1}, 34 | "new": {"$sum": {"$cond": ["$is_new", 1, 0]}} 35 | } 36 | }, 37 | ] 38 | 39 | ``` 40 | 41 | Initialize Beanie: 42 | 43 | ```python 44 | from pymongo import AsyncMongoClient 45 | 46 | from beanie import init_beanie 47 | 48 | 49 | async def main(): 50 | uri = "mongodb://beanie:beanie@localhost:27017" 51 | client = AsyncMongoClient(uri) 52 | db = client.bikes 53 | 54 | await init_beanie( 55 | database=db, 56 | document_models=[Bike, Metrics], 57 | recreate_views=True, 58 | ) 59 | ``` 60 | 61 | Create bikes: 62 | 63 | ```python 64 | await Bike(type="Mountain", frame_size=54, is_new=True).insert() 65 | await Bike(type="Mountain", frame_size=60, is_new=False).insert() 66 | await Bike(type="Road", frame_size=52, is_new=True).insert() 67 | await Bike(type="Road", frame_size=54, is_new=True).insert() 68 | await Bike(type="Road", frame_size=58, is_new=False).insert() 69 | ``` 70 | 71 | Find metrics for `type == "Road"` 72 | 73 | ```python 74 | results = await Metrics.find(Metrics.type == "Road").to_list() 75 | print(results) 76 | 77 | >> [Metrics(type='Road', number=3, new=2)] 78 | ``` 79 | 80 | Aggregate over metrics to get the count of all the new bikes: 81 | 82 | ```python 83 | results = await Metrics.aggregate([{ 84 | "$group": { 85 | "_id": None, 86 | "new_total": {"$sum": "$new"} 87 | } 88 | }]).to_list() 89 | 90 | print(results) 91 | 92 | >> [{'_id': None, 'new_total': 3}] 93 | ``` 94 | 95 | A better result can be achieved by using find query aggregation syntactic sugar: 96 | 97 | ```python 98 | results = await 
Metrics.all().sum(Metrics.new) 99 | 100 | print(results) 101 | 102 | >> 3 103 | ``` 104 | -------------------------------------------------------------------------------- /beanie/odm/utils/typing.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import sys 3 | from typing import Any, Dict, Optional, Tuple, Type 4 | 5 | from beanie.odm.fields import IndexedAnnotation 6 | 7 | from .pydantic import IS_PYDANTIC_V2, get_field_type 8 | 9 | if sys.version_info >= (3, 8): 10 | from typing import get_args, get_origin 11 | else: 12 | from typing_extensions import get_args, get_origin 13 | 14 | 15 | def extract_id_class(annotation) -> Type[Any]: 16 | if get_origin(annotation) is not None: 17 | try: 18 | annotation = next( 19 | arg for arg in get_args(annotation) if arg is not type(None) 20 | ) 21 | except StopIteration: 22 | annotation = None 23 | if inspect.isclass(annotation): 24 | return annotation 25 | raise ValueError("Unknown annotation: {}".format(annotation)) 26 | 27 | 28 | def get_index_attributes(field) -> Optional[Tuple[int, Dict[str, Any]]]: 29 | """Gets the index attributes from the field, if it is indexed. 30 | 31 | :param field: The field to get the index attributes from. 32 | 33 | :return: The index attributes, if the field is indexed. Otherwise, None. 34 | """ 35 | # For fields that are directly typed with `Indexed()`, the type will have 36 | # an `_indexed` attribute. 37 | field_type = get_field_type(field) 38 | if hasattr(field_type, "_indexed"): 39 | return getattr(field_type, "_indexed", None) 40 | 41 | # For fields that are use `Indexed` within `Annotated`, the field will have 42 | # metadata that might contain an `IndexedAnnotation` instance. 43 | if IS_PYDANTIC_V2: 44 | # In Pydantic 2, the field has a `metadata` attribute with 45 | # the annotations. 
46 | metadata = getattr(field, "metadata", None) 47 | elif hasattr(field, "annotation") and hasattr( 48 | field.annotation, "__metadata__" 49 | ): 50 | # In Pydantic 1, the field has an `annotation` attribute with the 51 | # type assigned to the field. If the type is annotated, it will 52 | # have a `__metadata__` attribute with the annotations. 53 | metadata = field.annotation.__metadata__ 54 | else: 55 | return None 56 | 57 | if metadata is None: 58 | return None 59 | 60 | try: 61 | iter(metadata) 62 | except TypeError: 63 | return None 64 | 65 | indexed_annotation = next( 66 | ( 67 | annotation 68 | for annotation in metadata 69 | if isinstance(annotation, IndexedAnnotation) 70 | ), 71 | None, 72 | ) 73 | 74 | return getattr(indexed_annotation, "_indexed", None) 75 | -------------------------------------------------------------------------------- /tests/odm/documents/test_multi_model.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import ( 2 | DocumentMultiModelOne, 3 | DocumentMultiModelTwo, 4 | DocumentUnion, 5 | ) 6 | 7 | 8 | class TestMultiModel: 9 | async def test_multi_model(self): 10 | doc_1 = await DocumentMultiModelOne().insert() 11 | doc_2 = await DocumentMultiModelTwo().insert() 12 | 13 | new_doc_1 = await DocumentMultiModelOne.get(doc_1.id) 14 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id) 15 | 16 | assert new_doc_1 is not None 17 | assert new_doc_2 is not None 18 | 19 | new_doc_1 = await DocumentMultiModelTwo.get(doc_1.id) 20 | new_doc_2 = await DocumentMultiModelOne.get(doc_2.id) 21 | 22 | assert new_doc_1 is None 23 | assert new_doc_2 is None 24 | 25 | new_docs_1 = await DocumentMultiModelOne.find({}).to_list() 26 | new_docs_2 = await DocumentMultiModelTwo.find({}).to_list() 27 | 28 | assert len(new_docs_1) == 1 29 | assert len(new_docs_2) == 1 30 | 31 | await DocumentMultiModelOne.update_all({"$set": {"shared": 100}}) 32 | 33 | new_doc_1 = await 
DocumentMultiModelOne.get(doc_1.id) 34 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id) 35 | 36 | assert new_doc_1.shared == 100 37 | assert new_doc_2.shared == 0 38 | 39 | async def test_union_doc(self): 40 | await DocumentMultiModelOne().insert() 41 | await DocumentMultiModelTwo().insert() 42 | await DocumentMultiModelOne().insert() 43 | await DocumentMultiModelTwo().insert() 44 | 45 | docs = await DocumentUnion.all().to_list() 46 | assert isinstance(docs[0], DocumentMultiModelOne) 47 | assert isinstance(docs[1], DocumentMultiModelTwo) 48 | assert isinstance(docs[2], DocumentMultiModelOne) 49 | assert isinstance(docs[3], DocumentMultiModelTwo) 50 | 51 | async def test_union_doc_aggregation(self): 52 | await DocumentMultiModelOne().insert() 53 | await DocumentMultiModelTwo().insert() 54 | await DocumentMultiModelOne().insert() 55 | await DocumentMultiModelTwo().insert() 56 | 57 | docs = await DocumentUnion.aggregate( 58 | [{"$match": {"$expr": {"$eq": ["$int_filed", 0]}}}] 59 | ).to_list() 60 | assert len(docs) == 2 61 | 62 | async def test_union_doc_link(self): 63 | doc_1 = await DocumentMultiModelOne().insert() 64 | await DocumentMultiModelTwo(linked_doc=doc_1).insert() 65 | 66 | docs = await DocumentMultiModelTwo.find({}, fetch_links=True).to_list() 67 | assert isinstance(docs[0].linked_doc, DocumentMultiModelOne) 68 | -------------------------------------------------------------------------------- /tests/typing/decorators.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Coroutine 2 | 3 | from typing_extensions import Protocol, TypeAlias, assert_type 4 | 5 | from beanie import Document 6 | from beanie.odm.actions import EventTypes, wrap_with_actions 7 | from beanie.odm.utils.self_validation import validate_self_before 8 | from beanie.odm.utils.state import ( 9 | previous_saved_state_needed, 10 | save_state_after, 11 | saved_state_needed, 12 | ) 13 | 14 | 15 | def 
sync_func(doc_self: Document, arg1: str, arg2: int, /) -> Document: 16 | """ 17 | Models `Document` sync method that expects self 18 | """ 19 | raise NotImplementedError 20 | 21 | 22 | SyncFunc: TypeAlias = Callable[[Document, str, int], Document] 23 | 24 | 25 | async def async_func(doc_self: Document, arg1: str, arg2: int, /) -> Document: 26 | """ 27 | Models `Document` async method that expects self 28 | """ 29 | raise NotImplementedError 30 | 31 | 32 | AsyncFunc: TypeAlias = Callable[ 33 | [Document, str, int], Coroutine[Any, Any, Document] 34 | ] 35 | 36 | 37 | def test_wrap_with_actions_preserves_signature() -> None: 38 | assert_type(async_func, AsyncFunc) 39 | assert_type(wrap_with_actions(EventTypes.SAVE)(async_func), AsyncFunc) 40 | 41 | 42 | def test_save_state_after_preserves_signature() -> None: 43 | assert_type(async_func, AsyncFunc) 44 | assert_type(save_state_after(async_func), AsyncFunc) 45 | 46 | 47 | def test_validate_self_before_preserves_signature() -> None: 48 | assert_type(async_func, AsyncFunc) 49 | assert_type(validate_self_before(async_func), AsyncFunc) 50 | 51 | 52 | def test_saved_state_needed_preserves_signature() -> None: 53 | assert_type(async_func, AsyncFunc) 54 | assert_type(saved_state_needed(async_func), AsyncFunc) 55 | 56 | assert_type(sync_func, SyncFunc) 57 | assert_type(saved_state_needed(sync_func), SyncFunc) 58 | 59 | 60 | def test_previous_saved_state_needed_preserves_signature() -> None: 61 | assert_type(async_func, AsyncFunc) 62 | assert_type(previous_saved_state_needed(async_func), AsyncFunc) 63 | 64 | assert_type(sync_func, SyncFunc) 65 | assert_type(previous_saved_state_needed(sync_func), SyncFunc) 66 | 67 | 68 | class ExpectsDocumentSelf(Protocol): 69 | def __call__(self, doc_self: Document, /) -> Any: ... 
70 | 71 | 72 | def test_document_insert_expects_self() -> None: 73 | test_insert: ExpectsDocumentSelf = Document.insert # noqa: F841 74 | 75 | 76 | def test_document_save_expects_self() -> None: 77 | test_insert: ExpectsDocumentSelf = Document.save # noqa: F841 78 | 79 | 80 | def test_document_replace_expects_self() -> None: 81 | test_insert: ExpectsDocumentSelf = Document.replace # noqa: F841 82 | -------------------------------------------------------------------------------- /tests/odm/documents/test_validation_on_save.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | from bson import ObjectId 5 | from pydantic import BaseModel, ValidationError 6 | 7 | from beanie import PydanticObjectId 8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 9 | from tests.odm.models import ( 10 | DocumentWithValidationOnSave, 11 | Lock, 12 | WindowWithValidationOnSave, 13 | ) 14 | 15 | 16 | async def test_validate_on_insert(): 17 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 18 | doc.num_1 = "wrong_value" 19 | with pytest.raises(ValidationError): 20 | await doc.insert() 21 | 22 | 23 | async def test_validate_on_replace(): 24 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 25 | await doc.insert() 26 | doc.num_1 = "wrong_value" 27 | with pytest.raises(ValidationError): 28 | await doc.replace() 29 | 30 | 31 | async def test_validate_on_save_changes(): 32 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 33 | await doc.insert() 34 | doc.num_1 = "wrong_value" 35 | with pytest.raises(ValidationError): 36 | await doc.save_changes() 37 | 38 | 39 | async def test_validate_on_save_keep_the_id_type(): 40 | class UpdateModel(BaseModel): 41 | num_1: Optional[int] = None 42 | related: Optional[PydanticObjectId] = None 43 | 44 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 45 | await doc.insert() 46 | update = UpdateModel(related=PydanticObjectId()) 47 | if 
IS_PYDANTIC_V2: 48 | doc = doc.model_copy(update=update.model_dump(exclude_unset=True)) 49 | else: 50 | doc = doc.copy(update=update.dict(exclude_unset=True)) 51 | doc.num_2 = 1000 52 | await doc.save() 53 | in_db = ( 54 | await DocumentWithValidationOnSave.get_pymongo_collection().find_one( 55 | {"_id": doc.id} 56 | ) 57 | ) 58 | assert isinstance(in_db["related"], ObjectId) 59 | new_doc = await DocumentWithValidationOnSave.get(doc.id) 60 | assert isinstance(new_doc.related, PydanticObjectId) 61 | 62 | 63 | async def test_validate_on_save_action(): 64 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 65 | await doc.insert() 66 | assert doc.num_2 == 3 67 | 68 | 69 | async def test_validate_on_save_skip_action(): 70 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 71 | await doc.insert(skip_actions=["num_2_plus_1"]) 72 | assert doc.num_2 == 2 73 | 74 | 75 | async def test_validate_on_save_dbref(): 76 | lock = Lock(k=1) 77 | await lock.insert() 78 | window = WindowWithValidationOnSave( 79 | x=1, 80 | y=1, 81 | lock=lock.to_ref(), # this is what exactly we want to test 82 | ) 83 | await window.insert() 84 | -------------------------------------------------------------------------------- /beanie/odm/operators/find/array.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Optional 3 | 4 | from beanie.odm.operators.find import BaseFindOperator 5 | 6 | 7 | class BaseFindArrayOperator(BaseFindOperator, ABC): ... 
8 | 9 | 10 | class All(BaseFindArrayOperator): 11 | """ 12 | `$all` array query operator 13 | 14 | Example: 15 | 16 | ```python 17 | class Sample(Document): 18 | results: List[int] 19 | 20 | All(Sample.results, [80, 85]) 21 | ``` 22 | 23 | Will return query object like 24 | 25 | ```python 26 | {"results": {"$all": [80, 85]}} 27 | ``` 28 | 29 | MongoDB doc: 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | field, 36 | values: list, 37 | ): 38 | self.field = field 39 | self.values_list = values 40 | 41 | @property 42 | def query(self): 43 | return {self.field: {"$all": self.values_list}} 44 | 45 | 46 | class ElemMatch(BaseFindArrayOperator): 47 | """ 48 | `$elemMatch` array query operator 49 | 50 | Example: 51 | 52 | ```python 53 | class Sample(Document): 54 | results: List[int] 55 | 56 | ElemMatch(Sample.results, {"$in": [80, 85]}) 57 | ``` 58 | 59 | Will return query object like 60 | 61 | ```python 62 | {"results": {"$elemMatch": {"$in": [80, 85]}}} 63 | ``` 64 | 65 | MongoDB doc: 66 | 67 | """ 68 | 69 | def __init__( 70 | self, 71 | field, 72 | expression: Optional[dict] = None, 73 | **kwargs: Any, 74 | ): 75 | self.field = field 76 | 77 | if expression is None: 78 | self.expression = kwargs 79 | else: 80 | self.expression = expression 81 | 82 | @property 83 | def query(self): 84 | return {self.field: {"$elemMatch": self.expression}} 85 | 86 | 87 | class Size(BaseFindArrayOperator): 88 | """ 89 | `$size` array query operator 90 | 91 | Example: 92 | 93 | ```python 94 | class Sample(Document): 95 | results: List[int] 96 | 97 | Size(Sample.results, 2) 98 | ``` 99 | 100 | Will return query object like 101 | 102 | ```python 103 | {"results": {"$size": 2}} 104 | ``` 105 | 106 | MongoDB doc: 107 | 108 | """ 109 | 110 | def __init__( 111 | self, 112 | field, 113 | num: int, 114 | ): 115 | self.field = field 116 | self.num = num 117 | 118 | @property 119 | def query(self): 120 | return {self.field: {"$size": self.num}} 121 | 
-------------------------------------------------------------------------------- /docs/tutorial/indexes.md: -------------------------------------------------------------------------------- 1 | ## Indexes setup 2 | 3 | There is more than one way to set up indexes with Beanie. 4 | 5 | ### Indexed function 6 | 7 | To set up an index over a single field, the `Indexed` function can be used to wrap the type 8 | and does not require a `Settings` class: 9 | 10 | ```python 11 | from typing import Annotated 12 | from beanie import Document, Indexed 13 | 14 | class Sample(Document): 15 | num: Annotated[int, Indexed()] 16 | description: str 17 | ``` 18 | 19 | The `Indexed` function takes an optional `index_type` argument, which may be set to a pymongo index type: 20 | 21 | ```python 22 | from typing import Annotated 23 | 24 | import pymongo 25 | from beanie import Document, Indexed 26 | 27 | class Sample(Document): 28 | description: Annotated[str, Indexed(index_type=pymongo.TEXT)] 29 | ``` 30 | 31 | The `Indexed` function also supports PyMongo's `IndexModel` keyword arguments (see the [PyMongo Documentation](https://pymongo.readthedocs.io/en/stable/api/pymongo/operations.html#pymongo.operations.IndexModel) for details). 32 | 33 | For example, to create a `unique` index: 34 | 35 | ```python 36 | from typing import Annotated 37 | from beanie import Document, Indexed 38 | 39 | class Sample(Document): 40 | name: Annotated[str, Indexed(unique=True)] 41 | ``` 42 | 43 | The `Indexed` function can also be used directly in the type annotation, by giving it the wrapped type as the first argument. Note that this might not work with some Pydantic V2 types, such as `UUID4` or `EmailStr`. 44 | 45 | ```python 46 | from beanie import Document, Indexed 47 | 48 | 49 | class Sample(Document): 50 | name: Indexed(str, unique=True) 51 | ``` 52 | 53 | ### Multi-field indexes 54 | 55 | The `indexes` field of the inner `Settings` class is responsible for more complex indexes. 56 | It is a list where items can be: 57 | 58 | - Single key.
Name of the document's field (this is equivalent to using the Indexed function described above without any additional arguments) 59 | - List of (key, direction) pairs. Key - string, name of the document's field. Direction - pymongo direction ( 60 | example: `pymongo.ASCENDING`) 61 | - `pymongo.IndexModel` instance - the most flexible 62 | option. [PyMongo Documentation](https://pymongo.readthedocs.io/en/stable/api/pymongo/operations.html#pymongo.operations.IndexModel) 63 | 64 | ```python 65 | import pymongo 66 | from pymongo import IndexModel 67 | 68 | from beanie import Document 69 | 70 | 71 | class Sample(Document): 72 | test_int: int 73 | test_str: str 74 | 75 | class Settings: 76 | indexes = [ 77 | "test_int", 78 | [ 79 | ("test_int", pymongo.ASCENDING), 80 | ("test_str", pymongo.DESCENDING), 81 | ], 82 | IndexModel( 83 | [("test_str", pymongo.DESCENDING)], 84 | name="test_string_index_DESCENDING", 85 | ), 86 | ] 87 | ``` 88 | -------------------------------------------------------------------------------- /beanie/odm/queries/delete.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Generator, Mapping, Optional, Type 2 | 3 | from pymongo import DeleteMany as DeleteManyPyMongo 4 | from pymongo import DeleteOne as DeleteOnePyMongo 5 | from pymongo.asynchronous.client_session import AsyncClientSession 6 | from pymongo.results import DeleteResult 7 | 8 | from beanie.odm.bulk import BulkWriter 9 | from beanie.odm.interfaces.clone import CloneInterface 10 | from beanie.odm.interfaces.session import SessionMethods 11 | 12 | if TYPE_CHECKING: 13 | from beanie.odm.documents import DocType 14 | 15 | 16 | class DeleteQuery(SessionMethods, CloneInterface): 17 | """ 18 | Deletion Query 19 | """ 20 | 21 | def __init__( 22 | self, 23 | document_model: Type["DocType"], 24 | find_query: Mapping[str, Any], 25 | bulk_writer: Optional[BulkWriter] = None, 26 | **pymongo_kwargs: Any, 27 | ): 28 | 
self.document_model = document_model 29 | self.find_query = find_query 30 | self.session: Optional[AsyncClientSession] = None 31 | self.bulk_writer = bulk_writer 32 | self.pymongo_kwargs: Dict[str, Any] = pymongo_kwargs 33 | 34 | 35 | class DeleteMany(DeleteQuery): 36 | def __await__( 37 | self, 38 | ) -> Generator[DeleteResult, None, Optional[DeleteResult]]: 39 | """ 40 | Run the query 41 | :return: 42 | """ 43 | if self.bulk_writer is None: 44 | return ( 45 | yield from self.document_model.get_pymongo_collection() 46 | .delete_many( 47 | self.find_query, 48 | session=self.session, 49 | **self.pymongo_kwargs, 50 | ) 51 | .__await__() 52 | ) 53 | else: 54 | self.bulk_writer.add_operation( 55 | self.document_model, 56 | DeleteManyPyMongo(self.find_query, **self.pymongo_kwargs), 57 | ) 58 | return None 59 | 60 | 61 | class DeleteOne(DeleteQuery): 62 | def __await__( 63 | self, 64 | ) -> Generator[DeleteResult, None, Optional[DeleteResult]]: 65 | """ 66 | Run the query 67 | :return: 68 | """ 69 | if self.bulk_writer is None: 70 | return ( 71 | yield from self.document_model.get_pymongo_collection() 72 | .delete_one( 73 | self.find_query, 74 | session=self.session, 75 | **self.pymongo_kwargs, 76 | ) 77 | .__await__() 78 | ) 79 | else: 80 | self.bulk_writer.add_operation( 81 | self.document_model, 82 | DeleteOnePyMongo(self.find_query), 83 | **self.pymongo_kwargs, 84 | ) 85 | return None 86 | -------------------------------------------------------------------------------- /assets/logo/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /scripts/generate_changelog.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from 
typing import List, Set 5 | 6 | import requests # type: ignore 7 | 8 | 9 | @dataclass 10 | class PullRequest: 11 | number: int 12 | title: str 13 | user: str 14 | user_url: str 15 | url: str 16 | 17 | 18 | class ChangelogGenerator: 19 | def __init__( 20 | self, 21 | username: str, 22 | repository: str, 23 | current_version: str, 24 | new_version: str, 25 | ): 26 | self.username = username 27 | self.repository = repository 28 | self.base_url = f"https://api.github.com/repos/{username}/{repository}" 29 | self.current_version = current_version 30 | self.new_version = new_version 31 | self.commits = self.get_commits_after_tag(current_version) 32 | self.prs = self.get_prs_for_commits(self.commits) 33 | 34 | def get_commits_after_tag(self, tag: str) -> List[str]: 35 | result = subprocess.run( 36 | ["git", "log", f"{tag}..HEAD", "--pretty=format:%H"], 37 | stdout=subprocess.PIPE, 38 | text=True, 39 | ) 40 | return result.stdout.split() 41 | 42 | def get_pr_for_commit(self, commit_sha: str) -> PullRequest: 43 | url = f"{self.base_url}/commits/{commit_sha}/pulls" 44 | response = requests.get(url) 45 | response.raise_for_status() 46 | pr_data = response.json()[0] 47 | return PullRequest( 48 | number=pr_data["number"], 49 | title=pr_data["title"], 50 | user=pr_data["user"]["login"], 51 | user_url=pr_data["user"]["html_url"], 52 | url=pr_data["html_url"], 53 | ) 54 | 55 | def get_prs_for_commits(self, commit_shas: List[str]) -> List[PullRequest]: 56 | prs: List[PullRequest] = [] 57 | unique_prs: Set[int] = set() 58 | for commit_sha in commit_shas: 59 | pr = self.get_pr_for_commit(commit_sha) 60 | pr_id = pr.number 61 | if pr_id not in unique_prs: 62 | unique_prs.add(pr_id) 63 | prs.append(pr) 64 | prs.sort(key=lambda pr: pr.number, reverse=True) 65 | return prs 66 | 67 | def generate_changelog(self) -> str: 68 | markdown = f"\n## [{self.new_version}] - {datetime.now().strftime('%Y-%m-%d')}\n" 69 | for pr in self.prs: 70 | markdown += ( 71 | f"### {pr.title.capitalize()}\n" 72 
| f"- Author - [{pr.user}]({pr.user_url})\n" 73 | f"- PR <{pr.url}>\n" 74 | ) 75 | markdown += f"\n[{self.new_version}]: https://pypi.org/project/{self.repository}/{self.new_version}\n" 76 | return markdown 77 | 78 | 79 | if __name__ == "__main__": 80 | generator = ChangelogGenerator( 81 | username="BeanieODM", 82 | repository="beanie", 83 | current_version="2.0.0", 84 | new_version="2.0.1", 85 | ) 86 | 87 | changelog = generator.generate_changelog() 88 | print(changelog) 89 | -------------------------------------------------------------------------------- /tests/migrations/test_free_fall.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.migrations.models import RunningDirections 7 | from beanie.odm.documents import Document 8 | from beanie.odm.models import InspectionStatuses 9 | 10 | 11 | class Tag(BaseModel): 12 | color: str 13 | name: str 14 | 15 | 16 | class OldNote(Document): 17 | name: str 18 | tag: Tag 19 | 20 | class Settings: 21 | name = "notes" 22 | 23 | 24 | class Note(Document): 25 | title: str 26 | tag: Tag 27 | 28 | class Settings: 29 | name = "notes" 30 | 31 | 32 | @pytest.fixture() 33 | async def notes(db): 34 | await init_beanie(database=db, document_models=[OldNote]) 35 | await OldNote.delete_all() 36 | for i in range(10): 37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red")) 38 | await note.insert() 39 | yield 40 | await OldNote.delete_all() 41 | 42 | 43 | async def test_migration_free_fall(settings, notes, db): 44 | if not db.client.is_mongos and not len(db.client.nodes) > 1: 45 | return pytest.skip( 46 | "MongoDB server does not support transactions as it is neighter a mongos instance not a replica set." 
47 | ) 48 | 49 | migration_settings = MigrationSettings( 50 | connection_uri=settings.mongodb_dsn, 51 | database_name=settings.mongodb_db_name, 52 | path="tests/migrations/migrations_for_test/free_fall", 53 | ) 54 | await run_migrate(migration_settings) 55 | 56 | await init_beanie(database=db, document_models=[Note]) 57 | inspection = await Note.inspect_collection() 58 | assert inspection.status == InspectionStatuses.OK 59 | note = await Note.find_one({}) 60 | assert note.title == "0" 61 | 62 | migration_settings.direction = RunningDirections.BACKWARD 63 | await run_migrate(migration_settings) 64 | inspection = await OldNote.inspect_collection() 65 | assert inspection.status == InspectionStatuses.OK 66 | note = await OldNote.find_one({}) 67 | assert note.name == "0" 68 | 69 | 70 | async def test_migration_free_fall_no_use_transactions(settings, notes, db): 71 | migration_settings = MigrationSettings( 72 | connection_uri=settings.mongodb_dsn, 73 | database_name=settings.mongodb_db_name, 74 | path="tests/migrations/migrations_for_test/free_fall", 75 | use_transaction=False, 76 | ) 77 | await run_migrate(migration_settings) 78 | 79 | await init_beanie(database=db, document_models=[Note]) 80 | inspection = await Note.inspect_collection() 81 | assert inspection.status == InspectionStatuses.OK 82 | note = await Note.find_one({}) 83 | assert note.title == "0" 84 | 85 | migration_settings.direction = RunningDirections.BACKWARD 86 | await run_migrate(migration_settings) 87 | inspection = await OldNote.inspect_collection() 88 | assert inspection.status == InspectionStatuses.OK 89 | note = await OldNote.find_one({}) 90 | assert note.name == "0" 91 | -------------------------------------------------------------------------------- /tests/migrations/test_remove_indexes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic.main import BaseModel 3 | 4 | from beanie import init_beanie 5 | from 
beanie.executors.migrate import MigrationSettings, run_migrate 6 | from beanie.odm.documents import Document 7 | 8 | 9 | class Tag(BaseModel): 10 | color: str 11 | name: str 12 | 13 | 14 | class OldNote(Document): 15 | title: str 16 | tag: Tag 17 | 18 | class Settings: 19 | name = "notes" 20 | indexes = ["title"] 21 | 22 | 23 | class Note(Document): 24 | title: str 25 | tag: Tag 26 | 27 | class Settings: 28 | name = "notes" 29 | 30 | 31 | @pytest.fixture() 32 | async def notes(db): 33 | await init_beanie(database=db, document_models=[OldNote]) 34 | await OldNote.delete_all() 35 | for i in range(10): 36 | note = OldNote(title=str(i), tag=Tag(name="test", color="red")) 37 | await note.insert() 38 | yield 39 | await OldNote.delete_all() 40 | 41 | 42 | async def test_remove_index_allowed(settings, notes, db): 43 | migration_settings = MigrationSettings( 44 | connection_uri=settings.mongodb_dsn, 45 | database_name=settings.mongodb_db_name, 46 | path="tests/migrations/migrations_for_test/remove_index", 47 | allow_index_dropping=True, 48 | ) 49 | await run_migrate(migration_settings) 50 | 51 | await init_beanie( 52 | database=db, document_models=[Note], allow_index_dropping=False 53 | ) 54 | collection = Note.get_pymongo_collection() 55 | index_info = await collection.index_information() 56 | assert index_info == { 57 | "_id_": {"key": [("_id", 1)], "v": 2}, 58 | } 59 | 60 | 61 | async def test_remove_index_default(settings, notes, db): 62 | migration_settings = MigrationSettings( 63 | connection_uri=settings.mongodb_dsn, 64 | database_name=settings.mongodb_db_name, 65 | path="tests/migrations/migrations_for_test/remove_index", 66 | ) 67 | await run_migrate(migration_settings) 68 | 69 | await init_beanie( 70 | database=db, document_models=[Note], allow_index_dropping=False 71 | ) 72 | collection = Note.get_pymongo_collection() 73 | index_info = await collection.index_information() 74 | assert index_info == { 75 | "_id_": {"key": [("_id", 1)], "v": 2}, 76 | "title_1": 
{"key": [("title", 1)], "v": 2}, 77 | } 78 | 79 | 80 | async def test_remove_index_not_allowed(settings, notes, db): 81 | migration_settings = MigrationSettings( 82 | connection_uri=settings.mongodb_dsn, 83 | database_name=settings.mongodb_db_name, 84 | path="tests/migrations/migrations_for_test/remove_index", 85 | allow_index_dropping=False, 86 | ) 87 | await run_migrate(migration_settings) 88 | 89 | await init_beanie( 90 | database=db, document_models=[Note], allow_index_dropping=False 91 | ) 92 | collection = Note.get_pymongo_collection() 93 | index_info = await collection.index_information() 94 | assert index_info == { 95 | "_id_": {"key": [("_id", 1)], "v": 2}, 96 | "title_1": {"key": [("title", 1)], "v": 2}, 97 | } 98 | -------------------------------------------------------------------------------- /docs/tutorial/actions.md: -------------------------------------------------------------------------------- 1 | # Event-based actions 2 | 3 | You can register methods as pre- or post- actions for document events. 
4 | 5 | Currently supported events: 6 | 7 | - Save 8 | - Insert 9 | - Replace 10 | - Update 11 | - SaveChanges 12 | - Delete 13 | - ValidateOnSave 14 | 15 | Currently supported directions: 16 | 17 | - `Before` 18 | - `After` 19 | 20 | Operations that trigger events: 21 | 22 | - `insert()` for Insert 23 | - `replace()` for Replace 24 | - `save()` for Save; it also triggers Insert when it creates a new document, or Replace when it replaces an existing document 25 | - `save_changes()` for SaveChanges 26 | - `insert()`, `replace()`, `save_changes()`, and `save()` for ValidateOnSave 27 | - `set()`, `update()` for Update 28 | - `delete()` for Delete 29 | 30 | To register an action, you can use the `@before_event` and `@after_event` decorators respectively: 31 | 32 | ```python 33 | from beanie import Document, Insert, Replace, before_event, after_event 34 | 35 | 36 | class Sample(Document): 37 | num: int 38 | name: str 39 | 40 | @before_event(Insert) 41 | def capitalize_name(self): 42 | self.name = self.name.capitalize() 43 | 44 | @after_event(Replace) 45 | def num_change(self): 46 | self.num -= 1 47 | ``` 48 | 49 | It is possible to register an action for several events: 50 | 51 | ```python 52 | from beanie import Document, Insert, Replace, before_event 53 | 54 | 55 | class Sample(Document): 56 | num: int 57 | name: str 58 | 59 | @before_event(Insert, Replace) 60 | def capitalize_name(self): 61 | self.name = self.name.capitalize() 62 | ``` 63 | 64 | This will capitalize the `name` field value before each document's Insert and Replace. 65 | 66 | Both sync and async methods can be used as actions. 67 | 68 | ```python 69 | from beanie import Document, Insert, Replace, after_event 70 | 71 | 72 | class Sample(Document): 73 | num: int 74 | name: str 75 | 76 | @after_event(Insert, Replace) 77 | async def send_callback(self): 78 | await client.send(self.id) 79 | ``` 80 | 81 | Actions can be selectively skipped by passing the `skip_actions` argument when calling 82 | the operations that trigger events.
`skip_actions` accepts a list of directions and action names. 83 | 84 | ```python 85 | from beanie import After, Before, Document, Insert, Replace, before_event, after_event 86 | 87 | 88 | class Sample(Document): 89 | num: int 90 | name: str 91 | 92 | @before_event(Insert) 93 | def capitalize_name(self): 94 | self.name = self.name.capitalize() 95 | 96 | @before_event(Replace) 97 | def redact_name(self): 98 | self.name = "[REDACTED]" 99 | 100 | @after_event(Replace) 101 | def num_change(self): 102 | self.num -= 1 103 | 104 | 105 | sample = Sample(num=1, name="sample") 106 | 107 | # capitalize_name will not be executed 108 | await sample.insert(skip_actions=['capitalize_name']) 109 | 110 | # num_change will not be executed 111 | await sample.replace(skip_actions=[After]) 112 | 113 | # redact_name and num_change will not be executed 114 | await sample.replace(skip_actions=[Before, 'num_change']) 115 | ``` 116 | -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | ## Installing beanie 4 | 5 | You can simply install Beanie from the [PyPI](https://pypi.org/project/beanie/): 6 | 7 | ### PIP 8 | 9 | ```shell 10 | pip install beanie 11 | ``` 12 | 13 | ### Poetry 14 | 15 | ```shell 16 | poetry add beanie 17 | ``` 18 | 19 | ### Optional dependencies 20 | 21 | Beanie supports some optional dependencies from PyMongo (`pip` or `poetry` can be used).
22 | 23 | GSSAPI authentication requires `gssapi` extra dependency: 24 | 25 | ```bash 26 | pip install "beanie[gssapi]" 27 | ``` 28 | 29 | MONGODB-AWS authentication requires `aws` extra dependency: 30 | 31 | ```bash 32 | pip install "beanie[aws]" 33 | ``` 34 | 35 | Support for mongodb+srv:// URIs requires `srv` extra dependency: 36 | 37 | ```bash 38 | pip install "beanie[srv]" 39 | ``` 40 | 41 | OCSP requires `ocsp` extra dependency: 42 | 43 | ```bash 44 | pip install "beanie[ocsp]" 45 | ``` 46 | 47 | Wire protocol compression with snappy requires `snappy` extra 48 | dependency: 49 | 50 | ```bash 51 | pip install "beanie[snappy]" 52 | ``` 53 | 54 | Wire protocol compression with zstandard requires `zstd` extra 55 | dependency: 56 | 57 | ```bash 58 | pip install "beanie[zstd]" 59 | ``` 60 | 61 | Client-Side Field Level Encryption requires `encryption` extra 62 | dependency: 63 | 64 | ```bash 65 | pip install "beanie[encryption]" 66 | ``` 67 | 68 | You can install all optional dependencies at once with the following 69 | command: 70 | 71 | ```bash 72 | pip install "beanie[gssapi,aws,ocsp,snappy,srv,zstd,encryption]" 73 | ``` 74 | 75 | ## Initialization 76 | 77 | Getting Beanie set up in your code is easy: 78 | 79 | 1. Write your database model as a Pydantic class but use `beanie.Document` instead of `pydantic.BaseModel`. 80 | 2. Initialize an async PyMongo client (`AsyncMongoClient`); Beanie uses it as its async database engine under the hood. 81 | 3. Call `beanie.init_beanie` with a database from that client and the list of Beanie models. 82 | 83 | The code below should get you started and shows some of the field types that you can use with Beanie.
84 | 85 | ```python 86 | from typing import Optional 87 | 88 | from pymongo import AsyncMongoClient 89 | from pydantic import BaseModel 90 | 91 | from beanie import Document, Indexed, init_beanie 92 | 93 | 94 | class Category(BaseModel): 95 | name: str 96 | description: str 97 | 98 | 99 | # This is the model that will be saved to the database 100 | class Product(Document): 101 | name: str # You can use normal types just like in pydantic 102 | description: Optional[str] = None 103 | price: Indexed(float) # You can also specify that a field should correspond to an index 104 | category: Category # You can include pydantic models as well 105 | 106 | 107 | # Call this from within your event loop to get beanie setup. 108 | async def init(): 109 | # Create Async PyMongo client 110 | client = AsyncMongoClient("mongodb://user:pass@host:27017") 111 | 112 | # Init beanie with the Product document class 113 | await init_beanie(database=client.db_name, document_models=[Product]) 114 | ``` 115 | -------------------------------------------------------------------------------- /.github/scripts/handlers/gh.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from typing import List 5 | 6 | import requests # type: ignore 7 | 8 | 9 | @dataclass 10 | class PullRequest: 11 | number: int 12 | title: str 13 | user: str 14 | user_url: str 15 | url: str 16 | 17 | 18 | class GitHubHandler: 19 | def __init__( 20 | self, 21 | username: str, 22 | repository: str, 23 | current_version: str, 24 | new_version: str, 25 | ): 26 | self.username = username 27 | self.repository = repository 28 | self.base_url = f"https://api.github.com/repos/{username}/{repository}" 29 | self.current_version = current_version 30 | self.new_version = new_version 31 | self.commits = self.get_commits_after_tag(current_version) 32 | self.prs = [self.get_pr_for_commit(commit) for commit in 
self.commits] 33 | 34 | def get_commits_after_tag(self, tag: str) -> List[str]: 35 | result = subprocess.run( 36 | ["git", "log", f"{tag}..HEAD", "--pretty=format:%H"], 37 | stdout=subprocess.PIPE, 38 | text=True, 39 | ) 40 | return result.stdout.split() 41 | 42 | def get_pr_for_commit(self, commit_sha: str) -> PullRequest: 43 | url = f"{self.base_url}/commits/{commit_sha}/pulls" 44 | response = requests.get(url) 45 | response.raise_for_status() 46 | pr_data = response.json()[0] 47 | return PullRequest( 48 | number=pr_data["number"], 49 | title=pr_data["title"], 50 | user=pr_data["user"]["login"], 51 | user_url=pr_data["user"]["html_url"], 52 | url=pr_data["html_url"], 53 | ) 54 | 55 | def build_markdown_for_many_prs(self) -> str: 56 | markdown = f"\n## [{self.new_version}] - {datetime.now().strftime('%Y-%m-%d')}\n" 57 | for pr in self.prs: 58 | markdown += ( 59 | f"### {pr.title.capitalize()}\n" 60 | f"- Author - [{pr.user}]({pr.user_url})\n" 61 | f"- PR <{pr.url}>\n" 62 | ) 63 | markdown += f"\n[{self.new_version}]: https://pypi.org/project/{self.repository}/{self.new_version}\n" 64 | return markdown 65 | 66 | def commit_changes(self): 67 | self.run_git_command( 68 | ["git", "config", "--global", "user.name", "github-actions[bot]"] 69 | ) 70 | self.run_git_command( 71 | [ 72 | "git", 73 | "config", 74 | "--global", 75 | "user.email", 76 | "github-actions[bot]@users.noreply.github.com", 77 | ] 78 | ) 79 | self.run_git_command(["git", "add", "."]) 80 | self.run_git_command( 81 | ["git", "commit", "-m", f"Bump version to {self.new_version}"] 82 | ) 83 | self.run_git_command(["git", "tag", self.new_version]) 84 | self.git_push() 85 | 86 | def git_push(self): 87 | self.run_git_command(["git", "push", "origin", "main", "--tags"]) 88 | 89 | @staticmethod 90 | def run_git_command(command: List[str]): 91 | subprocess.run(command, check=True) 92 | -------------------------------------------------------------------------------- /beanie/odm/operators/update/array.py: 
-------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update import BaseUpdateOperator 2 | 3 | 4 | class BaseUpdateArrayOperator(BaseUpdateOperator): # Base for array update operators: subclasses set `operator`, and `query` renders {operator: expression}. 5 | operator = "" 6 | 7 | def __init__(self, expression): 8 | self.expression = expression 9 | 10 | @property 11 | def query(self): 12 | return {self.operator: self.expression} 13 | 14 | 15 | class AddToSet(BaseUpdateArrayOperator): 16 | """ 17 | `$addToSet` update array query operator 18 | 19 | Example: 20 | 21 | ```python 22 | class Sample(Document): 23 | results: List[int] 24 | 25 | AddToSet({Sample.results: 2}) 26 | ``` 27 | 28 | Will return query object like 29 | 30 | ```python 31 | {"$addToSet": {"results": 2}} 32 | ``` 33 | 34 | MongoDB docs: 35 | <https://www.mongodb.com/docs/manual/reference/operator/update/addToSet/> 36 | """ 37 | 38 | operator = "$addToSet" 39 | 40 | 41 | class Pop(BaseUpdateArrayOperator): 42 | """ 43 | `$pop` update array query operator (`-1` removes the first element, `1` the last) 44 | 45 | Example: 46 | 47 | ```python 48 | class Sample(Document): 49 | results: List[int] 50 | 51 | Pop({Sample.results: -1}) 52 | ``` 53 | 54 | Will return query object like 55 | 56 | ```python 57 | {"$pop": {"results": -1}} 58 | ``` 59 | 60 | MongoDB docs: 61 | <https://www.mongodb.com/docs/manual/reference/operator/update/pop/> 62 | """ 63 | 64 | operator = "$pop" 65 | 66 | 67 | class Pull(BaseUpdateArrayOperator): 68 | """ 69 | `$pull` update array query operator 70 | 71 | Example: 72 | 73 | ```python 74 | class Sample(Document): 75 | results: List[int] 76 | 77 | Pull(In(Sample.results, [1, 2, 3, 4, 5])) 78 | ``` 79 | 80 | Will return query object like 81 | 82 | ```python 83 | {"$pull": {"results": {"$in": [1, 2, 3, 4, 5]}}} 84 | ``` 85 | 86 | MongoDB docs: 87 | <https://www.mongodb.com/docs/manual/reference/operator/update/pull/> 88 | """ 89 | 90 | operator = "$pull" 91 | 92 | 93 | class Push(BaseUpdateArrayOperator): 94 | """ 95 | `$push` update array query operator 96 | 97 | Example: 98 | 99 | ```python 100 | class Sample(Document): 101 | results: List[int] 102 | 103 | Push({Sample.results: 1}) 104 | ``` 105 | 106 | Will return query object like 107 | 108 | ```python 109 | {"$push": { "results": 1}} 110 | ```
111 | 112 | MongoDB docs: 113 | <https://www.mongodb.com/docs/manual/reference/operator/update/push/> 114 | """ 115 | 116 | operator = "$push" 117 | 118 | 119 | class PullAll(BaseUpdateArrayOperator): 120 | """ 121 | `$pullAll` update array query operator 122 | 123 | Example: 124 | 125 | ```python 126 | class Sample(Document): 127 | results: List[int] 128 | 129 | PullAll({ Sample.results: [ 0, 5 ] }) 130 | ``` 131 | 132 | Will return query object like 133 | 134 | ```python 135 | {"$pullAll": { "results": [ 0, 5 ] }} 136 | ``` 137 | 138 | MongoDB docs: 139 | <https://www.mongodb.com/docs/manual/reference/operator/update/pullAll/> 140 | """ 141 | 142 | operator = "$pullAll" 143 | -------------------------------------------------------------------------------- /beanie/odm/union_doc.py: -------------------------------------------------------------------------------- 1 | from typing import Any, ClassVar, Dict, Optional, Type, TypeVar 2 | 3 | from pymongo.asynchronous.client_session import AsyncClientSession 4 | 5 | from beanie.exceptions import UnionDocNotInited 6 | from beanie.odm.bulk import BulkWriter 7 | from beanie.odm.interfaces.aggregate import AggregateInterface 8 | from beanie.odm.interfaces.detector import DetectionInterface, ModelType 9 | from beanie.odm.interfaces.find import FindInterface 10 | from beanie.odm.interfaces.getters import OtherGettersInterface 11 | from beanie.odm.settings.union_doc import UnionDocSettings 12 | 13 | UnionDocType = TypeVar("UnionDocType", bound="UnionDoc") 14 | 15 | 16 | class UnionDoc( 17 | FindInterface, 18 | AggregateInterface, 19 | OtherGettersInterface, 20 | DetectionInterface, 21 | ): 22 | _document_models: ClassVar[Optional[Dict[str, Type]]] = None 23 | _is_inited: ClassVar[bool] = False 24 | _settings: ClassVar[UnionDocSettings] 25 | 26 | @classmethod 27 | def get_settings(cls) -> UnionDocSettings: 28 | return cls._settings 29 | 30 | @classmethod 31 | def register_doc(cls, name: str, doc_model: Type): # Store doc_model under name and return this union's settings name; raises UnionDocNotInited if the union has not been initialised yet. 32 | if cls._document_models is None: 33 | cls._document_models = {} 34 | 35 | if cls._is_inited is False: 36 | raise UnionDocNotInited 37 | 38 |
cls._document_models[name] = doc_model 39 | return cls.get_settings().name 40 | 41 | @classmethod 42 | def get_model_type(cls) -> ModelType: 43 | return ModelType.UnionDoc 44 | 45 | @classmethod 46 | def bulk_writer( 47 | cls, 48 | session: Optional[AsyncClientSession] = None, 49 | ordered: bool = True, 50 | bypass_document_validation: bool = False, 51 | comment: Optional[Any] = None, 52 | ) -> BulkWriter: 53 | """ 54 | Returns a BulkWriter instance for handling bulk write operations. 55 | 56 | :param session: Optional[AsyncClientSession] - pymongo session. 57 | The session instance used for transactional operations. 58 | :param ordered: bool 59 | If ``True`` (the default), requests will be performed on the server serially, in the order provided. If an error 60 | occurs, all remaining operations are aborted. If ``False``, requests will be performed on the server in 61 | arbitrary order, possibly in parallel, and all operations will be attempted. 62 | :param bypass_document_validation: bool, optional 63 | If ``True``, allows the write to opt-out of document-level validation. Default is ``False``. 64 | :param comment: Any, optional 65 | A user-provided comment to attach to the BulkWriter. 66 | 67 | :returns: BulkWriter 68 | An instance of BulkWriter configured with the provided settings. 69 | 70 | Example Usage: 71 | -------------- 72 | This method is typically used within an asynchronous context manager. 73 | 74 | ..
code-block:: python 75 | 76 | async with Document.bulk_writer(ordered=True) as bulk: 77 | await Document.insert_one(Document(field="value"), bulk_writer=bulk) 78 | """ 79 | return BulkWriter( 80 | session, ordered, cls, bypass_document_validation, comment 81 | ) 82 | -------------------------------------------------------------------------------- /docs/tutorial/update.md: -------------------------------------------------------------------------------- 1 | # Updating & Deleting 2 | 3 | Now that we know how to find documents, how do we change them or delete them? 4 | 5 | ## Saving changes to existing documents 6 | 7 | The easiest way to change a document in the database is to use either the `replace` or `save` method on an altered document. 8 | These methods both write the document to the database, 9 | but `replace` will raise an exception when the document does not exist yet, while `save` will insert the document. 10 | 11 | Using `save()` method: 12 | 13 | ```python 14 | bar = await Product.find_one(Product.name == "Mars") 15 | bar.price = 10 16 | await bar.save() 17 | ``` 18 | 19 | Otherwise, use the `replace()` method, which throws: 20 | - a `ValueError` if the document does not have an `id` yet, or 21 | - a `beanie.exceptions.DocumentNotFound` if it does, but the `id` is not present in the collection 22 | 23 | ```python 24 | bar.price = 10 25 | try: 26 | await bar.replace() 27 | except (ValueError, beanie.exceptions.DocumentNotFound): 28 | print("Can't replace a non existing document") 29 | ``` 30 | 31 | Note that these methods require multiple queries to the database and replace the entire document with the new version. 32 | A more tailored solution can often be created by applying update queries directly on the database level. 33 | 34 | ## Update queries 35 | 36 | Update queries can be performed on the result of a `find` or `find_one` query, 37 | or on a document that was returned from an earlier query.
38 | Simpler updates can be performed using the `set`, `inc`, and `current_date` methods: 39 | 40 | ```python 41 | bar = await Product.find_one(Product.name == "Mars") 42 | await bar.set({Product.name: "Gold bar"}) 43 | await Product.find(Product.price > .5).inc({Product.price: 1}) 44 | ``` 45 | 46 | More complex update operations can be performed by calling `update()` with an update operator, similar to find queries: 47 | 48 | ```python 49 | await Product.find_one(Product.name == "Tony's").update(Set({Product.price: 3.33})) 50 | ``` 51 | 52 | The whole list of the update query operators can be found [here](../api-documentation/operators/update.md). 53 | 54 | Native MongoDB syntax is also supported: 55 | 56 | ```python 57 | await Product.find_one(Product.name == "Tony's").update({"$set": {Product.price: 3.33}}) 58 | ``` 59 | 60 | ## Upsert 61 | 62 | To insert a document when no documents are matched against the search criteria, the `upsert` method can be used: 63 | 64 | ```python 65 | await Product.find_one(Product.name == "Tony's").upsert( 66 | Set({Product.price: 3.33}), 67 | on_insert=Product(name="Tony's", price=3.33, category=chocolate) 68 | ) 69 | ``` 70 | 71 | ## Deleting documents 72 | 73 | Deleting documents works just like updating them: you simply call `delete()` on the found documents: 74 | 75 | ```python 76 | bar = await Product.find_one(Product.name == "Milka") 77 | await bar.delete() 78 | 79 | await Product.find_one(Product.name == "Milka").delete() 80 | 81 | await Product.find(Product.category.name == "Chocolate").delete() 82 | ``` 83 | 84 | ## Response Type 85 | 86 | For the object methods `update` and `upsert`, you can use the `response_type` parameter to specify the type of response. 87 | 88 | The options are: 89 | - `UpdateResponse.UPDATE_RESULT` - returns the result of the update operation. 90 | - `UpdateResponse.NEW_DOCUMENT` - returns the newly updated document. 91 | - `UpdateResponse.OLD_DOCUMENT` - returns the document before the update.
92 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_geospatial.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.geospatial import ( 2 | Box, 3 | GeoIntersects, 4 | GeoWithin, 5 | Near, 6 | NearSphere, 7 | ) 8 | from tests.odm.models import Sample 9 | 10 | 11 | async def test_geo_intersects(): # Each test in this module asserts that a geospatial operator equals the raw MongoDB query dict it should render. 12 | q = GeoIntersects( 13 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]] 14 | ) 15 | assert q == { 16 | "geo": { 17 | "$geoIntersects": { 18 | "$geometry": { 19 | "type": "Polygon", 20 | "coordinates": [[1, 1], [2, 2], [3, 3]], 21 | } 22 | } 23 | } 24 | } 25 | 26 | 27 | async def test_geo_within(): 28 | q = GeoWithin( 29 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]] 30 | ) 31 | assert q == { 32 | "geo": { 33 | "$geoWithin": { 34 | "$geometry": { 35 | "type": "Polygon", 36 | "coordinates": [[1, 1], [2, 2], [3, 3]], 37 | } 38 | } 39 | } 40 | } 41 | 42 | 43 | async def test_box(): 44 | q = Box(Sample.geo, lower_left=[1, 3], upper_right=[2, 4]) 45 | assert q == {"geo": {"$geoWithin": {"$box": [[1, 3], [2, 4]]}}} 46 | 47 | 48 | async def test_near(): # max_distance/min_distance are only emitted when passed (see the three cases below). 49 | q = Near(Sample.geo, longitude=1.1, latitude=2.2) 50 | assert q == { 51 | "geo": { 52 | "$near": { 53 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]} 54 | } 55 | } 56 | } 57 | 58 | q = Near(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1) 59 | assert q == { 60 | "geo": { 61 | "$near": { 62 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 63 | "$maxDistance": 1, 64 | } 65 | } 66 | } 67 | 68 | q = Near( 69 | Sample.geo, 70 | longitude=1.1, 71 | latitude=2.2, 72 | max_distance=1, 73 | min_distance=0.5, 74 | ) 75 | assert q == { 76 | "geo": { 77 | "$near": { 78 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 79 | "$maxDistance": 1, 80 | "$minDistance": 0.5, 81 | } 82 | } 83 | } 84 | 85 | 86 | async def
test_near_sphere(): 87 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2) 88 | assert q == { 89 | "geo": { 90 | "$nearSphere": { 91 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]} 92 | } 93 | } 94 | } 95 | 96 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1) 97 | assert q == { 98 | "geo": { 99 | "$nearSphere": { 100 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 101 | "$maxDistance": 1, 102 | } 103 | } 104 | } 105 | 106 | q = NearSphere( 107 | Sample.geo, 108 | longitude=1.1, 109 | latitude=2.2, 110 | max_distance=1, 111 | min_distance=0.5, 112 | ) 113 | assert q == { 114 | "geo": { 115 | "$nearSphere": { 116 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 117 | "$maxDistance": 1, 118 | "$minDistance": 0.5, 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /tests/odm/test_cache.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from tests.odm.models import DocumentTestModel 4 | 5 | 6 | async def test_find_one(documents): # A cached find_one keeps returning the pre-update document until ignore_cache=True is passed or the sleep below has elapsed. 7 | await documents(5) 8 | doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 9 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set( 10 | {DocumentTestModel.test_str: "NEW_VALUE"} 11 | ) 12 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 13 | assert doc == new_doc 14 | 15 | new_doc = await DocumentTestModel.find_one( 16 | DocumentTestModel.test_int == 1, ignore_cache=True 17 | ) 18 | assert doc != new_doc 19 | 20 | await asyncio.sleep(10) # presumably longer than DocumentTestModel's cache expiration; confirm in tests/odm/models.py 21 | 22 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 23 | assert doc != new_doc 24 | 25 | 26 | async def test_find_many(documents): 27 | await documents(5) 28 | docs = await DocumentTestModel.find( 29 | DocumentTestModel.test_int > 1 30 | ).to_list() 31 | 32 | await DocumentTestModel.find(DocumentTestModel.test_int >
1).set( 33 | {DocumentTestModel.test_str: "NEW_VALUE"} 34 | ) 35 | 36 | new_docs = await DocumentTestModel.find( 37 | DocumentTestModel.test_int > 1 38 | ).to_list() 39 | assert docs == new_docs 40 | 41 | new_docs = await DocumentTestModel.find( 42 | DocumentTestModel.test_int > 1, ignore_cache=True 43 | ).to_list() 44 | assert docs != new_docs 45 | 46 | await asyncio.sleep(10) # same cache-expiry wait as in test_find_one 47 | 48 | new_docs = await DocumentTestModel.find( 49 | DocumentTestModel.test_int > 1 50 | ).to_list() 51 | assert docs != new_docs 52 | 53 | 54 | async def test_aggregation(documents): 55 | await documents(5) 56 | docs = await DocumentTestModel.aggregate( 57 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 58 | ).to_list() 59 | 60 | await DocumentTestModel.find(DocumentTestModel.test_int > 1).set( 61 | {DocumentTestModel.test_str: "NEW_VALUE"} 62 | ) 63 | 64 | new_docs = await DocumentTestModel.aggregate( 65 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 66 | ).to_list() 67 | assert docs == new_docs 68 | 69 | new_docs = await DocumentTestModel.aggregate( 70 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 71 | ignore_cache=True, 72 | ).to_list() 73 | assert docs != new_docs 74 | 75 | await asyncio.sleep(10) # same cache-expiry wait as in test_find_one 76 | 77 | new_docs = await DocumentTestModel.aggregate( 78 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 79 | ).to_list() 80 | assert docs != new_docs 81 | 82 | 83 | async def test_capacity(documents): # After caching 10 lookups, test_int == 1 is expected to be re-read fresh while test_int == 9 is still served stale; presumably a bounded LRU cache, confirm its capacity in tests/odm/models.py. 84 | await documents(10) 85 | docs = [] 86 | for i in range(10): 87 | docs.append( 88 | await DocumentTestModel.find_one(DocumentTestModel.test_int == i) 89 | ) 90 | 91 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set( 92 | {DocumentTestModel.test_str: "NEW_VALUE"} 93 | ) 94 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 9).set( 95 | {DocumentTestModel.test_str: "NEW_VALUE"} 96 | ) 97 | 98 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int ==
1) 99 | assert docs[1] != new_doc 100 | 101 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 9) 102 | assert docs[9] == new_doc 103 | -------------------------------------------------------------------------------- /beanie/odm/utils/state.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from functools import wraps 3 | from typing import TYPE_CHECKING, TypeVar 4 | 5 | from typing_extensions import ParamSpec 6 | 7 | from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved 8 | 9 | if TYPE_CHECKING: 10 | from beanie.odm.documents import AnyDocMethod, AsyncDocMethod, DocType 11 | 12 | P = ParamSpec("P") 13 | R = TypeVar("R") 14 | 15 | 16 | def check_if_state_saved(self: "DocType") -> None: # Raise unless state management is enabled and a state has been saved. 17 | if not self.use_state_management(): 18 | raise StateManagementIsTurnedOff( 19 | "State management is turned off for this document" 20 | ) 21 | if self._saved_state is None: 22 | raise StateNotSaved("No state was saved") 23 | 24 | 25 | def saved_state_needed( # Decorator: call check_if_state_saved(self) before the wrapped sync or async method runs. 26 | f: "AnyDocMethod[DocType, P, R]", 27 | ) -> "AnyDocMethod[DocType, P, R]": 28 | @wraps(f) 29 | def sync_wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R: 30 | check_if_state_saved(self) 31 | return f(self, *args, **kwargs) 32 | 33 | @wraps(f) 34 | async def async_wrapper( 35 | self: "DocType", *args: P.args, **kwargs: P.kwargs 36 | ) -> R: 37 | check_if_state_saved(self) 38 | # type ignore because there is no nice/proper way to annotate both sync 39 | # and async case without parametrized TypeVar, which is not supported 40 | return await f(self, *args, **kwargs) # type: ignore[misc] 41 | 42 | if inspect.iscoroutinefunction(f): 43 | # type ignore because there is no nice/proper way to annotate both sync 44 | # and async case without parametrized TypeVar, which is not supported 45 | return async_wrapper # type: ignore[return-value] 46 | return sync_wrapper 47 | 48 | 49 | def check_if_previous_state_saved(self:
"DocType") -> None: # Raise unless state management and its save-previous-state option are both enabled. 50 | if not self.use_state_management(): 51 | raise StateManagementIsTurnedOff( 52 | "State management is turned off for this document" 53 | ) 54 | if not self.state_management_save_previous(): 55 | raise StateManagementIsTurnedOff( 56 | "State management's option to save previous state is turned off for this document" 57 | ) 58 | 59 | 60 | def previous_saved_state_needed( # Decorator: call check_if_previous_state_saved(self) before the wrapped sync or async method runs. 61 | f: "AnyDocMethod[DocType, P, R]", 62 | ) -> "AnyDocMethod[DocType, P, R]": 63 | @wraps(f) 64 | def sync_wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R: 65 | check_if_previous_state_saved(self) 66 | return f(self, *args, **kwargs) 67 | 68 | @wraps(f) 69 | async def async_wrapper( 70 | self: "DocType", *args: P.args, **kwargs: P.kwargs 71 | ) -> R: 72 | check_if_previous_state_saved(self) 73 | # type ignore because there is no nice/proper way to annotate both sync 74 | # and async case without parametrized TypeVar, which is not supported 75 | return await f(self, *args, **kwargs) # type: ignore[misc] 76 | 77 | if inspect.iscoroutinefunction(f): 78 | # type ignore because there is no nice/proper way to annotate both sync 79 | # and async case without parametrized TypeVar, which is not supported 80 | return async_wrapper # type: ignore[return-value] 81 | return sync_wrapper 82 | 83 | 84 | def save_state_after( # Decorator for async methods: call self._save_state() after f returns (not when f raises). 85 | f: "AsyncDocMethod[DocType, P, R]", 86 | ) -> "AsyncDocMethod[DocType, P, R]": 87 | @wraps(f) 88 | async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R: 89 | result = await f(self, *args, **kwargs) 90 | self._save_state() 91 | return result 92 | 93 | return wrapper 94 | --------------------------------------------------------------------------------