├── beanie
│   ├── py.typed
│   ├── odm
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── general.py
│   │   │   ├── self_validation.py
│   │   │   ├── projection.py
│   │   │   ├── dump.py
│   │   │   ├── relations.py
│   │   │   ├── pydantic.py
│   │   │   ├── typing.py
│   │   │   └── state.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── clone.py
│   │   │   ├── detector.py
│   │   │   ├── inheritance.py
│   │   │   ├── session.py
│   │   │   ├── getters.py
│   │   │   └── setters.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   ├── cursor.py
│   │   │   └── delete.py
│   │   ├── settings
│   │   │   ├── __init__.py
│   │   │   ├── union_doc.py
│   │   │   ├── view.py
│   │   │   ├── base.py
│   │   │   ├── document.py
│   │   │   └── timeseries.py
│   │   ├── custom_types
│   │   │   ├── bson
│   │   │   │   ├── __init__.py
│   │   │   │   └── binary.py
│   │   │   ├── __init__.py
│   │   │   ├── decimal.py
│   │   │   └── re.py
│   │   ├── operators
│   │   │   ├── find
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bitwise.py
│   │   │   │   ├── element.py
│   │   │   │   └── array.py
│   │   │   ├── update
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bitwise.py
│   │   │   │   └── array.py
│   │   │   └── __init__.py
│   │   ├── enums.py
│   │   ├── models.py
│   │   ├── registry.py
│   │   ├── cache.py
│   │   ├── views.py
│   │   └── union_doc.py
│   ├── executors
│   │   └── __init__.py
│   ├── migrations
│   │   ├── __init__.py
│   │   ├── controllers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── free_fall.py
│   │   ├── template.py
│   │   ├── utils.py
│   │   ├── database.py
│   │   └── models.py
│   ├── exceptions.py
│   ├── __init__.py
│   └── operators.py
├── tests
│   ├── __init__.py
│   ├── odm
│   │   ├── __init__.py
│   │   ├── query
│   │   │   ├── __init__.py
│   │   │   └── test_update_methods.py
│   │   ├── documents
│   │   │   ├── __init__.py
│   │   │   ├── test_pydantic_config.py
│   │   │   ├── test_exists.py
│   │   │   ├── test_count.py
│   │   │   ├── test_pydantic_extras.py
│   │   │   ├── test_inspect.py
│   │   │   ├── test_replace.py
│   │   │   ├── test_distinct.py
│   │   │   ├── test_sync.py
│   │   │   ├── test_delete.py
│   │   │   ├── test_aggregate.py
│   │   │   ├── test_multi_model.py
│   │   │   └── test_validation_on_save.py
│   │   ├── operators
│   │   │   ├── __init__.py
│   │   │   ├── find
│   │   │   │   ├── __init__.py
│   │   │   │   ├── test_element.py
│   │   │   │   ├── test_bitwise.py
│   │   │   │   ├── test_array.py
│   │   │   │   ├── test_logical.py
│   │   │   │   ├── test_evaluation.py
│   │   │   │   ├── test_comparison.py
│   │   │   │   └── test_geospatial.py
│   │   │   └── update
│   │   │       ├── __init__.py
│   │   │       ├── test_bitwise.py
│   │   │       ├── test_array.py
│   │   │       └── test_general.py
│   │   ├── custom_types
│   │   │   ├── __init__.py
│   │   │   ├── test_decimal_annotation.py
│   │   │   └── test_bson_binary.py
│   │   ├── test_cursor.py
│   │   ├── test_deprecated.py
│   │   ├── test_root_models.py
│   │   ├── views.py
│   │   ├── test_views.py
│   │   ├── test_concurrency.py
│   │   ├── test_beanie_object_dumping.py
│   │   ├── test_timesries_collection.py
│   │   ├── test_typing_utils.py
│   │   ├── test_id.py
│   │   ├── test_consistency.py
│   │   ├── test_lazy_parsing.py
│   │   ├── test_expression_fields.py
│   │   └── test_cache.py
│   ├── fastapi
│   │   ├── __init__.py
│   │   ├── test_openapi_schema_generation.py
│   │   ├── models.py
│   │   ├── app.py
│   │   ├── conftest.py
│   │   ├── routes.py
│   │   └── test_api.py
│   ├── migrations
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── iterative
│   │   │   ├── __init__.py
│   │   │   ├── test_change_value_subfield.py
│   │   │   ├── test_rename_field.py
│   │   │   ├── test_change_value.py
│   │   │   ├── test_change_subfield.py
│   │   │   └── test_pack_unpack.py
│   │   ├── conftest.py
│   │   ├── migrations_for_test
│   │   │   ├── many_migrations
│   │   │   │   ├── 20210413170700_3_skip_backward.py
│   │   │   │   ├── 20210413170709_3_skip_forward.py
│   │   │   │   ├── 20210413170640_1.py
│   │   │   │   ├── 20210413170645_2.py
│   │   │   │   ├── 20210413170728_4_5.py
│   │   │   │   └── 20210413170734_6_7.py
│   │   │   ├── change_subfield_value
│   │   │   │   └── 20210413143405_change_subfield_value.py
│   │   │   ├── change_value
│   │   │   │   └── 20210413115234_change_value.py
│   │   │   ├── remove_index
│   │   │   │   └── 20210414135045_remove_index.py
│   │   │   ├── rename_field
│   │   │   │   └── 20210407203225_rename_field.py
│   │   │   ├── change_subfield
│   │   │   │   └── 20210413152406_change_subfield.py
│   │   │   ├── pack_unpack
│   │   │   │   └── 20210413135927_pack_unpack.py
│   │   │   ├── break
│   │   │   │   └── 20210413211219_break.py
│   │   │   └── free_fall
│   │   │       └── 20210413210446_free_fall.py
│   │   ├── test_break.py
│   │   ├── test_free_fall.py
│   │   └── test_remove_indexes.py
│   ├── typing
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── aggregation.py
│   │   ├── find.py
│   │   └── decorators.py
│   ├── test_beanie.py
│   └── conftest.py
├── docs
│   ├── CNAME
│   ├── assets
│   │   ├── favicon.png
│   │   └── color_scheme.css
│   ├── tutorial
│   │   ├── delete.md
│   │   ├── lazy_parse.md
│   │   ├── on_save_validation.md
│   │   ├── aggregate.md
│   │   ├── cache.md
│   │   ├── time_series.md
│   │   ├── revision.md
│   │   ├── init.md
│   │   ├── insert.md
│   │   ├── multi-model.md
│   │   ├── views.md
│   │   ├── indexes.md
│   │   ├── actions.md
│   │   └── update.md
│   ├── publishing.md
│   └── getting-started.md
├── .github
│   ├── scripts
│   │   └── handlers
│   │       ├── __init__.py
│   │       └── gh.py
│   ├── workflows
│   │   ├── github-actions-publish-docs.yml
│   │   ├── github-actions-publish-project.yml
│   │   ├── close_inactive_issues.yml
│   │   └── github-actions-tests.yml
│   └── ISSUE_TEMPLATE
│       ├── config.yml
│       └── bug_report.md
├── CONTRIBUTING.md
├── CODE_OF_CONDUCT.md
├── scripts
│   ├── publish_docs.sh
│   └── generate_changelog.py
├── .pypirc
├── .pre-commit-config.yaml
└── assets
    └── logo
        └── logo.svg
/beanie/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/executors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/utils/general.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | beanie-odm.dev
--------------------------------------------------------------------------------
/tests/fastapi/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/migrations/models.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/query/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/typing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/interfaces/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/queries/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/settings/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/documents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/operators/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/scripts/handlers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/custom_types/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/odm/operators/update/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/migrations/controllers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/odm/custom_types/bson/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/beanie/migrations/template.py:
--------------------------------------------------------------------------------
1 | class Forward: ...
2 |
3 |
4 | class Backward: ...
5 |
--------------------------------------------------------------------------------
/docs/assets/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeanieODM/beanie/HEAD/docs/assets/favicon.png
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contributing
2 | ------------
3 |
4 | Please check [this page in the documentation](https://beanie-odm.dev/development/).
5 |
--------------------------------------------------------------------------------
/beanie/odm/settings/union_doc.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.settings.base import ItemSettings
2 |
3 |
4 | class UnionDocSettings(ItemSettings): ...
5 |
--------------------------------------------------------------------------------
/beanie/odm/interfaces/clone.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 |
4 | class CloneInterface:
5 | def clone(self):
6 | return deepcopy(self)
7 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | Code of Conduct
2 | ---------------
3 |
4 | Please check [this page in the documentation](https://roman-right.github.io/beanie/code-of-conduct/).
5 |
--------------------------------------------------------------------------------
/beanie/odm/operators/find/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from beanie.odm.operators import BaseOperator
4 |
5 |
6 | class BaseFindOperator(BaseOperator, ABC): ...
7 |
--------------------------------------------------------------------------------
/beanie/migrations/utils.py:
--------------------------------------------------------------------------------
1 | def update_dict(d, u):
2 | for k, v in u.items():
3 | if isinstance(v, dict):
4 | d[k] = update_dict(d.get(k, {}), v)
5 | else:
6 | d[k] = v
7 | return d
8 |
--------------------------------------------------------------------------------
/tests/odm/operators/update/test_bitwise.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update.bitwise import Bit
2 | from tests.odm.models import Sample
3 |
4 |
5 | def test_bit():
6 | q = Bit({Sample.integer: 2})
7 | assert q == {"$bit": {"integer": 2}}
8 |
--------------------------------------------------------------------------------
/docs/assets/color_scheme.css:
--------------------------------------------------------------------------------
1 | [data-md-color-scheme="slate"] {
2 |
3 | --md-primary-fg-color: #2e303e; /* Panel */
4 | --md-default-bg-color: #2e303e; /* BG */
5 | --md-typeset-a-color: #8e91aa;
6 | --md-accent-fg-color: #658eda
7 | }
--------------------------------------------------------------------------------
/scripts/publish_docs.sh:
--------------------------------------------------------------------------------
1 | pydoc-markdown
2 | cd docs/build
3 |
4 | remote_repo="https://x-access-token:${GITHUB_TOKEN}@${GITHUB_DOMAIN:-"github.com"}/${GITHUB_REPOSITORY}.git"
5 | git remote rm origin
6 | git remote add origin "${remote_repo}"
7 | mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/tests/typing/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 | from beanie import Document
4 |
5 |
6 | class Test(Document):
7 | foo: str
8 | bar: str
9 | baz: str
10 |
11 |
12 | class ProjectionTest(BaseModel):
13 | foo: str
14 | bar: int
15 |
--------------------------------------------------------------------------------
/beanie/odm/operators/update/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Mapping
3 |
4 | from beanie.odm.operators import BaseOperator
5 |
6 |
7 | class BaseUpdateOperator(BaseOperator):
8 | @property
9 | @abstractmethod
10 | def query(self) -> Mapping[str, Any]: ...
11 |
--------------------------------------------------------------------------------
/beanie/odm/custom_types/__init__.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
2 |
3 | if IS_PYDANTIC_V2:
4 | from beanie.odm.custom_types.decimal import DecimalAnnotation
5 | else:
6 | from decimal import Decimal as DecimalAnnotation
7 |
8 | __all__ = [
9 | "DecimalAnnotation",
10 | ]
11 |
--------------------------------------------------------------------------------
/.pypirc:
--------------------------------------------------------------------------------
1 | [distutils]
2 | index-servers =
3 | pypi
4 | testpypi
5 |
6 | [pypi]
7 | repository = https://upload.pypi.org/legacy/
8 | username = __token__
9 | password = ${PYPI_TOKEN}
10 |
11 | [testpypi]
12 | repository = https://test.pypi.org/legacy/
13 | username = roman-right
14 | password = =$C[wT}^]5EWvX(p#9Po
--------------------------------------------------------------------------------
/beanie/odm/interfaces/detector.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class ModelType(str, Enum):
5 | Document = "Document"
6 | View = "View"
7 | UnionDoc = "UnionDoc"
8 |
9 |
10 | class DetectionInterface:
11 | @classmethod
12 | def get_model_type(cls) -> ModelType:
13 | return ModelType.Document
14 |
--------------------------------------------------------------------------------
/beanie/odm/custom_types/decimal.py:
--------------------------------------------------------------------------------
1 | import decimal
2 |
3 | import bson
4 | import pydantic
5 | from typing_extensions import Annotated
6 |
7 | DecimalAnnotation = Annotated[
8 | decimal.Decimal,
9 | pydantic.BeforeValidator(
10 | lambda v: v.to_decimal() if isinstance(v, bson.Decimal128) else v
11 | ),
12 | ]
13 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_pydantic_config.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import ValidationError
3 |
4 | from tests.odm.models import DocumentWithPydanticConfig
5 |
6 |
7 | def test_pydantic_config():
8 | doc = DocumentWithPydanticConfig(num_1=2)
9 | with pytest.raises(ValidationError):
10 | doc.num_1 = "wrong"
11 |
--------------------------------------------------------------------------------
/beanie/odm/settings/view.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Type, Union
2 |
3 | from pydantic import Field
4 |
5 | from beanie.odm.settings.base import ItemSettings
6 |
7 |
8 | class ViewSettings(ItemSettings):
9 | source: Union[str, Type]
10 | pipeline: List[Dict[str, Any]]
11 |
12 | max_nesting_depths_per_field: dict = Field(default_factory=dict)
13 | max_nesting_depth: int = 3
14 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/charliermarsh/ruff-pre-commit
3 | rev: v0.11.6
4 | hooks:
5 | - id: ruff
6 | args: [ --fix ]
7 | - id: ruff-format
8 | - repo: https://github.com/pre-commit/mirrors-mypy
9 | rev: v1.15.0
10 | hooks:
11 | - id: mypy
12 | additional_dependencies:
13 | - types-click
14 | exclude: ^tests/
15 |
--------------------------------------------------------------------------------
/beanie/odm/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import pymongo
4 |
5 |
6 | class SortDirection(int, Enum):
7 | """
8 | Sorting directions
9 | """
10 |
11 | ASCENDING = pymongo.ASCENDING
12 | DESCENDING = pymongo.DESCENDING
13 |
14 |
15 | class InspectionStatuses(str, Enum):
16 | """
17 | Statuses of the collection inspection
18 | """
19 |
20 | FAIL = "FAIL"
21 | OK = "OK"
22 |
--------------------------------------------------------------------------------
/tests/odm/test_cursor.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import DocumentTestModel
2 |
3 |
4 | async def test_to_list(documents):
5 | await documents(10)
6 | result = await DocumentTestModel.find_all().to_list()
7 | assert len(result) == 10
8 |
9 |
10 | async def test_async_for(documents):
11 | await documents(10)
12 | async for document in DocumentTestModel.find_all():
13 | assert document.test_int in list(range(10))
14 |
--------------------------------------------------------------------------------
/tests/migrations/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie import init_beanie
4 | from beanie.migrations.models import MigrationLog
5 |
6 |
7 | @pytest.fixture(autouse=True)
8 | async def init(db):
9 | await init_beanie(
10 | database=db,
11 | document_models=[
12 | MigrationLog,
13 | ],
14 | )
15 |
16 |
17 | @pytest.fixture(autouse=True)
18 | async def remove_migrations_log(db, init):
19 | await MigrationLog.delete_all()
20 |
--------------------------------------------------------------------------------
/.github/workflows/github-actions-publish-docs.yml:
--------------------------------------------------------------------------------
1 | name: Publish docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | jobs:
7 | publish_docs:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v3
11 | - uses: actions/setup-python@v2
12 | with:
13 | python-version: 3.10.9
14 | - name: install dependencies
15 | run: pip3 install .[doc]
16 | - name: publish docs
17 | run: bash scripts/publish_docs.sh
--------------------------------------------------------------------------------
/.github/workflows/github-actions-publish-project.yml:
--------------------------------------------------------------------------------
1 | name: Publish project
2 | on:
3 | push:
4 | branches:
5 | - main
6 | jobs:
7 | publish_project:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v3
11 | - name: install flit
12 | run: pip3 install flit
13 | - name: publish project
14 | env:
15 | FLIT_USERNAME: __token__
16 | FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }}
17 | run: flit publish
--------------------------------------------------------------------------------
/beanie/migrations/controllers/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Type
3 |
4 | from beanie.odm.documents import Document
5 |
6 |
7 | class BaseMigrationController(ABC):
8 | def __init__(self, function):
9 | self.function = function
10 |
11 | @abstractmethod
12 | async def run(self, session):
13 | pass
14 |
15 | @property
16 | @abstractmethod
17 | def models(self) -> List[Type[Document]]:
18 | pass
19 |
--------------------------------------------------------------------------------
/beanie/migrations/database.py:
--------------------------------------------------------------------------------
1 | from pymongo import AsyncMongoClient
2 |
3 |
4 | class DBHandler:
5 | @classmethod
6 | def set_db(cls, uri, db_name):
7 | cls.client = AsyncMongoClient(uri)
8 | cls.database = cls.client[db_name]
9 |
10 | @classmethod
11 | def get_cli(cls):
12 | return cls.client if hasattr(cls, "client") else None
13 |
14 | @classmethod
15 | def get_db(cls):
16 | return cls.database if hasattr(cls, "database") else None
17 |
--------------------------------------------------------------------------------
/tests/odm/test_deprecated.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie import init_beanie
4 | from beanie.exceptions import Deprecation
5 | from tests.odm.models import DocWithCollectionInnerClass
6 |
7 |
8 | class TestDeprecations:
9 | async def test_doc_with_inner_collection_class_init(self, db):
10 | with pytest.raises(Deprecation):
11 | await init_beanie(
12 | database=db,
13 | document_models=[DocWithCollectionInnerClass],
14 | )
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | contact_links:
3 | - name: Question
4 | url: 'https://github.com/roman-right/beanie/discussions/new?category=question'
5 | about: Ask a question about how to use Beanie using github discussions
6 |
7 | - name: Feature Request
8 | url: 'https://github.com/roman-right/beanie/discussions/new?category=feature-request'
9 | about: >
10 | If you think we should add a new feature to Beanie, please start a discussion, once it attracts
11 | wider support, it can be migrated to an issue
--------------------------------------------------------------------------------
/tests/test_beanie.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | if sys.version_info >= (3, 11):
5 | import tomllib
6 | else:
7 | import tomli as tomllib
8 |
9 | from beanie import __version__
10 |
11 |
12 | def parse_version_from_pyproject():
13 | pyproject = Path(__file__).parent.parent / "pyproject.toml"
14 | with pyproject.open("rb") as f:
15 | toml_data = tomllib.load(f)
16 | return toml_data["project"]["version"]
17 |
18 |
19 | def test_version():
20 | assert __version__ == parse_version_from_pyproject()
21 |
--------------------------------------------------------------------------------
/docs/tutorial/delete.md:
--------------------------------------------------------------------------------
1 | # Delete documents
2 |
3 | Beanie supports single and batch deletions:
4 |
5 | ## Single
6 |
7 | ```python
8 |
9 | await Product.find_one(Product.name == "Milka").delete()
10 |
11 | # Or
12 | bar = await Product.find_one(Product.name == "Milka")
13 | await bar.delete()
14 | ```
15 |
16 | ## Many
17 |
18 | ```python
19 |
20 | await Product.find(Product.category.name == "Chocolate").delete()
21 | ```
22 |
23 | ## All
24 |
25 | ```python
26 |
27 | await Product.delete_all()
28 | # Or
29 | await Product.all().delete()
30 |
31 | ```
--------------------------------------------------------------------------------
/beanie/odm/interfaces/inheritance.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | ClassVar,
3 | Dict,
4 | Optional,
5 | Type,
6 | )
7 |
8 |
9 | class InheritanceInterface:
10 | _children: ClassVar[Dict[str, Type]]
11 | _parent: ClassVar[Optional[Type]]
12 | _inheritance_inited: ClassVar[bool]
13 | _class_id: ClassVar[Optional[str]] = None
14 |
15 | @classmethod
16 | def add_child(cls, name: str, clas: Type):
17 | cls._children[name] = clas
18 | if cls._parent is not None:
19 | cls._parent.add_child(name, clas)
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | ```python
15 | Please add a code snippet here that reproduces the problem completely
16 | ```
17 |
18 | **Expected behavior**
19 | A clear and concise description of what you expected to happen.
20 |
21 | **Additional context**
22 | Add any other context about the problem here.
23 |
--------------------------------------------------------------------------------
/beanie/odm/interfaces/session.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pymongo.asynchronous.client_session import AsyncClientSession
4 |
5 |
6 | class SessionMethods:
7 | """
8 | Session methods
9 | """
10 |
11 | def set_session(self, session: Optional[AsyncClientSession] = None):
12 | """
13 | Set session
14 | :param session: Optional[AsyncClientSession] - pymongo session
15 | :return:
16 | """
17 | if session is not None:
18 | self.session: Optional[AsyncClientSession] = session
19 | return self
20 |
--------------------------------------------------------------------------------
/beanie/odm/models.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from pydantic import BaseModel
4 |
5 | from beanie.odm.enums import InspectionStatuses
6 | from beanie.odm.fields import PydanticObjectId
7 |
8 |
9 | class InspectionError(BaseModel):
10 | """
11 | Inspection error details
12 | """
13 |
14 | document_id: PydanticObjectId
15 | error: str
16 |
17 |
18 | class InspectionResult(BaseModel):
19 | """
20 | Collection inspection result
21 | """
22 |
23 | status: InspectionStatuses = InspectionStatuses.OK
24 | errors: List[InspectionError] = []
25 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_exists.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import DocumentTestModel
2 |
3 |
4 | async def test_count_with_filter_query(documents):
5 | await documents(4, "uno", True)
6 | await documents(2, "dos", True)
7 | await documents(1, "cuatro", True)
8 | e = await DocumentTestModel.find_many({"test_str": "dos"}).exists()
9 | assert e is True
10 |
11 | e = await DocumentTestModel.find_one({"test_str": "dos"}).exists()
12 | assert e is True
13 |
14 | e = await DocumentTestModel.find_many({"test_str": "wrong"}).exists()
15 | assert e is False
16 |
--------------------------------------------------------------------------------
/tests/fastapi/test_openapi_schema_generation.py:
--------------------------------------------------------------------------------
1 | from json import dumps
2 |
3 | from fastapi.openapi.utils import get_openapi
4 |
5 | from tests.fastapi.app import app
6 |
7 |
8 | def test_openapi_schema_generation():
9 | openapi_schema_json_str = dumps(
10 | get_openapi(
11 | title=app.title,
12 | version=app.version,
13 | openapi_version=app.openapi_version,
14 | description=app.description,
15 | routes=app.routes,
16 | ),
17 | )
18 |
19 | assert openapi_schema_json_str is not None
20 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_element.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.element import Exists, Type
2 | from tests.odm.models import Sample
3 |
4 |
5 | async def test_exists():
6 | q = Exists(Sample.integer, True)
7 | assert q == {"integer": {"$exists": True}}
8 |
9 | q = Exists(Sample.integer, False)
10 | assert q == {"integer": {"$exists": False}}
11 |
12 | q = Exists(Sample.integer)
13 | assert q == {"integer": {"$exists": True}}
14 |
15 |
16 | async def test_type():
17 | q = Type(Sample.integer, "smth")
18 | assert q == {"integer": {"$type": "smth"}}
19 |
--------------------------------------------------------------------------------
/beanie/odm/operators/update/bitwise.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from beanie.odm.operators.update import BaseUpdateOperator
4 |
5 |
6 | class BaseUpdateBitwiseOperator(BaseUpdateOperator, ABC): ...
7 |
8 |
9 | class Bit(BaseUpdateBitwiseOperator):
10 | """
11 | `$bit` update query operator
12 |
13 | MongoDB doc:
14 |
15 | """
16 |
17 | def __init__(self, expression: dict):
18 | self.expression = expression
19 |
20 | @property
21 | def query(self):
22 | return {"$bit": self.expression}
23 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title(self, input_document: Note, output_document: Note):
22 | if input_document.title == "3":
23 | output_document.title = "three"
24 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Backward:
20 | @iterative_migration()
21 | async def change_title(self, input_document: Note, output_document: Note):
22 | if input_document.title == "three":
23 | output_document.title = "3"
24 |
--------------------------------------------------------------------------------
/tests/odm/custom_types/test_decimal_annotation.py:
--------------------------------------------------------------------------------
1 | from decimal import Decimal
2 |
3 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
4 | from tests.odm.models import DocumentWithDecimalField
5 |
6 |
7 | def test_decimal_deserialize():
8 | m = DocumentWithDecimalField(amt=Decimal("1.4"))
9 | if IS_PYDANTIC_V2:
10 | m_json = m.model_dump_json()
11 | m_from_json = DocumentWithDecimalField.model_validate_json(m_json)
12 | else:
13 | m_json = m.json()
14 | m_from_json = DocumentWithDecimalField.parse_raw(m_json)
15 | assert isinstance(m_from_json.amt, Decimal)
16 | assert m_from_json.amt == Decimal("1.4")
17 |
--------------------------------------------------------------------------------
/beanie/odm/utils/self_validation.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from typing import TYPE_CHECKING, TypeVar
3 |
4 | from typing_extensions import ParamSpec
5 |
6 | if TYPE_CHECKING:
7 | from beanie.odm.documents import AsyncDocMethod, DocType
8 |
9 | P = ParamSpec("P")
10 | R = TypeVar("R")
11 |
12 |
13 | def validate_self_before(
14 | f: "AsyncDocMethod[DocType, P, R]",
15 | ) -> "AsyncDocMethod[DocType, P, R]":
16 | @wraps(f)
17 | async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
18 | await self.validate_self(*args, **kwargs)
19 | return await f(self, *args, **kwargs)
20 |
21 | return wrapper
22 |
--------------------------------------------------------------------------------
/beanie/odm/custom_types/bson/binary.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import bson
4 | import pydantic
5 | from typing_extensions import Annotated
6 |
7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
8 |
9 |
10 | def _to_bson_binary(value: Any) -> bson.Binary:
11 | return value if isinstance(value, bson.Binary) else bson.Binary(value)
12 |
13 |
14 | if IS_PYDANTIC_V2:
15 | BsonBinary = Annotated[
16 | bson.Binary, pydantic.PlainValidator(_to_bson_binary)
17 | ]
18 | else:
19 |
20 | class BsonBinary(bson.Binary): # type: ignore[no-redef]
21 | @classmethod
22 | def __get_validators__(cls):
23 | yield _to_bson_binary
24 |
--------------------------------------------------------------------------------
/tests/typing/aggregation.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from tests.typing.models import ProjectionTest, Test
4 |
5 |
6 | async def aggregate() -> List[Dict[str, Any]]:
7 | result = await Test.aggregate([]).to_list()
8 | result_2 = await Test.find().aggregate([]).to_list()
9 | return result or result_2
10 |
11 |
12 | async def aggregate_with_projection() -> List[ProjectionTest]:
13 | result = (
14 | await Test.find()
15 | .aggregate([], projection_model=ProjectionTest)
16 | .to_list()
17 | )
18 | result_2 = await Test.aggregate(
19 | [], projection_model=ProjectionTest
20 | ).to_list()
21 | return result or result_2
22 |
--------------------------------------------------------------------------------
/beanie/odm/custom_types/re.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import bson
4 | import pydantic
5 | from typing_extensions import Annotated
6 |
7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
8 |
9 |
10 | def _to_bson_regex(v):
11 | return v.try_compile() if isinstance(v, bson.Regex) else v
12 |
13 |
14 | if IS_PYDANTIC_V2:
15 | Pattern = Annotated[
16 | re.Pattern,
17 | pydantic.BeforeValidator(
18 | lambda v: v.try_compile() if isinstance(v, bson.Regex) else v
19 | ),
20 | ]
21 | else:
22 |
23 | class Pattern(bson.Regex): # type: ignore[no-redef]
24 | @classmethod
25 | def __get_validators__(cls):
26 | yield _to_bson_regex
27 |
--------------------------------------------------------------------------------
/tests/odm/test_root_models.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
2 | from tests.odm.models import DocumentWithRootModelAsAField
3 |
4 | if IS_PYDANTIC_V2:
5 |
6 | class TestRootModels:
7 | async def test_insert(self):
8 | doc = DocumentWithRootModelAsAField(pets=["dog", "cat", "fish"])
9 | await doc.insert()
10 |
11 | new_doc = await DocumentWithRootModelAsAField.get(doc.id)
12 | assert new_doc.pets.root == ["dog", "cat", "fish"]
13 |
14 | collection = DocumentWithRootModelAsAField.get_pymongo_collection()
15 | raw_doc = await collection.find_one({"_id": doc.id})
16 | assert raw_doc["pets"] == ["dog", "cat", "fish"]
17 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pymongo import AsyncMongoClient
3 |
4 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
5 |
6 | if IS_PYDANTIC_V2:
7 | from pydantic_settings import BaseSettings
8 | else:
9 | from pydantic import BaseSettings
10 |
11 |
12 | class Settings(BaseSettings):
13 | mongodb_dsn: str = "mongodb://localhost:27017/beanie_db"
14 | mongodb_db_name: str = "beanie_db"
15 |
16 |
17 | @pytest.fixture
18 | def settings():
19 | return Settings()
20 |
21 |
22 | @pytest.fixture
23 | async def cli(settings):
24 | client = AsyncMongoClient(settings.mongodb_dsn)
25 |
26 | yield client
27 |
28 | await client.close()
29 |
30 |
31 | @pytest.fixture
32 | def db(cli, settings):
33 | return cli[settings.mongodb_db_name]
34 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_color(self, input_document: Note, output_document: Note):
22 | output_document.tag.color = "blue"
23 |
24 |
25 | class Backward:
26 | @iterative_migration()
27 | async def change_title(self, input_document: Note, output_document: Note):
28 | output_document.tag.color = "red"
29 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_count.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import DocumentTestModel
2 |
3 |
4 | async def test_count(documents):
5 | await documents(4, "uno", True)
6 | c = await DocumentTestModel.count()
7 | assert c == 4
8 |
9 |
10 | async def test_count_with_filter_query(documents):
11 | await documents(4, "uno", True)
12 | await documents(2, "dos", True)
13 | await documents(1, "cuatro", True)
14 | c = await DocumentTestModel.find_many({"test_str": "dos"}).count()
15 | assert c == 2
16 |
17 |
18 | async def test_count_with_limit(documents):
19 | await documents(5, "five", True)
20 | c = await DocumentTestModel.find_all().limit(1).count()
21 | assert c == 1
22 | d = await DocumentTestModel.find_all().count()
23 | assert d == 5
24 |
--------------------------------------------------------------------------------
/tests/odm/views.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.fields import Link
2 | from beanie.odm.views import View
3 | from tests.odm.models import DocumentTestModel, DocumentTestModelWithLink
4 |
5 |
6 | class ViewForTest(View):
7 | number: int
8 | string: str
9 |
10 | class Settings:
11 | view_name = "test_view"
12 | source = DocumentTestModel
13 | pipeline = [
14 | {"$match": {"$expr": {"$gt": ["$test_int", 8]}}},
15 | {"$project": {"number": "$test_int", "string": "$test_str"}},
16 | ]
17 |
18 |
19 | class ViewForTestWithLink(View):
20 | link: Link[DocumentTestModel]
21 |
22 | class Settings:
23 | view_name = "test_view_with_link"
24 | source = DocumentTestModelWithLink
25 | pipeline = [{"$set": {"link": "$test_link"}}]
26 |
--------------------------------------------------------------------------------
/beanie/odm/interfaces/getters.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | from pymongo.asynchronous.collection import AsyncCollection
4 |
5 | from beanie.odm.settings.base import ItemSettings
6 |
7 |
8 | class OtherGettersInterface:
9 | @classmethod
10 | @abstractmethod
11 | def get_settings(cls) -> ItemSettings:
12 | pass
13 |
14 | @classmethod
15 | def get_pymongo_collection(cls) -> AsyncCollection:
16 | return cls.get_settings().pymongo_collection
17 |
18 | @classmethod
19 | def get_collection_name(cls) -> str:
20 | return cls.get_settings().name # type: ignore
21 |
22 | @classmethod
23 | def get_bson_encoders(cls):
24 | return cls.get_settings().bson_encoders
25 |
26 | @classmethod
27 | def get_link_fields(cls):
28 | return None
29 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_bitwise.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.bitwise import (
2 | BitsAllClear,
3 | BitsAllSet,
4 | BitsAnyClear,
5 | BitsAnySet,
6 | )
7 | from tests.odm.models import Sample
8 |
9 |
10 | async def test_bits_all_clear():
11 | q = BitsAllClear(Sample.integer, "smth")
12 | assert q == {"integer": {"$bitsAllClear": "smth"}}
13 |
14 |
15 | async def test_bits_all_set():
16 | q = BitsAllSet(Sample.integer, "smth")
17 | assert q == {"integer": {"$bitsAllSet": "smth"}}
18 |
19 |
20 | async def test_any_clear():
21 | q = BitsAnyClear(Sample.integer, "smth")
22 | assert q == {"integer": {"$bitsAnyClear": "smth"}}
23 |
24 |
25 | async def test_any_set():
26 | q = BitsAnySet(Sample.integer, "smth")
27 | assert q == {"integer": {"$bitsAnySet": "smth"}}
28 |
--------------------------------------------------------------------------------
/tests/odm/operators/update/test_array.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update.array import (
2 | AddToSet,
3 | Pop,
4 | Pull,
5 | PullAll,
6 | Push,
7 | )
8 | from tests.odm.models import Sample
9 |
10 |
11 | def test_add_to_set():
12 | q = AddToSet({Sample.integer: 2})
13 | assert q == {"$addToSet": {"integer": 2}}
14 |
15 |
16 | def test_pop():
17 | q = Pop({Sample.integer: 2})
18 | assert q == {"$pop": {"integer": 2}}
19 |
20 |
21 | def test_pull():
22 | q = Pull({Sample.integer: 2})
23 | assert q == {"$pull": {"integer": 2}}
24 |
25 |
26 | def test_push():
27 | q = Push({Sample.integer: 2})
28 | assert q == {"$push": {"integer": 2}}
29 |
30 |
31 | def test_pull_all():
32 | q = PullAll({Sample.integer: 2})
33 | assert q == {"$pullAll": {"integer": 2}}
34 |
--------------------------------------------------------------------------------
/beanie/migrations/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from enum import Enum
3 | from typing import List, Optional
4 |
5 | from pydantic import Field
6 | from pydantic.main import BaseModel
7 |
8 | from beanie.odm.documents import Document
9 |
10 |
11 | class MigrationLog(Document):
12 | ts: datetime = Field(default_factory=datetime.now)
13 | name: str
14 | is_current: bool
15 |
16 | class Settings:
17 | name = "migrations_log"
18 |
19 |
20 | class RunningDirections(str, Enum):
21 | FORWARD = "FORWARD"
22 | BACKWARD = "BACKWARD"
23 |
24 |
25 | class RunningMode(BaseModel):
26 | direction: RunningDirections
27 | distance: int = 0
28 |
29 |
30 | class ParsedMigrations(BaseModel):
31 | path: str
32 | names: List[str]
33 | current: Optional[MigrationLog] = None
34 |
--------------------------------------------------------------------------------
/beanie/odm/interfaces/setters.py:
--------------------------------------------------------------------------------
1 | from typing import ClassVar, Optional
2 |
3 | from beanie.odm.settings.document import DocumentSettings
4 |
5 |
6 | class SettersInterface:
7 | _document_settings: ClassVar[Optional[DocumentSettings]]
8 |
9 | @classmethod
10 | def set_collection(cls, collection):
11 | """
12 | Collection setter
13 | """
14 | cls._document_settings.pymongo_collection = collection
15 |
16 | @classmethod
17 | def set_database(cls, database):
18 | """
19 | Database setter
20 | """
21 | cls._document_settings.pymongo_db = database
22 |
23 | @classmethod
24 | def set_collection_name(cls, name: str):
25 | """
26 | Collection name setter
27 | """
28 | cls._document_settings.name = name # type: ignore
29 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title(self, input_document: Note, output_document: Note):
22 | if input_document.title == "1":
23 | output_document.title = "one"
24 |
25 |
26 | class Backward:
27 | @iterative_migration()
28 | async def change_title(self, input_document: Note, output_document: Note):
29 | if input_document.title == "one":
30 | output_document.title = "1"
31 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title(self, input_document: Note, output_document: Note):
22 | if input_document.title == "2":
23 | output_document.title = "two"
24 |
25 |
26 | class Backward:
27 | @iterative_migration()
28 | async def change_title(self, input_document: Note, output_document: Note):
29 | if input_document.title == "two":
30 | output_document.title = "2"
31 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title(self, input_document: Note, output_document: Note):
22 | if input_document.title == "5":
23 | output_document.title = "five"
24 |
25 |
26 | class Backward:
27 | @iterative_migration()
28 | async def change_title(self, input_document: Note, output_document: Note):
29 | if input_document.title == "five":
30 | output_document.title = "5"
31 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_array.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.array import All, ElemMatch, Size
2 | from tests.odm.models import PackageElemMatch, Sample
3 |
4 |
5 | async def test_all():
6 | q = All(Sample.integer, [1, 2, 3])
7 | assert q == {"integer": {"$all": [1, 2, 3]}}
8 |
9 |
10 | async def test_elem_match():
11 | q = ElemMatch(Sample.integer, {"a": "b"})
12 | assert q == {"integer": {"$elemMatch": {"a": "b"}}}
13 |
14 |
15 | async def test_size():
16 | q = Size(Sample.integer, 4)
17 | assert q == {"integer": {"$size": 4}}
18 |
19 |
20 | async def test_elem_match_nested():
21 | q = ElemMatch(
22 | PackageElemMatch.releases, major_ver=7, minor_ver=1, build_ver=0
23 | )
24 | assert q == {
25 | "releases": {
26 | "$elemMatch": {"major_ver": 7, "minor_ver": 1, "build_ver": 0}
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/tests/fastapi/models.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from pydantic import Field
4 |
5 | from beanie import Document, Indexed, Link
6 | from beanie.odm.fields import BackLink
7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
8 |
9 |
10 | class WindowAPI(Document):
11 | x: int
12 | y: int
13 |
14 |
15 | class DoorAPI(Document):
16 | t: int = 10
17 |
18 |
19 | class RoofAPI(Document):
20 | r: int = 100
21 |
22 |
23 | class HouseAPI(Document):
24 | windows: List[Link[WindowAPI]]
25 | name: Indexed(str)
26 | height: Indexed(int) = 2
27 |
28 |
29 | class House(Document):
30 | name: str
31 | owner: Link["Person"]
32 |
33 |
34 | class Person(Document):
35 | name: str
36 | house: BackLink[House] = (
37 | Field(json_schema_extra={"original_field": "owner"})
38 | if IS_PYDANTIC_V2
39 | else Field(original_field="owner")
40 | )
41 |
--------------------------------------------------------------------------------
/docs/tutorial/lazy_parse.md:
--------------------------------------------------------------------------------
1 | ## Using Lazy Parsing in Queries
2 | Lazy parsing allows you to skip the parsing and validation process for documents and instead call it on demand for each field separately. This can be useful for optimizing performance in certain scenarios.
3 |
4 | To use lazy parsing in your queries, you can pass the `lazy_parse=True` parameter to your find method.
5 |
6 | Here's an example of how to use lazy parsing in a find query:
7 |
8 | ```python
9 | await Sample.find(Sample.number == 10, lazy_parse=True).to_list()
10 | ```
11 |
12 | By setting `lazy_parse=True`, parsing and validation are skipped at query time and performed on demand when the respective fields are accessed. This can improve the performance of your query by reducing the amount of processing required upfront. However, keep in mind that lazy parsing may introduce some additional overhead when the fields are accessed later on.
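13 | 
14 | As a minimal sketch (assuming the same `Sample` document as in the example above), the deferred work only happens once a field is read:
15 | 
16 | ```python
17 | samples = await Sample.find(Sample.number == 10, lazy_parse=True).to_list()
18 | 
19 | # The fetched documents still hold raw data; fields are not parsed or validated yet.
20 | first = samples[0]
21 | 
22 | # Reading a field triggers parsing and validation of that field on demand.
23 | print(first.number)
24 | ```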
--------------------------------------------------------------------------------
/tests/fastapi/app.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 |
3 | from fastapi import FastAPI
4 | from pymongo import AsyncMongoClient
5 |
6 | from beanie import init_beanie
7 | from tests.conftest import Settings
8 | from tests.fastapi.models import (
9 | DoorAPI,
10 | House,
11 | HouseAPI,
12 | Person,
13 | RoofAPI,
14 | WindowAPI,
15 | )
16 | from tests.fastapi.routes import house_router
17 |
18 |
19 | @asynccontextmanager
20 | async def live_span(_: FastAPI):
21 | # CREATE ASYNC PYMONGO CLIENT
22 | client = AsyncMongoClient(Settings().mongodb_dsn)
23 |
24 | # INIT BEANIE
25 | await init_beanie(
26 | client.beanie_db,
27 | document_models=[House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI],
28 | )
29 |
30 | yield
31 |
32 |
33 | app = FastAPI(lifespan=live_span)
34 |
35 | # ADD ROUTES
36 | app.include_router(house_router, prefix="/v1", tags=["house"])
37 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class OldNote(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 | indexes = ["title"]
18 |
19 |
20 | class Note(Document):
21 | title: str
22 | tag: Tag
23 |
24 | class Settings:
25 | name = "notes"
26 | indexes = [
27 | "_id",
28 | ]
29 |
30 |
31 | class Forward:
32 | @iterative_migration()
33 | async def name_to_title(
34 | self, input_document: OldNote, output_document: Note
35 | ): ...
36 |
37 |
38 | class Backward:
39 | @iterative_migration()
40 | async def title_to_name(
41 | self, input_document: Note, output_document: OldNote
42 | ): ...
43 |
--------------------------------------------------------------------------------
/docs/tutorial/on_save_validation.md:
--------------------------------------------------------------------------------
1 | # On save validation
2 |
3 | Pydantic has a very useful config to validate values on assignment - `validate_assignment = True`.
4 | Unfortunately, this is an expensive operation and doesn't fit some use cases.
5 | Instead, you can validate all the values right before saving the document (`insert`, `replace`, `save`, `save_changes`)
6 | with the Beanie setting `validate_on_save`.
7 |
8 | This feature must be turned on in the `Settings` inner class explicitly:
9 |
10 | ```python
11 | class Sample(Document):
12 | num: int
13 | name: str
14 |
15 | class Settings:
16 | validate_on_save = True
17 | ```
18 |
19 | If any field has a wrong value,
20 | it will raise an error on write operations (`insert`, `replace`, `save`, `save_changes`).
21 |
22 | ```python
23 | sample = await Sample.find_one(Sample.name == "Test")
24 | sample.num = "wrong value type"
25 |
26 | # Next call will raise an error
27 | await sample.replace()
28 | ```
29 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class OldNote(Document):
12 | name: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Note(Document):
20 | title: str
21 | tag: Tag
22 |
23 | class Settings:
24 | name = "notes"
25 |
26 |
27 | class Forward:
28 | @iterative_migration()
29 | async def name_to_title(
30 | self, input_document: OldNote, output_document: Note
31 | ):
32 | output_document.title = input_document.name
33 |
34 |
35 | class Backward:
36 | @iterative_migration()
37 | async def title_to_name(
38 | self, input_document: Note, output_document: OldNote
39 | ):
40 | output_document.name = input_document.title
41 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_pydantic_extras.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.odm.models import (
4 | DocumentWithExtras,
5 | DocumentWithExtrasKw,
6 | DocumentWithPydanticConfig,
7 | )
8 |
9 |
10 | async def test_pydantic_extras():
11 | doc = DocumentWithExtras(num_1=2)
12 | doc.extra_value = "foo"
13 | await doc.save()
14 |
15 | loaded_doc = await DocumentWithExtras.get(doc.id)
16 |
17 | assert loaded_doc.extra_value == "foo"
18 |
19 |
20 | @pytest.mark.skip(reason="setting extra to allow via class kwargs not working")
21 | async def test_pydantic_extras_kw():
22 | doc = DocumentWithExtrasKw(num_1=2)
23 | doc.extra_value = "foo"
24 | await doc.save()
25 |
26 |     loaded_doc = await DocumentWithExtrasKw.get(doc.id)
27 |
28 | assert loaded_doc.extra_value == "foo"
29 |
30 |
31 | async def test_fail_with_no_extras():
32 | doc = DocumentWithPydanticConfig(num_1=2)
33 | with pytest.raises(ValueError):
34 | doc.extra_value = "foo"
35 |
--------------------------------------------------------------------------------
/tests/odm/custom_types/test_bson_binary.py:
--------------------------------------------------------------------------------
1 | import bson
2 | import pytest
3 |
4 | from beanie import BsonBinary
5 | from beanie.odm.utils.pydantic import get_model_dump, parse_model
6 | from tests.odm.models import DocumentWithBsonBinaryField
7 |
8 |
9 | @pytest.mark.parametrize("binary_field", [bson.Binary(b"test"), b"test"])
10 | async def test_bson_binary(binary_field):
11 | doc = DocumentWithBsonBinaryField(binary_field=binary_field)
12 | await doc.insert()
13 | assert doc.binary_field == BsonBinary(b"test")
14 |
15 | new_doc = await DocumentWithBsonBinaryField.get(doc.id)
16 | assert new_doc.binary_field == BsonBinary(b"test")
17 |
18 |
19 | @pytest.mark.parametrize("binary_field", [bson.Binary(b"test"), b"test"])
20 | def test_bson_binary_roundtrip(binary_field):
21 | doc = DocumentWithBsonBinaryField(binary_field=binary_field)
22 | doc_dict = get_model_dump(doc)
23 | new_doc = parse_model(DocumentWithBsonBinaryField, doc_dict)
24 | assert new_doc == doc
25 |
--------------------------------------------------------------------------------
/beanie/odm/operators/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from collections.abc import Mapping
3 | from copy import copy, deepcopy
4 | from typing import Any, Dict
5 | from typing import Mapping as MappingType
6 |
7 |
8 | class BaseOperator(Mapping):
9 | """
10 | Base operator.
11 | """
12 |
13 | @property
14 | @abstractmethod
15 | def query(self) -> MappingType[str, Any]: ...
16 |
17 | def __getitem__(self, item: str):
18 | return self.query[item]
19 |
20 | def __iter__(self):
21 | return iter(self.query)
22 |
23 | def __len__(self):
24 | return len(self.query)
25 |
26 | def __repr__(self):
27 | return repr(self.query)
28 |
29 | def __str__(self):
30 | return str(self.query)
31 |
32 | def __copy__(self):
33 | return copy(self.query)
34 |
35 | def __deepcopy__(self, memodict: Dict[str, Any] = {}):
36 | return deepcopy(self.query)
37 |
38 | def copy(self):
39 | return copy(self)
40 |
--------------------------------------------------------------------------------
/beanie/odm/registry.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, ForwardRef, Type, Union
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class DocsRegistry:
7 | _registry: Dict[str, Type[BaseModel]] = {}
8 |
9 | @classmethod
10 | def register(cls, name: str, doc_type: Type[BaseModel]):
11 | cls._registry[name] = doc_type
12 |
13 | @classmethod
14 | def get(cls, name: str) -> Type[BaseModel]:
15 | return cls._registry[name]
16 |
17 | @classmethod
18 | def evaluate_fr(cls, forward_ref: Union[ForwardRef, Type]):
19 | """
20 | Evaluate forward ref
21 |
22 | :param forward_ref: ForwardRef - forward ref to evaluate
23 | :return: Type[BaseModel] - class of the forward ref
24 | """
25 | if (
26 | isinstance(forward_ref, ForwardRef)
27 | and forward_ref.__forward_arg__ in cls._registry
28 | ):
29 | return cls._registry[forward_ref.__forward_arg__]
30 | else:
31 | return forward_ref
32 |
--------------------------------------------------------------------------------
/.github/workflows/close_inactive_issues.yml:
--------------------------------------------------------------------------------
1 | name: Close inactive issues
2 | on:
3 | schedule:
4 | - cron: "30 1 * * *"
5 |
6 | jobs:
7 | close-issues:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | issues: write
11 | pull-requests: write
12 | steps:
13 | - uses: actions/stale@v5
14 | with:
15 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity.'
16 | stale-pr-message: 'This PR is stale because it has been open 45 days with no activity.'
17 | close-issue-message: 'This issue was closed because it has been stalled for 14 days with no activity.'
18 | close-pr-message: 'This PR was closed because it has been stalled for 14 days with no activity.'
19 | exempt-issue-labels: 'bug,feature-request,typing bug,feature request,doc,documentation'
20 | days-before-issue-stale: 30
21 | days-before-pr-stale: 45
22 | days-before-issue-close: 14
23 | days-before-pr-close: 14
--------------------------------------------------------------------------------
/tests/fastapi/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from asgi_lifespan import LifespanManager
3 | from httpx import ASGITransport, AsyncClient
4 |
5 | from tests.fastapi.app import app
6 | from tests.fastapi.models import (
7 | DoorAPI,
8 | House,
9 | HouseAPI,
10 | Person,
11 | RoofAPI,
12 | WindowAPI,
13 | )
14 |
15 |
16 | @pytest.fixture(autouse=True)
17 | async def api_client(clean_db):
18 | """api client fixture."""
19 | async with LifespanManager(app, startup_timeout=100, shutdown_timeout=100):
20 | server_name = "http://localhost"
21 | async with AsyncClient(
22 | transport=ASGITransport(app=app), base_url=server_name
23 | ) as ac:
24 | yield ac
25 |
26 |
27 | @pytest.fixture(autouse=True)
28 | async def clean_db(db):
29 | models = [House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI]
30 | yield None
31 |
32 | for model in models:
33 | await model.get_pymongo_collection().drop()
34 | await model.get_pymongo_collection().drop_indexes()
35 |
--------------------------------------------------------------------------------
/tests/typing/find.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from tests.typing.models import ProjectionTest, Test
4 |
5 |
6 | async def find_many() -> List[Test]:
7 | return await Test.find().to_list()
8 |
9 |
10 | async def find_many_with_projection() -> List[ProjectionTest]:
11 | return await Test.find().project(projection_model=ProjectionTest).to_list()
12 |
13 |
14 | async def find_many_generator() -> List[Test]:
15 | docs: List[Test] = []
16 | async for doc in Test.find():
17 | docs.append(doc)
18 | return docs
19 |
20 |
21 | async def find_many_generator_with_projection() -> List[ProjectionTest]:
22 | docs: List[ProjectionTest] = []
23 | async for doc in Test.find().project(projection_model=ProjectionTest):
24 | docs.append(doc)
25 | return docs
26 |
27 |
28 | async def find_one() -> Optional[Test]:
29 | return await Test.find_one()
30 |
31 |
32 | async def find_one_with_projection() -> Optional[ProjectionTest]:
33 | return await Test.find_one().project(projection_model=ProjectionTest)
34 |
--------------------------------------------------------------------------------
/tests/odm/test_views.py:
--------------------------------------------------------------------------------
1 | from tests.odm.views import ViewForTest, ViewForTestWithLink
2 |
3 |
4 | class TestViews:
5 | async def test_simple(self, documents):
6 | await documents(number=15)
7 | results = await ViewForTest.all().to_list()
8 | assert len(results) == 6
9 |
10 | async def test_aggregate(self, documents):
11 | await documents(number=15)
12 | results = await ViewForTest.aggregate(
13 | [
14 | {"$set": {"test_field": 1}},
15 | {"$match": {"$expr": {"$lt": ["$number", 12]}}},
16 | ]
17 | ).to_list()
18 | assert len(results) == 3
19 | assert results[0]["test_field"] == 1
20 |
21 | async def test_link(self, documents_with_links):
22 | await documents_with_links()
23 | results = await ViewForTestWithLink.all().to_list()
24 | for document in results:
25 | await document.fetch_all_links()
26 |
27 | for i, document in enumerate(results):
28 | assert document.link.test_int == i
29 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class OldTag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Tag(BaseModel):
12 | color: str
13 | title: str
14 |
15 |
16 | class OldNote(Document):
17 | title: str
18 | tag: OldTag
19 |
20 | class Settings:
21 | name = "notes"
22 |
23 |
24 | class Note(Document):
25 | title: str
26 | tag: Tag
27 |
28 | class Settings:
29 | name = "notes"
30 |
31 |
32 | class Forward:
33 | @iterative_migration()
34 | async def change_color(
35 | self, input_document: OldNote, output_document: Note
36 | ):
37 | output_document.tag.title = input_document.tag.name
38 |
39 |
40 | class Backward:
41 | @iterative_migration()
42 | async def change_title(
43 | self, input_document: Note, output_document: OldNote
44 | ):
45 | output_document.tag.name = input_document.tag.title
46 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_inspect.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.models import InspectionStatuses
2 | from tests.odm.models import DocumentTestModel, DocumentTestModelFailInspection
3 |
4 |
5 | async def test_inspect_ok(documents):
6 | await documents(10, "smth")
7 | result = await DocumentTestModel.inspect_collection()
8 | assert result.status == InspectionStatuses.OK
9 | assert result.errors == []
10 |
11 |
12 | async def test_inspect_fail(documents):
13 | await documents(10, "smth")
14 | result = await DocumentTestModelFailInspection.inspect_collection()
15 | assert result.status == InspectionStatuses.FAIL
16 | assert len(result.errors) == 10
17 | assert (
18 | "1 validation error for DocumentTestModelFailInspection"
19 | in result.errors[0].error
20 | )
21 |
22 |
23 | async def test_inspect_ok_with_session(documents, session):
24 | await documents(10, "smth")
25 | result = await DocumentTestModel.inspect_collection(session=session)
26 | assert result.status == InspectionStatuses.OK
27 | assert result.errors == []
28 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class OldNote(Document):
12 | title: str
13 | tag_name: str
14 | tag_color: str
15 |
16 | class Settings:
17 | name = "notes"
18 |
19 |
20 | class Note(Document):
21 | title: str
22 | tag: Tag
23 |
24 | class Settings:
25 | name = "notes"
26 |
27 |
28 | class Forward:
29 | @iterative_migration()
30 | async def pack(self, input_document: OldNote, output_document: Note):
31 | output_document.tag = Tag(
32 | name=input_document.tag_name, color=input_document.tag_color
33 | )
34 |
35 |
36 | class Backward:
37 | @iterative_migration()
38 | async def unpack(self, input_document: Note, output_document: OldNote):
39 | output_document.tag_name = input_document.tag.name
40 | output_document.tag_color = input_document.tag.color
41 |
--------------------------------------------------------------------------------
/beanie/migrations/controllers/free_fall.py:
--------------------------------------------------------------------------------
1 | from inspect import signature
2 | from typing import Any, List, Type
3 |
4 | from beanie.migrations.controllers.base import BaseMigrationController
5 | from beanie.odm.documents import Document
6 |
7 |
8 | def free_fall_migration(document_models: List[Type[Document]]):
9 | class FreeFallMigrationController(BaseMigrationController):
10 | def __init__(self, function):
11 | self.function = function
12 | self.function_signature = signature(function)
13 | self.document_models = document_models
14 |
15 | def __call__(self, *args: Any, **kwargs: Any):
16 | pass
17 |
18 | @property
19 | def models(self) -> List[Type[Document]]:
20 | return self.document_models
21 |
22 | async def run(self, session):
23 | function_kwargs = {"session": session}
24 | if "self" in self.function_signature.parameters:
25 | function_kwargs["self"] = None
26 | await self.function(**function_kwargs)
27 |
28 | return FreeFallMigrationController
29 |
--------------------------------------------------------------------------------
/docs/tutorial/aggregate.md:
--------------------------------------------------------------------------------
1 | # Aggregations
2 |
3 | You can perform aggregation queries with Beanie as well. For example, to calculate the average price:
4 |
5 | ```python
6 | # With a search:
7 | avg_price = await Product.find(
8 | Product.category.name == "Chocolate"
9 | ).avg(Product.price)
10 |
11 | # Over the whole collection:
12 | avg_price = await Product.avg(Product.price)
13 | ```
14 |
15 | A full list of available methods can be found [here](../api-documentation/interfaces.md/#aggregatemethods).
16 |
17 | You can also use the native PyMongo syntax by calling the `aggregate` method.
18 | However, as Beanie will not know what output to expect, you will have to supply a projection model yourself.
19 | If you do not supply a projection model, then a dictionary will be returned.
20 |
21 | ```python
22 | class OutputItem(BaseModel):
23 | id: str = Field(None, alias="_id")
24 | total: float
25 |
26 |
27 | result = await Product.find(
28 | Product.category.name == "Chocolate").aggregate(
29 | [{"$group": {"_id": "$category.name", "total": {"$avg": "$price"}}}],
30 | projection_model=OutputItem
31 | ).to_list()
32 |
33 | ```
34 |
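35 | As noted above, if no `projection_model` is supplied, the same pipeline yields plain
36 | dictionaries. A minimal sketch of that variant:
37 |
38 | ```python
39 | raw_result = await Product.find(
40 |     Product.category.name == "Chocolate"
41 | ).aggregate(
42 |     [{"$group": {"_id": "$category.name", "total": {"$avg": "$price"}}}]
43 | ).to_list()
44 |
45 | # each item is a plain dict, e.g. {"_id": "Chocolate", "total": ...}
46 | ```
47 |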
--------------------------------------------------------------------------------
/.github/workflows/github-actions-tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on:
3 | pull_request:
4 |
5 | jobs:
6 | run-tests:
7 | strategy:
8 | fail-fast: false
9 | matrix:
10 | python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
11 | mongodb-version: [ "4.4", "5.0", "6.0", "7.0", "8.0" ]
12 | pydantic-version: [ "1.10.18", "2.9.2" , "2.10.4", "2.11"]
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - uses: actions/setup-python@v5
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | cache: pip
20 | cache-dependency-path: pyproject.toml
21 | - name: Start MongoDB
22 | uses: supercharge/mongodb-github-action@1.11.0
23 | with:
24 | mongodb-version: ${{ matrix.mongodb-version }}
25 | mongodb-replica-set: test-rs
26 | - name: install dependencies
27 | run: pip install .[test,ci]
28 | - name: install pydantic
29 | run: pip install pydantic==${{ matrix.pydantic-version }}
30 | - name: run tests
31 | env:
32 | PYTHON_JIT: 1
33 | run: pytest -v
34 |
--------------------------------------------------------------------------------
/tests/odm/test_concurrency.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from pymongo import AsyncMongoClient
4 |
5 | from beanie import Document, init_beanie
6 |
7 |
8 | class SampleModel(Document):
9 | s: str = "TEST"
10 | i: int = 10
11 |
12 |
13 | class SampleModel2(SampleModel): ...
14 |
15 |
16 | class SampleModel3(SampleModel2): ...
17 |
18 |
19 | class TestConcurrency:
20 | async def test_without_init(self, settings):
21 | clients = []
22 | for _ in range(10):
23 | client = AsyncMongoClient(settings.mongodb_dsn)
24 | clients.append(client)
25 | db = client[settings.mongodb_db_name]
26 | await init_beanie(
27 | db, document_models=[SampleModel3, SampleModel, SampleModel2]
28 | )
29 |
30 | async def insert_find():
31 | await SampleModel2().insert()
32 | docs = await SampleModel2.find(SampleModel2.i == 10).to_list()
33 | return docs
34 |
35 | await asyncio.gather(*[insert_find() for _ in range(10)])
36 |
37 | await SampleModel2.delete_all()
38 |
39 | [await client.close() for client in clients]
40 |
--------------------------------------------------------------------------------
/docs/tutorial/cache.md:
--------------------------------------------------------------------------------
1 | # Cache
2 | All query results can be cached locally.
3 |
4 | This feature must be explicitly turned on in the `Settings` inner class.
5 |
6 | ```python
7 | class Sample(Document):
8 | num: int
9 | name: str
10 |
11 | class Settings:
12 | use_cache = True
13 | ```
14 |
15 | Beanie uses an LRU cache with an expiration time.
16 | You can set the `capacity` (the maximum number of cached queries) and the expiration time in the `Settings` inner class.
17 |
18 | ```python
19 | class Sample(Document):
20 | num: int
21 | name: str
22 |
23 | class Settings:
24 | use_cache = True
25 | cache_expiration_time = datetime.timedelta(seconds=10)
26 | cache_capacity = 5
27 | ```
28 |
29 | Any query will be cached for this document class.
30 |
31 | ```python
32 | # on the first call it will go to the database
33 | samples = await Sample.find(Sample.num > 10).to_list()
34 |
35 | # on the second - it will use cache instead
36 | samples = await Sample.find(Sample.num > 10).to_list()
37 |
38 | await asyncio.sleep(15)
39 |
40 | # if the expiration time was reached it will go to the database again
41 | samples = await Sample.find(Sample.num > 10).to_list()
42 | ```
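43 |
44 | To bypass the cache for a single call, the find methods accept an `ignore_cache` flag.
45 | A minimal sketch; the exact parameter name is an assumption, so check the API reference:
46 |
47 | ```python
48 | # go straight to the database for this query, skipping the cache (assumed parameter)
49 | fresh_samples = await Sample.find(Sample.num > 10, ignore_cache=True).to_list()
50 | ```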
--------------------------------------------------------------------------------
/docs/tutorial/time_series.md:
--------------------------------------------------------------------------------
1 | # Time series
2 |
3 | You can set up a timeseries collection using the inner `Settings` class.
4 |
5 | **Be aware, timeseries collections are supported by MongoDB 5.0 and higher only. The fields `bucket_max_span_seconds` and `bucket_rounding_seconds`, however, require MongoDB 6.3 or higher.**
6 |
7 | ```python
8 | from datetime import datetime
9 |
10 | from beanie import Document, TimeSeriesConfig, Granularity
11 | from pydantic import Field
12 |
13 |
14 | class Sample(Document):
15 | ts: datetime = Field(default_factory=datetime.now)
16 | meta: str
17 |
18 | class Settings:
19 | timeseries = TimeSeriesConfig(
20 | time_field="ts", # Required
21 | meta_field="meta", # Optional
22 | granularity=Granularity.hours, # Optional
23 | bucket_max_span_seconds=3600, # Optional
24 | bucket_rounding_seconds=3600, # Optional
25 | expire_after_seconds=2 # Optional
26 | )
27 | ```
28 |
29 | The `TimeSeriesConfig` fields map to the corresponding parameters of MongoDB's time series collection creation command.
30 |
31 | MongoDB documentation: https://docs.mongodb.com/manual/core/timeseries-collections/
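32 |
33 | For orientation, the settings above map roughly to the following raw collection creation call.
34 | This is only an illustrative sketch using the async PyMongo API; the collection name and the
35 | exact option set are assumptions, not what Beanie executes verbatim.
36 |
37 | ```python
38 | # Illustrative only: roughly how TimeSeriesConfig fields translate to MongoDB options
39 | await db.create_collection(
40 |     "Sample",  # assumed collection name (defaults to the class name)
41 |     timeseries={
42 |         "timeField": "ts",       # time_field
43 |         "metaField": "meta",     # meta_field
44 |         "granularity": "hours",  # granularity
45 |     },
46 |     expireAfterSeconds=2,        # expire_after_seconds
47 | )
48 | ```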
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/break/20210413211219_break.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, Indexed, PydanticObjectId, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class OldNote(Document):
12 | name: Indexed(str, unique=True)
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Note(Document):
20 | name: Indexed(str, unique=True)
21 | title: str
22 | tag: Tag
23 |
24 | class Settings:
25 | name = "notes"
26 |
27 |
28 | fixed_id = PydanticObjectId("6076f1f3e4b7f6b7a0f6e5a0")
29 |
30 |
31 | class Forward:
32 | @iterative_migration(batch_size=2)
33 | async def name_to_title(
34 | self, input_document: OldNote, output_document: Note
35 | ):
36 | output_document.title = input_document.name
37 | if output_document.title > "5":
38 | output_document.name = "5"
39 |
40 |
41 | class Backward:
42 | @iterative_migration()
43 | async def title_to_name(
44 | self, input_document: Note, output_document: OldNote
45 | ):
46 | output_document.name = input_document.title
47 |
--------------------------------------------------------------------------------
/beanie/odm/settings/base.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from typing import Any, Dict, Optional, Type
3 |
4 | from pydantic import BaseModel, Field
5 | from pymongo.asynchronous.collection import AsyncCollection
6 | from pymongo.asynchronous.database import AsyncDatabase
7 |
8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
9 |
10 | if IS_PYDANTIC_V2:
11 | from pydantic import ConfigDict
12 |
13 |
14 | class ItemSettings(BaseModel):
15 | name: Optional[str] = None
16 |
17 | use_cache: bool = False
18 | cache_capacity: int = 32
19 | cache_expiration_time: timedelta = timedelta(minutes=10)
20 | bson_encoders: Dict[Any, Any] = Field(default_factory=dict)
21 | projection: Optional[Dict[str, Any]] = None
22 |
23 | pymongo_db: Optional[AsyncDatabase] = None
24 | pymongo_collection: Optional[AsyncCollection] = None
25 |
26 | union_doc: Optional[Type] = None
27 | union_doc_alias: Optional[str] = None
28 | class_id: str = "_class_id"
29 |
30 | is_root: bool = False
31 |
32 | if IS_PYDANTIC_V2:
33 | model_config = ConfigDict(
34 | arbitrary_types_allowed=True,
35 | )
36 | else:
37 |
38 | class Config:
39 | arbitrary_types_allowed = True
40 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, free_fall_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class OldNote(Document):
12 | name: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Note(Document):
20 | title: str
21 | tag: Tag
22 |
23 | class Settings:
24 | name = "notes"
25 |
26 |
27 | class Forward:
28 | @free_fall_migration(document_models=[OldNote, Note])
29 | async def name_to_title(self, session):
30 | async for old_note in OldNote.find_all():
31 | new_note = Note(
32 | id=old_note.id, title=old_note.name, tag=old_note.tag
33 | )
34 | await new_note.replace(session=session)
35 |
36 |
37 | class Backward:
38 | @free_fall_migration(document_models=[OldNote, Note])
39 | async def title_to_name(self, session):
40 | async for old_note in Note.find_all():
41 | new_note = OldNote(
42 | id=old_note.id, name=old_note.title, tag=old_note.tag
43 | )
44 | await new_note.replace(session=session)
45 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_replace.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import Sample
2 |
3 |
4 | async def test_replace_one(preset_documents):
5 | count_1_before = await Sample.find_many(Sample.integer == 1).count()
6 | count_2_before = await Sample.find_many(Sample.integer == 2).count()
7 |
8 | a_2 = await Sample.find_one(Sample.integer == 2)
9 | await Sample.find_one(Sample.integer == 1).replace_one(a_2)
10 |
11 | count_1_after = await Sample.find_many(Sample.integer == 1).count()
12 | count_2_after = await Sample.find_many(Sample.integer == 2).count()
13 |
14 | assert count_1_after == count_1_before - 1
15 | assert count_2_after == count_2_before + 1
16 |
17 |
18 | async def test_replace_self(preset_documents):
19 | count_1_before = await Sample.find_many(Sample.integer == 1).count()
20 | count_2_before = await Sample.find_many(Sample.integer == 2).count()
21 |
22 | a_1 = await Sample.find_one(Sample.integer == 1)
23 | a_1.integer = 2
24 | await a_1.replace()
25 |
26 | count_1_after = await Sample.find_many(Sample.integer == 1).count()
27 | count_2_after = await Sample.find_many(Sample.integer == 2).count()
28 |
29 | assert count_1_after == count_1_before - 1
30 | assert count_2_after == count_2_before + 1
31 |
--------------------------------------------------------------------------------
/beanie/exceptions.py:
--------------------------------------------------------------------------------
1 | class WrongDocumentUpdateStrategy(Exception):
2 | pass
3 |
4 |
5 | class DocumentNotFound(Exception):
6 | pass
7 |
8 |
9 | class DocumentAlreadyCreated(Exception):
10 | pass
11 |
12 |
13 | class DocumentWasNotSaved(Exception):
14 | pass
15 |
16 |
17 | class CollectionWasNotInitialized(Exception):
18 | pass
19 |
20 |
21 | class MigrationException(Exception):
22 | pass
23 |
24 |
25 | class ReplaceError(Exception):
26 | pass
27 |
28 |
29 | class StateManagementIsTurnedOff(Exception):
30 | pass
31 |
32 |
33 | class StateNotSaved(Exception):
34 | pass
35 |
36 |
37 | class RevisionIdWasChanged(Exception):
38 | pass
39 |
40 |
41 | class NotSupported(Exception):
42 | pass
43 |
44 |
45 | class MongoDBVersionError(Exception):
46 | pass
47 |
48 |
49 | class ViewWasNotInitialized(Exception):
50 | pass
51 |
52 |
53 | class ViewHasNoSettings(Exception):
54 | pass
55 |
56 |
57 | class UnionHasNoRegisteredDocs(Exception):
58 | pass
59 |
60 |
61 | class UnionDocNotInited(Exception):
62 | pass
63 |
64 |
65 | class DocWasNotRegisteredInUnionClass(Exception):
66 | pass
67 |
68 |
69 | class Deprecation(Exception):
70 | pass
71 |
72 |
73 | class ApplyChangesException(Exception):
74 | pass
75 |
--------------------------------------------------------------------------------
/beanie/odm/settings/document.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from pydantic import Field
4 |
5 | from beanie.odm.fields import IndexModelField
6 | from beanie.odm.settings.base import ItemSettings
7 | from beanie.odm.settings.timeseries import TimeSeriesConfig
8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
9 |
10 | if IS_PYDANTIC_V2:
11 | from pydantic import ConfigDict
12 |
13 |
14 | class DocumentSettings(ItemSettings):
15 | use_state_management: bool = False
16 | state_management_replace_objects: bool = False
17 | state_management_save_previous: bool = False
18 | validate_on_save: bool = False
19 | use_revision: bool = False
20 | single_root_inheritance: bool = False
21 |
22 | indexes: List[IndexModelField] = Field(default_factory=list)
23 | merge_indexes: bool = False
24 | timeseries: Optional[TimeSeriesConfig] = None
25 |
26 | lazy_parsing: bool = False
27 |
28 | keep_nulls: bool = True
29 |
30 | max_nesting_depths_per_field: dict = Field(default_factory=dict)
31 | max_nesting_depth: int = 3
32 |
33 | if IS_PYDANTIC_V2:
34 | model_config = ConfigDict(
35 | arbitrary_types_allowed=True,
36 | )
37 | else:
38 |
39 | class Config:
40 | arbitrary_types_allowed = True
41 |
--------------------------------------------------------------------------------
/beanie/odm/utils/projection.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Type, TypeVar
2 |
3 | from pydantic import BaseModel
4 |
5 | from beanie.odm.interfaces.detector import ModelType
6 | from beanie.odm.utils.pydantic import get_config_value, get_model_fields
7 |
8 | ProjectionModelType = TypeVar("ProjectionModelType", bound=BaseModel)
9 |
10 |
11 | def get_projection(
12 | model: Type[ProjectionModelType],
13 | ) -> Optional[Dict[str, int]]:
14 | if hasattr(model, "get_model_type") and (
15 | model.get_model_type() == ModelType.UnionDoc # type: ignore
16 | or ( # type: ignore
17 | model.get_model_type() == ModelType.Document # type: ignore
18 | and model._inheritance_inited # type: ignore
19 | )
20 | ): # type: ignore
21 | return None
22 |
23 | if hasattr(model, "Settings"): # MyPy checks
24 | settings = getattr(model, "Settings")
25 |
26 | if hasattr(settings, "projection"):
27 | return getattr(settings, "projection")
28 |
29 | if get_config_value(model, "extra") == "allow":
30 | return None
31 |
32 | document_projection: Dict[str, int] = {}
33 |
34 | for name, field in get_model_fields(model).items():
35 | document_projection[field.alias or name] = 1
36 | return document_projection
37 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_logical.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie.odm.operators.find.logical import And, Nor, Not, Or
4 | from tests.odm.models import Sample
5 |
6 |
7 | async def test_and():
8 | q = And(Sample.integer == 1)
9 | assert q == {"integer": 1}
10 |
11 | q = And(Sample.integer == 1, Sample.nested.integer > 3)
12 | assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]}
13 |
14 |
15 | async def test_not(preset_documents):
16 | q = Not(Sample.integer == 1)
17 | assert q == {"integer": {"$not": {"$eq": 1}}}
18 |
19 | docs = await Sample.find(q).to_list()
20 | assert len(docs) == 7
21 |
22 | with pytest.raises(AttributeError):
23 | q = Not(And(Sample.integer == 1, Sample.nested.integer > 3))
24 | await Sample.find(q).to_list()
25 |
26 |
27 | async def test_nor():
28 | q = Nor(Sample.integer == 1)
29 | assert q == {"$nor": [{"integer": 1}]}
30 |
31 | q = Nor(Sample.integer == 1, Sample.nested.integer > 3)
32 | assert q == {"$nor": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]}
33 |
34 |
35 | async def test_or():
36 | q = Or(Sample.integer == 1)
37 | assert q == {"integer": 1}
38 |
39 | q = Or(Sample.integer == 1, Sample.nested.integer > 3)
40 | assert q == {"$or": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]}
41 |
--------------------------------------------------------------------------------
/docs/tutorial/revision.md:
--------------------------------------------------------------------------------
1 | # Revision
2 |
3 | This feature helps with concurrent operations.
4 | It stores `revision_id` together with the document and changes it on each document update.
5 | If an application tries to change the document using an older local copy, an exception will be raised.
6 | Only after the local copy has been synced with the database is the application allowed to change the data.
7 | This helps to avoid data loss.
8 |
9 | ### Be aware
10 | The revision id feature may work incorrectly with `BulkWriter`.
11 |
12 | ### Usage
13 |
14 | This feature must be explicitly turned on in the `Settings` inner class:
15 |
16 | ```python
17 | class Sample(Document):
18 | num: int
19 | name: str
20 |
21 | class Settings:
22 | use_revision = True
23 | ```
24 |
25 | Any changing operation will check if the local copy of the document has the up-to-date `revision_id` value:
26 |
27 | ```python
28 | s = await Sample.find_one(Sample.name == "TestName")
29 | s.num = 10
30 |
31 | # If a concurrent process already changed the doc, the next operation will raise an error
32 | await s.replace()
33 | ```
34 |
35 | If you want to ignore revision and apply all the changes even if the local copy is outdated,
36 | you can use the `ignore_revision` parameter:
37 |
38 | ```python
39 | await s.replace(ignore_revision=True)
40 | ```
41 |
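42 | The conflict can also be handled explicitly. A minimal sketch, assuming the mismatch raises
43 | `RevisionIdWasChanged` from `beanie.exceptions`:
44 |
45 | ```python
46 | from beanie.exceptions import RevisionIdWasChanged
47 |
48 | try:
49 |     await s.replace()
50 | except RevisionIdWasChanged:
51 |     # the local copy is outdated: re-fetch the document, or force the write
52 |     await s.replace(ignore_revision=True)
53 | ```
54 |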
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title_4(
22 | self, input_document: Note, output_document: Note
23 | ):
24 | if input_document.title == "4":
25 | output_document.title = "four"
26 |
27 | @iterative_migration()
28 | async def change_title_5(
29 | self, input_document: Note, output_document: Note
30 | ):
31 | if input_document.title == "5":
32 | output_document.title = "five"
33 |
34 |
35 | class Backward:
36 | @iterative_migration()
37 | async def change_title_5(
38 | self, input_document: Note, output_document: Note
39 | ):
40 | if input_document.title == "five":
41 | output_document.title = "5"
42 |
43 | @iterative_migration()
44 | async def change_title_4(
45 | self, input_document: Note, output_document: Note
46 | ):
47 | if input_document.title == "four":
48 | output_document.title = "4"
49 |
--------------------------------------------------------------------------------
/tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py:
--------------------------------------------------------------------------------
1 | from pydantic.main import BaseModel
2 |
3 | from beanie import Document, iterative_migration
4 |
5 |
6 | class Tag(BaseModel):
7 | color: str
8 | name: str
9 |
10 |
11 | class Note(Document):
12 | title: str
13 | tag: Tag
14 |
15 | class Settings:
16 | name = "notes"
17 |
18 |
19 | class Forward:
20 | @iterative_migration()
21 | async def change_title_6(
22 | self, input_document: Note, output_document: Note
23 | ):
24 | if input_document.title == "6":
25 | output_document.title = "six"
26 |
27 | @iterative_migration()
28 | async def change_title_7(
29 | self, input_document: Note, output_document: Note
30 | ):
31 | if input_document.title == "7":
32 | output_document.title = "seven"
33 |
34 |
35 | class Backward:
36 | @iterative_migration()
37 | async def change_title_7(
38 | self, input_document: Note, output_document: Note
39 | ):
40 | if input_document.title == "seven":
41 | output_document.title = "7"
42 |
43 | @iterative_migration()
44 | async def change_title_6(
45 | self, input_document: Note, output_document: Note
46 | ):
47 | if input_document.title == "six":
48 | output_document.title = "6"
49 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_distinct.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import DocumentTestModel
2 |
3 |
4 | async def test_distinct_unique(documents, document_not_inserted):
5 | await documents(1, "uno")
6 | await documents(2, "dos")
7 | await documents(3, "cuatro")
8 | expected_result = ["cuatro", "dos", "uno"]
9 | unique_test_strs = await DocumentTestModel.distinct("test_str", {})
10 | assert unique_test_strs == expected_result
11 | document_not_inserted.test_str = "uno"
12 | await document_not_inserted.insert()
13 | another_unique_test_strs = await DocumentTestModel.distinct("test_str", {})
14 | assert another_unique_test_strs == expected_result
15 |
16 |
17 | async def test_distinct_different_value(documents, document_not_inserted):
18 | await documents(1, "uno")
19 | await documents(2, "dos")
20 | await documents(3, "cuatro")
21 | expected_result = ["cuatro", "dos", "uno"]
22 | unique_test_strs = await DocumentTestModel.distinct("test_str", {})
23 | assert unique_test_strs == expected_result
24 | document_not_inserted.test_str = "diff_val"
25 | await document_not_inserted.insert()
26 | another_unique_test_strs = await DocumentTestModel.distinct("test_str", {})
27 | assert not another_unique_test_strs == expected_result
28 | assert another_unique_test_strs == ["cuatro", "diff_val", "dos", "uno"]
29 |
--------------------------------------------------------------------------------
/tests/odm/operators/update/test_general.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update.general import (
2 | CurrentDate,
3 | Inc,
4 | Max,
5 | Min,
6 | Mul,
7 | Rename,
8 | Set,
9 | SetOnInsert,
10 | Unset,
11 | )
12 | from tests.odm.models import Sample
13 |
14 |
15 | def test_set():
16 | q = Set({Sample.integer: 2})
17 | assert q == {"$set": {"integer": 2}}
18 |
19 |
20 | def test_current_date():
21 | q = CurrentDate({Sample.integer: 2})
22 | assert q == {"$currentDate": {"integer": 2}}
23 |
24 |
25 | def test_inc():
26 | q = Inc({Sample.integer: 2})
27 | assert q == {"$inc": {"integer": 2}}
28 |
29 |
30 | def test_min():
31 | q = Min({Sample.integer: 2})
32 | assert q == {"$min": {"integer": 2}}
33 |
34 |
35 | def test_max():
36 | q = Max({Sample.integer: 2})
37 | assert q == {"$max": {"integer": 2}}
38 |
39 |
40 | def test_mul():
41 | q = Mul({Sample.integer: 2})
42 | assert q == {"$mul": {"integer": 2}}
43 |
44 |
45 | def test_rename():
46 | q = Rename({Sample.integer: 2})
47 | assert q == {"$rename": {"integer": 2}}
48 |
49 |
50 | def test_set_on_insert():
51 | q = SetOnInsert({Sample.integer: 2})
52 | assert q == {"$setOnInsert": {"integer": 2}}
53 |
54 |
55 | def test_unset():
56 | q = Unset({Sample.integer: 2})
57 | assert q == {"$unset": {"integer": 2}}
58 |
--------------------------------------------------------------------------------
/tests/odm/test_beanie_object_dumping.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import BaseModel, Field
3 |
4 | from beanie import Link, PydanticObjectId
5 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
6 | from tests.odm.models import DocumentTestModelWithSoftDelete
7 |
8 |
9 | class TestModel(BaseModel):
10 | my_id: PydanticObjectId = Field(default_factory=PydanticObjectId)
11 | fake_doc: Link[DocumentTestModelWithSoftDelete]
12 |
13 |
14 | def data_maker():
15 | return TestModel(
16 | my_id="5f4e3f3b7c0c9d001f7d4c8e",
17 | fake_doc=DocumentTestModelWithSoftDelete(
18 | test_int=1, test_str="test", id="5f4e3f3b7c0c9d001f7d4c8f"
19 | ),
20 | )
21 |
22 |
23 | @pytest.mark.skipif(
24 | not IS_PYDANTIC_V2,
25 | reason="model dumping support is more complete with pydantic v2",
26 | )
27 | def test_id_types_preserved_when_dumping_to_python():
28 | dumped = data_maker().model_dump(mode="python")
29 | assert isinstance(dumped["my_id"], PydanticObjectId)
30 | assert isinstance(dumped["fake_doc"]["id"], PydanticObjectId)
31 |
32 |
33 | @pytest.mark.skipif(
34 | not IS_PYDANTIC_V2,
35 | reason="model dumping support is more complete with pydantic v2",
36 | )
37 | def test_id_types_serialized_when_dumping_to_json():
38 | dumped = data_maker().model_dump(mode="json")
39 | assert isinstance(dumped["my_id"], str)
40 | assert isinstance(dumped["fake_doc"]["id"], str)
41 |
--------------------------------------------------------------------------------
/tests/odm/test_timesries_collection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie import init_beanie
4 | from beanie.exceptions import MongoDBVersionError
5 | from tests.odm.models import DocumentWithTimeseries
6 |
7 |
8 | async def test_timeseries_collection(db):
9 | build_info = await db.command({"buildInfo": 1})
10 | mongo_version = build_info["version"]
11 | major_version = int(mongo_version.split(".")[0])
12 | if major_version < 5:
13 | with pytest.raises(MongoDBVersionError):
14 | await init_beanie(
15 | database=db, document_models=[DocumentWithTimeseries]
16 | )
17 |
18 | if major_version >= 5:
19 | await init_beanie(
20 | database=db, document_models=[DocumentWithTimeseries]
21 | )
22 | info = await db.command(
23 | {
24 | "listCollections": 1,
25 | "filter": {"name": "DocumentWithTimeseries"},
26 | }
27 | )
28 |
29 | assert info["cursor"]["firstBatch"][0] == {
30 | "name": "DocumentWithTimeseries",
31 | "type": "timeseries",
32 | "options": {
33 | "expireAfterSeconds": 2,
34 | "timeseries": {
35 | "timeField": "ts",
36 | "granularity": "seconds",
37 | "bucketMaxSpanSeconds": 3600,
38 | },
39 | },
40 | "info": {"readOnly": False},
41 | }
42 |
--------------------------------------------------------------------------------
/beanie/odm/cache.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import datetime
3 | from datetime import timedelta, timezone
4 | from typing import Any, Optional
5 |
6 | from pydantic import BaseModel, Field
7 |
8 |
9 | class CachedItem(BaseModel):
10 | timestamp: datetime.datetime = Field(
11 | default_factory=lambda: datetime.datetime.now(tz=timezone.utc)
12 | )
13 | value: Any
14 |
15 |
16 | class LRUCache:
17 | def __init__(self, capacity: int, expiration_time: timedelta):
18 | self.capacity: int = capacity
19 | self.expiration_time: timedelta = expiration_time
20 | self.cache: collections.OrderedDict = collections.OrderedDict()
21 |
22 |     def get(self, key) -> Optional[Any]:
23 | try:
24 | item: CachedItem = self.cache.pop(key)
25 | if (
26 | datetime.datetime.now(tz=timezone.utc) - item.timestamp
27 | > self.expiration_time
28 | ):
29 | return None
30 | self.cache[key] = item
31 | return item.value
32 | except KeyError:
33 | return None
34 |
35 | def set(self, key, value) -> None:
36 | try:
37 | self.cache.pop(key)
38 | except KeyError:
39 | if len(self.cache) >= self.capacity:
40 | self.cache.popitem(last=False)
41 | self.cache[key] = CachedItem(value=value)
42 |
43 | @staticmethod
44 | def create_key(*args):
45 | return str(args) # TODO think about this
46 |
--------------------------------------------------------------------------------
/docs/tutorial/init.md:
--------------------------------------------------------------------------------
1 | Beanie uses Async PyMongo as an async database engine.
2 | To initialize the documents you defined, provide an Async PyMongo database instance
3 | and a list of your document models to the `init_beanie(...)` function, as shown in the example:
4 |
5 | ```python
6 | from beanie import init_beanie, Document
7 | from pymongo import AsyncMongoClient
8 |
9 | class Sample(Document):
10 | name: str
11 |
12 | async def init():
13 | # Create Async PyMongo client
14 | client = AsyncMongoClient(
15 | "mongodb://user:pass@host:27017"
16 | )
17 |
18 | # Initialize beanie with the Sample document class and a database
19 | await init_beanie(database=client.db_name, document_models=[Sample])
20 | ```
21 |
22 | This creates the collection (if necessary) and sets up any indexes that are defined.
23 |
24 |
25 | `init_beanie` supports not only a list of classes as the document_models argument,
26 | but also strings with dot-separated paths:
27 |
28 | ```python
29 | await init_beanie(
30 | database=client.db_name,
31 | document_models=[
32 | "app.models.DemoDocument",
33 | ],
34 | )
35 | ```
36 |
37 | ### Warning
38 |
39 | `init_beanie` supports a parameter named `allow_index_dropping` that allows Beanie to drop indexes from your collections.
40 | `allow_index_dropping` is set to `False` by default. If you set it to `True`,
41 | ensure that you are not managing your indexes in another manner;
42 | if you are, they will be deleted when `allow_index_dropping=True` is set.
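43 |
44 | A minimal sketch of passing the flag explicitly, reusing the `Sample` model from above:
45 |
46 | ```python
47 | await init_beanie(
48 |     database=client.db_name,
49 |     document_models=[Sample],
50 |     allow_index_dropping=False,  # the default; set True only if Beanie manages all indexes
51 | )
52 | ```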
--------------------------------------------------------------------------------
/beanie/odm/utils/dump.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Optional, Set
2 |
3 | from beanie.odm.utils.encoder import Encoder
4 |
5 | if TYPE_CHECKING:
6 | from beanie.odm.documents import Document
7 |
8 |
9 | def get_dict(
10 | document: "Document",
11 | to_db: bool = False,
12 | exclude: Optional[Set[str]] = None,
13 | keep_nulls: bool = True,
14 | ):
15 | if exclude is None:
16 | exclude = set()
17 | if document.id is None:
18 | exclude.add("_id")
19 | include = set()
20 | if document.get_settings().use_revision:
21 | include.add("revision_id")
22 | encoder = Encoder(
23 | exclude=exclude, include=include, to_db=to_db, keep_nulls=keep_nulls
24 | )
25 | return encoder.encode(document)
26 |
27 |
28 | def get_nulls(
29 | document: "Document",
30 | exclude: Optional[Set[str]] = None,
31 | ):
32 | dictionary = get_dict(document, exclude=exclude, keep_nulls=True)
33 | return filter_none(dictionary)
34 |
35 |
36 | def get_top_level_nones(
37 | document: "Document",
38 | exclude: Optional[Set[str]] = None,
39 | ):
40 | dictionary = get_dict(document, exclude=exclude, keep_nulls=True)
41 | return {k: v for k, v in dictionary.items() if v is None}
42 |
43 |
44 | def filter_none(d):
45 | result = {}
46 | for k, v in d.items():
47 | if isinstance(v, dict):
48 | filtered = filter_none(v)
49 | if filtered:
50 | result[k] = filtered
51 | elif v is None:
52 | result[k] = v
53 | return result
54 |
--------------------------------------------------------------------------------
/beanie/odm/operators/find/bitwise.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from beanie.odm.fields import ExpressionField
4 | from beanie.odm.operators.find import BaseFindOperator
5 |
6 |
7 | class BaseFindBitwiseOperator(BaseFindOperator):
8 | operator = ""
9 |
10 | def __init__(self, field: Union[str, ExpressionField], bitmask):
11 | self.field = field
12 | self.bitmask = bitmask
13 |
14 | @property
15 | def query(self):
16 | return {self.field: {self.operator: self.bitmask}}
17 |
18 |
19 | class BitsAllClear(BaseFindBitwiseOperator):
20 | """
21 | `$bitsAllClear` query operator
22 |
23 | MongoDB doc:
24 |     https://docs.mongodb.com/manual/reference/operator/query/bitsAllClear/
25 | """
26 |
27 | operator = "$bitsAllClear"
28 |
29 |
30 | class BitsAllSet(BaseFindBitwiseOperator):
31 | """
32 | `$bitsAllSet` query operator
33 |
34 | MongoDB doc:
35 | https://docs.mongodb.com/manual/reference/operator/query/bitsAllSet/
36 | """
37 |
38 | operator = "$bitsAllSet"
39 |
40 |
41 | class BitsAnyClear(BaseFindBitwiseOperator):
42 | """
43 | `$bitsAnyClear` query operator
44 |
45 | MongoDB doc:
46 | https://docs.mongodb.com/manual/reference/operator/query/bitsAnyClear/
47 | """
48 |
49 | operator = "$bitsAnyClear"
50 |
51 |
52 | class BitsAnySet(BaseFindBitwiseOperator):
53 | """
54 | `$bitsAnySet` query operator
55 |
56 | MongoDB doc:
57 | https://docs.mongodb.com/manual/reference/operator/query/bitsAnySet/
58 | """
59 |
60 | operator = "$bitsAnySet"
61 |
--------------------------------------------------------------------------------
/beanie/odm/utils/relations.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 | from typing import TYPE_CHECKING, Any, Dict
3 | from typing import Mapping as MappingType
4 |
5 | from beanie.odm.fields import (
6 | ExpressionField,
7 | )
8 |
9 | # from pydantic.fields import ModelField
10 | # from pydantic.typing import get_origin
11 |
12 | if TYPE_CHECKING:
13 | from beanie import Document
14 |
15 |
16 | def convert_ids(
17 | query: MappingType[str, Any], doc: "Document", fetch_links: bool
18 | ) -> Dict[str, Any]:
19 | # TODO add all the cases
20 | new_query = {}
21 | for k, v in query.items():
22 | k_splitted = k.split(".")
23 | if (
24 | isinstance(k, ExpressionField)
25 | and doc.get_link_fields() is not None
26 | and len(k_splitted) == 2
27 | and k_splitted[0] in doc.get_link_fields().keys() # type: ignore
28 | and k_splitted[1] == "id"
29 | ):
30 | if fetch_links:
31 | new_k = f"{k_splitted[0]}._id"
32 | else:
33 | new_k = f"{k_splitted[0]}.$id"
34 | else:
35 | new_k = k
36 | new_v: Any
37 | if isinstance(v, Mapping):
38 | new_v = convert_ids(v, doc, fetch_links)
39 | elif isinstance(v, list):
40 | new_v = [
41 | convert_ids(ele, doc, fetch_links)
42 | if isinstance(ele, Mapping)
43 | else ele
44 | for ele in v
45 | ]
46 | else:
47 | new_v = v
48 |
49 | new_query[new_k] = new_v
50 | return new_query
51 |
--------------------------------------------------------------------------------
/tests/odm/test_typing_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import pytest
4 | from pydantic import BaseModel
5 | from typing_extensions import Annotated
6 |
7 | from beanie import Document, Link
8 | from beanie.odm.fields import Indexed
9 | from beanie.odm.utils.pydantic import get_model_fields
10 | from beanie.odm.utils.typing import extract_id_class, get_index_attributes
11 |
12 |
13 | class Lock(Document):
14 | k: int
15 |
16 |
17 | class TestTyping:
18 | def test_extract_id_class(self):
19 | # Union
20 | assert extract_id_class(Union[str, int]) is str
21 | assert extract_id_class(Union[str, None]) is str
22 | assert extract_id_class(Union[str, None, int]) is str
23 | # Optional
24 | assert extract_id_class(Optional[str]) is str
25 | # Link
26 | assert extract_id_class(Link[Lock]) == Lock
27 |
28 | @pytest.mark.parametrize(
29 | "type,result",
30 | (
31 | (str, None),
32 | (Indexed(str), (1, {})),
33 | (Indexed(str, "text", unique=True), ("text", {"unique": True})),
34 | (Annotated[str, Indexed()], (1, {})),
35 | (
36 | Annotated[str, "other metadata", Indexed(unique=True)],
37 | (1, {"unique": True}),
38 | ),
39 | (Annotated[str, "other metadata"], None),
40 | ),
41 | )
42 | def test_get_index_attributes(self, type, result):
43 | class Foo(BaseModel):
44 | bar: type
45 |
46 | field = get_model_fields(Foo)["bar"]
47 | assert get_index_attributes(field) == result
48 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_sync.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie.exceptions import ApplyChangesException
4 | from beanie.odm.documents import MergeStrategy
5 | from tests.odm.models import DocumentToTestSync
6 |
7 |
8 | class TestSync:
9 | async def test_merge_remote(self):
10 | doc = DocumentToTestSync()
11 | await doc.insert()
12 |
13 | doc2 = await DocumentToTestSync.get(doc.id)
14 | doc2.s = "foo"
15 |
16 | doc.i = 100
17 | await doc.save()
18 |
19 | await doc2.sync()
20 |
21 | assert doc2.s == "TEST"
22 | assert doc2.i == 100
23 |
24 | async def test_merge_local(self):
25 | doc = DocumentToTestSync(d={"option_1": {"s": "foo"}})
26 | await doc.insert()
27 |
28 | doc2 = await DocumentToTestSync.get(doc.id)
29 | doc2.s = "foo"
30 | doc2.n.option_1.s = "bar"
31 | doc2.d["option_1"]["s"] = "bar"
32 |
33 | doc.i = 100
34 | await doc.save()
35 |
36 | await doc2.sync(merge_strategy=MergeStrategy.local)
37 |
38 | assert doc2.s == "foo"
39 | assert doc2.n.option_1.s == "bar"
40 | assert doc2.d["option_1"]["s"] == "bar"
41 |
42 | assert doc2.i == 100
43 |
44 | async def test_merge_local_impossible_apply_changes(self):
45 | doc = DocumentToTestSync(d={"option_1": {"s": "foo"}})
46 | await doc.insert()
47 |
48 | doc2 = await DocumentToTestSync.get(doc.id)
49 | doc2.d["option_1"]["s"] = {"foo": "bar"}
50 |
51 | doc.d = {"option_1": "nothing"}
52 | await doc.save()
53 | with pytest.raises(ApplyChangesException):
54 | await doc2.sync(merge_strategy=MergeStrategy.local)
55 |
--------------------------------------------------------------------------------
/beanie/odm/operators/find/element.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import List, Union
3 |
4 | from beanie.odm.operators.find import BaseFindOperator
5 |
6 |
7 | class BaseFindElementOperator(BaseFindOperator, ABC): ...
8 |
9 |
10 | class Exists(BaseFindElementOperator):
11 | """
12 | `$exists` query operator
13 |
14 | Example:
15 |
16 | ```python
17 | class Product(Document):
18 | price: float
19 |
20 | Exists(Product.price, True)
21 | ```
22 |
23 | Will return query object like
24 |
25 | ```python
26 | {"price": {"$exists": True}}
27 | ```
28 |
29 | MongoDB doc:
30 |     https://docs.mongodb.com/manual/reference/operator/query/exists/
31 | """
32 |
33 | def __init__(
34 | self,
35 | field,
36 | value: bool = True,
37 | ):
38 | self.field = field
39 | self.value = value
40 |
41 | @property
42 | def query(self):
43 | return {self.field: {"$exists": self.value}}
44 |
45 |
46 | class Type(BaseFindElementOperator):
47 | """
48 | `$type` query operator
49 |
50 | Example:
51 |
52 | ```python
53 | class Product(Document):
54 | price: float
55 |
56 | Type(Product.price, "decimal")
57 | ```
58 |
59 | Will return query object like
60 |
61 | ```python
62 | {"price": {"$type": "decimal"}}
63 | ```
64 |
65 | MongoDB doc:
66 |     https://docs.mongodb.com/manual/reference/operator/query/type/
67 | """
68 |
69 | def __init__(self, field, types: Union[List[str], str]):
70 | self.field = field
71 | self.types = types
72 |
73 | @property
74 | def query(self):
75 | return {self.field: {"$type": self.types}}
76 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/test_change_value_subfield.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class Tag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class Note(Document):
17 | title: str
18 | tag: Tag
19 |
20 | class Settings:
21 | name = "notes"
22 |
23 |
24 | @pytest.fixture()
25 | async def notes(db):
26 | await init_beanie(database=db, document_models=[Note])
27 | await Note.delete_all()
28 | for i in range(10):
29 | note = Note(title=str(i), tag=Tag(name="test", color="red"))
30 | await note.insert()
31 | yield
32 | await Note.delete_all()
33 |
34 |
35 | async def test_migration_change_subfield_value(settings, notes, db):
36 | migration_settings = MigrationSettings(
37 | connection_uri=settings.mongodb_dsn,
38 | database_name=settings.mongodb_db_name,
39 | path="tests/migrations/migrations_for_test/change_subfield_value",
40 | )
41 | await run_migrate(migration_settings)
42 |
43 | await init_beanie(database=db, document_models=[Note])
44 | inspection = await Note.inspect_collection()
45 | assert inspection.status == InspectionStatuses.OK
46 | note = await Note.find_one({})
47 | assert note.tag.color == "blue"
48 |
49 | migration_settings.direction = RunningDirections.BACKWARD
50 | await run_migrate(migration_settings)
51 | inspection = await Note.inspect_collection()
52 | assert inspection.status == InspectionStatuses.OK
53 | note = await Note.find_one({})
54 | assert note.tag.color == "red"
55 |
--------------------------------------------------------------------------------
/tests/odm/test_id.py:
--------------------------------------------------------------------------------
1 | from uuid import UUID
2 |
3 | import pytest
4 | from pydantic import BaseModel
5 |
6 | from beanie import PydanticObjectId
7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
8 | from tests.odm.models import DocumentWithCustomIdInt, DocumentWithCustomIdUUID
9 |
10 |
11 | class A(BaseModel):
12 | id: PydanticObjectId
13 |
14 |
15 | async def test_uuid_id():
16 | doc = DocumentWithCustomIdUUID(name="TEST")
17 | await doc.insert()
18 | new_doc = await DocumentWithCustomIdUUID.get(doc.id)
19 | assert isinstance(new_doc.id, UUID)
20 |
21 |
22 | async def test_integer_id():
23 | doc = DocumentWithCustomIdInt(name="TEST", id=1)
24 | await doc.insert()
25 | new_doc = await DocumentWithCustomIdInt.get(doc.id)
26 | assert isinstance(new_doc.id, int)
27 |
28 |
29 | @pytest.mark.skipif(
30 | not IS_PYDANTIC_V2,
31 | reason="supports only pydantic v2",
32 | )
33 | async def test_pydantic_object_id_validation_json():
34 | deserialized = A.model_validate_json('{"id": "5eb7cf5a86d9755df3a6c593"}')
35 | assert isinstance(deserialized.id, PydanticObjectId)
36 | assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
37 | assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")
38 |
39 |
40 | @pytest.mark.skipif(
41 | not IS_PYDANTIC_V2,
42 | reason="supports only pydantic v2",
43 | )
44 | @pytest.mark.parametrize(
45 | "data",
46 | [
47 | "5eb7cf5a86d9755df3a6c593",
48 | PydanticObjectId("5eb7cf5a86d9755df3a6c593"),
49 | ],
50 | )
51 | async def test_pydantic_object_id_serialization(data):
52 | deserialized = A(**{"id": data})
53 | assert isinstance(deserialized.id, PydanticObjectId)
54 | assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
55 | assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")
56 |
--------------------------------------------------------------------------------
/docs/tutorial/insert.md:
--------------------------------------------------------------------------------
1 | # Insert the documents
2 |
3 | Beanie documents behave just like pydantic models (because they subclass `pydantic.BaseModel`).
4 | Hence, a document can be created in a similar fashion to pydantic:
5 |
6 | ```python
7 | from typing import Optional
8 |
9 | from pydantic import BaseModel
10 |
11 | from beanie import Document, Indexed
12 |
13 |
14 | class Category(BaseModel):
15 | name: str
16 | description: str
17 |
18 |
19 | class Product(Document): # This is the model
20 | name: str
21 | description: Optional[str] = None
22 | price: Indexed(float)
23 | category: Category
24 |
25 | class Settings:
26 | name = "products"
27 |
28 |
29 | chocolate = Category(name="Chocolate", description="A preparation of roasted and ground cacao seeds.")
30 | tonybar = Product(name="Tony's", price=5.95, category=chocolate)
31 | marsbar = Product(name="Mars", price=1, category=chocolate)
32 | ```
33 |
34 | This however does not save the documents to the database yet.
35 |
36 | ## Insert a single document
37 |
38 | To insert a document into the database, you can call either `insert()` or `create()` on it (they are synonyms):
39 |
40 | ```python
41 | await tonybar.insert()
42 | await marsbar.create() # does exactly the same as insert()
43 | ```
44 | You can also call `save()`, which behaves in the same manner for new documents, but will also update existing documents.
45 | See the [section on updating](updating-&-deleting.md) of this tutorial for more details.
46 |
47 | If you prefer, you can also call the `insert_one` class method:
48 |
49 | ```python
50 | await Product.insert_one(tonybar)
51 | ```
52 |
53 | ## Inserting many documents
54 |
55 | To reduce the number of database queries,
56 | documents of the same type should be inserted together by calling the class method `insert_many`:
57 |
58 | ```python
59 | await Product.insert_many([tonybar, marsbar])
60 | ```
61 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_delete.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import DocumentTestModel
2 |
3 |
4 | async def test_delete_one(documents):
5 | await documents(4, "uno")
6 | await documents(2, "dos")
7 | await documents(1, "cuatro")
8 | await DocumentTestModel.find_one({"test_str": "uno"}).delete()
9 | documents = await DocumentTestModel.find_all().to_list()
10 | assert len(documents) == 6
11 |
12 |
13 | async def test_delete_one_not_found(documents):
14 | await documents(4, "uno")
15 | await documents(2, "dos")
16 | await documents(1, "cuatro")
17 | await DocumentTestModel.find_one({"test_str": "wrong"}).delete()
18 | documents = await DocumentTestModel.find_all().to_list()
19 | assert len(documents) == 7
20 |
21 |
22 | async def test_delete_many(documents):
23 | await documents(4, "uno")
24 | await documents(2, "dos")
25 | await documents(1, "cuatro")
26 | await DocumentTestModel.find_many({"test_str": "uno"}).delete()
27 | documents = await DocumentTestModel.find_all().to_list()
28 | assert len(documents) == 3
29 |
30 |
31 | async def test_delete_many_not_found(documents):
32 | await documents(4, "uno")
33 | await documents(2, "dos")
34 | await documents(1, "cuatro")
35 | await DocumentTestModel.find_many({"test_str": "wrong"}).delete()
36 | documents = await DocumentTestModel.find_all().to_list()
37 | assert len(documents) == 7
38 |
39 |
40 | async def test_delete_all(documents):
41 | await documents(4, "uno")
42 | await documents(2, "dos")
43 | await documents(1, "cuatro")
44 | await DocumentTestModel.delete_all()
45 | documents = await DocumentTestModel.find_all().to_list()
46 | assert len(documents) == 0
47 |
48 |
49 | async def test_delete(document):
50 | doc_id = document.id
51 | await document.delete()
52 | new_document = await DocumentTestModel.get(doc_id)
53 | assert new_document is None
54 |
--------------------------------------------------------------------------------
/tests/odm/test_consistency.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update.general import Set
2 | from tests.odm.models import (
3 | DocumentTestModel,
4 | DocumentWithTimeStampToTestConsistency,
5 | )
6 |
7 |
8 | class TestResponseOfTheChangingOperations:
9 | async def test_insert(self, document_not_inserted):
10 | result = await document_not_inserted.insert()
11 | assert isinstance(result, DocumentTestModel)
12 |
13 | async def test_update(self, document):
14 | result = await document.update(Set({"test_int": 43}))
15 | assert isinstance(result, DocumentTestModel)
16 |
17 | async def test_save(self, document, document_not_inserted):
18 | document.test_int = 43
19 | result = await document.save()
20 | assert isinstance(result, DocumentTestModel)
21 |
22 | document_not_inserted.test_int = 43
23 | result = await document_not_inserted.save()
24 | assert isinstance(result, DocumentTestModel)
25 |
26 | async def test_save_changes(self, document):
27 | document.test_int = 43
28 | result = await document.save_changes()
29 | assert isinstance(result, DocumentTestModel)
30 |
31 | async def test_replace(self, document):
32 | result = await document.replace()
33 | assert isinstance(result, DocumentTestModel)
34 |
35 | async def test_set(self, document):
36 | result = await document.set({"test_int": 43})
37 | assert isinstance(result, DocumentTestModel)
38 |
39 | async def test_inc(self, document):
40 | result = await document.inc({"test_int": 1})
41 | assert isinstance(result, DocumentTestModel)
42 |
43 | async def test_current_date(self):
44 | document = DocumentWithTimeStampToTestConsistency()
45 | await document.insert()
46 | result = await document.current_date({"ts": True})
47 | assert isinstance(result, DocumentWithTimeStampToTestConsistency)
48 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/test_rename_field.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class Tag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class OldNote(Document):
17 | name: str
18 | tag: Tag
19 |
20 | class Settings:
21 | name = "notes"
22 |
23 |
24 | class Note(Document):
25 | title: str
26 | tag: Tag
27 |
28 | class Settings:
29 | name = "notes"
30 |
31 |
32 | @pytest.fixture()
33 | async def notes(db):
34 | await init_beanie(database=db, document_models=[OldNote])
35 | await OldNote.delete_all()
36 | for i in range(10):
37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red"))
38 | await note.insert()
39 | yield
40 | # await OldNote.delete_all()
41 |
42 |
43 | async def test_migration_rename_field(settings, notes, db):
44 | migration_settings = MigrationSettings(
45 | connection_uri=settings.mongodb_dsn,
46 | database_name=settings.mongodb_db_name,
47 | path="tests/migrations/migrations_for_test/rename_field",
48 | )
49 | await run_migrate(migration_settings)
50 |
51 | await init_beanie(database=db, document_models=[Note])
52 | inspection = await Note.inspect_collection()
53 | assert inspection.status == InspectionStatuses.OK
54 | note = await Note.find_one({})
55 | assert note.title == "0"
56 |
57 | migration_settings.direction = RunningDirections.BACKWARD
58 | await run_migrate(migration_settings)
59 | inspection = await OldNote.inspect_collection()
60 | assert inspection.status == InspectionStatuses.OK
61 | note = await OldNote.find_one({})
62 | assert note.name == "0"
63 |
--------------------------------------------------------------------------------
/beanie/odm/utils/pydantic.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Type
2 |
3 | import pydantic
4 | from pydantic import BaseModel
5 |
6 | IS_PYDANTIC_V2 = int(pydantic.VERSION.split(".")[0]) >= 2
7 | IS_PYDANTIC_V2_10 = (
8 | IS_PYDANTIC_V2 and int(pydantic.VERSION.split(".")[1]) >= 10
9 | )
10 |
11 | if IS_PYDANTIC_V2:
12 | from pydantic import TypeAdapter
13 | else:
14 | from pydantic import parse_obj_as
15 |
16 |
17 | def parse_object_as(object_type: Type, data: Any):
18 | if IS_PYDANTIC_V2:
19 | return TypeAdapter(object_type).validate_python(data)
20 | else:
21 | return parse_obj_as(object_type, data)
22 |
23 |
24 | def get_field_type(field):
25 | if IS_PYDANTIC_V2:
26 | return field.annotation
27 | else:
28 | return field.outer_type_
29 |
30 |
31 | def get_model_fields(model):
32 | if IS_PYDANTIC_V2:
33 | if not isinstance(model, type):
34 | model = model.__class__
35 | return model.model_fields
36 | else:
37 | return model.__fields__
38 |
39 |
40 | def parse_model(model_type: Type[BaseModel], data: Any):
41 | if IS_PYDANTIC_V2:
42 | return model_type.model_validate(data)
43 | else:
44 | return model_type.parse_obj(data)
45 |
46 |
47 | def get_extra_field_info(field, parameter: str):
48 | if IS_PYDANTIC_V2:
49 | if isinstance(field.json_schema_extra, dict):
50 | return field.json_schema_extra.get(parameter)
51 | return None
52 | else:
53 | return field.field_info.extra.get(parameter)
54 |
55 |
56 | def get_config_value(model, parameter: str):
57 | if IS_PYDANTIC_V2:
58 | return model.model_config.get(parameter)
59 | else:
60 | return getattr(model.Config, parameter, None)
61 |
62 |
63 | def get_model_dump(model, *args: Any, **kwargs: Any):
64 | if IS_PYDANTIC_V2:
65 | return model.model_dump(*args, **kwargs)
66 | else:
67 | return model.dict(*args, **kwargs)
68 |
--------------------------------------------------------------------------------
/beanie/odm/settings/timeseries.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Annotated, Any, Dict, Optional
3 |
4 | from pydantic import BaseModel, Field
5 |
6 |
7 | class Granularity(str, Enum):
8 | """
9 | Time Series Granularity
10 | """
11 |
12 | seconds = "seconds"
13 | minutes = "minutes"
14 | hours = "hours"
15 |
16 |
17 | class TimeSeriesConfig(BaseModel):
18 | """
19 | Time Series Collection config
20 | """
21 |
22 | time_field: str
23 | meta_field: Optional[str] = None
24 | granularity: Optional[Granularity] = None
25 | bucket_max_span_seconds: Optional[int] = None
26 | bucket_rounding_second: Annotated[
27 | Optional[int],
28 | Field(
29 | deprecated="This field is deprecated in favor of "
30 | "'bucket_rounding_seconds'.",
31 | ),
32 | ] = None
33 | bucket_rounding_seconds: Optional[int] = None
34 | expire_after_seconds: Optional[int] = None
35 |
36 | def build_query(self, collection_name: str) -> Dict[str, Any]:
37 | res: Dict[str, Any] = {"name": collection_name}
38 | timeseries: Dict[str, Any] = {"timeField": self.time_field}
39 | if self.meta_field is not None:
40 | timeseries["metaField"] = self.meta_field
41 | if self.granularity is not None:
42 | timeseries["granularity"] = self.granularity
43 | if self.bucket_max_span_seconds is not None:
44 | timeseries["bucketMaxSpanSeconds"] = self.bucket_max_span_seconds
45 | if self.bucket_rounding_second is not None: # Deprecated field
46 | timeseries["bucketRoundingSeconds"] = self.bucket_rounding_second
47 | if self.bucket_rounding_seconds is not None:
48 | timeseries["bucketRoundingSeconds"] = self.bucket_rounding_seconds
49 | res["timeseries"] = timeseries
50 | if self.expire_after_seconds is not None:
51 | res["expireAfterSeconds"] = self.expire_after_seconds
52 | return res
53 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/test_change_value.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class Tag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class Note(Document):
17 | title: str
18 | tag: Tag
19 |
20 | class Settings:
21 | name = "notes"
22 |
23 |
24 | @pytest.fixture()
25 | async def notes(db):
26 | await init_beanie(database=db, document_models=[Note])
27 | await Note.delete_all()
28 | for i in range(10):
29 | note = Note(title=str(i), tag=Tag(name="test", color="red"))
30 | await note.insert()
31 | yield
32 | await Note.delete_all()
33 |
34 |
35 | async def test_migration_change_value(settings, notes, db):
36 | migration_settings = MigrationSettings(
37 | connection_uri=settings.mongodb_dsn,
38 | database_name=settings.mongodb_db_name,
39 | path="tests/migrations/migrations_for_test/change_value",
40 | )
41 | await run_migrate(migration_settings)
42 |
43 | await init_beanie(database=db, document_models=[Note])
44 | inspection = await Note.inspect_collection()
45 | assert inspection.status == InspectionStatuses.OK
46 | note = await Note.find_one({"title": "five"})
47 | assert note is not None
48 |
49 | note = await Note.find_one({"title": "5"})
50 | assert note is None
51 |
52 | migration_settings.direction = RunningDirections.BACKWARD
53 | await run_migrate(migration_settings)
54 | inspection = await Note.inspect_collection()
55 | assert inspection.status == InspectionStatuses.OK
56 | note = await Note.find_one({"title": "5"})
57 | assert note is not None
58 |
59 | note = await Note.find_one({"title": "five"})
60 | assert note is None
61 |
--------------------------------------------------------------------------------
/tests/migrations/test_break.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import Indexed, init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.odm.documents import Document
7 | from beanie.odm.models import InspectionStatuses
8 |
9 |
10 | class Tag(BaseModel):
11 | color: str
12 | name: str
13 |
14 |
15 | class OldNote(Document):
16 | name: Indexed(str, unique=True)
17 | tag: Tag
18 |
19 | class Settings:
20 | name = "notes"
21 |
22 |
23 | class Note(Document):
24 | name: Indexed(str, unique=True)
25 | title: str
26 | tag: Tag
27 |
28 | class Settings:
29 | name = "notes"
30 |
31 |
32 | @pytest.fixture()
33 | async def notes(db):
34 | await init_beanie(database=db, document_models=[OldNote])
35 | await OldNote.delete_all()
36 | for i in range(10):
37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red"))
38 | await note.insert()
39 | yield
40 | await OldNote.delete_all()
41 | await OldNote.get_pymongo_collection().drop()
42 | await OldNote.get_pymongo_collection().drop_indexes()
43 |
44 |
45 | @pytest.mark.skip("TODO: Fix this test")
46 | async def test_migration_break(settings, notes, db):
47 | migration_settings = MigrationSettings(
48 | connection_uri=settings.mongodb_dsn,
49 | database_name=settings.mongodb_db_name,
50 | path="tests/migrations/migrations_for_test/break",
51 | )
52 | with pytest.raises(Exception):
53 | await run_migrate(migration_settings)
54 |
55 | await init_beanie(database=db, document_models=[OldNote])
56 | inspection = await OldNote.inspect_collection()
57 | assert inspection.status == InspectionStatuses.OK
58 | notes = await OldNote.get_pymongo_collection().find().to_list(length=100)
59 | names = set(n["name"] for n in notes)
60 | assert names == {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}
61 | for note in notes:
62 | assert "title" not in note
63 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/test_change_subfield.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class OldTag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class Tag(BaseModel):
17 | color: str
18 | title: str
19 |
20 |
21 | class OldNote(Document):
22 | title: str
23 | tag: OldTag
24 |
25 | class Settings:
26 | name = "notes"
27 |
28 |
29 | class Note(Document):
30 | title: str
31 | tag: Tag
32 |
33 | class Settings:
34 | name = "notes"
35 |
36 |
37 | @pytest.fixture()
38 | async def notes(db):
39 | await init_beanie(database=db, document_models=[OldNote])
40 | await OldNote.delete_all()
41 | for i in range(10):
42 | note = OldNote(title=str(i), tag=OldTag(name="test", color="red"))
43 | await note.insert()
44 | yield
45 | await OldNote.delete_all()
46 |
47 |
48 | async def test_migration_change_subfield_value(settings, notes, db):
49 | migration_settings = MigrationSettings(
50 | connection_uri=settings.mongodb_dsn,
51 | database_name=settings.mongodb_db_name,
52 | path="tests/migrations/migrations_for_test/change_subfield",
53 | )
54 | await run_migrate(migration_settings)
55 |
56 | await init_beanie(database=db, document_models=[Note])
57 | inspection = await Note.inspect_collection()
58 | assert inspection.status == InspectionStatuses.OK
59 | note = await Note.find_one({})
60 | assert note.tag.title == "test"
61 |
62 | migration_settings.direction = RunningDirections.BACKWARD
63 | await run_migrate(migration_settings)
64 | inspection = await OldNote.inspect_collection()
65 | assert inspection.status == InspectionStatuses.OK
66 | note = await OldNote.find_one({})
67 | assert note.tag.name == "test"
68 |
--------------------------------------------------------------------------------
/tests/migrations/iterative/test_pack_unpack.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class OldTag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class Tag(BaseModel):
17 | color: str
18 | name: str
19 |
20 |
21 | class OldNote(Document):
22 | title: str
23 | tag_name: str
24 | tag_color: str
25 |
26 | class Settings:
27 | name = "notes"
28 |
29 |
30 | class Note(Document):
31 | title: str
32 | tag: Tag
33 |
34 | class Settings:
35 | name = "notes"
36 |
37 |
38 | @pytest.fixture()
39 | async def notes(db):
40 | await init_beanie(database=db, document_models=[OldNote])
41 | await OldNote.delete_all()
42 | for i in range(10):
43 | note = OldNote(title=str(i), tag_name="test", tag_color="red")
44 | await note.insert()
45 | yield
46 | await OldNote.delete_all()
47 |
48 |
49 | async def test_migration_pack_unpack(settings, notes, db):
50 | migration_settings = MigrationSettings(
51 | connection_uri=settings.mongodb_dsn,
52 | database_name=settings.mongodb_db_name,
53 | path="tests/migrations/migrations_for_test/pack_unpack",
54 | )
55 | await run_migrate(migration_settings)
56 |
57 | await init_beanie(database=db, document_models=[Note])
58 | inspection = await Note.inspect_collection()
59 | assert inspection.status == InspectionStatuses.OK
60 | note = await Note.find_one({})
61 | assert note.tag.name == "test"
62 |
63 | migration_settings.direction = RunningDirections.BACKWARD
64 | await run_migrate(migration_settings)
65 | inspection = await OldNote.inspect_collection()
66 | assert inspection.status == InspectionStatuses.OK
67 | note = await OldNote.find_one({})
68 | assert note.tag_name == "test"
69 |
--------------------------------------------------------------------------------
/docs/tutorial/multi-model.md:
--------------------------------------------------------------------------------
1 | # Multi-model pattern
2 |
3 | Documents with different schemas can be stored in a single collection and managed correctly.
4 | The `UnionDoc` class is used for this.
5 |
6 | It supports the `find` and `aggregate` methods.
7 | For `find`, it parses each found document into its respective `Document` class.
8 |
9 | Documents with `union_doc` in their settings can still be used in `find` and other queries.
10 | Queries made through one such class will not see the data of the others.
11 |
12 | ## Example
13 |
14 | Create documents:
15 |
16 | ```python
17 | from beanie import Document, UnionDoc
18 |
19 |
20 | class Parent(UnionDoc): # Union
21 | class Settings:
22 | name = "union_doc_collection" # Collection name
23 | class_id = "_class_id" # _class_id is the default beanie internal field used to filter child Documents
24 |
25 |
26 | class One(Document):
27 | int_field: int = 0
28 | shared: int = 0
29 |
30 | class Settings:
31 | name = "One" # Name used to filer union document 'One', default to class name
32 | union_doc = Parent
33 |
34 |
35 | class Two(Document):
36 | str_field: str = "test"
37 | shared: int = 0
38 |
39 | class Settings:
40 | union_doc = Parent
41 | ```
42 |
43 | The schemas could be incompatible.
44 |
45 | Insert documents:
46 |
47 | ```python
48 | await One().insert()
49 | await One().insert()
50 | await One().insert()
51 |
52 | await Two().insert()
53 | ```
54 |
55 | Find all the documents of the first type:
56 |
57 | ```python
58 | docs = await One.all().to_list()
59 | print(len(docs))
60 |
61 | >> 3 # It found only documents of class One
62 | ```
63 |
64 | Of the second type:
65 |
66 | ```python
67 | docs = await Two.all().to_list()
68 | print(len(docs))
69 |
70 | >> 1 # It found only documents of class Two
71 | ```
72 |
73 | Of both:
74 |
75 | ```python
76 | docs = await Parent.all().to_list()
77 | print(len(docs))
78 |
79 | >> 4 # instances of both classes will be in the output here
80 | ```
81 |
82 | Aggregations will work separately for these two document classes too.
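
As a rough sketch of that separation, assuming the inserts above (the `$count` stage is just an illustrative pipeline):

```python
docs = await One.aggregate([{"$count": "total"}]).to_list()
print(docs)

>> [{'total': 3}]  # only documents of class One are counted

docs = await Parent.aggregate([{"$count": "total"}]).to_list()
print(docs)

>> [{'total': 4}]  # documents of both classes are counted
```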
83 |
--------------------------------------------------------------------------------
/tests/odm/test_lazy_parsing.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from beanie.odm.utils.dump import get_dict
4 | from beanie.odm.utils.parsing import parse_obj
5 | from tests.odm.models import SampleLazyParsing
6 |
7 |
8 | @pytest.fixture
9 | async def docs():
10 | for i in range(10):
11 | await SampleLazyParsing(i=i, s=str(i)).insert()
12 |
13 |
14 | class TestLazyParsing:
15 | async def test_find_all(self, docs):
16 | found_docs = await SampleLazyParsing.all(lazy_parse=True).to_list()
17 | saved_state = found_docs[0].get_saved_state()
18 | assert "_id" in saved_state
19 | del saved_state["_id"]
20 | assert found_docs[0].get_saved_state() == {}
21 | assert found_docs[0].i == 0
22 | assert found_docs[0].s == "0"
23 | assert found_docs[1]._store["i"] == 1
24 | assert found_docs[1]._store["s"] == "1"
25 | assert get_dict(found_docs[2])["i"] == 2
26 | assert get_dict(found_docs[2])["s"] == "2"
27 |
28 | async def test_find_many(self, docs):
29 | found_docs = await SampleLazyParsing.find(
30 | SampleLazyParsing.i <= 5, lazy_parse=True
31 | ).to_list()
32 | saved_state = found_docs[0].get_saved_state()
33 | assert "_id" in saved_state
34 | del saved_state["_id"]
35 | assert found_docs[0].get_saved_state() == {}
36 | assert found_docs[0].i == 0
37 | assert found_docs[0].s == "0"
38 | assert found_docs[1]._store["i"] == 1
39 | assert found_docs[1]._store["s"] == "1"
40 | assert get_dict(found_docs[2])["i"] == 2
41 | assert get_dict(found_docs[2])["s"] == "2"
42 |
43 | async def test_save_changes(self, docs):
44 | found_docs = await SampleLazyParsing.all(lazy_parse=True).to_list()
45 | doc = found_docs[0]
46 | doc.i = 1000
47 | await doc.save_changes()
48 | new_doc = await SampleLazyParsing.find_one(SampleLazyParsing.s == "0")
49 | assert new_doc.i == 1000
50 |
51 | async def test_default_list(self):
52 | res = parse_obj(SampleLazyParsing, {"i": "1", "s": "1"})
53 | assert res.lst == []
54 |
--------------------------------------------------------------------------------
/beanie/__init__.py:
--------------------------------------------------------------------------------
1 | from beanie.migrations.controllers.free_fall import free_fall_migration
2 | from beanie.migrations.controllers.iterative import iterative_migration
3 | from beanie.odm.actions import (
4 | After,
5 | Before,
6 | Delete,
7 | Insert,
8 | Replace,
9 | Save,
10 | SaveChanges,
11 | Update,
12 | ValidateOnSave,
13 | after_event,
14 | before_event,
15 | )
16 | from beanie.odm.bulk import BulkWriter
17 | from beanie.odm.custom_types import DecimalAnnotation
18 | from beanie.odm.custom_types.bson.binary import BsonBinary
19 | from beanie.odm.documents import (
20 | Document,
21 | DocumentWithSoftDelete,
22 | MergeStrategy,
23 | )
24 | from beanie.odm.enums import SortDirection
25 | from beanie.odm.fields import (
26 | BackLink,
27 | BeanieObjectId,
28 | DeleteRules,
29 | Indexed,
30 | Link,
31 | PydanticObjectId,
32 | WriteRules,
33 | )
34 | from beanie.odm.queries.update import UpdateResponse
35 | from beanie.odm.settings.timeseries import Granularity, TimeSeriesConfig
36 | from beanie.odm.union_doc import UnionDoc
37 | from beanie.odm.utils.init import init_beanie
38 | from beanie.odm.views import View
39 |
40 | __version__ = "2.0.1"
41 | __all__ = [
42 | # ODM
43 | "Document",
44 | "DocumentWithSoftDelete",
45 | "View",
46 | "UnionDoc",
47 | "init_beanie",
48 | "PydanticObjectId",
49 | "BeanieObjectId",
50 | "Indexed",
51 | "TimeSeriesConfig",
52 | "Granularity",
53 | "SortDirection",
54 | "MergeStrategy",
55 | # Actions
56 | "before_event",
57 | "after_event",
58 | "Insert",
59 | "Replace",
60 | "Save",
61 | "SaveChanges",
62 | "ValidateOnSave",
63 | "Delete",
64 | "Before",
65 | "After",
66 | "Update",
67 | # Bulk Write
68 | "BulkWriter",
69 | # Migrations
70 | "iterative_migration",
71 | "free_fall_migration",
72 | # Relations
73 | "Link",
74 | "BackLink",
75 | "WriteRules",
76 | "DeleteRules",
77 | # Custom Types
78 | "DecimalAnnotation",
79 | "BsonBinary",
80 | # UpdateResponse
81 | "UpdateResponse",
82 | ]
83 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_evaluation.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.evaluation import (
2 | Expr,
3 | JsonSchema,
4 | Mod,
5 | RegEx,
6 | Text,
7 | Where,
8 | )
9 | from tests.odm.models import Sample
10 |
11 |
12 | async def test_expr():
13 | q = Expr({"a": "B"})
14 | assert q == {"$expr": {"a": "B"}}
15 |
16 |
17 | async def test_json_schema():
18 | q = JsonSchema({"a": "B"})
19 | assert q == {"$jsonSchema": {"a": "B"}}
20 |
21 |
22 | async def test_mod():
23 | q = Mod(Sample.integer, 3, 2)
24 | assert q == {"integer": {"$mod": [3, 2]}}
25 |
26 |
27 | async def test_regex():
28 | q = RegEx(Sample.integer, "smth")
29 | assert q == {"integer": {"$regex": "smth"}}
30 |
31 | q = RegEx(Sample.integer, "smth", "options")
32 | assert q == {"integer": {"$regex": "smth", "$options": "options"}}
33 |
34 |
35 | async def test_text():
36 | q = Text("something")
37 | assert q == {
38 | "$text": {
39 | "$search": "something",
40 | "$caseSensitive": False,
41 | "$diacriticSensitive": False,
42 | }
43 | }
44 | q = Text("something", case_sensitive=True)
45 | assert q == {
46 | "$text": {
47 | "$search": "something",
48 | "$caseSensitive": True,
49 | "$diacriticSensitive": False,
50 | }
51 | }
52 | q = Text("something", diacritic_sensitive=True)
53 | assert q == {
54 | "$text": {
55 | "$search": "something",
56 | "$caseSensitive": False,
57 | "$diacriticSensitive": True,
58 | }
59 | }
60 | q = Text("something", diacritic_sensitive=None)
61 | assert q == {
62 | "$text": {
63 | "$search": "something",
64 | "$caseSensitive": False,
65 | }
66 | }
67 | q = Text("something", language="test")
68 | assert q == {
69 | "$text": {
70 | "$search": "something",
71 | "$caseSensitive": False,
72 | "$diacriticSensitive": False,
73 | "$language": "test",
74 | }
75 | }
76 |
77 |
78 | async def test_where():
79 | q = Where("test")
80 | assert q == {"$where": "test"}
81 |
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_comparison.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.comparison import (
2 | GT,
3 | GTE,
4 | LT,
5 | LTE,
6 | NE,
7 | Eq,
8 | In,
9 | NotIn,
10 | )
11 | from tests.odm.models import Sample
12 |
13 |
14 | async def test_eq():
15 | q = Sample.integer == 1
16 | assert q == {"integer": 1}
17 |
18 | q = Eq(Sample.integer, 1)
19 | assert q == {"integer": 1}
20 |
21 | q = Eq("integer", 1)
22 | assert q == {"integer": 1}
23 |
24 |
25 | async def test_gt():
26 | q = Sample.integer > 1
27 | assert q == {"integer": {"$gt": 1}}
28 |
29 | q = GT(Sample.integer, 1)
30 | assert q == {"integer": {"$gt": 1}}
31 |
32 | q = GT("integer", 1)
33 | assert q == {"integer": {"$gt": 1}}
34 |
35 |
36 | async def test_gte():
37 | q = Sample.integer >= 1
38 | assert q == {"integer": {"$gte": 1}}
39 |
40 | q = GTE(Sample.integer, 1)
41 | assert q == {"integer": {"$gte": 1}}
42 |
43 | q = GTE("integer", 1)
44 | assert q == {"integer": {"$gte": 1}}
45 |
46 |
47 | async def test_in():
48 | q = In(Sample.integer, [1])
49 | assert q == {"integer": {"$in": [1]}}
50 |
51 | q = In("integer", [1])
52 | assert q == {"integer": {"$in": [1]}}
53 |
54 |
55 | async def test_lt():
56 | q = Sample.integer < 1
57 | assert q == {"integer": {"$lt": 1}}
58 |
59 | q = LT(Sample.integer, 1)
60 | assert q == {"integer": {"$lt": 1}}
61 |
62 | q = LT("integer", 1)
63 | assert q == {"integer": {"$lt": 1}}
64 |
65 |
66 | async def test_lte():
67 | q = Sample.integer <= 1
68 | assert q == {"integer": {"$lte": 1}}
69 |
70 | q = LTE(Sample.integer, 1)
71 | assert q == {"integer": {"$lte": 1}}
72 |
73 | q = LTE("integer", 1)
74 | assert q == {"integer": {"$lte": 1}}
75 |
76 |
77 | async def test_ne():
78 | q = Sample.integer != 1
79 | assert q == {"integer": {"$ne": 1}}
80 |
81 | q = NE(Sample.integer, 1)
82 | assert q == {"integer": {"$ne": 1}}
83 |
84 | q = NE("integer", 1)
85 | assert q == {"integer": {"$ne": 1}}
86 |
87 |
88 | async def test_nin():
89 | q = NotIn(Sample.integer, [1])
90 | assert q == {"integer": {"$nin": [1]}}
91 |
92 | q = NotIn("integer", [1])
93 | assert q == {"integer": {"$nin": [1]}}
94 |
--------------------------------------------------------------------------------
/beanie/odm/views.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, ClassVar, Dict, Optional, Union
3 |
4 | from pydantic import BaseModel
5 |
6 | from beanie.exceptions import ViewWasNotInitialized
7 | from beanie.odm.fields import Link, LinkInfo
8 | from beanie.odm.interfaces.aggregate import AggregateInterface
9 | from beanie.odm.interfaces.detector import DetectionInterface, ModelType
10 | from beanie.odm.interfaces.find import FindInterface
11 | from beanie.odm.interfaces.getters import OtherGettersInterface
12 | from beanie.odm.settings.view import ViewSettings
13 |
14 |
15 | class View(
16 | BaseModel,
17 | FindInterface,
18 | AggregateInterface,
19 | OtherGettersInterface,
20 | DetectionInterface,
21 | ):
22 | """
23 | What is needed:
24 |
25 | - a source collection or view
26 | - an aggregation pipeline
27 |
28 | """
29 |
30 | # Relations
31 | _link_fields: ClassVar[Optional[Dict[str, LinkInfo]]] = None
32 |
33 | # Settings
34 | _settings: ClassVar[ViewSettings]
35 |
36 | @classmethod
37 | def get_settings(cls) -> ViewSettings:
38 | """
39 | Get view settings, which were created during
40 | the initialization step
41 |
42 | :return: ViewSettings class
43 | """
44 | if cls._settings is None:
45 | raise ViewWasNotInitialized
46 | return cls._settings
47 |
48 | async def fetch_link(self, field: Union[str, Any]):
49 | ref_obj = getattr(self, field, None)
50 | if isinstance(ref_obj, Link):
51 | value = await ref_obj.fetch(fetch_links=True)
52 | setattr(self, field, value)
53 | if isinstance(ref_obj, list) and ref_obj:
54 | values = await Link.fetch_list(ref_obj, fetch_links=True)
55 | setattr(self, field, values)
56 |
57 | async def fetch_all_links(self):
58 | coros = []
59 | link_fields = self.get_link_fields()
60 | if link_fields is not None:
61 | for ref in link_fields.values():
62 | coros.append(self.fetch_link(ref.field_name)) # TODO lists
63 | await asyncio.gather(*coros)
64 |
65 | @classmethod
66 | def get_link_fields(cls) -> Optional[Dict[str, LinkInfo]]:
67 | return cls._link_fields
68 |
69 | @classmethod
70 | def get_model_type(cls) -> ModelType:
71 | return ModelType.View
72 |
--------------------------------------------------------------------------------
/tests/fastapi/routes.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from fastapi import APIRouter, Body, status
4 | from pydantic import BaseModel
5 |
6 | from beanie import PydanticObjectId, WriteRules
7 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
8 | from tests.fastapi.models import House, HouseAPI, Person, WindowAPI
9 |
10 | house_router = APIRouter()
11 | if not IS_PYDANTIC_V2:
12 | from fastapi.encoders import ENCODERS_BY_TYPE
13 | from pydantic.json import ENCODERS_BY_TYPE as PYDANTIC_ENCODERS_BY_TYPE
14 |
15 | ENCODERS_BY_TYPE.update(PYDANTIC_ENCODERS_BY_TYPE)
16 |
17 |
18 | class WindowInput(BaseModel):
19 | id: PydanticObjectId
20 |
21 |
22 | @house_router.post("/windows/", response_model=WindowAPI)
23 | async def create_window(window: WindowAPI):
24 | await window.create()
25 | return window
26 |
27 |
28 | @house_router.post("/windows_2/")
29 | async def create_window_2(window: WindowAPI):
30 | return await window.save()
31 |
32 |
33 | @house_router.get("/windows/{id}", response_model=Optional[WindowAPI])
34 | async def get_window(id: PydanticObjectId):
35 | return await WindowAPI.get(id)
36 |
37 |
38 | @house_router.post("/houses/", response_model=HouseAPI)
39 | async def create_house(window: WindowAPI):
40 | house = HouseAPI(name="test_name", windows=[window])
41 | await house.insert(link_rule=WriteRules.WRITE)
42 | return house
43 |
44 |
45 | @house_router.post("/houses_with_window_link/", response_model=HouseAPI)
46 | async def create_houses_with_window_link(window: WindowInput):
47 | validator = (
48 | HouseAPI.model_validate if IS_PYDANTIC_V2 else HouseAPI.parse_obj
49 | )
50 | house = validator(
51 | dict(name="test_name", windows=[WindowAPI.link_from_id(window.id)])
52 | )
53 | await house.insert(link_rule=WriteRules.WRITE)
54 | return house
55 |
56 |
57 | @house_router.post("/houses_2/", response_model=HouseAPI)
58 | async def create_houses_2(house: HouseAPI):
59 | await house.insert(link_rule=WriteRules.WRITE)
60 | return house
61 |
62 |
63 | @house_router.post(
64 | "/house",
65 | response_model=House,
66 | status_code=status.HTTP_201_CREATED,
67 | )
68 | async def create_house_new(house: House = Body(...)):
69 | person = Person(name="Bob")
70 | house.owner = person
71 | await house.save(link_rule=WriteRules.WRITE)
72 | await house.sync()
73 | return house
74 |
--------------------------------------------------------------------------------
/beanie/odm/queries/cursor.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import (
3 | Any,
4 | Dict,
5 | Generic,
6 | List,
7 | Optional,
8 | Type,
9 | TypeVar,
10 | cast,
11 | )
12 |
13 | from pydantic.main import BaseModel
14 |
15 | from beanie.odm.utils.parsing import parse_obj
16 |
17 | CursorResultType = TypeVar("CursorResultType")
18 |
19 |
20 | class BaseCursorQuery(Generic[CursorResultType]):
21 | """
22 | BaseCursorQuery class. Wrapper over AsyncCursor
23 | that parses results with the projection model
24 | """
25 |
26 | cursor = None
27 | lazy_parse = False
28 |
29 | @abstractmethod
30 | def get_projection_model(self) -> Optional[Type[BaseModel]]: ...
31 |
32 | @abstractmethod
33 | async def get_cursor(self): ...
34 |
35 | def _cursor_params(self): ...
36 |
37 | def __aiter__(self):
38 | return self
39 |
40 | async def __anext__(self) -> CursorResultType:
41 | if self.cursor is None:
42 | self.cursor = await self.get_cursor()
43 | next_item = await self.cursor.__anext__()
44 | projection = self.get_projection_model()
45 | if projection is None:
46 | return next_item
47 | return parse_obj(projection, next_item, lazy_parse=self.lazy_parse) # type: ignore
48 |
49 | @abstractmethod
50 | def _get_cache(self) -> List[Dict[str, Any]]: ...
51 |
52 | @abstractmethod
53 | def _set_cache(self, data): ...
54 |
55 | async def to_list(
56 | self, length: Optional[int] = None
57 | ) -> List[CursorResultType]: # noqa
58 | """
59 | Get list of documents
60 |
61 | :param length: Optional[int] - length of the list
62 | :return: Union[List[BaseModel], List[Dict[str, Any]]]
63 | """
64 | cursor = await self.get_cursor()
65 | pymongo_list: List[Dict[str, Any]] = self._get_cache()
66 |
67 | if pymongo_list is None:
68 | pymongo_list = await cursor.to_list(length)
69 | self._set_cache(pymongo_list)
70 | projection = self.get_projection_model()
71 | if projection is not None:
72 | return cast(
73 | List[CursorResultType],
74 | [
75 | parse_obj(projection, i, lazy_parse=self.lazy_parse)
76 | for i in pymongo_list
77 | ],
78 | )
79 | return cast(List[CursorResultType], pymongo_list)
80 |
--------------------------------------------------------------------------------
/tests/odm/query/test_update_methods.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update.general import Max
2 | from beanie.odm.queries.update import UpdateMany, UpdateQuery
3 | from tests.odm.models import Sample
4 |
5 |
6 | async def test_set(session):
7 | q = Sample.find_many(Sample.integer == 1).set(
8 | {Sample.integer: 100}, session=session
9 | )
10 |
11 | assert isinstance(q, UpdateQuery)
12 | assert isinstance(q, UpdateMany)
13 | assert q.session == session
14 |
15 | assert q.update_query == {"$set": {"integer": 100}}
16 |
17 | q = (
18 | Sample.find_many(Sample.integer == 1)
19 | .update(Max({Sample.integer: 10}))
20 | .set({Sample.integer: 100})
21 | )
22 |
23 | assert isinstance(q, UpdateQuery)
24 | assert isinstance(q, UpdateMany)
25 |
26 | assert q.update_query == {
27 | "$max": {"integer": 10},
28 | "$set": {"integer": 100},
29 | }
30 |
31 |
32 | async def test_current_date(session):
33 | q = Sample.find_many(Sample.integer == 1).current_date(
34 | {Sample.timestamp: "timestamp"}, session=session
35 | )
36 |
37 | assert isinstance(q, UpdateQuery)
38 | assert isinstance(q, UpdateMany)
39 | assert q.session == session
40 |
41 | assert q.update_query == {"$currentDate": {"timestamp": "timestamp"}}
42 |
43 | q = (
44 | Sample.find_many(Sample.integer == 1)
45 | .update(Max({Sample.integer: 10}))
46 | .current_date({Sample.timestamp: "timestamp"})
47 | )
48 |
49 | assert isinstance(q, UpdateQuery)
50 | assert isinstance(q, UpdateMany)
51 |
52 | assert q.update_query == {
53 | "$max": {"integer": 10},
54 | "$currentDate": {"timestamp": "timestamp"},
55 | }
56 |
57 |
58 | async def test_inc(session):
59 | q = Sample.find_many(Sample.integer == 1).inc(
60 | {Sample.integer: 100}, session=session
61 | )
62 |
63 | assert isinstance(q, UpdateQuery)
64 | assert isinstance(q, UpdateMany)
65 | assert q.session == session
66 |
67 | assert q.update_query == {"$inc": {"integer": 100}}
68 |
69 | q = (
70 | Sample.find_many(Sample.integer == 1)
71 | .update(Max({Sample.integer: 10}))
72 | .inc({Sample.integer: 100})
73 | )
74 |
75 | assert isinstance(q, UpdateQuery)
76 | assert isinstance(q, UpdateMany)
77 |
78 | assert q.update_query == {
79 | "$max": {"integer": 10},
80 | "$inc": {"integer": 100},
81 | }
82 |
--------------------------------------------------------------------------------
/tests/fastapi/test_api.py:
--------------------------------------------------------------------------------
1 | from tests.fastapi.models import WindowAPI
2 |
3 |
4 | async def test_create_window(api_client):
5 | payload = {"x": 10, "y": 20}
6 | resp = await api_client.post("/v1/windows/", json=payload)
7 | resp_json = resp.json()
8 | assert resp_json["x"] == 10
9 | assert resp_json["y"] == 20
10 |
11 |
12 | async def test_get_window(api_client):
13 | payload = {"x": 10, "y": 20}
14 | resp = await api_client.post("/v1/windows/", json=payload)
15 | data1 = resp.json()
16 | window_id = data1["_id"]
17 | resp2 = await api_client.get(f"/v1/windows/{window_id}")
18 | data2 = resp2.json()
19 | assert data2 == data1
20 |
21 |
22 | async def test_create_house(api_client):
23 | payload = {"x": 10, "y": 20}
24 | resp = await api_client.post("/v1/houses/", json=payload)
25 | resp_json = resp.json()
26 | assert len(resp_json["windows"]) == 1
27 |
28 |
29 | async def test_create_house_with_window_link(api_client):
30 | payload = {"x": 10, "y": 20}
31 | resp = await api_client.post("/v1/windows/", json=payload)
32 |
33 | window_id = resp.json()["_id"]
34 |
35 | payload = {"id": window_id}
36 | resp = await api_client.post("/v1/houses_with_window_link/", json=payload)
37 | resp_json = resp.json()
38 | assert resp_json["windows"][0]["collection"] == "WindowAPI"
39 |
40 |
41 | async def test_create_house_2(api_client):
42 | window = WindowAPI(x=10, y=10)
43 | await window.insert()
44 | payload = {"name": "TEST", "windows": [str(window.id)]}
45 | resp = await api_client.post("/v1/houses_2/", json=payload)
46 | resp_json = resp.json()
47 | assert len(resp_json["windows"]) == 1
48 |
49 |
50 | async def test_revision_id(api_client):
51 | payload = {"x": 10, "y": 20}
52 | resp = await api_client.post("/v1/windows_2/", json=payload)
53 | resp_json = resp.json()
54 | assert "revision_id" not in resp_json
55 | assert resp_json == {"x": 10, "y": 20, "_id": resp_json["_id"]}
56 |
57 |
58 | async def test_create_house_new(api_client):
59 | payload = {
60 | "name": "FreshHouse",
61 | "owner": {"name": "will_be_overridden_to_Bob"},
62 | }
63 | resp = await api_client.post("/v1/house", json=payload)
64 | resp_json = resp.json()
65 |
66 | assert resp_json["name"] == payload["name"]
67 | assert resp_json["owner"]["name"] == payload["owner"]["name"][-3:]
68 | assert resp_json["owner"]["house"]["collection"] == "House"
69 |
--------------------------------------------------------------------------------
/beanie/operators.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.array import All, ElemMatch, Size
2 | from beanie.odm.operators.find.bitwise import (
3 | BitsAllClear,
4 | BitsAllSet,
5 | BitsAnyClear,
6 | BitsAnySet,
7 | )
8 | from beanie.odm.operators.find.comparison import (
9 | GT,
10 | GTE,
11 | LT,
12 | LTE,
13 | NE,
14 | Eq,
15 | In,
16 | NotIn,
17 | )
18 | from beanie.odm.operators.find.element import Exists, Type
19 | from beanie.odm.operators.find.evaluation import (
20 | Expr,
21 | JsonSchema,
22 | Mod,
23 | RegEx,
24 | Text,
25 | Where,
26 | )
27 | from beanie.odm.operators.find.geospatial import (
28 | Box,
29 | GeoIntersects,
30 | GeoWithin,
31 | GeoWithinTypes,
32 | Near,
33 | NearSphere,
34 | )
35 | from beanie.odm.operators.find.logical import And, Nor, Not, Or
36 | from beanie.odm.operators.update.array import (
37 | AddToSet,
38 | Pop,
39 | Pull,
40 | PullAll,
41 | Push,
42 | )
43 | from beanie.odm.operators.update.bitwise import Bit
44 | from beanie.odm.operators.update.general import (
45 | CurrentDate,
46 | Inc,
47 | Max,
48 | Min,
49 | Mul,
50 | Rename,
51 | Set,
52 | SetOnInsert,
53 | Unset,
54 | )
55 |
56 | __all__ = [
57 | # Find
58 | # Array
59 | "All",
60 | "ElemMatch",
61 | "Size",
62 | # Bitwise
63 | "BitsAllClear",
64 | "BitsAllSet",
65 | "BitsAnyClear",
66 | "BitsAnySet",
67 | # Comparison
68 | "Eq",
69 | "GT",
70 | "GTE",
71 | "In",
72 | "NotIn",
73 | "LT",
74 | "LTE",
75 | "NE",
76 | # Element
77 | "Exists",
78 | "Type",
79 | "Type",
80 | # Evaluation
81 | "Expr",
82 | "JsonSchema",
83 | "Mod",
84 | "RegEx",
85 | "Text",
86 | "Where",
87 | # Geospatial
88 | "GeoIntersects",
89 | "GeoWithinTypes",
90 | "GeoWithin",
91 | "Box",
92 | "Near",
93 | "NearSphere",
94 | # Logical
95 | "Or",
96 | "And",
97 | "Nor",
98 | "Not",
99 | # Update
100 | # Array
101 | "AddToSet",
102 | "Pop",
103 | "Pull",
104 | "Push",
105 | "PullAll",
106 | # Bitwise
107 | "Bit",
108 | # General
109 | "Set",
110 | "CurrentDate",
111 | "Inc",
112 | "Min",
113 | "Max",
114 | "Mul",
115 | "Rename",
116 | "SetOnInsert",
117 | "Unset",
118 | ]
119 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_aggregate.py:
--------------------------------------------------------------------------------
1 | from pydantic import Field
2 | from pydantic.main import BaseModel
3 |
4 | from tests.odm.models import DocumentTestModel
5 |
6 |
7 | async def test_aggregate(documents):
8 | await documents(4, "uno")
9 | await documents(2, "dos")
10 | await documents(1, "cuatro")
11 | result = await DocumentTestModel.aggregate(
12 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}]
13 | ).to_list()
14 | assert len(result) == 3
15 | assert {"_id": "cuatro", "total": 0} in result
16 | assert {"_id": "dos", "total": 1} in result
17 | assert {"_id": "uno", "total": 6} in result
18 |
19 |
20 | async def test_aggregate_with_filter(documents):
21 | await documents(4, "uno")
22 | await documents(2, "dos")
23 | await documents(1, "cuatro")
24 | result = (
25 | await DocumentTestModel.find(DocumentTestModel.test_int >= 1)
26 | .aggregate(
27 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}]
28 | )
29 | .to_list()
30 | )
31 | assert len(result) == 2
32 | assert {"_id": "dos", "total": 1} in result
33 | assert {"_id": "uno", "total": 6} in result
34 |
35 |
36 | async def test_aggregate_with_item_model(documents):
37 | class OutputItem(BaseModel):
38 | id: str = Field(None, alias="_id")
39 | total: int
40 |
41 | await documents(4, "uno")
42 | await documents(2, "dos")
43 | await documents(1, "cuatro")
44 | ids = []
45 | async for i in DocumentTestModel.aggregate(
46 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}],
47 | projection_model=OutputItem,
48 | ):
49 | if i.id == "cuatro":
50 | assert i.total == 0
51 | elif i.id == "dos":
52 | assert i.total == 1
53 | elif i.id == "uno":
54 | assert i.total == 6
55 | else:
56 | raise KeyError
57 | ids.append(i.id)
58 | assert set(ids) == {"cuatro", "dos", "uno"}
59 |
60 |
61 | async def test_aggregate_with_session(documents, session):
62 | await documents(4, "uno")
63 | await documents(2, "dos")
64 | await documents(1, "cuatro")
65 | result = await DocumentTestModel.aggregate(
66 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}],
67 | session=session,
68 | ).to_list()
69 | assert len(result) == 3
70 | assert {"_id": "cuatro", "total": 0} in result
71 | assert {"_id": "dos", "total": 1} in result
72 | assert {"_id": "uno", "total": 6} in result
73 |
--------------------------------------------------------------------------------
/tests/odm/test_expression_fields.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.enums import SortDirection
2 | from beanie.odm.operators.find.comparison import In, NotIn
3 | from tests.odm.models import Sample
4 |
5 |
6 | def test_nesting():
7 | assert Sample.id == "_id"
8 |
9 | q = Sample.find_many(Sample.integer == 1)
10 | assert q.get_filter_query() == {"integer": 1}
11 | assert Sample.integer == "integer"
12 |
13 | q = Sample.find_many(Sample.nested.integer == 1)
14 | assert q.get_filter_query() == {"nested.integer": 1}
15 | assert Sample.nested.integer == "nested.integer"
16 |
17 | q = Sample.find_many(Sample.union.s == "test")
18 | assert q.get_filter_query() == {"union.s": "test"}
19 | assert Sample.union.s == "union.s"
20 |
21 | q = Sample.find_many(Sample.nested.optional == None) # noqa
22 | assert q.get_filter_query() == {"nested.optional": None}
23 | assert Sample.nested.optional == "nested.optional"
24 |
25 | q = Sample.find_many(Sample.nested.integer == 1).find_many(
26 | Sample.nested.union.s == "test"
27 | )
28 | assert q.get_filter_query() == {
29 | "$and": [{"nested.integer": 1}, {"nested.union.s": "test"}]
30 | }
31 |
32 |
33 | def test_eq():
34 | q = Sample.find_many(Sample.integer == 1)
35 | assert q.get_filter_query() == {"integer": 1}
36 |
37 |
38 | def test_gt():
39 | q = Sample.find_many(Sample.integer > 1)
40 | assert q.get_filter_query() == {"integer": {"$gt": 1}}
41 |
42 |
43 | def test_gte():
44 | q = Sample.find_many(Sample.integer >= 1)
45 | assert q.get_filter_query() == {"integer": {"$gte": 1}}
46 |
47 |
48 | def test_in():
49 | q = Sample.find_many(In(Sample.integer, [1, 2, 3, 4]))
50 | assert dict(q.get_filter_query()) == {"integer": {"$in": [1, 2, 3, 4]}}
51 |
52 |
53 | def test_lt():
54 | q = Sample.find_many(Sample.integer < 1)
55 | assert q.get_filter_query() == {"integer": {"$lt": 1}}
56 |
57 |
58 | def test_lte():
59 | q = Sample.find_many(Sample.integer <= 1)
60 | assert q.get_filter_query() == {"integer": {"$lte": 1}}
61 |
62 |
63 | def test_ne():
64 | q = Sample.find_many(Sample.integer != 1)
65 | assert q.get_filter_query() == {"integer": {"$ne": 1}}
66 |
67 |
68 | def test_nin():
69 | q = Sample.find_many(NotIn(Sample.integer, [1, 2, 3, 4]))
70 | assert dict(q.get_filter_query()) == {"integer": {"$nin": [1, 2, 3, 4]}}
71 |
72 |
73 | def test_pos():
74 | q = +Sample.integer
75 | assert q == ("integer", SortDirection.ASCENDING)
76 |
77 |
78 | def test_neg():
79 | q = -Sample.integer
80 | assert q == ("integer", SortDirection.DESCENDING)
81 |
--------------------------------------------------------------------------------
/docs/publishing.md:
--------------------------------------------------------------------------------
1 | # Publishing a New Version of Beanie
2 |
3 | This guide provides step-by-step instructions on how to prepare and publish a new version of Beanie. Before starting, ensure that you have the necessary permissions to update the repository.
4 |
5 | ## 1. Prepare a New Version PR
6 |
7 | To publish a new version of Beanie, you need to create a pull request (PR) with the following updates:
8 |
9 | ### 1.1 Update the Version in `pyproject.toml`
10 |
11 | 1. Open the [`pyproject.toml`](https://github.com/BeanieODM/beanie/blob/main/pyproject.toml) file.
12 | 2. Update the `version` field to the new version number.
13 |
14 | ### 1.2 Update the Version in the `__init__.py` File
15 |
16 | 1. Open the [`__init__.py`](https://github.com/BeanieODM/beanie/blob/main/beanie/__init__.py) file.
17 | 2. Update the `__version__` variable to the new version number.
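
For example (the version below is a placeholder, not a real release number):

```python
# beanie/__init__.py
__version__ = "x.y.z"  # replace with the new version number
```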
18 |
19 | ### 1.3 Update the Changelog
20 |
21 | To update the changelog, follow these steps:
22 |
23 | #### 1.3.1 Set the Version in the Changelog Script
24 |
25 | 1. Open the [`scripts/generate_changelog.py`](https://github.com/BeanieODM/beanie/blob/main/scripts/generate_changelog.py) file.
26 | 2. Set the `current_version` to the current version and `new_version` to the new version in the script.
27 |
28 | #### 1.3.2 Run the Changelog Script
29 |
30 | 1. Run the following command to generate the updated changelog:
31 |
32 | ```bash
33 | python scripts/generate_changelog.py
34 | ```
35 |
36 | 2. The script will generate the changelog for the new version.
37 |
38 | #### 1.3.3 Update the Changelog File
39 |
40 | 1. Open the [`changelog.md`](https://github.com/BeanieODM/beanie/blob/main/docs/changelog.md) file.
41 | 2. Copy the generated changelog and paste it at the top of the `changelog.md` file.
42 |
43 | ### 1.4 Create and Submit the PR
44 |
45 | Once you have made the necessary updates, create a PR with a descriptive title and summary of the changes. Ensure that all checks pass before merging the PR.
46 |
47 | ## 2. Publishing the Version
48 |
49 | After the PR has been merged, the corresponding GitHub Action will publish the new version to PyPI.
50 |
51 | ## 3. Create a Git Tag and GitHub Release
52 |
53 | After the version has been published:
54 |
55 | 1. Pull the latest changes from the `master` branch:
56 | ```bash
57 | git pull origin master
58 | ```
59 |
60 | 2. Create a new Git tag with the version number:
61 | ```bash
62 | git tag -a v1.xx.y -m "Release v1.xx.y"
63 | ```
64 |
65 | 3. Push the tag to the remote repository:
66 | ```bash
67 | git push origin v1.xx.y
68 | ```
69 |
70 | 4. Create a new release on GitHub using the GitHub interface.
--------------------------------------------------------------------------------
/docs/tutorial/views.md:
--------------------------------------------------------------------------------
1 | # Views
2 |
3 | Virtual views are aggregation pipelines stored in MongoDB that act as collections for reading operations.
4 | You can use the `View` class the same way as `Document` for `find` and `aggregate` operations.
5 |
6 | ## Examples
7 |
8 | Create a view:
9 |
10 | ```python
11 | from pydantic import Field
12 |
13 | from beanie import Document, View
14 |
15 |
16 | class Bike(Document):
17 | type: str
18 | frame_size: int
19 | is_new: bool
20 |
21 |
22 | class Metrics(View):
23 | type: str = Field(alias="_id")
24 | number: int
25 | new: int
26 |
27 | class Settings:
28 | source = Bike
29 | pipeline = [
30 | {
31 | "$group": {
32 | "_id": "$type",
33 | "number": {"$sum": 1},
34 | "new": {"$sum": {"$cond": ["$is_new", 1, 0]}}
35 | }
36 | },
37 | ]
38 |
39 | ```
40 |
41 | Initialize Beanie:
42 |
43 | ```python
44 | from pymongo import AsyncMongoClient
45 |
46 | from beanie import init_beanie
47 |
48 |
49 | async def main():
50 | uri = "mongodb://beanie:beanie@localhost:27017"
51 | client = AsyncMongoClient(uri)
52 | db = client.bikes
53 |
54 | await init_beanie(
55 | database=db,
56 | document_models=[Bike, Metrics],
57 | recreate_views=True,
58 | )
59 | ```
60 |
61 | Create bikes:
62 |
63 | ```python
64 | await Bike(type="Mountain", frame_size=54, is_new=True).insert()
65 | await Bike(type="Mountain", frame_size=60, is_new=False).insert()
66 | await Bike(type="Road", frame_size=52, is_new=True).insert()
67 | await Bike(type="Road", frame_size=54, is_new=True).insert()
68 | await Bike(type="Road", frame_size=58, is_new=False).insert()
69 | ```
70 |
71 | Find metrics for `type == "Road"`
72 |
73 | ```python
74 | results = await Metrics.find(Metrics.type == "Road").to_list()
75 | print(results)
76 |
77 | >> [Metrics(type='Road', number=3, new=2)]
78 | ```
79 |
80 | Aggregate over metrics to get the count of all the new bikes:
81 |
82 | ```python
83 | results = await Metrics.aggregate([{
84 | "$group": {
85 | "_id": None,
86 | "new_total": {"$sum": "$new"}
87 | }
88 | }]).to_list()
89 |
90 | print(results)
91 |
92 | >> [{'_id': None, 'new_total': 3}]
93 | ```
94 |
95 | A cleaner result can be achieved by using the find query aggregation syntactic sugar:
96 |
97 | ```python
98 | results = await Metrics.all().sum(Metrics.new)
99 |
100 | print(results)
101 |
102 | >> 3
103 | ```
104 |
--------------------------------------------------------------------------------
/beanie/odm/utils/typing.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import sys
3 | from typing import Any, Dict, Optional, Tuple, Type
4 |
5 | from beanie.odm.fields import IndexedAnnotation
6 |
7 | from .pydantic import IS_PYDANTIC_V2, get_field_type
8 |
9 | if sys.version_info >= (3, 8):
10 | from typing import get_args, get_origin
11 | else:
12 | from typing_extensions import get_args, get_origin
13 |
14 |
15 | def extract_id_class(annotation) -> Type[Any]:
16 | if get_origin(annotation) is not None:
17 | try:
18 | annotation = next(
19 | arg for arg in get_args(annotation) if arg is not type(None)
20 | )
21 | except StopIteration:
22 | annotation = None
23 | if inspect.isclass(annotation):
24 | return annotation
25 | raise ValueError("Unknown annotation: {}".format(annotation))
26 |
27 |
28 | def get_index_attributes(field) -> Optional[Tuple[int, Dict[str, Any]]]:
29 | """Gets the index attributes from the field, if it is indexed.
30 |
31 | :param field: The field to get the index attributes from.
32 |
33 | :return: The index attributes, if the field is indexed. Otherwise, None.
34 | """
35 | # For fields that are directly typed with `Indexed()`, the type will have
36 | # an `_indexed` attribute.
37 | field_type = get_field_type(field)
38 | if hasattr(field_type, "_indexed"):
39 | return getattr(field_type, "_indexed", None)
40 |
41 | # For fields that use `Indexed` within `Annotated`, the field will have
42 | # metadata that might contain an `IndexedAnnotation` instance.
43 | if IS_PYDANTIC_V2:
44 | # In Pydantic 2, the field has a `metadata` attribute with
45 | # the annotations.
46 | metadata = getattr(field, "metadata", None)
47 | elif hasattr(field, "annotation") and hasattr(
48 | field.annotation, "__metadata__"
49 | ):
50 | # In Pydantic 1, the field has an `annotation` attribute with the
51 | # type assigned to the field. If the type is annotated, it will
52 | # have a `__metadata__` attribute with the annotations.
53 | metadata = field.annotation.__metadata__
54 | else:
55 | return None
56 |
57 | if metadata is None:
58 | return None
59 |
60 | try:
61 | iter(metadata)
62 | except TypeError:
63 | return None
64 |
65 | indexed_annotation = next(
66 | (
67 | annotation
68 | for annotation in metadata
69 | if isinstance(annotation, IndexedAnnotation)
70 | ),
71 | None,
72 | )
73 |
74 | return getattr(indexed_annotation, "_indexed", None)
75 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_multi_model.py:
--------------------------------------------------------------------------------
1 | from tests.odm.models import (
2 | DocumentMultiModelOne,
3 | DocumentMultiModelTwo,
4 | DocumentUnion,
5 | )
6 |
7 |
8 | class TestMultiModel:
9 | async def test_multi_model(self):
10 | doc_1 = await DocumentMultiModelOne().insert()
11 | doc_2 = await DocumentMultiModelTwo().insert()
12 |
13 | new_doc_1 = await DocumentMultiModelOne.get(doc_1.id)
14 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id)
15 |
16 | assert new_doc_1 is not None
17 | assert new_doc_2 is not None
18 |
19 | new_doc_1 = await DocumentMultiModelTwo.get(doc_1.id)
20 | new_doc_2 = await DocumentMultiModelOne.get(doc_2.id)
21 |
22 | assert new_doc_1 is None
23 | assert new_doc_2 is None
24 |
25 | new_docs_1 = await DocumentMultiModelOne.find({}).to_list()
26 | new_docs_2 = await DocumentMultiModelTwo.find({}).to_list()
27 |
28 | assert len(new_docs_1) == 1
29 | assert len(new_docs_2) == 1
30 |
31 | await DocumentMultiModelOne.update_all({"$set": {"shared": 100}})
32 |
33 | new_doc_1 = await DocumentMultiModelOne.get(doc_1.id)
34 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id)
35 |
36 | assert new_doc_1.shared == 100
37 | assert new_doc_2.shared == 0
38 |
39 | async def test_union_doc(self):
40 | await DocumentMultiModelOne().insert()
41 | await DocumentMultiModelTwo().insert()
42 | await DocumentMultiModelOne().insert()
43 | await DocumentMultiModelTwo().insert()
44 |
45 | docs = await DocumentUnion.all().to_list()
46 | assert isinstance(docs[0], DocumentMultiModelOne)
47 | assert isinstance(docs[1], DocumentMultiModelTwo)
48 | assert isinstance(docs[2], DocumentMultiModelOne)
49 | assert isinstance(docs[3], DocumentMultiModelTwo)
50 |
51 | async def test_union_doc_aggregation(self):
52 | await DocumentMultiModelOne().insert()
53 | await DocumentMultiModelTwo().insert()
54 | await DocumentMultiModelOne().insert()
55 | await DocumentMultiModelTwo().insert()
56 |
57 | docs = await DocumentUnion.aggregate(
58 | [{"$match": {"$expr": {"$eq": ["$int_filed", 0]}}}]
59 | ).to_list()
60 | assert len(docs) == 2
61 |
62 | async def test_union_doc_link(self):
63 | doc_1 = await DocumentMultiModelOne().insert()
64 | await DocumentMultiModelTwo(linked_doc=doc_1).insert()
65 |
66 | docs = await DocumentMultiModelTwo.find({}, fetch_links=True).to_list()
67 | assert isinstance(docs[0].linked_doc, DocumentMultiModelOne)
68 |
--------------------------------------------------------------------------------
/tests/typing/decorators.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Coroutine
2 |
3 | from typing_extensions import Protocol, TypeAlias, assert_type
4 |
5 | from beanie import Document
6 | from beanie.odm.actions import EventTypes, wrap_with_actions
7 | from beanie.odm.utils.self_validation import validate_self_before
8 | from beanie.odm.utils.state import (
9 | previous_saved_state_needed,
10 | save_state_after,
11 | saved_state_needed,
12 | )
13 |
14 |
15 | def sync_func(doc_self: Document, arg1: str, arg2: int, /) -> Document:
16 | """
17 | Models `Document` sync method that expects self
18 | """
19 | raise NotImplementedError
20 |
21 |
22 | SyncFunc: TypeAlias = Callable[[Document, str, int], Document]
23 |
24 |
25 | async def async_func(doc_self: Document, arg1: str, arg2: int, /) -> Document:
26 | """
27 | Models `Document` async method that expects self
28 | """
29 | raise NotImplementedError
30 |
31 |
32 | AsyncFunc: TypeAlias = Callable[
33 | [Document, str, int], Coroutine[Any, Any, Document]
34 | ]
35 |
36 |
37 | def test_wrap_with_actions_preserves_signature() -> None:
38 | assert_type(async_func, AsyncFunc)
39 | assert_type(wrap_with_actions(EventTypes.SAVE)(async_func), AsyncFunc)
40 |
41 |
42 | def test_save_state_after_preserves_signature() -> None:
43 | assert_type(async_func, AsyncFunc)
44 | assert_type(save_state_after(async_func), AsyncFunc)
45 |
46 |
47 | def test_validate_self_before_preserves_signature() -> None:
48 | assert_type(async_func, AsyncFunc)
49 | assert_type(validate_self_before(async_func), AsyncFunc)
50 |
51 |
52 | def test_saved_state_needed_preserves_signature() -> None:
53 | assert_type(async_func, AsyncFunc)
54 | assert_type(saved_state_needed(async_func), AsyncFunc)
55 |
56 | assert_type(sync_func, SyncFunc)
57 | assert_type(saved_state_needed(sync_func), SyncFunc)
58 |
59 |
60 | def test_previous_saved_state_needed_preserves_signature() -> None:
61 | assert_type(async_func, AsyncFunc)
62 | assert_type(previous_saved_state_needed(async_func), AsyncFunc)
63 |
64 | assert_type(sync_func, SyncFunc)
65 | assert_type(previous_saved_state_needed(sync_func), SyncFunc)
66 |
67 |
68 | class ExpectsDocumentSelf(Protocol):
69 | def __call__(self, doc_self: Document, /) -> Any: ...
70 |
71 |
72 | def test_document_insert_expects_self() -> None:
73 | test_insert: ExpectsDocumentSelf = Document.insert # noqa: F841
74 |
75 |
76 | def test_document_save_expects_self() -> None:
77 | test_insert: ExpectsDocumentSelf = Document.save # noqa: F841
78 |
79 |
80 | def test_document_replace_expects_self() -> None:
81 | test_insert: ExpectsDocumentSelf = Document.replace # noqa: F841
82 |
--------------------------------------------------------------------------------
/tests/odm/documents/test_validation_on_save.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 | from bson import ObjectId
5 | from pydantic import BaseModel, ValidationError
6 |
7 | from beanie import PydanticObjectId
8 | from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
9 | from tests.odm.models import (
10 | DocumentWithValidationOnSave,
11 | Lock,
12 | WindowWithValidationOnSave,
13 | )
14 |
15 |
16 | async def test_validate_on_insert():
17 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
18 | doc.num_1 = "wrong_value"
19 | with pytest.raises(ValidationError):
20 | await doc.insert()
21 |
22 |
23 | async def test_validate_on_replace():
24 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
25 | await doc.insert()
26 | doc.num_1 = "wrong_value"
27 | with pytest.raises(ValidationError):
28 | await doc.replace()
29 |
30 |
31 | async def test_validate_on_save_changes():
32 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
33 | await doc.insert()
34 | doc.num_1 = "wrong_value"
35 | with pytest.raises(ValidationError):
36 | await doc.save_changes()
37 |
38 |
39 | async def test_validate_on_save_keep_the_id_type():
40 | class UpdateModel(BaseModel):
41 | num_1: Optional[int] = None
42 | related: Optional[PydanticObjectId] = None
43 |
44 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
45 | await doc.insert()
46 | update = UpdateModel(related=PydanticObjectId())
47 | if IS_PYDANTIC_V2:
48 | doc = doc.model_copy(update=update.model_dump(exclude_unset=True))
49 | else:
50 | doc = doc.copy(update=update.dict(exclude_unset=True))
51 | doc.num_2 = 1000
52 | await doc.save()
53 | in_db = (
54 | await DocumentWithValidationOnSave.get_pymongo_collection().find_one(
55 | {"_id": doc.id}
56 | )
57 | )
58 | assert isinstance(in_db["related"], ObjectId)
59 | new_doc = await DocumentWithValidationOnSave.get(doc.id)
60 | assert isinstance(new_doc.related, PydanticObjectId)
61 |
62 |
63 | async def test_validate_on_save_action():
64 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
65 | await doc.insert()
66 | assert doc.num_2 == 3
67 |
68 |
69 | async def test_validate_on_save_skip_action():
70 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
71 | await doc.insert(skip_actions=["num_2_plus_1"])
72 | assert doc.num_2 == 2
73 |
74 |
75 | async def test_validate_on_save_dbref():
76 | lock = Lock(k=1)
77 | await lock.insert()
78 | window = WindowWithValidationOnSave(
79 | x=1,
80 | y=1,
81 |         lock=lock.to_ref(), # this is exactly what we want to test
82 | )
83 | await window.insert()
84 |
--------------------------------------------------------------------------------
/beanie/odm/operators/find/array.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Optional
3 |
4 | from beanie.odm.operators.find import BaseFindOperator
5 |
6 |
7 | class BaseFindArrayOperator(BaseFindOperator, ABC): ...
8 |
9 |
10 | class All(BaseFindArrayOperator):
11 | """
12 | `$all` array query operator
13 |
14 | Example:
15 |
16 | ```python
17 | class Sample(Document):
18 | results: List[int]
19 |
20 | All(Sample.results, [80, 85])
21 | ```
22 |
23 | Will return query object like
24 |
25 | ```python
26 | {"results": {"$all": [80, 85]}}
27 | ```
28 |
29 | MongoDB doc:
30 |
31 | """
32 |
33 | def __init__(
34 | self,
35 | field,
36 | values: list,
37 | ):
38 | self.field = field
39 | self.values_list = values
40 |
41 | @property
42 | def query(self):
43 | return {self.field: {"$all": self.values_list}}
44 |
45 |
46 | class ElemMatch(BaseFindArrayOperator):
47 | """
48 | `$elemMatch` array query operator
49 |
50 | Example:
51 |
52 | ```python
53 | class Sample(Document):
54 | results: List[int]
55 |
56 | ElemMatch(Sample.results, {"$in": [80, 85]})
57 | ```
58 |
59 | Will return query object like
60 |
61 | ```python
62 | {"results": {"$elemMatch": {"$in": [80, 85]}}}
63 | ```
64 |
65 | MongoDB doc:
66 |
67 | """
68 |
69 | def __init__(
70 | self,
71 | field,
72 | expression: Optional[dict] = None,
73 | **kwargs: Any,
74 | ):
75 | self.field = field
76 |
77 | if expression is None:
78 | self.expression = kwargs
79 | else:
80 | self.expression = expression
81 |
82 | @property
83 | def query(self):
84 | return {self.field: {"$elemMatch": self.expression}}
85 |
86 |
87 | class Size(BaseFindArrayOperator):
88 | """
89 | `$size` array query operator
90 |
91 | Example:
92 |
93 | ```python
94 | class Sample(Document):
95 | results: List[int]
96 |
97 | Size(Sample.results, 2)
98 | ```
99 |
100 | Will return query object like
101 |
102 | ```python
103 | {"results": {"$size": 2}}
104 | ```
105 |
106 | MongoDB doc:
107 |
108 | """
109 |
110 | def __init__(
111 | self,
112 | field,
113 | num: int,
114 | ):
115 | self.field = field
116 | self.num = num
117 |
118 | @property
119 | def query(self):
120 | return {self.field: {"$size": self.num}}
121 |
--------------------------------------------------------------------------------
/docs/tutorial/indexes.md:
--------------------------------------------------------------------------------
1 | ## Indexes setup
2 |
3 | There is more than one way to set up indexes using Beanie.
4 |
5 | ### Indexed function
6 |
7 | To set up an index on a single field, the `Indexed` function can be used to annotate the type
8 | and does not require a `Settings` class:
9 |
10 | ```python
11 | from beanie import Document, Indexed
12 |
13 |
14 | class Sample(Document):
15 | num: Annotated[int, Indexed()]
16 | description: str
17 | ```
18 |
19 | The `Indexed` function takes an optional `index_type` argument, which may be set to a pymongo index type:
20 |
21 | ```python
22 | import pymongo
23 |
24 | from beanie import Document, Indexed
25 |
26 |
27 | class Sample(Document):
28 | description: Annotated[str, Indexed(index_type=pymongo.TEXT)]
29 | ```
30 |
31 | The `Indexed` function also supports PyMongo's `IndexModel` keyword arguments (see the [PyMongo Documentation](https://pymongo.readthedocs.io/en/stable/api/pymongo/operations.html#pymongo.operations.IndexModel) for details).
32 |
33 | For example, to create a `unique` index:
34 |
35 | ```python
36 | from beanie import Document, Indexed
37 |
38 |
39 | class Sample(Document):
40 | name: Annotated[str, Indexed(unique=True)]
41 | ```
42 |
43 | The `Indexed` function can also be used directly in the type annotation, by giving it the wrapped type as the first argument. Note that this might not work with some Pydantic V2 types, such as `UUID4` or `EmailStr`.
44 |
45 | ```python
46 | from beanie import Document, Indexed
47 |
48 |
49 | class Sample(Document):
50 | name: Indexed(str, unique=True)
51 | ```
52 |
53 | ### Multi-field indexes
54 |
55 | The `indexes` field of the inner `Settings` class is responsible for more complex indexes.
56 | It is a list where items can be:
57 |
58 | - A single key: the name of the document's field (equivalent to using the `Indexed` function described above without any additional arguments)
59 | - A list of (key, direction) pairs, where the key is the name of the document's field and the direction is a pymongo direction
60 |   (for example, `pymongo.ASCENDING`)
61 | - `pymongo.IndexModel` instance - the most flexible
62 | option. [PyMongo Documentation](https://pymongo.readthedocs.io/en/stable/api/pymongo/operations.html#pymongo.operations.IndexModel)
63 |
64 | ```python
65 | import pymongo
66 | from pymongo import IndexModel
67 |
68 | from beanie import Document
69 |
70 |
71 | class Sample(Document):
72 | test_int: int
73 | test_str: str
74 |
75 | class Settings:
76 | indexes = [
77 | "test_int",
78 | [
79 | ("test_int", pymongo.ASCENDING),
80 | ("test_str", pymongo.DESCENDING),
81 | ],
82 | IndexModel(
83 | [("test_str", pymongo.DESCENDING)],
84 | name="test_string_index_DESCENDING",
85 | ),
86 | ]
87 | ```
88 |
--------------------------------------------------------------------------------
/beanie/odm/queries/delete.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, Generator, Mapping, Optional, Type
2 |
3 | from pymongo import DeleteMany as DeleteManyPyMongo
4 | from pymongo import DeleteOne as DeleteOnePyMongo
5 | from pymongo.asynchronous.client_session import AsyncClientSession
6 | from pymongo.results import DeleteResult
7 |
8 | from beanie.odm.bulk import BulkWriter
9 | from beanie.odm.interfaces.clone import CloneInterface
10 | from beanie.odm.interfaces.session import SessionMethods
11 |
12 | if TYPE_CHECKING:
13 | from beanie.odm.documents import DocType
14 |
15 |
16 | class DeleteQuery(SessionMethods, CloneInterface):
17 | """
18 | Deletion Query
19 | """
20 |
21 | def __init__(
22 | self,
23 | document_model: Type["DocType"],
24 | find_query: Mapping[str, Any],
25 | bulk_writer: Optional[BulkWriter] = None,
26 | **pymongo_kwargs: Any,
27 | ):
28 | self.document_model = document_model
29 | self.find_query = find_query
30 | self.session: Optional[AsyncClientSession] = None
31 | self.bulk_writer = bulk_writer
32 | self.pymongo_kwargs: Dict[str, Any] = pymongo_kwargs
33 |
34 |
35 | class DeleteMany(DeleteQuery):
36 | def __await__(
37 | self,
38 | ) -> Generator[DeleteResult, None, Optional[DeleteResult]]:
39 | """
40 | Run the query
41 | :return:
42 | """
43 | if self.bulk_writer is None:
44 | return (
45 | yield from self.document_model.get_pymongo_collection()
46 | .delete_many(
47 | self.find_query,
48 | session=self.session,
49 | **self.pymongo_kwargs,
50 | )
51 | .__await__()
52 | )
53 | else:
54 | self.bulk_writer.add_operation(
55 | self.document_model,
56 | DeleteManyPyMongo(self.find_query, **self.pymongo_kwargs),
57 | )
58 | return None
59 |
60 |
61 | class DeleteOne(DeleteQuery):
62 | def __await__(
63 | self,
64 | ) -> Generator[DeleteResult, None, Optional[DeleteResult]]:
65 | """
66 | Run the query
67 | :return:
68 | """
69 | if self.bulk_writer is None:
70 | return (
71 | yield from self.document_model.get_pymongo_collection()
72 | .delete_one(
73 | self.find_query,
74 | session=self.session,
75 | **self.pymongo_kwargs,
76 | )
77 | .__await__()
78 | )
79 | else:
80 |             self.bulk_writer.add_operation(
81 |                 self.document_model,
82 |                 DeleteOnePyMongo(self.find_query, **self.pymongo_kwargs),
83 |             )
84 |         return None
85 | 
--------------------------------------------------------------------------------
/assets/logo/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/scripts/generate_changelog.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from dataclasses import dataclass
3 | from datetime import datetime
4 | from typing import List, Set
5 |
6 | import requests # type: ignore
7 |
8 |
9 | @dataclass
10 | class PullRequest:
11 | number: int
12 | title: str
13 | user: str
14 | user_url: str
15 | url: str
16 |
17 |
18 | class ChangelogGenerator:
19 | def __init__(
20 | self,
21 | username: str,
22 | repository: str,
23 | current_version: str,
24 | new_version: str,
25 | ):
26 | self.username = username
27 | self.repository = repository
28 | self.base_url = f"https://api.github.com/repos/{username}/{repository}"
29 | self.current_version = current_version
30 | self.new_version = new_version
31 | self.commits = self.get_commits_after_tag(current_version)
32 | self.prs = self.get_prs_for_commits(self.commits)
33 |
34 | def get_commits_after_tag(self, tag: str) -> List[str]:
35 | result = subprocess.run(
36 | ["git", "log", f"{tag}..HEAD", "--pretty=format:%H"],
37 | stdout=subprocess.PIPE,
38 | text=True,
39 | )
40 | return result.stdout.split()
41 |
42 | def get_pr_for_commit(self, commit_sha: str) -> PullRequest:
43 | url = f"{self.base_url}/commits/{commit_sha}/pulls"
44 | response = requests.get(url)
45 | response.raise_for_status()
46 | pr_data = response.json()[0]
47 | return PullRequest(
48 | number=pr_data["number"],
49 | title=pr_data["title"],
50 | user=pr_data["user"]["login"],
51 | user_url=pr_data["user"]["html_url"],
52 | url=pr_data["html_url"],
53 | )
54 |
55 | def get_prs_for_commits(self, commit_shas: List[str]) -> List[PullRequest]:
56 | prs: List[PullRequest] = []
57 | unique_prs: Set[int] = set()
58 | for commit_sha in commit_shas:
59 | pr = self.get_pr_for_commit(commit_sha)
60 | pr_id = pr.number
61 | if pr_id not in unique_prs:
62 | unique_prs.add(pr_id)
63 | prs.append(pr)
64 | prs.sort(key=lambda pr: pr.number, reverse=True)
65 | return prs
66 |
67 | def generate_changelog(self) -> str:
68 | markdown = f"\n## [{self.new_version}] - {datetime.now().strftime('%Y-%m-%d')}\n"
69 | for pr in self.prs:
70 | markdown += (
71 | f"### {pr.title.capitalize()}\n"
72 | f"- Author - [{pr.user}]({pr.user_url})\n"
73 | f"- PR <{pr.url}>\n"
74 | )
75 | markdown += f"\n[{self.new_version}]: https://pypi.org/project/{self.repository}/{self.new_version}\n"
76 | return markdown
77 |
78 |
79 | if __name__ == "__main__":
80 | generator = ChangelogGenerator(
81 | username="BeanieODM",
82 | repository="beanie",
83 | current_version="2.0.0",
84 | new_version="2.0.1",
85 | )
86 |
87 | changelog = generator.generate_changelog()
88 | print(changelog)
89 |
--------------------------------------------------------------------------------
/tests/migrations/test_free_fall.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.migrations.models import RunningDirections
7 | from beanie.odm.documents import Document
8 | from beanie.odm.models import InspectionStatuses
9 |
10 |
11 | class Tag(BaseModel):
12 | color: str
13 | name: str
14 |
15 |
16 | class OldNote(Document):
17 | name: str
18 | tag: Tag
19 |
20 | class Settings:
21 | name = "notes"
22 |
23 |
24 | class Note(Document):
25 | title: str
26 | tag: Tag
27 |
28 | class Settings:
29 | name = "notes"
30 |
31 |
32 | @pytest.fixture()
33 | async def notes(db):
34 | await init_beanie(database=db, document_models=[OldNote])
35 | await OldNote.delete_all()
36 | for i in range(10):
37 | note = OldNote(name=str(i), tag=Tag(name="test", color="red"))
38 | await note.insert()
39 | yield
40 | await OldNote.delete_all()
41 |
42 |
43 | async def test_migration_free_fall(settings, notes, db):
44 | if not db.client.is_mongos and not len(db.client.nodes) > 1:
45 | return pytest.skip(
46 |             "MongoDB server does not support transactions as it is neither a mongos instance nor a replica set."
47 | )
48 |
49 | migration_settings = MigrationSettings(
50 | connection_uri=settings.mongodb_dsn,
51 | database_name=settings.mongodb_db_name,
52 | path="tests/migrations/migrations_for_test/free_fall",
53 | )
54 | await run_migrate(migration_settings)
55 |
56 | await init_beanie(database=db, document_models=[Note])
57 | inspection = await Note.inspect_collection()
58 | assert inspection.status == InspectionStatuses.OK
59 | note = await Note.find_one({})
60 | assert note.title == "0"
61 |
62 | migration_settings.direction = RunningDirections.BACKWARD
63 | await run_migrate(migration_settings)
64 | inspection = await OldNote.inspect_collection()
65 | assert inspection.status == InspectionStatuses.OK
66 | note = await OldNote.find_one({})
67 | assert note.name == "0"
68 |
69 |
70 | async def test_migration_free_fall_no_use_transactions(settings, notes, db):
71 | migration_settings = MigrationSettings(
72 | connection_uri=settings.mongodb_dsn,
73 | database_name=settings.mongodb_db_name,
74 | path="tests/migrations/migrations_for_test/free_fall",
75 | use_transaction=False,
76 | )
77 | await run_migrate(migration_settings)
78 |
79 | await init_beanie(database=db, document_models=[Note])
80 | inspection = await Note.inspect_collection()
81 | assert inspection.status == InspectionStatuses.OK
82 | note = await Note.find_one({})
83 | assert note.title == "0"
84 |
85 | migration_settings.direction = RunningDirections.BACKWARD
86 | await run_migrate(migration_settings)
87 | inspection = await OldNote.inspect_collection()
88 | assert inspection.status == InspectionStatuses.OK
89 | note = await OldNote.find_one({})
90 | assert note.name == "0"
91 |
--------------------------------------------------------------------------------
/tests/migrations/test_remove_indexes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic.main import BaseModel
3 |
4 | from beanie import init_beanie
5 | from beanie.executors.migrate import MigrationSettings, run_migrate
6 | from beanie.odm.documents import Document
7 |
8 |
9 | class Tag(BaseModel):
10 | color: str
11 | name: str
12 |
13 |
14 | class OldNote(Document):
15 | title: str
16 | tag: Tag
17 |
18 | class Settings:
19 | name = "notes"
20 | indexes = ["title"]
21 |
22 |
23 | class Note(Document):
24 | title: str
25 | tag: Tag
26 |
27 | class Settings:
28 | name = "notes"
29 |
30 |
31 | @pytest.fixture()
32 | async def notes(db):
33 | await init_beanie(database=db, document_models=[OldNote])
34 | await OldNote.delete_all()
35 | for i in range(10):
36 | note = OldNote(title=str(i), tag=Tag(name="test", color="red"))
37 | await note.insert()
38 | yield
39 | await OldNote.delete_all()
40 |
41 |
42 | async def test_remove_index_allowed(settings, notes, db):
43 | migration_settings = MigrationSettings(
44 | connection_uri=settings.mongodb_dsn,
45 | database_name=settings.mongodb_db_name,
46 | path="tests/migrations/migrations_for_test/remove_index",
47 | allow_index_dropping=True,
48 | )
49 | await run_migrate(migration_settings)
50 |
51 | await init_beanie(
52 | database=db, document_models=[Note], allow_index_dropping=False
53 | )
54 | collection = Note.get_pymongo_collection()
55 | index_info = await collection.index_information()
56 | assert index_info == {
57 | "_id_": {"key": [("_id", 1)], "v": 2},
58 | }
59 |
60 |
61 | async def test_remove_index_default(settings, notes, db):
62 | migration_settings = MigrationSettings(
63 | connection_uri=settings.mongodb_dsn,
64 | database_name=settings.mongodb_db_name,
65 | path="tests/migrations/migrations_for_test/remove_index",
66 | )
67 | await run_migrate(migration_settings)
68 |
69 | await init_beanie(
70 | database=db, document_models=[Note], allow_index_dropping=False
71 | )
72 | collection = Note.get_pymongo_collection()
73 | index_info = await collection.index_information()
74 | assert index_info == {
75 | "_id_": {"key": [("_id", 1)], "v": 2},
76 | "title_1": {"key": [("title", 1)], "v": 2},
77 | }
78 |
79 |
80 | async def test_remove_index_not_allowed(settings, notes, db):
81 | migration_settings = MigrationSettings(
82 | connection_uri=settings.mongodb_dsn,
83 | database_name=settings.mongodb_db_name,
84 | path="tests/migrations/migrations_for_test/remove_index",
85 | allow_index_dropping=False,
86 | )
87 | await run_migrate(migration_settings)
88 |
89 | await init_beanie(
90 | database=db, document_models=[Note], allow_index_dropping=False
91 | )
92 | collection = Note.get_pymongo_collection()
93 | index_info = await collection.index_information()
94 | assert index_info == {
95 | "_id_": {"key": [("_id", 1)], "v": 2},
96 | "title_1": {"key": [("title", 1)], "v": 2},
97 | }
98 |
--------------------------------------------------------------------------------
/docs/tutorial/actions.md:
--------------------------------------------------------------------------------
1 | # Event-based actions
2 |
3 | You can register methods as pre- or post- actions for document events.
4 |
5 | Currently supported events:
6 |
7 | - Save
8 | - Insert
9 | - Replace
10 | - Update
11 | - SaveChanges
12 | - Delete
13 | - ValidateOnSave
14 |
15 | Currently supported directions:
16 |
17 | - `Before`
18 | - `After`
19 |
20 | Current operations creating events:
21 |
22 | - `insert()` for Insert
23 | - `replace()` for Replace
24 | - `save()` for Save; it also triggers Insert when it creates a new document, or Replace when it replaces an existing one
25 | - `save_changes()` for SaveChanges
26 | - `insert()`, `replace()`, `save_changes()`, and `save()` for ValidateOnSave
27 | - `set()`, `update()` for Update
28 | - `delete()` for Delete
29 |
30 | To register an action, you can use `@before_event` and `@after_event` decorators respectively:
31 |
32 | ```python
33 | from beanie import Document, Insert, Replace, before_event, after_event
34 |
35 |
36 | class Sample(Document):
37 | num: int
38 | name: str
39 |
40 | @before_event(Insert)
41 | def capitalize_name(self):
42 | self.name = self.name.capitalize()
43 |
44 | @after_event(Replace)
45 | def num_change(self):
46 | self.num -= 1
47 | ```
48 |
49 | It is possible to register an action for several events:
50 |
51 | ```python
52 | from beanie import Document, Insert, Replace, before_event
53 |
54 |
55 | class Sample(Document):
56 | num: int
57 | name: str
58 |
59 | @before_event(Insert, Replace)
60 | def capitalize_name(self):
61 | self.name = self.name.capitalize()
62 | ```
63 |
64 | This will capitalize the `name` field value before each document's Insert and Replace.
65 |
66 | Both sync and async methods can be used as actions.
67 |
68 | ```python
69 | from beanie import Document, Insert, Replace, after_event
70 |
71 |
72 | class Sample(Document):
73 | num: int
74 | name: str
75 |
76 | @after_event(Insert, Replace)
77 | async def send_callback(self):
78 | await client.send(self.id)
79 | ```
80 |
81 | Actions can be selectively skipped by passing the `skip_actions` argument when calling
82 | the operations that trigger events. `skip_actions` accepts a list of directions and action names.
83 |
84 | ```python
85 | from beanie import After, Before, Document, Insert, Replace, before_event, after_event
86 |
87 |
88 | class Sample(Document):
89 | num: int
90 | name: str
91 |
92 | @before_event(Insert)
93 | def capitalize_name(self):
94 | self.name = self.name.capitalize()
95 |
96 | @before_event(Replace)
97 | def redact_name(self):
98 | self.name = "[REDACTED]"
99 |
100 | @after_event(Replace)
101 | def num_change(self):
102 | self.num -= 1
103 |
104 |
105 | sample = Sample(num=1, name="sample")
106 |
107 | # capitalize_name will not be executed
108 | await sample.insert(skip_actions=['capitalize_name'])
109 |
110 | # num_change will not be executed
111 | await sample.replace(skip_actions=[After])
112 |
113 | # redact_name and num_change will not be executed
114 | await sample.replace(skip_actions=[Before, 'num_change'])
115 | ```
116 |
--------------------------------------------------------------------------------
/docs/getting-started.md:
--------------------------------------------------------------------------------
1 | # Getting started
2 |
3 | ## Installing beanie
4 |
5 | You can simply install Beanie from [PyPI](https://pypi.org/project/beanie/):
6 |
7 | ### PIP
8 |
9 | ```shell
10 | pip install beanie
11 | ```
12 |
13 | ### Poetry
14 |
15 | ```shell
16 | poetry add beanie
17 | ```
18 |
19 | ### Optional dependencies
20 |
21 | Beanie exposes PyMongo's optional dependencies as package extras (either `pip` or `poetry` can be used to install them).
22 |
23 | GSSAPI authentication requires `gssapi` extra dependency:
24 |
25 | ```bash
26 | pip install "beanie[gssapi]"
27 | ```
28 |
29 | MONGODB-AWS authentication requires `aws` extra dependency:
30 |
31 | ```bash
32 | pip install "beanie[aws]"
33 | ```
34 |
35 | Support for mongodb+srv:// URIs requires `srv` extra dependency:
36 |
37 | ```bash
38 | pip install "beanie[srv]"
39 | ```
40 |
41 | OCSP requires `ocsp` extra dependency:
42 |
43 | ```bash
44 | pip install "beanie[ocsp]"
45 | ```
46 |
47 | Wire protocol compression with snappy requires `snappy` extra
48 | dependency:
49 |
50 | ```bash
51 | pip install "beanie[snappy]"
52 | ```
53 |
54 | Wire protocol compression with zstandard requires `zstd` extra
55 | dependency:
56 |
57 | ```bash
58 | pip install "beanie[zstd]"
59 | ```
60 |
61 | Client-Side Field Level Encryption requires `encryption` extra
62 | dependency:
63 |
64 | ```bash
65 | pip install "beanie[encryption]"
66 | ```
67 |
68 | You can install all dependencies automatically with the following
69 | command:
70 |
71 | ```bash
72 | pip install "beanie[gssapi,aws,ocsp,snappy,srv,zstd,encryption]"
73 | ```
74 |
75 | ## Initialization
76 |
77 | Getting Beanie set up in your code is really easy:
78 | 
79 | 1. Write your database model as a Pydantic class but use `beanie.Document` instead of `pydantic.BaseModel`.
80 | 2. Create an async PyMongo client, as Beanie uses it as the async database engine under the hood.
81 | 3. Call `beanie.init_beanie` with the database handle and the list of Beanie document models.
82 | 
83 | The code below should get you started and shows some of the field types that you can use with Beanie.
84 |
85 | ```python
86 | from typing import Optional
87 |
88 | from pymongo import AsyncMongoClient
89 | from pydantic import BaseModel
90 |
91 | from beanie import Document, Indexed, init_beanie
92 |
93 |
94 | class Category(BaseModel):
95 | name: str
96 | description: str
97 |
98 |
99 | # This is the model that will be saved to the database
100 | class Product(Document):
101 | name: str # You can use normal types just like in pydantic
102 | description: Optional[str] = None
103 | price: Indexed(float) # You can also specify that a field should correspond to an index
104 | category: Category # You can include pydantic models as well
105 |
106 |
107 | # Call this from within your event loop to get beanie setup.
108 | async def init():
109 | # Create Async PyMongo client
110 | client = AsyncMongoClient("mongodb://user:pass@host:27017")
111 |
112 | # Init beanie with the Product document class
113 | await init_beanie(database=client.db_name, document_models=[Product])
114 | ```
115 |
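116 | Once `init()` has run, the registered document models can be used right away. Here is a minimal usage sketch, assuming the `Product` and `Category` models defined above:
117 | 
118 | ```python
119 | async def example():
120 |     # Insert a new document into the collection backing the Product model
121 |     chocolate = Category(name="Chocolate", description="A preparation of roasted cacao seeds")
122 |     await Product(name="Tony's", price=5.95, category=chocolate).insert()
123 | 
124 |     # Query it back by a field value
125 |     product = await Product.find_one(Product.name == "Tony's")
126 |     print(product.price)
127 | ```
128 | 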
--------------------------------------------------------------------------------
/.github/scripts/handlers/gh.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from dataclasses import dataclass
3 | from datetime import datetime
4 | from typing import List
5 |
6 | import requests # type: ignore
7 |
8 |
9 | @dataclass
10 | class PullRequest:
11 | number: int
12 | title: str
13 | user: str
14 | user_url: str
15 | url: str
16 |
17 |
18 | class GitHubHandler:
19 | def __init__(
20 | self,
21 | username: str,
22 | repository: str,
23 | current_version: str,
24 | new_version: str,
25 | ):
26 | self.username = username
27 | self.repository = repository
28 | self.base_url = f"https://api.github.com/repos/{username}/{repository}"
29 | self.current_version = current_version
30 | self.new_version = new_version
31 | self.commits = self.get_commits_after_tag(current_version)
32 | self.prs = [self.get_pr_for_commit(commit) for commit in self.commits]
33 |
34 | def get_commits_after_tag(self, tag: str) -> List[str]:
35 | result = subprocess.run(
36 | ["git", "log", f"{tag}..HEAD", "--pretty=format:%H"],
37 | stdout=subprocess.PIPE,
38 | text=True,
39 | )
40 | return result.stdout.split()
41 |
42 | def get_pr_for_commit(self, commit_sha: str) -> PullRequest:
43 | url = f"{self.base_url}/commits/{commit_sha}/pulls"
44 | response = requests.get(url)
45 | response.raise_for_status()
46 | pr_data = response.json()[0]
47 | return PullRequest(
48 | number=pr_data["number"],
49 | title=pr_data["title"],
50 | user=pr_data["user"]["login"],
51 | user_url=pr_data["user"]["html_url"],
52 | url=pr_data["html_url"],
53 | )
54 |
55 | def build_markdown_for_many_prs(self) -> str:
56 | markdown = f"\n## [{self.new_version}] - {datetime.now().strftime('%Y-%m-%d')}\n"
57 | for pr in self.prs:
58 | markdown += (
59 | f"### {pr.title.capitalize()}\n"
60 | f"- Author - [{pr.user}]({pr.user_url})\n"
61 | f"- PR <{pr.url}>\n"
62 | )
63 | markdown += f"\n[{self.new_version}]: https://pypi.org/project/{self.repository}/{self.new_version}\n"
64 | return markdown
65 |
66 | def commit_changes(self):
67 | self.run_git_command(
68 | ["git", "config", "--global", "user.name", "github-actions[bot]"]
69 | )
70 | self.run_git_command(
71 | [
72 | "git",
73 | "config",
74 | "--global",
75 | "user.email",
76 | "github-actions[bot]@users.noreply.github.com",
77 | ]
78 | )
79 | self.run_git_command(["git", "add", "."])
80 | self.run_git_command(
81 | ["git", "commit", "-m", f"Bump version to {self.new_version}"]
82 | )
83 | self.run_git_command(["git", "tag", self.new_version])
84 | self.git_push()
85 |
86 | def git_push(self):
87 | self.run_git_command(["git", "push", "origin", "main", "--tags"])
88 |
89 | @staticmethod
90 | def run_git_command(command: List[str]):
91 | subprocess.run(command, check=True)
92 |
--------------------------------------------------------------------------------
/beanie/odm/operators/update/array.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.update import BaseUpdateOperator
2 |
3 |
4 | class BaseUpdateArrayOperator(BaseUpdateOperator):
5 | operator = ""
6 |
7 | def __init__(self, expression):
8 | self.expression = expression
9 |
10 | @property
11 | def query(self):
12 | return {self.operator: self.expression}
13 |
14 |
15 | class AddToSet(BaseUpdateArrayOperator):
16 | """
17 | `$addToSet` update array query operator
18 |
19 | Example:
20 |
21 | ```python
22 | class Sample(Document):
23 | results: List[int]
24 |
25 | AddToSet({Sample.results: 2})
26 | ```
27 |
28 | Will return query object like
29 |
30 | ```python
31 | {"$addToSet": {"results": 2}}
32 | ```
33 |
34 | MongoDB docs:
35 |
36 | """
37 |
38 | operator = "$addToSet"
39 |
40 |
41 | class Pop(BaseUpdateArrayOperator):
42 | """
43 | `$pop` update array query operator
44 |
45 | Example:
46 |
47 | ```python
48 | class Sample(Document):
49 | results: List[int]
50 |
51 |     Pop({Sample.results: -1})
52 | ```
53 |
54 | Will return query object like
55 |
56 | ```python
57 | {"$pop": {"results": -1}}
58 | ```
59 |
60 | MongoDB docs:
61 |
62 | """
63 |
64 | operator = "$pop"
65 |
66 |
67 | class Pull(BaseUpdateArrayOperator):
68 | """
69 | `$pull` update array query operator
70 |
71 | Example:
72 |
73 | ```python
74 | class Sample(Document):
75 | results: List[int]
76 |
77 |     Pull(In(Sample.results, [1, 2, 3, 4, 5]))
78 | ```
79 |
80 | Will return query object like
81 |
82 | ```python
83 |     {"$pull": {"results": {"$in": [1, 2, 3, 4, 5]}}}
84 | ```
85 |
86 | MongoDB docs:
87 |
88 | """
89 |
90 | operator = "$pull"
91 |
92 |
93 | class Push(BaseUpdateArrayOperator):
94 | """
95 | `$push` update array query operator
96 |
97 | Example:
98 |
99 | ```python
100 | class Sample(Document):
101 | results: List[int]
102 |
103 | Push({Sample.results: 1})
104 | ```
105 |
106 | Will return query object like
107 |
108 | ```python
109 | {"$push": { "results": 1}}
110 | ```
111 |
112 | MongoDB docs:
113 |
114 | """
115 |
116 | operator = "$push"
117 |
118 |
119 | class PullAll(BaseUpdateArrayOperator):
120 | """
121 | `$pullAll` update array query operator
122 |
123 | Example:
124 |
125 | ```python
126 | class Sample(Document):
127 | results: List[int]
128 |
129 | PullAll({ Sample.results: [ 0, 5 ] })
130 | ```
131 |
132 | Will return query object like
133 |
134 | ```python
135 | {"$pullAll": { "results": [ 0, 5 ] }}
136 | ```
137 |
138 | MongoDB docs:
139 |
140 | """
141 |
142 | operator = "$pullAll"
143 |
--------------------------------------------------------------------------------
/beanie/odm/union_doc.py:
--------------------------------------------------------------------------------
1 | from typing import Any, ClassVar, Dict, Optional, Type, TypeVar
2 |
3 | from pymongo.asynchronous.client_session import AsyncClientSession
4 |
5 | from beanie.exceptions import UnionDocNotInited
6 | from beanie.odm.bulk import BulkWriter
7 | from beanie.odm.interfaces.aggregate import AggregateInterface
8 | from beanie.odm.interfaces.detector import DetectionInterface, ModelType
9 | from beanie.odm.interfaces.find import FindInterface
10 | from beanie.odm.interfaces.getters import OtherGettersInterface
11 | from beanie.odm.settings.union_doc import UnionDocSettings
12 |
13 | UnionDocType = TypeVar("UnionDocType", bound="UnionDoc")
14 |
15 |
16 | class UnionDoc(
17 | FindInterface,
18 | AggregateInterface,
19 | OtherGettersInterface,
20 | DetectionInterface,
21 | ):
22 | _document_models: ClassVar[Optional[Dict[str, Type]]] = None
23 | _is_inited: ClassVar[bool] = False
24 | _settings: ClassVar[UnionDocSettings]
25 |
26 | @classmethod
27 | def get_settings(cls) -> UnionDocSettings:
28 | return cls._settings
29 |
30 | @classmethod
31 | def register_doc(cls, name: str, doc_model: Type):
32 | if cls._document_models is None:
33 | cls._document_models = {}
34 |
35 | if cls._is_inited is False:
36 | raise UnionDocNotInited
37 |
38 | cls._document_models[name] = doc_model
39 | return cls.get_settings().name
40 |
41 | @classmethod
42 | def get_model_type(cls) -> ModelType:
43 | return ModelType.UnionDoc
44 |
45 | @classmethod
46 | def bulk_writer(
47 | cls,
48 | session: Optional[AsyncClientSession] = None,
49 | ordered: bool = True,
50 | bypass_document_validation: bool = False,
51 | comment: Optional[Any] = None,
52 | ) -> BulkWriter:
53 | """
54 | Returns a BulkWriter instance for handling bulk write operations.
55 |
56 | :param session: Optional[AsyncClientSession] - pymongo session.
57 | The session instance used for transactional operations.
58 | :param ordered: bool
59 | If ``True`` (the default), requests will be performed on the server serially, in the order provided. If an error
60 | occurs, all remaining operations are aborted. If ``False``, requests will be performed on the server in
61 | arbitrary order, possibly in parallel, and all operations will be attempted.
62 | :param bypass_document_validation: bool, optional
63 | If ``True``, allows the write to opt-out of document-level validation. Default is ``False``.
64 | :param comment: str, optional
65 | A user-provided comment to attach to the BulkWriter.
66 |
67 | :returns: BulkWriter
68 | An instance of BulkWriter configured with the provided settings.
69 |
70 | Example Usage:
71 | --------------
72 | This method is typically used within an asynchronous context manager.
73 |
74 | .. code-block:: python
75 |
76 | async with Document.bulk_writer(ordered=True) as bulk:
77 | await Document.insert_one(Document(field="value"), bulk_writer=bulk)
78 | """
79 | return BulkWriter(
80 | session, ordered, cls, bypass_document_validation, comment
81 | )
82 |
--------------------------------------------------------------------------------
/docs/tutorial/update.md:
--------------------------------------------------------------------------------
1 | # Updating & Deleting
2 |
3 | Now that we know how to find documents, how do we change them or delete them?
4 |
5 | ## Saving changes to existing documents
6 |
7 | The easiest way to change a document in the database is to use either the `replace` or `save` method on an altered document.
8 | These methods both write the document to the database,
9 | but `replace` will raise an exception when the document does not exist yet, while `save` will insert the document.
10 |
11 | Using `save()` method:
12 |
13 | ```python
14 | bar = await Product.find_one(Product.name == "Mars")
15 | bar.price = 10
16 | await bar.save()
17 | ```
18 |
19 | Otherwise, use the `replace()` method, which throws:
20 | - a `ValueError` if the document does not have an `id` yet, or
21 | - a `beanie.exceptions.DocumentNotFound` if it does, but the `id` is not present in the collection
22 |
23 | ```python
24 | bar.price = 10
25 | try:
26 | await bar.replace()
27 | except (ValueError, beanie.exceptions.DocumentNotFound):
28 | print("Can't replace a non existing document")
29 | ```
30 |
31 | Note that these methods require multiple queries to the database and replace the entire document with the new version.
32 | A more tailored solution can often be created by applying update queries directly on the database level.
33 |
34 | ## Update queries
35 |
36 | Update queries can be performed on the result of a `find` or `find_one` query,
37 | or on a document that was returned from an earlier query.
38 | Simpler updates can be performed using the `set`, `inc`, and `current_date` methods:
39 |
40 | ```python
41 | bar = await Product.find_one(Product.name == "Mars")
42 | await bar.set({Product.name: "Gold bar"})
43 | await Product.find(Product.price > 0.5).inc({Product.price: 1})
44 | ```
45 |
46 | More complex update operations can be performed by calling `update()` with an update operator, similar to find queries:
47 |
48 | ```python
49 | await Product.find_one(Product.name == "Tony's").update(Set({Product.price: 3.33}))
50 | ```
51 |
52 | The whole list of the update query operators can be found [here](../api-documentation/operators/update.md).
53 |
54 | Native MongoDB syntax is also supported:
55 |
56 | ```python
57 | await Product.find_one(Product.name == "Tony's").update({"$set": {Product.price: 3.33}})
58 | ```
59 |
60 | ## Upsert
61 |
62 | To insert a document when no documents are matched against the search criteria, the `upsert` method can be used:
63 |
64 | ```python
65 | await Product.find_one(Product.name == "Tony's").upsert(
66 | Set({Product.price: 3.33}),
67 | on_insert=Product(name="Tony's", price=3.33, category=chocolate)
68 | )
69 | ```
70 |
71 | ## Deleting documents
72 |
73 | Deleting documents works just like updating them: simply call `delete()` on the found documents:
74 |
75 | ```python
76 | bar = await Product.find_one(Product.name == "Milka")
77 | await bar.delete()
78 |
79 | await Product.find_one(Product.name == "Milka").delete()
80 |
81 | await Product.find(Product.category.name == "Chocolate").delete()
82 | ```
83 |
84 | ## Response Type
85 |
86 | For the `update` and `upsert` methods, you can use the `response_type` parameter to specify what the operation should return (see the example below).
87 |
88 | The options are:
89 | - `UpdateResponse.UPDATE_RESULT` - returns the result of the update operation.
90 | - `UpdateResponse.NEW_DOCUMENT` - returns the newly updated document.
91 | - `UpdateResponse.OLD_DOCUMENT` - returns the document before the update.
92 |
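93 | For example, a minimal sketch of getting the updated document back from an update query, assuming the `Product` model and the `Set` operator used in the examples above, and that `UpdateResponse` is importable from the top-level `beanie` package:
94 | 
95 | ```python
96 | from beanie import UpdateResponse
97 | 
98 | product = await Product.find_one(Product.name == "Tony's").update(
99 |     Set({Product.price: 3.33}),
100 |     response_type=UpdateResponse.NEW_DOCUMENT,  # return the document after the update
101 | )
102 | ```
103 | 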
--------------------------------------------------------------------------------
/tests/odm/operators/find/test_geospatial.py:
--------------------------------------------------------------------------------
1 | from beanie.odm.operators.find.geospatial import (
2 | Box,
3 | GeoIntersects,
4 | GeoWithin,
5 | Near,
6 | NearSphere,
7 | )
8 | from tests.odm.models import Sample
9 |
10 |
11 | async def test_geo_intersects():
12 | q = GeoIntersects(
13 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]]
14 | )
15 | assert q == {
16 | "geo": {
17 | "$geoIntersects": {
18 | "$geometry": {
19 | "type": "Polygon",
20 | "coordinates": [[1, 1], [2, 2], [3, 3]],
21 | }
22 | }
23 | }
24 | }
25 |
26 |
27 | async def test_geo_within():
28 | q = GeoWithin(
29 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]]
30 | )
31 | assert q == {
32 | "geo": {
33 | "$geoWithin": {
34 | "$geometry": {
35 | "type": "Polygon",
36 | "coordinates": [[1, 1], [2, 2], [3, 3]],
37 | }
38 | }
39 | }
40 | }
41 |
42 |
43 | async def test_box():
44 | q = Box(Sample.geo, lower_left=[1, 3], upper_right=[2, 4])
45 | assert q == {"geo": {"$geoWithin": {"$box": [[1, 3], [2, 4]]}}}
46 |
47 |
48 | async def test_near():
49 | q = Near(Sample.geo, longitude=1.1, latitude=2.2)
50 | assert q == {
51 | "geo": {
52 | "$near": {
53 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}
54 | }
55 | }
56 | }
57 |
58 | q = Near(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1)
59 | assert q == {
60 | "geo": {
61 | "$near": {
62 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]},
63 | "$maxDistance": 1,
64 | }
65 | }
66 | }
67 |
68 | q = Near(
69 | Sample.geo,
70 | longitude=1.1,
71 | latitude=2.2,
72 | max_distance=1,
73 | min_distance=0.5,
74 | )
75 | assert q == {
76 | "geo": {
77 | "$near": {
78 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]},
79 | "$maxDistance": 1,
80 | "$minDistance": 0.5,
81 | }
82 | }
83 | }
84 |
85 |
86 | async def test_near_sphere():
87 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2)
88 | assert q == {
89 | "geo": {
90 | "$nearSphere": {
91 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}
92 | }
93 | }
94 | }
95 |
96 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1)
97 | assert q == {
98 | "geo": {
99 | "$nearSphere": {
100 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]},
101 | "$maxDistance": 1,
102 | }
103 | }
104 | }
105 |
106 | q = NearSphere(
107 | Sample.geo,
108 | longitude=1.1,
109 | latitude=2.2,
110 | max_distance=1,
111 | min_distance=0.5,
112 | )
113 | assert q == {
114 | "geo": {
115 | "$nearSphere": {
116 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]},
117 | "$maxDistance": 1,
118 | "$minDistance": 0.5,
119 | }
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/tests/odm/test_cache.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from tests.odm.models import DocumentTestModel
4 |
5 |
6 | async def test_find_one(documents):
7 | await documents(5)
8 | doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1)
9 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set(
10 | {DocumentTestModel.test_str: "NEW_VALUE"}
11 | )
12 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1)
13 | assert doc == new_doc
14 |
15 | new_doc = await DocumentTestModel.find_one(
16 | DocumentTestModel.test_int == 1, ignore_cache=True
17 | )
18 | assert doc != new_doc
19 |
20 | await asyncio.sleep(10)
21 |
22 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1)
23 | assert doc != new_doc
24 |
25 |
26 | async def test_find_many(documents):
27 | await documents(5)
28 | docs = await DocumentTestModel.find(
29 | DocumentTestModel.test_int > 1
30 | ).to_list()
31 |
32 | await DocumentTestModel.find(DocumentTestModel.test_int > 1).set(
33 | {DocumentTestModel.test_str: "NEW_VALUE"}
34 | )
35 |
36 | new_docs = await DocumentTestModel.find(
37 | DocumentTestModel.test_int > 1
38 | ).to_list()
39 | assert docs == new_docs
40 |
41 | new_docs = await DocumentTestModel.find(
42 | DocumentTestModel.test_int > 1, ignore_cache=True
43 | ).to_list()
44 | assert docs != new_docs
45 |
46 | await asyncio.sleep(10)
47 |
48 | new_docs = await DocumentTestModel.find(
49 | DocumentTestModel.test_int > 1
50 | ).to_list()
51 | assert docs != new_docs
52 |
53 |
54 | async def test_aggregation(documents):
55 | await documents(5)
56 | docs = await DocumentTestModel.aggregate(
57 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}]
58 | ).to_list()
59 |
60 | await DocumentTestModel.find(DocumentTestModel.test_int > 1).set(
61 | {DocumentTestModel.test_str: "NEW_VALUE"}
62 | )
63 |
64 | new_docs = await DocumentTestModel.aggregate(
65 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}]
66 | ).to_list()
67 | assert docs == new_docs
68 |
69 | new_docs = await DocumentTestModel.aggregate(
70 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}],
71 | ignore_cache=True,
72 | ).to_list()
73 | assert docs != new_docs
74 |
75 | await asyncio.sleep(10)
76 |
77 | new_docs = await DocumentTestModel.aggregate(
78 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}]
79 | ).to_list()
80 | assert docs != new_docs
81 |
82 |
83 | async def test_capacity(documents):
84 | await documents(10)
85 | docs = []
86 | for i in range(10):
87 | docs.append(
88 | await DocumentTestModel.find_one(DocumentTestModel.test_int == i)
89 | )
90 |
91 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set(
92 | {DocumentTestModel.test_str: "NEW_VALUE"}
93 | )
94 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 9).set(
95 | {DocumentTestModel.test_str: "NEW_VALUE"}
96 | )
97 |
98 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1)
99 | assert docs[1] != new_doc
100 |
101 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 9)
102 | assert docs[9] == new_doc
103 |
--------------------------------------------------------------------------------
/beanie/odm/utils/state.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from functools import wraps
3 | from typing import TYPE_CHECKING, TypeVar
4 |
5 | from typing_extensions import ParamSpec
6 |
7 | from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved
8 |
9 | if TYPE_CHECKING:
10 | from beanie.odm.documents import AnyDocMethod, AsyncDocMethod, DocType
11 |
12 | P = ParamSpec("P")
13 | R = TypeVar("R")
14 |
15 |
16 | def check_if_state_saved(self: "DocType"):
17 | if not self.use_state_management():
18 | raise StateManagementIsTurnedOff(
19 | "State management is turned off for this document"
20 | )
21 | if self._saved_state is None:
22 | raise StateNotSaved("No state was saved")
23 |
24 |
25 | def saved_state_needed(
26 | f: "AnyDocMethod[DocType, P, R]",
27 | ) -> "AnyDocMethod[DocType, P, R]":
28 | @wraps(f)
29 | def sync_wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
30 | check_if_state_saved(self)
31 | return f(self, *args, **kwargs)
32 |
33 | @wraps(f)
34 | async def async_wrapper(
35 | self: "DocType", *args: P.args, **kwargs: P.kwargs
36 | ) -> R:
37 | check_if_state_saved(self)
38 | # type ignore because there is no nice/proper way to annotate both sync
39 | # and async case without parametrized TypeVar, which is not supported
40 | return await f(self, *args, **kwargs) # type: ignore[misc]
41 |
42 | if inspect.iscoroutinefunction(f):
43 | # type ignore because there is no nice/proper way to annotate both sync
44 | # and async case without parametrized TypeVar, which is not supported
45 | return async_wrapper # type: ignore[return-value]
46 | return sync_wrapper
47 |
48 |
49 | def check_if_previous_state_saved(self: "DocType"):
50 | if not self.use_state_management():
51 | raise StateManagementIsTurnedOff(
52 | "State management is turned off for this document"
53 | )
54 | if not self.state_management_save_previous():
55 | raise StateManagementIsTurnedOff(
56 | "State management's option to save previous state is turned off for this document"
57 | )
58 |
59 |
60 | def previous_saved_state_needed(
61 | f: "AnyDocMethod[DocType, P, R]",
62 | ) -> "AnyDocMethod[DocType, P, R]":
63 | @wraps(f)
64 | def sync_wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
65 | check_if_previous_state_saved(self)
66 | return f(self, *args, **kwargs)
67 |
68 | @wraps(f)
69 | async def async_wrapper(
70 | self: "DocType", *args: P.args, **kwargs: P.kwargs
71 | ) -> R:
72 | check_if_previous_state_saved(self)
73 | # type ignore because there is no nice/proper way to annotate both sync
74 | # and async case without parametrized TypeVar, which is not supported
75 | return await f(self, *args, **kwargs) # type: ignore[misc]
76 |
77 | if inspect.iscoroutinefunction(f):
78 | # type ignore because there is no nice/proper way to annotate both sync
79 | # and async case without parametrized TypeVar, which is not supported
80 | return async_wrapper # type: ignore[return-value]
81 | return sync_wrapper
82 |
83 |
84 | def save_state_after(
85 | f: "AsyncDocMethod[DocType, P, R]",
86 | ) -> "AsyncDocMethod[DocType, P, R]":
87 | @wraps(f)
88 | async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
89 | result = await f(self, *args, **kwargs)
90 | self._save_state()
91 | return result
92 |
93 | return wrapper
94 |
--------------------------------------------------------------------------------