├── tests ├── __init__.py ├── odm │ ├── __init__.py │ ├── query │ │ ├── __init__.py │ │ ├── test_update_methods.py │ │ ├── test_aggregate_methods.py │ │ ├── test_delete.py │ │ ├── test_aggregate.py │ │ ├── test_update.py │ │ └── test_find.py │ ├── test_deprecated.py │ ├── documents │ │ ├── __init__.py │ │ ├── test_pydantic_config.py │ │ ├── test_count.py │ │ ├── test_exists.py │ │ ├── test_pydantic_extras.py │ │ ├── test_inspect.py │ │ ├── test_replace.py │ │ ├── test_validation_on_save.py │ │ ├── test_distinct.py │ │ ├── test_revision.py │ │ ├── test_delete.py │ │ ├── test_create.py │ │ ├── test_multi_model.py │ │ ├── test_aggregate.py │ │ ├── test_find.py │ │ ├── test_bulk_write.py │ │ ├── test_init.py │ │ └── test_update.py │ ├── operators │ │ ├── __init__.py │ │ ├── find │ │ │ ├── __init__.py │ │ │ ├── test_array.py │ │ │ ├── test_element.py │ │ │ ├── test_bitwise.py │ │ │ ├── test_logical.py │ │ │ ├── test_evaluation.py │ │ │ ├── test_comparison.py │ │ │ └── test_geospatial.py │ │ └── update │ │ │ ├── __init__.py │ │ │ ├── test_bitwise.py │ │ │ ├── test_array.py │ │ │ └── test_general.py │ ├── test_cursor.py │ ├── views.py │ ├── test_id.py │ ├── test_views.py │ ├── test_timesries_collection.py │ ├── test_encoder.py │ ├── test_expression_fields.py │ ├── test_actions.py │ ├── test_cache.py │ ├── test_fields.py │ ├── test_state_management.py │ ├── conftest.py │ ├── test_relations.py │ └── models.py ├── test_beanie.py ├── test_beanita.py └── conftest.py ├── beanita ├── __init__.py ├── client.py ├── cursor.py ├── results.py ├── db.py └── collection.py ├── .pre-commit-config.yaml ├── pyproject.toml ├── README.md ├── LICENSE └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/odm/query/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/test_deprecated.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/documents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/find/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/odm/operators/update/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beanita/__init__.py: -------------------------------------------------------------------------------- 1 | from beanita.client import Client 2 | 3 | __version__ = "0.1.0" 4 | __all__ = ["Client"] 5 | -------------------------------------------------------------------------------- /tests/test_beanie.py: -------------------------------------------------------------------------------- 1 | from beanie import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ == "1.11.4" 6 | -------------------------------------------------------------------------------- /tests/test_beanita.py: -------------------------------------------------------------------------------- 1 | from beanita import __version__ 
2 | 3 | 4 | def test_version(): 5 | assert __version__ == "0.1.0" 6 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_bitwise.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.bitwise import Bit 2 | from tests.odm.models import Sample 3 | 4 | 5 | def test_bit(): 6 | q = Bit({Sample.integer: 2}) 7 | assert q == {"$bit": {"integer": 2}} 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | language_version: python3.9 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.9.2 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /tests/odm/documents/test_pydantic_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import ValidationError 3 | 4 | from tests.odm.models import DocumentWithPydanticConfig 5 | 6 | 7 | def test_pydantic_config(): 8 | doc = DocumentWithPydanticConfig(num_1=2) 9 | with pytest.raises(ValidationError): 10 | doc.num_1 = "wrong" 11 | -------------------------------------------------------------------------------- /tests/odm/test_cursor.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_to_list(documents): 5 | await documents(10) 6 | result = await DocumentTestModel.find_all().to_list() 7 | assert len(result) == 10 8 | 9 | 10 | async def test_async_for(documents): 11 | await documents(10) 12 | async for document in DocumentTestModel.find_all(): 13 | assert document.test_int in list(range(10)) 14 | 
-------------------------------------------------------------------------------- /tests/odm/views.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.views import View 2 | from tests.odm.models import DocumentTestModel 3 | 4 | 5 | class TestView(View): 6 | number: int 7 | string: str 8 | 9 | class Settings: 10 | view_name = "test_view" 11 | source = DocumentTestModel 12 | pipeline = [ 13 | {"$match": {"$expr": {"$gt": ["$test_int", 8]}}}, 14 | {"$project": {"number": "$test_int", "string": "$test_str"}}, 15 | ] 16 | -------------------------------------------------------------------------------- /tests/odm/documents/test_count.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_count(documents): 5 | await documents(4, "uno", True) 6 | c = await DocumentTestModel.count() 7 | assert c == 4 8 | 9 | 10 | async def test_count_with_filter_query(documents): 11 | await documents(4, "uno", True) 12 | await documents(2, "dos", True) 13 | await documents(1, "cuatro", True) 14 | c = await DocumentTestModel.find_many({"test_str": "dos"}).count() 15 | assert c == 2 16 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseSettings 3 | 4 | from beanita.client import Client 5 | 6 | 7 | class Settings(BaseSettings): 8 | mongodb_path: str = "beanita_path" 9 | mongodb_db_name: str = "beanie_db" 10 | 11 | 12 | @pytest.fixture 13 | def settings(): 14 | return Settings() 15 | 16 | 17 | @pytest.fixture() 18 | def cli(settings, loop): 19 | return Client(settings.mongodb_path) 20 | 21 | 22 | @pytest.fixture() 23 | def db(cli, settings, loop): 24 | return cli[settings.mongodb_db_name] 25 | 
-------------------------------------------------------------------------------- /tests/odm/operators/find/test_array.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.array import All, ElemMatch, Size 2 | from tests.odm.models import Sample 3 | 4 | 5 | async def test_all(): 6 | q = All(Sample.integer, [1, 2, 3]) 7 | assert q == {"integer": {"$all": [1, 2, 3]}} 8 | 9 | 10 | async def test_elem_match(): 11 | q = ElemMatch(Sample.integer, {"a": "b"}) 12 | assert q == {"integer": {"$elemMatch": {"a": "b"}}} 13 | 14 | 15 | async def test_size(): 16 | q = Size(Sample.integer, 4) 17 | assert q == {"integer": {"$size": 4}} 18 | -------------------------------------------------------------------------------- /tests/odm/documents/test_exists.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 | 4 | async def test_count_with_filter_query(documents): 5 | await documents(4, "uno", True) 6 | await documents(2, "dos", True) 7 | await documents(1, "cuatro", True) 8 | e = await DocumentTestModel.find_many({"test_str": "dos"}).exists() 9 | assert e is True 10 | 11 | e = await DocumentTestModel.find_one({"test_str": "dos"}).exists() 12 | assert e is True 13 | 14 | e = await DocumentTestModel.find_many({"test_str": "wrong"}).exists() 15 | assert e is False 16 | -------------------------------------------------------------------------------- /beanita/client.py: -------------------------------------------------------------------------------- 1 | from mongita import MongitaClientDisk 2 | from mongita.common import ok_name 3 | from mongita.errors import InvalidName 4 | 5 | from beanita.db import Database 6 | 7 | 8 | class Client(MongitaClientDisk): 9 | def __getitem__(self, db_name): 10 | try: 11 | return self._cache[db_name] 12 | except KeyError: 13 | if not ok_name(db_name): 14 | raise InvalidName("Database cannot be named %r." 
% db_name) 15 | db = Database(db_name, self) 16 | self._cache[db_name] = db 17 | return db 18 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_element.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.element import Exists, Type 2 | from tests.odm.models import Sample 3 | 4 | 5 | async def test_exists(): 6 | q = Exists(Sample.integer, True) 7 | assert q == {"integer": {"$exists": True}} 8 | 9 | q = Exists(Sample.integer, False) 10 | assert q == {"integer": {"$exists": False}} 11 | 12 | q = Exists(Sample.integer) 13 | assert q == {"integer": {"$exists": True}} 14 | 15 | 16 | async def test_type(): 17 | q = Type(Sample.integer, "smth") 18 | assert q == {"integer": {"$type": "smth"}} 19 | -------------------------------------------------------------------------------- /beanita/cursor.py: -------------------------------------------------------------------------------- 1 | class Cursor: 2 | def __init__(self, cur): 3 | self.cur = cur 4 | 5 | def __aiter__(self): 6 | return self 7 | 8 | async def __anext__(self): 9 | try: 10 | return self.cur.next() 11 | except StopIteration: 12 | raise StopAsyncIteration 13 | 14 | async def to_list(self, length=None): 15 | res = [] 16 | ind = 0 17 | for i in self.cur: 18 | if length is not None and length == ind: 19 | break 20 | res.append(i) 21 | ind += 1 22 | return res 23 | -------------------------------------------------------------------------------- /tests/odm/test_id.py: -------------------------------------------------------------------------------- 1 | from uuid import UUID 2 | 3 | import pytest 4 | 5 | from tests.odm.models import DocumentWithCustomIdUUID, DocumentWithCustomIdInt 6 | 7 | 8 | @pytest.mark.skip("Not supported") 9 | async def test_uuid_id(): 10 | doc = DocumentWithCustomIdUUID(name="TEST") 11 | await doc.insert() 12 | new_doc = await DocumentWithCustomIdUUID.get(doc.id) 13 | 
assert type(new_doc.id) == UUID 14 | 15 | 16 | @pytest.mark.skip("Not supported") 17 | async def test_integer_id(): 18 | doc = DocumentWithCustomIdInt(name="TEST", id=1) 19 | await doc.insert() 20 | new_doc = await DocumentWithCustomIdInt.get(doc.id) 21 | assert type(new_doc.id) == int 22 | -------------------------------------------------------------------------------- /tests/odm/test_views.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.views import TestView 4 | 5 | 6 | @pytest.mark.skip("Not supported") 7 | class TestViews: 8 | async def test_simple(self, documents): 9 | await documents(number=15) 10 | results = await TestView.all().to_list() 11 | assert len(results) == 6 12 | 13 | async def test_aggregate(self, documents): 14 | await documents(number=15) 15 | results = await TestView.aggregate( 16 | [ 17 | {"$set": {"test_field": 1}}, 18 | {"$match": {"$expr": {"$lt": ["$number", 12]}}}, 19 | ] 20 | ).to_list() 21 | assert len(results) == 3 22 | assert results[0]["test_field"] == 1 23 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_bitwise.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.bitwise import ( 2 | BitsAllClear, 3 | BitsAllSet, 4 | BitsAnyClear, 5 | BitsAnySet, 6 | ) 7 | from tests.odm.models import Sample 8 | 9 | 10 | async def test_bits_all_clear(): 11 | q = BitsAllClear(Sample.integer, "smth") 12 | assert q == {"integer": {"$bitsAllClear": "smth"}} 13 | 14 | 15 | async def test_bits_all_set(): 16 | q = BitsAllSet(Sample.integer, "smth") 17 | assert q == {"integer": {"$bitsAllSet": "smth"}} 18 | 19 | 20 | async def test_any_clear(): 21 | q = BitsAnyClear(Sample.integer, "smth") 22 | assert q == {"integer": {"$bitsAnyClear": "smth"}} 23 | 24 | 25 | async def test_any_set(): 26 | q = BitsAnySet(Sample.integer, "smth") 27 | assert q == 
{"integer": {"$bitsAnySet": "smth"}} 28 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_array.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.array import ( 2 | AddToSet, 3 | Pull, 4 | PullAll, 5 | Pop, 6 | Push, 7 | ) 8 | from tests.odm.models import Sample 9 | 10 | 11 | def test_add_to_set(): 12 | q = AddToSet({Sample.integer: 2}) 13 | assert q == {"$addToSet": {"integer": 2}} 14 | 15 | 16 | def test_pop(): 17 | q = Pop({Sample.integer: 2}) 18 | assert q == {"$pop": {"integer": 2}} 19 | 20 | 21 | def test_pull(): 22 | q = Pull({Sample.integer: 2}) 23 | assert q == {"$pull": {"integer": 2}} 24 | 25 | 26 | def test_push(): 27 | q = Push({Sample.integer: 2}) 28 | assert q == {"$push": {"integer": 2}} 29 | 30 | 31 | def test_pull_all(): 32 | q = PullAll({Sample.integer: 2}) 33 | assert q == {"$pullAll": {"integer": 2}} 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "beanita" 3 | version = "0.1.0" 4 | description = "Local MongoDB-like database, based on Mongita and prepared to work with Beanie ODM" 5 | authors = ["Roman Right "] 6 | homepage = "https://github.com/roman-right/beanita" 7 | repository = "https://github.com/roman-right/beanita" 8 | readme = "README.md" 9 | license = "MIT" 10 | include = [ 11 | "LICENSE", 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.9" 16 | mongita = "^1.1.1" 17 | 18 | [tool.poetry.dev-dependencies] 19 | pytest = "^5.2" 20 | beanie = "^1.11.4" 21 | pytest-aiohttp = "^0.3.0" 22 | 23 | [build-system] 24 | requires = ["poetry-core>=1.0.0"] 25 | build-backend = "poetry.core.masonry.api" 26 | 27 | [tool.black] 28 | line-length = 79 29 | include = '\.pyi?$' 30 | exclude = ''' 31 | /( 32 | \.git 33 | | \.hg 34 | | \.mypy_cache 
35 | | \.tox 36 | | \.venv 37 | | _build 38 | | buck-out 39 | | build 40 | | dist 41 | )/ 42 | ''' -------------------------------------------------------------------------------- /tests/odm/documents/test_pydantic_extras.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.models import ( 4 | DocumentWithExtras, 5 | DocumentWithPydanticConfig, 6 | DocumentWithExtrasKw, 7 | ) 8 | 9 | 10 | async def test_pydantic_extras(): 11 | doc = DocumentWithExtras(num_1=2) 12 | doc.extra_value = "foo" 13 | await doc.save() 14 | 15 | loaded_doc = await DocumentWithExtras.get(doc.id) 16 | 17 | assert loaded_doc.extra_value == "foo" 18 | 19 | 20 | @pytest.mark.skip(reason="setting extra to allow via class kwargs not working") 21 | async def test_pydantic_extras_kw(): 22 | doc = DocumentWithExtrasKw(num_1=2) 23 | doc.extra_value = "foo" 24 | await doc.save() 25 | 26 | loaded_doc = await DocumentWithExtras.get(doc.id) 27 | 28 | assert loaded_doc.extra_value == "foo" 29 | 30 | 31 | async def test_fail_with_no_extras(): 32 | doc = DocumentWithPydanticConfig(num_1=2) 33 | with pytest.raises(ValueError): 34 | doc.extra_value = "foo" 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Beanita 2 | 3 | Local MongoDB-like database, based on [Mongita](https://github.com/scottrogowski/mongita) and prepared to work with [Beanie ODM](https://github.com/roman-right/beanie) 4 | 5 | I highly recommend using it only for experiment purposes. It is safer to use a real MongoDB database and for testing, and for production. 
6 | 7 | ### Install 8 | 9 | ```shell 10 | pip install beanita 11 | ``` 12 | or 13 | ```shell 14 | poetry add beanita 15 | ``` 16 | 17 | ### Init 18 | 19 | ```python 20 | from beanie import init_beanie, Document 21 | from beanita import Client 22 | 23 | 24 | class Sample(Document): 25 | name: str 26 | 27 | 28 | async def init_database(): 29 | cli = Client("LOCAL_DIRECTORY") 30 | db = cli["DATABASE_NAME"] 31 | await init_beanie( 32 | database=db, 33 | document_models=[Sample], 34 | ) 35 | ``` 36 | 37 | ### Not supported 38 | 39 | - Links 40 | - Aggregations 41 | - Union Documents 42 | - other features, that were not implemented in Mongita -------------------------------------------------------------------------------- /tests/odm/operators/find/test_logical.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.logical import And, Not, Nor, Or 2 | from tests.odm.models import Sample 3 | 4 | 5 | async def test_and(): 6 | q = And(Sample.integer == 1) 7 | assert q == {"integer": 1} 8 | 9 | q = And(Sample.integer == 1, Sample.nested.integer > 3) 10 | assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 11 | 12 | 13 | async def test_not(): 14 | q = Not(Sample.integer == 1) 15 | assert q == {"$not": {"integer": 1}} 16 | 17 | 18 | async def test_nor(): 19 | q = Nor(Sample.integer == 1) 20 | assert q == {"$nor": [{"integer": 1}]} 21 | 22 | q = Nor(Sample.integer == 1, Sample.nested.integer > 3) 23 | assert q == {"$nor": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 24 | 25 | 26 | async def test_or(): 27 | q = Or(Sample.integer == 1) 28 | assert q == {"integer": 1} 29 | 30 | q = Or(Sample.integer == 1, Sample.nested.integer > 3) 31 | assert q == {"$or": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} 32 | -------------------------------------------------------------------------------- /beanita/results.py: -------------------------------------------------------------------------------- 1 | 
class InsertOneResult: 2 | def __init__(self, inserted_id): 3 | self.acknowledged = True 4 | self.inserted_id = inserted_id 5 | 6 | 7 | class InsertManyResult: 8 | def __init__(self, documents): 9 | self.acknowledged = True 10 | self.inserted_ids = [d["_id"] for d in documents] 11 | 12 | 13 | class UpdateResult: 14 | def __init__( 15 | self, matched_count, modified_count, upserted_id=None, **kwargs 16 | ): 17 | self.acknowledged = True 18 | self.matched_count = matched_count 19 | self.modified_count = modified_count 20 | self.upserted_id = upserted_id 21 | self.updatedExisting = modified_count 22 | 23 | @property 24 | def raw_result(self): 25 | return self.__dict__ 26 | 27 | @classmethod 28 | def from_mongita_result(cls, res): 29 | return cls(**res.__dict__) 30 | 31 | 32 | class DeleteResult: 33 | def __init__(self, deleted_count): 34 | self.acknowledged = True 35 | self.deleted_count = deleted_count 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2022 Roman Right 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /tests/odm/documents/test_inspect.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from beanie.odm.models import InspectionStatuses 3 | from tests.odm.models import DocumentTestModel, DocumentTestModelFailInspection 4 | 5 | 6 | @pytest.mark.skip("Not supported") 7 | async def test_inspect_ok(documents): 8 | await documents(10, "smth") 9 | result = await DocumentTestModel.inspect_collection() 10 | assert result.status == InspectionStatuses.OK 11 | assert result.errors == [] 12 | 13 | 14 | @pytest.mark.skip("Not supported") 15 | async def test_inspect_fail(documents): 16 | await documents(10, "smth") 17 | result = await DocumentTestModelFailInspection.inspect_collection() 18 | assert result.status == InspectionStatuses.FAIL 19 | assert len(result.errors) == 10 20 | assert ( 21 | "1 validation error for DocumentTestModelFailInspection" 22 | in result.errors[0].error 23 | ) 24 | 25 | 26 | @pytest.mark.skip("Not supported") 27 | async def test_inspect_ok_with_session(documents, session): 28 | await documents(10, "smth") 29 | result = await DocumentTestModel.inspect_collection(session=session) 30 | assert result.status == InspectionStatuses.OK 31 | assert result.errors == [] 32 | -------------------------------------------------------------------------------- /tests/odm/documents/test_replace.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import Sample 2 | 3 | 4 | async def test_replace_one(preset_documents): 5 | count_1_before = await Sample.find_many(Sample.integer == 1).count() 6 | count_2_before = await 
Sample.find_many(Sample.integer == 2).count() 7 | 8 | a_2 = await Sample.find_one(Sample.integer == 2) 9 | await Sample.find_one(Sample.integer == 1).replace_one(a_2) 10 | 11 | count_1_after = await Sample.find_many(Sample.integer == 1).count() 12 | count_2_after = await Sample.find_many(Sample.integer == 2).count() 13 | 14 | assert count_1_after == count_1_before - 1 15 | assert count_2_after == count_2_before + 1 16 | 17 | 18 | async def test_replace_self(preset_documents): 19 | count_1_before = await Sample.find_many(Sample.integer == 1).count() 20 | count_2_before = await Sample.find_many(Sample.integer == 2).count() 21 | 22 | a_1 = await Sample.find_one(Sample.integer == 1) 23 | a_1.integer = 2 24 | await a_1.replace() 25 | 26 | count_1_after = await Sample.find_many(Sample.integer == 1).count() 27 | count_2_after = await Sample.find_many(Sample.integer == 2).count() 28 | 29 | assert count_1_after == count_1_before - 1 30 | assert count_2_after == count_2_before + 1 31 | -------------------------------------------------------------------------------- /tests/odm/documents/test_validation_on_save.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import ValidationError 3 | 4 | from tests.odm.models import DocumentWithValidationOnSave 5 | 6 | 7 | async def test_validate_on_insert(): 8 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 9 | doc.num_1 = "wrong_value" 10 | with pytest.raises(ValidationError): 11 | await doc.insert() 12 | 13 | 14 | async def test_validate_on_replace(): 15 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 16 | await doc.insert() 17 | doc.num_1 = "wrong_value" 18 | with pytest.raises(ValidationError): 19 | await doc.replace() 20 | 21 | 22 | async def test_validate_on_save_changes(): 23 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 24 | await doc.insert() 25 | doc.num_1 = "wrong_value" 26 | with pytest.raises(ValidationError): 27 | await 
doc.save_changes() 28 | 29 | 30 | async def test_validate_on_save_action(): 31 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 32 | await doc.insert() 33 | assert doc.num_2 == 3 34 | 35 | 36 | async def test_validate_on_save_skip_action(): 37 | doc = DocumentWithValidationOnSave(num_1=1, num_2=2) 38 | await doc.insert(skip_actions=["num_2_plus_1"]) 39 | assert doc.num_2 == 2 40 | -------------------------------------------------------------------------------- /tests/odm/operators/update/test_general.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.general import ( 2 | Set, 3 | CurrentDate, 4 | Inc, 5 | Unset, 6 | SetOnInsert, 7 | Rename, 8 | Mul, 9 | Max, 10 | Min, 11 | ) 12 | from tests.odm.models import Sample 13 | 14 | 15 | def test_set(): 16 | q = Set({Sample.integer: 2}) 17 | assert q == {"$set": {"integer": 2}} 18 | 19 | 20 | def test_current_date(): 21 | q = CurrentDate({Sample.integer: 2}) 22 | assert q == {"$currentDate": {"integer": 2}} 23 | 24 | 25 | def test_inc(): 26 | q = Inc({Sample.integer: 2}) 27 | assert q == {"$inc": {"integer": 2}} 28 | 29 | 30 | def test_min(): 31 | q = Min({Sample.integer: 2}) 32 | assert q == {"$min": {"integer": 2}} 33 | 34 | 35 | def test_max(): 36 | q = Max({Sample.integer: 2}) 37 | assert q == {"$max": {"integer": 2}} 38 | 39 | 40 | def test_mul(): 41 | q = Mul({Sample.integer: 2}) 42 | assert q == {"$mul": {"integer": 2}} 43 | 44 | 45 | def test_rename(): 46 | q = Rename({Sample.integer: 2}) 47 | assert q == {"$rename": {"integer": 2}} 48 | 49 | 50 | def test_set_on_insert(): 51 | q = SetOnInsert({Sample.integer: 2}) 52 | assert q == {"$setOnInsert": {"integer": 2}} 53 | 54 | 55 | def test_unset(): 56 | q = Unset({Sample.integer: 2}) 57 | assert q == {"$unset": {"integer": 2}} 58 | -------------------------------------------------------------------------------- /beanita/db.py: 
-------------------------------------------------------------------------------- 1 | from mongita.common import ok_name 2 | from mongita.database import Database as MongitaDatabase 3 | from mongita.errors import InvalidName 4 | 5 | from beanita.collection import Collection 6 | 7 | 8 | class Database(MongitaDatabase): 9 | UNIMPLEMENTED = [ 10 | "aggregate", 11 | "codec_options", 12 | "create_collection", 13 | "dereference", 14 | "get_collection", 15 | "next", 16 | "profiling_info", 17 | "profiling_level", 18 | "read_concern", 19 | "read_preference", 20 | "set_profiling_level", 21 | "validate_collection", 22 | "watch", 23 | "with_options", 24 | "write_concern", 25 | ] 26 | 27 | def __getitem__(self, collection_name): 28 | try: 29 | return self._cache[collection_name] 30 | except KeyError: 31 | if not ok_name(collection_name): 32 | raise InvalidName( 33 | "Collection cannot be named %r." % collection_name 34 | ) 35 | coll = Collection(collection_name, self) 36 | self._cache[collection_name] = coll 37 | return coll 38 | 39 | async def command(self, *args, **kwargs): 40 | return {"version": "4.4"} 41 | 42 | async def list_collection_names(self): 43 | return super().list_collection_names() 44 | -------------------------------------------------------------------------------- /tests/odm/documents/test_distinct.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.models import DocumentTestModel 4 | 5 | 6 | @pytest.mark.skip("Not supported") 7 | async def test_distinct_unique(documents, document_not_inserted): 8 | await documents(1, "uno") 9 | await documents(2, "dos") 10 | await documents(3, "cuatro") 11 | expected_result = ["cuatro", "dos", "uno"] 12 | unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 13 | assert unique_test_strs == expected_result 14 | document_not_inserted.test_str = "uno" 15 | await document_not_inserted.insert() 16 | another_unique_test_strs = await 
DocumentTestModel.distinct("test_str", {}) 17 | assert another_unique_test_strs == expected_result 18 | 19 | 20 | @pytest.mark.skip("Not supported") 21 | async def test_distinct_different_value(documents, document_not_inserted): 22 | await documents(1, "uno") 23 | await documents(2, "dos") 24 | await documents(3, "cuatro") 25 | expected_result = ["cuatro", "dos", "uno"] 26 | unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 27 | assert unique_test_strs == expected_result 28 | document_not_inserted.test_str = "diff_val" 29 | await document_not_inserted.insert() 30 | another_unique_test_strs = await DocumentTestModel.distinct("test_str", {}) 31 | assert not another_unique_test_strs == expected_result 32 | assert another_unique_test_strs == ["cuatro", "diff_val", "dos", "uno"] 33 | -------------------------------------------------------------------------------- /tests/odm/test_timesries_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie import init_beanie 4 | from beanie.exceptions import MongoDBVersionError 5 | from tests.odm.models import DocumentWithTimeseries 6 | 7 | 8 | async def test_timeseries_collection(db): 9 | build_info = await db.command({"buildInfo": 1}) 10 | mongo_version = build_info["version"] 11 | major_version = int(mongo_version.split(".")[0]) 12 | if major_version < 5: 13 | with pytest.raises(MongoDBVersionError): 14 | await init_beanie( 15 | database=db, document_models=[DocumentWithTimeseries] 16 | ) 17 | 18 | if major_version >= 5: 19 | await init_beanie( 20 | database=db, document_models=[DocumentWithTimeseries] 21 | ) 22 | info = await db.command( 23 | { 24 | "listCollections": 1, 25 | "filter": {"name": "DocumentWithTimeseries"}, 26 | } 27 | ) 28 | 29 | assert info["cursor"]["firstBatch"][0] == { 30 | "name": "DocumentWithTimeseries", 31 | "type": "timeseries", 32 | "options": { 33 | "expireAfterSeconds": 2, 34 | "timeseries": { 35 | 
"timeField": "ts", 36 | "granularity": "seconds", 37 | "bucketMaxSpanSeconds": 3600, 38 | }, 39 | }, 40 | "info": {"readOnly": False}, 41 | } 42 | -------------------------------------------------------------------------------- /tests/odm/test_encoder.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, date 2 | 3 | from bson import Binary 4 | 5 | from beanie.odm.utils.encoder import Encoder 6 | from tests.odm.models import ( 7 | DocumentForEncodingTest, 8 | DocumentForEncodingTestDate, 9 | ) 10 | 11 | 12 | async def test_encode_datetime(): 13 | assert isinstance(Encoder().encode(datetime.now()), datetime) 14 | 15 | doc = DocumentForEncodingTest(datetime_field=datetime.now()) 16 | await doc.insert() 17 | new_doc = await DocumentForEncodingTest.get(doc.id) 18 | assert isinstance(new_doc.datetime_field, datetime) 19 | 20 | 21 | async def test_encode_date(): 22 | assert isinstance(Encoder().encode(datetime.now()), datetime) 23 | 24 | doc = DocumentForEncodingTestDate() 25 | await doc.insert() 26 | new_doc = await DocumentForEncodingTestDate.get(doc.id) 27 | assert new_doc.date_field == doc.date_field 28 | assert isinstance(new_doc.date_field, date) 29 | 30 | 31 | def test_encode_with_custom_encoder(): 32 | assert isinstance( 33 | Encoder(custom_encoders={datetime: str}).encode(datetime.now()), str 34 | ) 35 | 36 | 37 | async def test_bytes(): 38 | encoded_b = Encoder().encode(b"test") 39 | assert isinstance(encoded_b, Binary) 40 | assert encoded_b.subtype == 0 41 | 42 | doc = DocumentForEncodingTest(bytes_field=b"test") 43 | await doc.insert() 44 | new_doc = await DocumentForEncodingTest.get(doc.id) 45 | assert isinstance(new_doc.bytes_field, bytes) 46 | 47 | 48 | async def test_bytes_already_binary(): 49 | b = Binary(b"123", 3) 50 | encoded_b = Encoder().encode(b) 51 | assert isinstance(encoded_b, Binary) 52 | assert encoded_b.subtype == 3 53 | 
-------------------------------------------------------------------------------- /tests/odm/documents/test_revision.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.exceptions import RevisionIdWasChanged 4 | from tests.odm.models import DocumentWithRevisionTurnedOn 5 | 6 | 7 | @pytest.mark.skip("Not supported") 8 | async def test_replace(): 9 | doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) 10 | await doc.insert() 11 | 12 | doc.num_1 = 2 13 | await doc.replace() 14 | 15 | doc.num_2 = 3 16 | await doc.replace() 17 | 18 | for i in range(5): 19 | found_doc = await DocumentWithRevisionTurnedOn.get(doc.id) 20 | found_doc.num_1 += 1 21 | await found_doc.replace() 22 | 23 | doc._previous_revision_id = "wrong" 24 | doc.num_1 = 4 25 | with pytest.raises(RevisionIdWasChanged): 26 | await doc.replace() 27 | 28 | await doc.replace(ignore_revision=True) 29 | 30 | 31 | async def test_update(): 32 | doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) 33 | await doc.insert() 34 | 35 | doc.num_1 = 2 36 | await doc.save_changes() 37 | 38 | doc.num_2 = 3 39 | await doc.save_changes() 40 | 41 | for i in range(5): 42 | found_doc = await DocumentWithRevisionTurnedOn.get(doc.id) 43 | found_doc.num_1 += 1 44 | await found_doc.save_changes() 45 | 46 | doc._previous_revision_id = "wrong" 47 | doc.num_1 = 4 48 | with pytest.raises(RevisionIdWasChanged): 49 | await doc.save_changes() 50 | 51 | await doc.save_changes(ignore_revision=True) 52 | 53 | 54 | async def test_empty_update(): 55 | doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) 56 | await doc.insert() 57 | 58 | # This fails with RevisionIdWasChanged 59 | await doc.update({"$set": {"num_1": 1}}) 60 | -------------------------------------------------------------------------------- /tests/odm/documents/test_delete.py: -------------------------------------------------------------------------------- 1 | from tests.odm.models import DocumentTestModel 2 | 3 
| 4 | async def test_delete_one(documents): 5 | await documents(4, "uno") 6 | await documents(2, "dos") 7 | await documents(1, "cuatro") 8 | await DocumentTestModel.find_one({"test_str": "uno"}).delete() 9 | documents = await DocumentTestModel.find_all().to_list() 10 | assert len(documents) == 6 11 | 12 | 13 | async def test_delete_one_not_found(documents): 14 | await documents(4, "uno") 15 | await documents(2, "dos") 16 | await documents(1, "cuatro") 17 | await DocumentTestModel.find_one({"test_str": "wrong"}).delete() 18 | documents = await DocumentTestModel.find_all().to_list() 19 | assert len(documents) == 7 20 | 21 | 22 | async def test_delete_many(documents): 23 | await documents(4, "uno") 24 | await documents(2, "dos") 25 | await documents(1, "cuatro") 26 | await DocumentTestModel.find_many({"test_str": "uno"}).delete() 27 | documents = await DocumentTestModel.find_all().to_list() 28 | assert len(documents) == 3 29 | 30 | 31 | async def test_delete_many_not_found(documents): 32 | await documents(4, "uno") 33 | await documents(2, "dos") 34 | await documents(1, "cuatro") 35 | await DocumentTestModel.find_many({"test_str": "wrong"}).delete() 36 | documents = await DocumentTestModel.find_all().to_list() 37 | assert len(documents) == 7 38 | 39 | 40 | async def test_delete_all(documents): 41 | await documents(4, "uno") 42 | await documents(2, "dos") 43 | await documents(1, "cuatro") 44 | await DocumentTestModel.delete_all() 45 | documents = await DocumentTestModel.find_all().to_list() 46 | assert len(documents) == 0 47 | 48 | 49 | async def test_delete(document): 50 | doc_id = document.id 51 | await document.delete() 52 | new_document = await DocumentTestModel.get(doc_id) 53 | assert new_document is None 54 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_evaluation.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.evaluation 
import ( 2 | Expr, 3 | JsonSchema, 4 | Mod, 5 | RegEx, 6 | Text, 7 | Where, 8 | ) 9 | from tests.odm.models import Sample 10 | 11 | 12 | async def test_expr(): 13 | q = Expr({"a": "B"}) 14 | assert q == {"$expr": {"a": "B"}} 15 | 16 | 17 | async def test_json_schema(): 18 | q = JsonSchema({"a": "B"}) 19 | assert q == {"$jsonSchema": {"a": "B"}} 20 | 21 | 22 | async def test_mod(): 23 | q = Mod(Sample.integer, 3, 2) 24 | assert q == {"integer": {"$mod": [3, 2]}} 25 | 26 | 27 | async def test_regex(): 28 | q = RegEx(Sample.integer, "smth") 29 | assert q == {"integer": {"$regex": "smth"}} 30 | 31 | q = RegEx(Sample.integer, "smth", "options") 32 | assert q == {"integer": {"$regex": "smth", "$options": "options"}} 33 | 34 | 35 | async def test_text(): 36 | q = Text("something") 37 | assert q == { 38 | "$text": { 39 | "$search": "something", 40 | "$caseSensitive": False, 41 | "$diacriticSensitive": False, 42 | } 43 | } 44 | q = Text("something", case_sensitive=True) 45 | assert q == { 46 | "$text": { 47 | "$search": "something", 48 | "$caseSensitive": True, 49 | "$diacriticSensitive": False, 50 | } 51 | } 52 | q = Text("something", diacritic_sensitive=True) 53 | assert q == { 54 | "$text": { 55 | "$search": "something", 56 | "$caseSensitive": False, 57 | "$diacriticSensitive": True, 58 | } 59 | } 60 | q = Text("something", language="test") 61 | assert q == { 62 | "$text": { 63 | "$search": "something", 64 | "$caseSensitive": False, 65 | "$diacriticSensitive": False, 66 | "$language": "test", 67 | } 68 | } 69 | 70 | 71 | async def test_where(): 72 | q = Where("test") 73 | assert q == {"$where": "test"} 74 | -------------------------------------------------------------------------------- /tests/odm/documents/test_create.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.odm.fields import PydanticObjectId 4 | from mongita.errors import DuplicateKeyError 5 | 6 | from tests.odm.models import 
DocumentTestModel 7 | 8 | 9 | async def test_insert_one(document_not_inserted): 10 | result = await DocumentTestModel.insert_one(document_not_inserted) 11 | document = await DocumentTestModel.get(result.id) 12 | assert document is not None 13 | assert document.test_int == document_not_inserted.test_int 14 | assert document.test_list == document_not_inserted.test_list 15 | assert document.test_str == document_not_inserted.test_str 16 | 17 | 18 | async def test_insert_many(documents_not_inserted): 19 | await DocumentTestModel.insert_many(documents_not_inserted(10)) 20 | documents = await DocumentTestModel.find_all().to_list() 21 | assert len(documents) == 10 22 | 23 | 24 | async def test_create(document_not_inserted): 25 | await document_not_inserted.insert() 26 | assert isinstance(document_not_inserted.id, PydanticObjectId) 27 | 28 | 29 | async def test_create_twice(document_not_inserted): 30 | await document_not_inserted.insert() 31 | with pytest.raises(DuplicateKeyError): 32 | await document_not_inserted.insert() 33 | 34 | 35 | async def test_insert_one_with_session(document_not_inserted, session): 36 | result = await DocumentTestModel.insert_one( 37 | document_not_inserted, session=session 38 | ) 39 | document = await DocumentTestModel.get(result.id, session=session) 40 | assert document is not None 41 | assert document.test_int == document_not_inserted.test_int 42 | assert document.test_list == document_not_inserted.test_list 43 | assert document.test_str == document_not_inserted.test_str 44 | 45 | 46 | async def test_insert_many_with_session(documents_not_inserted, session): 47 | await DocumentTestModel.insert_many( 48 | documents_not_inserted(10), session=session 49 | ) 50 | documents = await DocumentTestModel.find_all(session=session).to_list() 51 | assert len(documents) == 10 52 | 53 | 54 | async def test_create_with_session(document_not_inserted, session): 55 | await document_not_inserted.insert(session=session) 56 | assert 
isinstance(document_not_inserted.id, PydanticObjectId) 57 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_comparison.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.comparison import ( 2 | Eq, 3 | GT, 4 | GTE, 5 | In, 6 | LT, 7 | LTE, 8 | NE, 9 | NotIn, 10 | ) 11 | from tests.odm.models import Sample 12 | 13 | 14 | async def test_eq(): 15 | q = Sample.integer == 1 16 | assert q == {"integer": 1} 17 | 18 | q = Eq(Sample.integer, 1) 19 | assert q == {"integer": 1} 20 | 21 | q = Eq("integer", 1) 22 | assert q == {"integer": 1} 23 | 24 | 25 | async def test_gt(): 26 | q = Sample.integer > 1 27 | assert q == {"integer": {"$gt": 1}} 28 | 29 | q = GT(Sample.integer, 1) 30 | assert q == {"integer": {"$gt": 1}} 31 | 32 | q = GT("integer", 1) 33 | assert q == {"integer": {"$gt": 1}} 34 | 35 | 36 | async def test_gte(): 37 | q = Sample.integer >= 1 38 | assert q == {"integer": {"$gte": 1}} 39 | 40 | q = GTE(Sample.integer, 1) 41 | assert q == {"integer": {"$gte": 1}} 42 | 43 | q = GTE("integer", 1) 44 | assert q == {"integer": {"$gte": 1}} 45 | 46 | 47 | async def test_in(): 48 | q = In(Sample.integer, [1]) 49 | assert q == {"integer": {"$in": [1]}} 50 | 51 | q = In(Sample.integer, [1]) 52 | assert q == {"integer": {"$in": [1]}} 53 | 54 | 55 | async def test_lt(): 56 | q = Sample.integer < 1 57 | assert q == {"integer": {"$lt": 1}} 58 | 59 | q = LT(Sample.integer, 1) 60 | assert q == {"integer": {"$lt": 1}} 61 | 62 | q = LT("integer", 1) 63 | assert q == {"integer": {"$lt": 1}} 64 | 65 | 66 | async def test_lte(): 67 | q = Sample.integer <= 1 68 | assert q == {"integer": {"$lte": 1}} 69 | 70 | q = LTE(Sample.integer, 1) 71 | assert q == {"integer": {"$lte": 1}} 72 | 73 | q = LTE("integer", 1) 74 | assert q == {"integer": {"$lte": 1}} 75 | 76 | 77 | async def test_ne(): 78 | q = Sample.integer != 1 79 | assert q == {"integer": {"$ne": 1}} 80 
| 81 | q = NE(Sample.integer, 1) 82 | assert q == {"integer": {"$ne": 1}} 83 | 84 | q = NE("integer", 1) 85 | assert q == {"integer": {"$ne": 1}} 86 | 87 | 88 | async def test_nin(): 89 | q = NotIn(Sample.integer, [1]) 90 | assert q == {"integer": {"$nin": [1]}} 91 | 92 | q = NotIn(Sample.integer, [1]) 93 | assert q == {"integer": {"$nin": [1]}} 94 | -------------------------------------------------------------------------------- /tests/odm/query/test_update_methods.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.update.general import Max 2 | from beanie.odm.queries.update import UpdateQuery, UpdateMany 3 | from tests.odm.models import Sample 4 | 5 | 6 | async def test_set(session): 7 | q = Sample.find_many(Sample.integer == 1).set( 8 | {Sample.integer: 100}, session=session 9 | ) 10 | 11 | assert isinstance(q, UpdateQuery) 12 | assert isinstance(q, UpdateMany) 13 | assert q.session == session 14 | 15 | assert q.update_query == {"$set": {"integer": 100}} 16 | 17 | q = ( 18 | Sample.find_many(Sample.integer == 1) 19 | .update(Max({Sample.integer: 10})) 20 | .set({Sample.integer: 100}) 21 | ) 22 | 23 | assert isinstance(q, UpdateQuery) 24 | assert isinstance(q, UpdateMany) 25 | 26 | assert q.update_query == { 27 | "$max": {"integer": 10}, 28 | "$set": {"integer": 100}, 29 | } 30 | 31 | 32 | async def test_current_date(session): 33 | q = Sample.find_many(Sample.integer == 1).current_date( 34 | {Sample.timestamp: "timestamp"}, session=session 35 | ) 36 | 37 | assert isinstance(q, UpdateQuery) 38 | assert isinstance(q, UpdateMany) 39 | assert q.session == session 40 | 41 | assert q.update_query == {"$currentDate": {"timestamp": "timestamp"}} 42 | 43 | q = ( 44 | Sample.find_many(Sample.integer == 1) 45 | .update(Max({Sample.integer: 10})) 46 | .current_date({Sample.timestamp: "timestamp"}) 47 | ) 48 | 49 | assert isinstance(q, UpdateQuery) 50 | assert isinstance(q, UpdateMany) 51 | 52 | assert 
q.update_query == { 53 | "$max": {"integer": 10}, 54 | "$currentDate": {"timestamp": "timestamp"}, 55 | } 56 | 57 | 58 | async def test_inc(session): 59 | q = Sample.find_many(Sample.integer == 1).inc( 60 | {Sample.integer: 100}, session=session 61 | ) 62 | 63 | assert isinstance(q, UpdateQuery) 64 | assert isinstance(q, UpdateMany) 65 | assert q.session == session 66 | 67 | assert q.update_query == {"$inc": {"integer": 100}} 68 | 69 | q = ( 70 | Sample.find_many(Sample.integer == 1) 71 | .update(Max({Sample.integer: 10})) 72 | .inc({Sample.integer: 100}) 73 | ) 74 | 75 | assert isinstance(q, UpdateQuery) 76 | assert isinstance(q, UpdateMany) 77 | 78 | assert q.update_query == { 79 | "$max": {"integer": 10}, 80 | "$inc": {"integer": 100}, 81 | } 82 | -------------------------------------------------------------------------------- /tests/odm/test_expression_fields.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.enums import SortDirection 2 | from beanie.odm.operators.find.comparison import In, NotIn 3 | from tests.odm.models import Sample 4 | 5 | 6 | def test_nesting(): 7 | assert Sample.id == "_id" 8 | 9 | q = Sample.find_many(Sample.integer == 1) 10 | assert q.get_filter_query() == {"integer": 1} 11 | assert Sample.integer == "integer" 12 | 13 | q = Sample.find_many(Sample.nested.integer == 1) 14 | assert q.get_filter_query() == {"nested.integer": 1} 15 | assert Sample.nested.integer == "nested.integer" 16 | 17 | q = Sample.find_many(Sample.union.s == "test") 18 | assert q.get_filter_query() == {"union.s": "test"} 19 | assert Sample.union.s == "union.s" 20 | 21 | q = Sample.find_many(Sample.nested.optional == None) # noqa 22 | assert q.get_filter_query() == {"nested.optional": None} 23 | assert Sample.nested.optional == "nested.optional" 24 | 25 | q = Sample.find_many(Sample.nested.integer == 1).find_many( 26 | Sample.nested.union.s == "test" 27 | ) 28 | assert q.get_filter_query() == { 29 | "$and": 
[{"nested.integer": 1}, {"nested.union.s": "test"}] 30 | } 31 | 32 | 33 | def test_eq(): 34 | q = Sample.find_many(Sample.integer == 1) 35 | assert q.get_filter_query() == {"integer": 1} 36 | 37 | 38 | def test_gt(): 39 | q = Sample.find_many(Sample.integer > 1) 40 | assert q.get_filter_query() == {"integer": {"$gt": 1}} 41 | 42 | 43 | def test_gte(): 44 | q = Sample.find_many(Sample.integer >= 1) 45 | assert q.get_filter_query() == {"integer": {"$gte": 1}} 46 | 47 | 48 | def test_in(): 49 | q = Sample.find_many(In(Sample.integer, [1, 2, 3, 4])) 50 | assert dict(q.get_filter_query()) == {"integer": {"$in": [1, 2, 3, 4]}} 51 | 52 | 53 | def test_lt(): 54 | q = Sample.find_many(Sample.integer < 1) 55 | assert q.get_filter_query() == {"integer": {"$lt": 1}} 56 | 57 | 58 | def test_lte(): 59 | q = Sample.find_many(Sample.integer <= 1) 60 | assert q.get_filter_query() == {"integer": {"$lte": 1}} 61 | 62 | 63 | def test_ne(): 64 | q = Sample.find_many(Sample.integer != 1) 65 | assert q.get_filter_query() == {"integer": {"$ne": 1}} 66 | 67 | 68 | def test_nin(): 69 | q = Sample.find_many(NotIn(Sample.integer, [1, 2, 3, 4])) 70 | assert dict(q.get_filter_query()) == {"integer": {"$nin": [1, 2, 3, 4]}} 71 | 72 | 73 | def test_pos(): 74 | q = +Sample.integer 75 | assert q == ("integer", SortDirection.ASCENDING) 76 | 77 | 78 | def test_neg(): 79 | q = -Sample.integer 80 | assert q == ("integer", SortDirection.DESCENDING) 81 | -------------------------------------------------------------------------------- /tests/odm/documents/test_multi_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.models import ( 4 | DocumentMultiModelOne, 5 | DocumentMultiModelTwo, 6 | DocumentUnion, 7 | ) 8 | 9 | 10 | @pytest.mark.skip("Not supported") 11 | class TestMultiModel: 12 | async def test_multi_model(self): 13 | doc_1 = await DocumentMultiModelOne().insert() 14 | doc_2 = await DocumentMultiModelTwo().insert() 15 
| 16 | new_doc_1 = await DocumentMultiModelOne.get(doc_1.id) 17 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id) 18 | 19 | assert new_doc_1 is not None 20 | assert new_doc_2 is not None 21 | 22 | new_doc_1 = await DocumentMultiModelTwo.get(doc_1.id) 23 | new_doc_2 = await DocumentMultiModelOne.get(doc_2.id) 24 | 25 | assert new_doc_1 is None 26 | assert new_doc_2 is None 27 | 28 | new_docs_1 = await DocumentMultiModelOne.find({}).to_list() 29 | new_docs_2 = await DocumentMultiModelTwo.find({}).to_list() 30 | 31 | assert len(new_docs_1) == 1 32 | assert len(new_docs_2) == 1 33 | 34 | await DocumentMultiModelOne.update_all({"$set": {"shared": 100}}) 35 | 36 | new_doc_1 = await DocumentMultiModelOne.get(doc_1.id) 37 | new_doc_2 = await DocumentMultiModelTwo.get(doc_2.id) 38 | 39 | assert new_doc_1.shared == 100 40 | assert new_doc_2.shared == 0 41 | 42 | async def test_union_doc(self): 43 | await DocumentMultiModelOne().insert() 44 | await DocumentMultiModelTwo().insert() 45 | await DocumentMultiModelOne().insert() 46 | await DocumentMultiModelTwo().insert() 47 | 48 | docs = await DocumentUnion.all().to_list() 49 | assert isinstance(docs[0], DocumentMultiModelOne) 50 | assert isinstance(docs[1], DocumentMultiModelTwo) 51 | assert isinstance(docs[2], DocumentMultiModelOne) 52 | assert isinstance(docs[3], DocumentMultiModelTwo) 53 | 54 | async def test_union_doc_aggregation(self): 55 | await DocumentMultiModelOne().insert() 56 | await DocumentMultiModelTwo().insert() 57 | await DocumentMultiModelOne().insert() 58 | await DocumentMultiModelTwo().insert() 59 | 60 | docs = await DocumentUnion.aggregate( 61 | [{"$match": {"$expr": {"$eq": ["$int_filed", 0]}}}] 62 | ).to_list() 63 | assert len(docs) == 2 64 | 65 | async def test_union_doc_link(self): 66 | doc_1 = await DocumentMultiModelOne().insert() 67 | await DocumentMultiModelTwo(linked_doc=doc_1).insert() 68 | 69 | docs = await DocumentMultiModelTwo.find({}, fetch_links=True).to_list() 70 | assert 
isinstance(docs[0].linked_doc, DocumentMultiModelOne) 71 | -------------------------------------------------------------------------------- /tests/odm/documents/test_aggregate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import Field 3 | from pydantic.main import BaseModel 4 | 5 | from tests.odm.models import DocumentTestModel 6 | 7 | 8 | @pytest.mark.skip("Not supported") 9 | async def test_aggregate(documents): 10 | await documents(4, "uno") 11 | await documents(2, "dos") 12 | await documents(1, "cuatro") 13 | result = await DocumentTestModel.aggregate( 14 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 15 | ).to_list() 16 | assert len(result) == 3 17 | assert {"_id": "cuatro", "total": 0} in result 18 | assert {"_id": "dos", "total": 1} in result 19 | assert {"_id": "uno", "total": 6} in result 20 | 21 | 22 | @pytest.mark.skip("Not supported") 23 | async def test_aggregate_with_filter(documents): 24 | await documents(4, "uno") 25 | await documents(2, "dos") 26 | await documents(1, "cuatro") 27 | result = ( 28 | await DocumentTestModel.find(DocumentTestModel.test_int >= 1) 29 | .aggregate( 30 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 31 | ) 32 | .to_list() 33 | ) 34 | assert len(result) == 2 35 | assert {"_id": "dos", "total": 1} in result 36 | assert {"_id": "uno", "total": 6} in result 37 | 38 | 39 | @pytest.mark.skip("Not supported") 40 | async def test_aggregate_with_item_model(documents): 41 | class OutputItem(BaseModel): 42 | id: str = Field(None, alias="_id") 43 | total: int 44 | 45 | await documents(4, "uno") 46 | await documents(2, "dos") 47 | await documents(1, "cuatro") 48 | ids = [] 49 | async for i in DocumentTestModel.aggregate( 50 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 51 | projection_model=OutputItem, 52 | ): 53 | if i.id == "cuatro": 54 | assert i.total == 0 55 | elif i.id == "dos": 56 | assert 
i.total == 1 57 | elif i.id == "uno": 58 | assert i.total == 6 59 | else: 60 | raise KeyError 61 | ids.append(i.id) 62 | assert set(ids) == {"cuatro", "dos", "uno"} 63 | 64 | 65 | @pytest.mark.skip("Not supported") 66 | async def test_aggregate_with_session(documents, session): 67 | await documents(4, "uno") 68 | await documents(2, "dos") 69 | await documents(1, "cuatro") 70 | result = await DocumentTestModel.aggregate( 71 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 72 | session=session, 73 | ).to_list() 74 | assert len(result) == 3 75 | assert {"_id": "cuatro", "total": 0} in result 76 | assert {"_id": "dos", "total": 1} in result 77 | assert {"_id": "uno", "total": 6} in result 78 | -------------------------------------------------------------------------------- /tests/odm/test_actions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie import Before, After 4 | 5 | from tests.odm.models import DocumentWithActions, InheritedDocumentWithActions 6 | 7 | 8 | class TestActions: 9 | @pytest.mark.parametrize( 10 | "doc_class", [DocumentWithActions, InheritedDocumentWithActions] 11 | ) 12 | async def test_actions_insert(self, doc_class): 13 | test_name = f"test_actions_insert_{doc_class}" 14 | sample = doc_class(name=test_name) 15 | 16 | await sample.insert() 17 | assert sample.name != test_name 18 | assert sample.name == test_name.capitalize() 19 | assert sample.num_1 == 1 20 | assert sample.num_2 == 9 21 | 22 | @pytest.mark.parametrize( 23 | "doc_class", [DocumentWithActions, InheritedDocumentWithActions] 24 | ) 25 | async def test_actions_replace(self, doc_class): 26 | test_name = f"test_actions_replace_{doc_class}" 27 | sample = doc_class(name=test_name) 28 | 29 | await sample.insert() 30 | 31 | await sample.replace() 32 | assert sample.num_1 == 2 33 | assert sample.num_3 == 99 34 | 35 | @pytest.mark.parametrize( 36 | "doc_class", [DocumentWithActions, 
InheritedDocumentWithActions] 37 | ) 38 | async def test_skip_actions_insert(self, doc_class): 39 | test_name = f"test_skip_actions_insert_{doc_class}" 40 | sample = doc_class(name=test_name) 41 | 42 | await sample.insert(skip_actions=[After, "capitalize_name"]) 43 | # capitalize_name has been skipped 44 | assert sample.name == test_name 45 | # add_one has not been skipped 46 | assert sample.num_1 == 1 47 | # num_2_change has been skipped 48 | assert sample.num_2 == 10 49 | 50 | @pytest.mark.parametrize( 51 | "doc_class", [DocumentWithActions, InheritedDocumentWithActions] 52 | ) 53 | async def test_skip_actions_replace(self, doc_class): 54 | test_name = f"test_skip_actions_replace{doc_class}" 55 | sample = doc_class(name=test_name) 56 | 57 | await sample.insert() 58 | 59 | await sample.replace(skip_actions=[Before, "num_3_change"]) 60 | # add_one has been skipped 61 | assert sample.num_1 == 1 62 | # num_3_change has been skipped 63 | assert sample.num_3 == 100 64 | 65 | @pytest.mark.parametrize( 66 | "doc_class", [DocumentWithActions, InheritedDocumentWithActions] 67 | ) 68 | async def test_actions_delete(self, doc_class): 69 | test_name = f"test_actions_insert_{doc_class}" 70 | sample = doc_class(name=test_name) 71 | 72 | await sample.delete() 73 | assert sample.Inner.inner_num_1 == 1 74 | assert sample.Inner.inner_num_2 == 2 75 | -------------------------------------------------------------------------------- /tests/odm/query/test_aggregate_methods.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.odm.models import Sample 4 | 5 | 6 | @pytest.mark.skip("Not supported") 7 | async def test_sum(preset_documents, session): 8 | n = await Sample.find_many(Sample.integer == 1).sum(Sample.increment) 9 | 10 | assert n == 12 11 | 12 | n = await Sample.find_many(Sample.integer == 1).sum( 13 | Sample.increment, session=session 14 | ) 15 | 16 | assert n == 12 17 | 18 | 19 | @pytest.mark.skip("Not 
supported") 20 | async def test_sum_without_docs(session): 21 | n = await Sample.find_many(Sample.integer == 1).sum(Sample.increment) 22 | 23 | assert n is None 24 | 25 | n = await Sample.find_many(Sample.integer == 1).sum( 26 | Sample.increment, session=session 27 | ) 28 | 29 | assert n is None 30 | 31 | 32 | @pytest.mark.skip("Not supported") 33 | async def test_avg(preset_documents, session): 34 | n = await Sample.find_many(Sample.integer == 1).avg(Sample.increment) 35 | 36 | assert n == 4 37 | n = await Sample.find_many(Sample.integer == 1).avg( 38 | Sample.increment, session=session 39 | ) 40 | 41 | assert n == 4 42 | 43 | 44 | @pytest.mark.skip("Not supported") 45 | async def test_avg_without_docs(session): 46 | n = await Sample.find_many(Sample.integer == 1).avg(Sample.increment) 47 | 48 | assert n is None 49 | n = await Sample.find_many(Sample.integer == 1).avg( 50 | Sample.increment, session=session 51 | ) 52 | 53 | assert n is None 54 | 55 | 56 | @pytest.mark.skip("Not supported") 57 | async def test_max(preset_documents, session): 58 | n = await Sample.find_many(Sample.integer == 1).max(Sample.increment) 59 | 60 | assert n == 5 61 | 62 | n = await Sample.find_many(Sample.integer == 1).max( 63 | Sample.increment, session=session 64 | ) 65 | 66 | assert n == 5 67 | 68 | 69 | @pytest.mark.skip("Not supported") 70 | async def test_max_without_docs(session): 71 | n = await Sample.find_many(Sample.integer == 1).max(Sample.increment) 72 | 73 | assert n is None 74 | 75 | n = await Sample.find_many(Sample.integer == 1).max( 76 | Sample.increment, session=session 77 | ) 78 | 79 | assert n is None 80 | 81 | 82 | @pytest.mark.skip("Not supported") 83 | async def test_min(preset_documents, session): 84 | n = await Sample.find_many(Sample.integer == 1).min(Sample.increment) 85 | 86 | assert n == 3 87 | 88 | n = await Sample.find_many(Sample.integer == 1).min( 89 | Sample.increment, session=session 90 | ) 91 | 92 | assert n == 3 93 | 94 | 95 | @pytest.mark.skip("Not 
supported") 96 | async def test_min_without_docs(session): 97 | n = await Sample.find_many(Sample.integer == 1).min(Sample.increment) 98 | 99 | assert n is None 100 | 101 | n = await Sample.find_many(Sample.integer == 1).min( 102 | Sample.increment, session=session 103 | ) 104 | 105 | assert n is None 106 | -------------------------------------------------------------------------------- /tests/odm/operators/find/test_geospatial.py: -------------------------------------------------------------------------------- 1 | from beanie.odm.operators.find.geospatial import ( 2 | GeoIntersects, 3 | GeoWithin, 4 | Near, 5 | NearSphere, 6 | ) 7 | from tests.odm.models import Sample 8 | 9 | 10 | async def test_geo_intersects(): 11 | q = GeoIntersects( 12 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]] 13 | ) 14 | assert q == { 15 | "geo": { 16 | "$geoIntersects": { 17 | "$geometry": { 18 | "type": "Polygon", 19 | "coordinates": [[1, 1], [2, 2], [3, 3]], 20 | } 21 | } 22 | } 23 | } 24 | 25 | 26 | async def test_geo_within(): 27 | q = GeoWithin( 28 | Sample.geo, geo_type="Polygon", coordinates=[[1, 1], [2, 2], [3, 3]] 29 | ) 30 | assert q == { 31 | "geo": { 32 | "$geoWithin": { 33 | "$geometry": { 34 | "type": "Polygon", 35 | "coordinates": [[1, 1], [2, 2], [3, 3]], 36 | } 37 | } 38 | } 39 | } 40 | 41 | 42 | async def test_near(): 43 | q = Near(Sample.geo, longitude=1.1, latitude=2.2) 44 | assert q == { 45 | "geo": { 46 | "$near": { 47 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]} 48 | } 49 | } 50 | } 51 | 52 | q = Near(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1) 53 | assert q == { 54 | "geo": { 55 | "$near": { 56 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 57 | "$maxDistance": 1, 58 | } 59 | } 60 | } 61 | 62 | q = Near( 63 | Sample.geo, 64 | longitude=1.1, 65 | latitude=2.2, 66 | max_distance=1, 67 | min_distance=0.5, 68 | ) 69 | assert q == { 70 | "geo": { 71 | "$near": { 72 | "$geometry": {"type": "Point", 
"coordinates": [1.1, 2.2]}, 73 | "$maxDistance": 1, 74 | "$minDistance": 0.5, 75 | } 76 | } 77 | } 78 | 79 | 80 | async def test_near_sphere(): 81 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2) 82 | assert q == { 83 | "geo": { 84 | "$nearSphere": { 85 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]} 86 | } 87 | } 88 | } 89 | 90 | q = NearSphere(Sample.geo, longitude=1.1, latitude=2.2, max_distance=1) 91 | assert q == { 92 | "geo": { 93 | "$nearSphere": { 94 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 95 | "$maxDistance": 1, 96 | } 97 | } 98 | } 99 | 100 | q = NearSphere( 101 | Sample.geo, 102 | longitude=1.1, 103 | latitude=2.2, 104 | max_distance=1, 105 | min_distance=0.5, 106 | ) 107 | assert q == { 108 | "geo": { 109 | "$nearSphere": { 110 | "$geometry": {"type": "Point", "coordinates": [1.1, 2.2]}, 111 | "$maxDistance": 1, 112 | "$minDistance": 0.5, 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /tests/odm/query/test_delete.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mongita.errors import MongitaError 3 | 4 | from tests.odm.models import Sample 5 | from beanie.odm.queries.delete import DeleteMany 6 | 7 | 8 | async def test_delete_many(preset_documents): 9 | count_before = await Sample.count() 10 | count_find = await Sample.find_many(Sample.integer > 1).count() # noqa 11 | delete_result = await Sample.find_many(Sample.integer > 1).delete() # noqa 12 | count_deleted = delete_result.deleted_count 13 | count_after = await Sample.count() 14 | assert count_before - count_find == count_after 15 | assert count_after + count_deleted == count_before 16 | assert isinstance( 17 | Sample.find_many(Sample.integer > 1).delete_many(), 18 | DeleteMany, 19 | ) # noqa 20 | 21 | 22 | async def test_delete_all(preset_documents): 23 | count_before = await Sample.count() 24 | delete_result = await 
Sample.delete_all() 25 | count_deleted = delete_result.deleted_count 26 | count_after = await Sample.count() 27 | assert count_after == 0 28 | assert count_after + count_deleted == count_before 29 | 30 | 31 | async def test_delete_self(preset_documents): 32 | count_before = await Sample.count() 33 | result = await Sample.find_many(Sample.integer > 1).to_list() # noqa 34 | a = result[0] 35 | delete_result = await a.delete() 36 | count_deleted = delete_result.deleted_count 37 | count_after = await Sample.count() 38 | assert count_before == count_after + 1 39 | assert count_deleted == 1 40 | 41 | 42 | async def test_delete_one(preset_documents): 43 | count_before = await Sample.count() 44 | delete_result = await Sample.find_one(Sample.integer > 1).delete() # noqa 45 | count_after = await Sample.count() 46 | count_deleted = delete_result.deleted_count 47 | assert count_before == count_after + 1 48 | assert count_deleted == 1 49 | 50 | count_before = await Sample.count() 51 | delete_result = await Sample.find_one( 52 | Sample.integer > 1 53 | ).delete_one() # noqa 54 | count_deleted = delete_result.deleted_count 55 | count_after = await Sample.count() 56 | assert count_before == count_after + 1 57 | assert count_deleted == 1 58 | 59 | 60 | async def test_delete_many_with_session(preset_documents, session): 61 | count_before = await Sample.count() 62 | count_find = await Sample.find_many(Sample.integer > 1).count() # noqa 63 | q = Sample.find_many(Sample.integer > 1).delete(session=session) # noqa 64 | assert q.session == session 65 | 66 | q = ( 67 | Sample.find_many(Sample.integer > 1) 68 | .delete() 69 | .set_session(session=session) 70 | ) # noqa 71 | 72 | assert q.session == session 73 | 74 | delete_result = await q 75 | count_deleted = delete_result.deleted_count 76 | count_after = await Sample.count() 77 | assert count_before - count_find == count_after 78 | assert count_after + count_deleted == count_before 79 | 80 | 81 | @pytest.mark.skip("Not supported") 82 | 
async def test_delete_pymongo_kwargs(preset_documents): 83 | with pytest.raises(MongitaError): 84 | await Sample.find_many(Sample.increment > 4).delete(wrong="integer_1") 85 | 86 | delete_result = await Sample.find_many(Sample.increment > 4).delete( 87 | hint="integer_1" 88 | ) 89 | assert delete_result is not None 90 | 91 | delete_result = await Sample.find_one(Sample.increment > 4).delete( 92 | hint="integer_1" 93 | ) 94 | assert delete_result is not None 95 | -------------------------------------------------------------------------------- /tests/odm/test_cache.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from tests.odm.models import DocumentTestModel 6 | 7 | 8 | async def test_find_one(documents): 9 | await documents(5) 10 | doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 11 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set( 12 | {DocumentTestModel.test_str: "NEW_VALUE"} 13 | ) 14 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 15 | assert doc == new_doc 16 | 17 | new_doc = await DocumentTestModel.find_one( 18 | DocumentTestModel.test_int == 1, ignore_cache=True 19 | ) 20 | assert doc != new_doc 21 | 22 | await asyncio.sleep(10) 23 | 24 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 25 | assert doc != new_doc 26 | 27 | 28 | async def test_find_many(documents): 29 | await documents(5) 30 | docs = await DocumentTestModel.find( 31 | DocumentTestModel.test_int > 1 32 | ).to_list() 33 | 34 | await DocumentTestModel.find(DocumentTestModel.test_int > 1).set( 35 | {DocumentTestModel.test_str: "NEW_VALUE"} 36 | ) 37 | 38 | new_docs = await DocumentTestModel.find( 39 | DocumentTestModel.test_int > 1 40 | ).to_list() 41 | assert docs == new_docs 42 | 43 | new_docs = await DocumentTestModel.find( 44 | DocumentTestModel.test_int > 1, ignore_cache=True 45 | ).to_list() 46 | assert 
docs != new_docs 47 | 48 | await asyncio.sleep(10) 49 | 50 | new_docs = await DocumentTestModel.find( 51 | DocumentTestModel.test_int > 1 52 | ).to_list() 53 | assert docs != new_docs 54 | 55 | 56 | @pytest.mark.skip("Not supported") 57 | async def test_aggregation(documents): 58 | await documents(5) 59 | docs = await DocumentTestModel.aggregate( 60 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 61 | ).to_list() 62 | 63 | await DocumentTestModel.find(DocumentTestModel.test_int > 1).set( 64 | {DocumentTestModel.test_str: "NEW_VALUE"} 65 | ) 66 | 67 | new_docs = await DocumentTestModel.aggregate( 68 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 69 | ).to_list() 70 | assert docs == new_docs 71 | 72 | new_docs = await DocumentTestModel.aggregate( 73 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}], 74 | ignore_cache=True, 75 | ).to_list() 76 | assert docs != new_docs 77 | 78 | await asyncio.sleep(10) 79 | 80 | new_docs = await DocumentTestModel.aggregate( 81 | [{"$group": {"_id": "$test_str", "total": {"$sum": "$test_int"}}}] 82 | ).to_list() 83 | assert docs != new_docs 84 | 85 | 86 | async def test_capacity(documents): 87 | await documents(10) 88 | docs = [] 89 | for i in range(10): 90 | docs.append( 91 | await DocumentTestModel.find_one(DocumentTestModel.test_int == i) 92 | ) 93 | 94 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 1).set( 95 | {DocumentTestModel.test_str: "NEW_VALUE"} 96 | ) 97 | await DocumentTestModel.find_one(DocumentTestModel.test_int == 9).set( 98 | {DocumentTestModel.test_str: "NEW_VALUE"} 99 | ) 100 | 101 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 1) 102 | assert docs[1] != new_doc 103 | 104 | new_doc = await DocumentTestModel.find_one(DocumentTestModel.test_int == 9) 105 | assert docs[9] == new_doc 106 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | config.cnf 2 | *.pyc 3 | *.iml 4 | */*.pytest* 5 | .rnd 6 | ### Python template 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | .static_storage/ 62 | .media/ 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | ### VirtualEnv template 112 | # Virtualenv 113 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 114 | .Python 
115 | [Bb]in 116 | [Ii]nclude 117 | [Ll]ib 118 | [Ll]ib64 119 | [Ll]ocal 120 | pyvenv.cfg 121 | .venv 122 | pip-selfcheck.json 123 | ### JetBrains template 124 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 125 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 126 | 127 | # User-specific stuff: 128 | .idea/**/workspace.xml 129 | .idea/**/tasks.xml 130 | .idea/dictionaries 131 | 132 | # Sensitive or high-churn files: 133 | .idea/**/dataSources/ 134 | .idea/**/dataSources.ids 135 | .idea/**/dataSources.xml 136 | .idea/**/dataSources.local.xml 137 | .idea/**/sqlDataSources.xml 138 | .idea/**/dynamic.xml 139 | .idea/**/uiDesigner.xml 140 | 141 | # Gradle: 142 | .idea/**/gradle.xml 143 | .idea/**/libraries 144 | 145 | # CMake 146 | cmake-build-debug/ 147 | cmake-build-release/ 148 | 149 | # Mongo Explorer plugin: 150 | .idea/**/mongoSettings.xml 151 | 152 | ## File-based project format: 153 | *.iws 154 | 155 | ## Plugin-specific files: 156 | 157 | # IntelliJ 158 | out/ 159 | 160 | # mpeltonen/sbt-idea plugin 161 | .idea_modules/ 162 | 163 | # JIRA plugin 164 | atlassian-ide-plugin.xml 165 | 166 | # Cursive Clojure plugin 167 | .idea/replstate.xml 168 | 169 | # Crashlytics plugin (for Android Studio and IntelliJ) 170 | com_crashlytics_export_strings.xml 171 | crashlytics.properties 172 | crashlytics-build.properties 173 | fabric.properties 174 | 175 | .idea 176 | .pytest_cache 177 | docs/api 178 | docs/_rst 179 | tags 180 | 181 | tests/assets/tmp 182 | src/api_files/storage_dir 183 | docker-compose-aws.yml 184 | tilt_modules 185 | 186 | # Poetry stuff 187 | poetry.lock 188 | 189 | tests/beanita_path/ 190 | .monty.storage* 191 | *.monty.storage.cfg -------------------------------------------------------------------------------- /tests/odm/query/test_aggregate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic 
import Field 3 | from pydantic.main import BaseModel 4 | 5 | from tests.odm.models import Sample 6 | 7 | 8 | @pytest.mark.skip("Not supported") 9 | async def test_aggregate(preset_documents): 10 | q = Sample.aggregate( 11 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] 12 | ) 13 | assert q.get_aggregation_pipeline() == [ 14 | {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}} 15 | ] 16 | result = await q.to_list() 17 | assert len(result) == 4 18 | assert {"_id": "test_3", "total": 3} in result 19 | assert {"_id": "test_1", "total": 3} in result 20 | assert {"_id": "test_0", "total": 0} in result 21 | assert {"_id": "test_2", "total": 6} in result 22 | 23 | 24 | @pytest.mark.skip("Not supported") 25 | async def test_aggregate_with_filter(preset_documents): 26 | q = Sample.find(Sample.increment >= 4).aggregate( 27 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] 28 | ) 29 | assert q.get_aggregation_pipeline() == [ 30 | {"$match": {"increment": {"$gte": 4}}}, 31 | {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, 32 | ] 33 | result = await q.to_list() 34 | assert len(result) == 3 35 | assert {"_id": "test_1", "total": 2} in result 36 | assert {"_id": "test_2", "total": 6} in result 37 | assert {"_id": "test_3", "total": 3} in result 38 | 39 | 40 | @pytest.mark.skip("Not supported") 41 | async def test_aggregate_with_projection_model(preset_documents): 42 | class OutputItem(BaseModel): 43 | id: str = Field(None, alias="_id") 44 | total: int 45 | 46 | ids = [] 47 | q = Sample.find(Sample.increment >= 4).aggregate( 48 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}], 49 | projection_model=OutputItem, 50 | ) 51 | assert q.get_aggregation_pipeline() == [ 52 | {"$match": {"increment": {"$gte": 4}}}, 53 | {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, 54 | {"$project": {"_id": 1, "total": 1}}, 55 | ] 56 | async for i in q: 57 | if i.id == "test_1": 58 | assert i.total == 2 59 | elif 
i.id == "test_2": 60 | assert i.total == 6 61 | elif i.id == "test_3": 62 | assert i.total == 3 63 | else: 64 | raise KeyError 65 | ids.append(i.id) 66 | assert set(ids) == {"test_1", "test_2", "test_3"} 67 | 68 | 69 | @pytest.mark.skip("Not supported") 70 | async def test_aggregate_with_session(preset_documents, session): 71 | q = Sample.find(Sample.increment >= 4).aggregate( 72 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}], 73 | session=session, 74 | ) 75 | assert q.session == session 76 | 77 | q = Sample.find(Sample.increment >= 4, session=session).aggregate( 78 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] 79 | ) 80 | assert q.session == session 81 | 82 | result = await q.to_list() 83 | 84 | assert len(result) == 3 85 | assert {"_id": "test_1", "total": 2} in result 86 | assert {"_id": "test_2", "total": 6} in result 87 | assert {"_id": "test_3", "total": 3} in result 88 | 89 | 90 | @pytest.mark.skip("Not supported") 91 | async def test_aggregate_pymongo_kwargs(preset_documents): 92 | with pytest.raises(TypeError): 93 | await Sample.find(Sample.increment >= 4).aggregate( 94 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}], 95 | wrong=True, 96 | ) 97 | 98 | with pytest.raises(TypeError): 99 | await Sample.find(Sample.increment >= 4).aggregate( 100 | [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}], 101 | hint="integer_1", 102 | ) 103 | -------------------------------------------------------------------------------- /beanita/collection.py: -------------------------------------------------------------------------------- 1 | from mongita.collection import ( 2 | Collection as MongitaCollection, 3 | _validate_filter, 4 | _get_item_from_doc, 5 | ) 6 | from mongita.common import support_alert 7 | from mongita.errors import MongitaError 8 | 9 | from beanita.cursor import Cursor 10 | from beanita.results import UpdateResult 11 | 12 | 13 | class Collection(MongitaCollection): 14 | UNIMPLEMENTED = [ 15 
| "aggregate", 16 | "aggregate_raw_batches", 17 | "bulk_write", 18 | "codec_options", 19 | "ensure_index", 20 | "estimated_document_count", 21 | "find_one_and_delete", 22 | "find_one_and_replace", 23 | "find_one_and_update", 24 | "find_raw_batches", 25 | "inline_map_reduce", 26 | "list_indexes", 27 | "map_reduce", 28 | "next", 29 | "options", 30 | "read_concern", 31 | "read_preference", 32 | "rename", 33 | "watch", 34 | ] 35 | 36 | @support_alert 37 | async def index_information(self): 38 | ret = {"_id_": {"key": [("_id", 1)]}} 39 | metadata = self.__get_metadata() 40 | for idx in metadata.get("indexes", {}).values(): 41 | ret[idx["_id"]] = {"key": [(idx["key_str"], idx["direction"])]} 42 | return ret 43 | 44 | async def create_indexes(self, indexes): 45 | return [] # TODO create indexes here 46 | 47 | @support_alert 48 | async def insert_one(self, document, session): 49 | return super().insert_one(document) 50 | 51 | @support_alert 52 | async def insert_many(self, documents, ordered=True, session=None): 53 | return super().insert_many(documents, ordered) 54 | 55 | @support_alert 56 | async def find_one( 57 | self, filter=None, sort=None, skip=None, session=None, projection=None 58 | ): 59 | return super().find_one(filter, sort, skip) 60 | 61 | @support_alert 62 | def find( 63 | self, 64 | filter=None, 65 | sort=None, 66 | limit=None, 67 | skip=None, 68 | projection=None, 69 | session=None, 70 | ): 71 | print(filter) 72 | res = super().find(filter, sort, limit, skip) 73 | print(res) 74 | return Cursor(res) 75 | 76 | @support_alert 77 | async def update_one(self, filter, update, upsert=False, session=None): 78 | return super().update_one(filter, update, upsert) 79 | 80 | @support_alert 81 | async def update_many(self, filter, update, upsert=False, session=None): 82 | return super().update_many(filter, update, upsert) 83 | 84 | @support_alert 85 | async def replace_one( 86 | self, filter, replacement, upsert=False, session=None 87 | ): 88 | res = 
super().replace_one(filter, replacement, upsert) 89 | return UpdateResult.from_mongita_result(res) 90 | 91 | @support_alert 92 | async def delete_one(self, filter, session=None): 93 | return super().delete_one(filter) 94 | 95 | @support_alert 96 | async def delete_many(self, filter, session=None): 97 | return super().delete_many(filter) 98 | 99 | @support_alert 100 | async def count_documents(self, filter): 101 | return super().count_documents(filter) 102 | 103 | @support_alert 104 | async def distinct(self, key, filter=None, session=None): 105 | """ 106 | Given a key, return all distinct documents matching the key 107 | 108 | :param key str: 109 | :param filter dict|None: 110 | :rtype: list[str] 111 | """ 112 | 113 | if not isinstance(key, str): 114 | raise MongitaError("The 'key' parameter must be a string") 115 | filter = filter or {} 116 | _validate_filter(filter) 117 | uniq = set() 118 | for doc in await self.find(filter).to_list(): 119 | uniq.add(_get_item_from_doc(doc, key)) 120 | uniq.discard(None) 121 | return list(uniq) 122 | 123 | async def drop(self): 124 | await self.delete_many({}) 125 | 126 | async def drop_indexes(self): 127 | pass # TODO add index dropping 128 | # for index_name in (await self.index_information()).keys(): 129 | # self.drop_index(index_name) 130 | -------------------------------------------------------------------------------- /tests/odm/test_fields.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | from typing import Mapping, AbstractSet 4 | import pytest 5 | from pydantic import BaseModel, ValidationError 6 | 7 | from beanie.odm.fields import PydanticObjectId 8 | from beanie.odm.utils.dump import get_dict 9 | from beanie.odm.utils.encoder import Encoder 10 | from tests.odm.models import ( 11 | DocumentWithCustomFiledsTypes, 12 | DocumentWithBsonEncodersFiledsTypes, 13 | DocumentTestModel, 14 | Sample, 15 | ) 16 | 17 | 18 | class M(BaseModel): 19 
| p: PydanticObjectId 20 | 21 | 22 | def test_pydantic_object_id_wrong_input(): 23 | with pytest.raises(ValidationError): 24 | M(p="test") 25 | 26 | 27 | def test_pydantic_object_id_bytes_input(): 28 | p = PydanticObjectId() 29 | m = M(p=str(p).encode("utf-8")) 30 | assert m.p == p 31 | with pytest.raises(ValidationError): 32 | M(p=b"test") 33 | 34 | 35 | async def test_bson_encoders_filed_types(): 36 | custom = DocumentWithBsonEncodersFiledsTypes( 37 | color="7fffd4", timestamp=datetime.datetime.utcnow() 38 | ) 39 | encoded = get_dict(custom) 40 | assert isinstance(encoded["timestamp"], str) 41 | c = await custom.insert() 42 | c_fromdb = await DocumentWithBsonEncodersFiledsTypes.get(c.id) 43 | assert c_fromdb.color.as_hex() == c.color.as_hex() 44 | assert isinstance(c_fromdb.timestamp, datetime.datetime) 45 | assert c_fromdb.timestamp, custom.timestamp 46 | 47 | 48 | async def test_custom_filed_types(): 49 | custom1 = DocumentWithCustomFiledsTypes( 50 | color="#753c38", 51 | decimal=500, 52 | secret_bytes=b"secret_bytes", 53 | secret_string="super_secret_password", 54 | ipv4address="127.0.0.1", 55 | ipv4interface="192.0.2.5/24", 56 | ipv4network="192.0.2.0/24", 57 | ipv6address="::abc:7:def", 58 | ipv6interface="2001:db00::2/24", 59 | ipv6network="2001:db00::0/24", 60 | timedelta=4782453, 61 | set_type={"one", "two", "three"}, 62 | tuple_type=tuple([3, "string"]), 63 | path="/etc/hosts", 64 | ) 65 | custom2 = DocumentWithCustomFiledsTypes( 66 | color="magenta", 67 | decimal=500.213, 68 | secret_bytes=b"secret_bytes", 69 | secret_string="super_secret_password", 70 | ipv4address="127.0.0.1", 71 | ipv4interface="192.0.2.5/24", 72 | ipv4network="192.0.2.0/24", 73 | ipv6address="::abc:7:def", 74 | ipv6interface="2001:db00::2/24", 75 | ipv6network="2001:db00::0/24", 76 | timedelta=4782453, 77 | set_type=["one", "two", "three"], 78 | tuple_type=[3, "three"], 79 | path=Path("C:\\Windows"), 80 | ) 81 | c1 = await custom1.insert() 82 | c2 = await custom2.insert() 83 | 
c1_fromdb = await DocumentWithCustomFiledsTypes.get(c1.id) 84 | c2_fromdb = await DocumentWithCustomFiledsTypes.get(c2.id) 85 | assert set(c1_fromdb.set_type) == set(c1.set_type) 86 | assert set(c2_fromdb.set_type) == set(c2.set_type) 87 | c1_fromdb.set_type = c2_fromdb.set_type = c1.set_type = c2.set_type = None 88 | c1_fromdb.revision_id = None 89 | c2_fromdb.revision_id = None 90 | assert Encoder().encode(c1_fromdb) == Encoder().encode(c1) 91 | assert Encoder().encode(c2_fromdb) == Encoder().encode(c2) 92 | 93 | 94 | async def test_hidden(document): 95 | document = await DocumentTestModel.find_one() 96 | 97 | assert "test_list" not in document.dict() 98 | 99 | 100 | @pytest.mark.parametrize("exclude", [{"test_int"}, {"test_doc": {"test_int"}}]) 101 | async def test_param_exclude(document, exclude): 102 | document = await DocumentTestModel.find_one() 103 | 104 | doc_dict = document.dict(exclude=exclude) 105 | if isinstance(exclude, AbstractSet): 106 | for k in exclude: 107 | assert k not in doc_dict 108 | elif isinstance(exclude, Mapping): 109 | for k, v in exclude.items(): 110 | if isinstance(v, bool) and v: 111 | assert k not in doc_dict 112 | elif isinstance(v, AbstractSet): 113 | for another_k in v: 114 | assert another_k not in doc_dict[k] 115 | 116 | 117 | def test_expression_fields(): 118 | assert Sample.nested.integer == "nested.integer" 119 | assert Sample.nested["integer"] == "nested.integer" 120 | -------------------------------------------------------------------------------- /tests/odm/test_state_management.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from bson import ObjectId 3 | 4 | from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved 5 | from tests.odm.models import ( 6 | DocumentWithTurnedOnStateManagement, 7 | DocumentWithTurnedOnReplaceObjects, 8 | DocumentWithTurnedOffStateManagement, 9 | InternalDoc, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def state(): 15 | 
return { 16 | "num_1": 1, 17 | "num_2": 2, 18 | "_id": ObjectId(), 19 | "internal": InternalDoc(), 20 | } 21 | 22 | 23 | @pytest.fixture 24 | def doc_default(state): 25 | return DocumentWithTurnedOnStateManagement._parse_obj_saving_state(state) 26 | 27 | 28 | @pytest.fixture 29 | def doc_replace(state): 30 | return DocumentWithTurnedOnReplaceObjects._parse_obj_saving_state(state) 31 | 32 | 33 | @pytest.fixture 34 | async def saved_doc_default(doc_default): 35 | await doc_default.insert() 36 | return doc_default 37 | 38 | 39 | def test_use_state_management_property(): 40 | assert DocumentWithTurnedOnStateManagement.use_state_management() is True 41 | assert DocumentWithTurnedOffStateManagement.use_state_management() is False 42 | 43 | 44 | def test_save_state(): 45 | doc = DocumentWithTurnedOnStateManagement( 46 | num_1=1, num_2=2, internal=InternalDoc(num=1, string="s") 47 | ) 48 | assert doc.get_saved_state() is None 49 | doc._save_state() 50 | assert doc.get_saved_state() == { 51 | "num_1": 1, 52 | "num_2": 2, 53 | "internal": {"num": 1, "string": "s", "lst": [1, 2, 3, 4, 5]}, 54 | } 55 | 56 | 57 | def test_parse_object_with_saving_state(): 58 | obj = { 59 | "num_1": 1, 60 | "num_2": 2, 61 | "_id": ObjectId(), 62 | "internal": InternalDoc(), 63 | } 64 | doc = DocumentWithTurnedOnStateManagement._parse_obj_saving_state(obj) 65 | assert doc.get_saved_state() == obj 66 | 67 | 68 | def test_saved_state_needed(): 69 | doc_1 = DocumentWithTurnedOffStateManagement(num_1=1, num_2=2) 70 | with pytest.raises(StateManagementIsTurnedOff): 71 | doc_1.is_changed 72 | 73 | doc_2 = DocumentWithTurnedOnStateManagement( 74 | num_1=1, num_2=2, internal=InternalDoc() 75 | ) 76 | with pytest.raises(StateNotSaved): 77 | doc_2.is_changed 78 | 79 | 80 | def test_if_changed(doc_default): 81 | assert doc_default.is_changed is False 82 | doc_default.num_1 = 10 83 | assert doc_default.is_changed is True 84 | 85 | 86 | def test_get_changes_default(doc_default): 87 | doc_default.internal.num 
= 1000 88 | doc_default.internal.string = "new_value" 89 | doc_default.internal.lst.append(100) 90 | assert doc_default.get_changes() == { 91 | "internal.num": 1000, 92 | "internal.string": "new_value", 93 | "internal.lst": [1, 2, 3, 4, 5, 100], 94 | } 95 | 96 | 97 | def test_get_changes_default_whole(doc_default): 98 | doc_default.internal = {"num": 1000, "string": "new_value"} 99 | assert doc_default.get_changes() == { 100 | "internal.num": 1000, 101 | "internal.string": "new_value", 102 | } 103 | 104 | 105 | def test_get_changes_replace(doc_replace): 106 | doc_replace.internal.num = 1000 107 | doc_replace.internal.string = "new_value" 108 | assert doc_replace.get_changes() == { 109 | "internal": { 110 | "num": 1000, 111 | "string": "new_value", 112 | "lst": [1, 2, 3, 4, 5], 113 | } 114 | } 115 | 116 | 117 | def test_get_changes_replace_whole(doc_replace): 118 | doc_replace.internal = {"num": 1000, "string": "new_value"} 119 | assert doc_replace.get_changes() == { 120 | "internal": { 121 | "num": 1000, 122 | "string": "new_value", 123 | } 124 | } 125 | 126 | 127 | async def test_save_changes(saved_doc_default): 128 | saved_doc_default.internal.num = 10000 129 | await saved_doc_default.save_changes() 130 | assert saved_doc_default.get_saved_state()["internal"]["num"] == 10000 131 | 132 | new_doc = await DocumentWithTurnedOnStateManagement.get( 133 | saved_doc_default.id 134 | ) 135 | assert new_doc.internal.num == 10000 136 | 137 | 138 | async def test_find_one(saved_doc_default, state): 139 | new_doc = await DocumentWithTurnedOnStateManagement.get( 140 | saved_doc_default.id 141 | ) 142 | assert new_doc.get_saved_state() == state 143 | 144 | new_doc = await DocumentWithTurnedOnStateManagement.find_one( 145 | DocumentWithTurnedOnStateManagement.id == saved_doc_default.id 146 | ) 147 | assert new_doc.get_saved_state() == state 148 | 149 | 150 | async def test_find_many(): 151 | docs = [] 152 | for i in range(10): 153 | docs.append( 154 | 
DocumentWithTurnedOnStateManagement( 155 | num_1=i, num_2=i + 1, internal=InternalDoc() 156 | ) 157 | ) 158 | await DocumentWithTurnedOnStateManagement.insert_many(docs) 159 | 160 | found_docs = await DocumentWithTurnedOnStateManagement.find( 161 | DocumentWithTurnedOnStateManagement.num_1 > 4 162 | ).to_list() 163 | 164 | for doc in found_docs: 165 | assert doc.get_saved_state() is not None 166 | 167 | 168 | async def test_insert(state): 169 | doc = DocumentWithTurnedOnStateManagement.parse_obj(state) 170 | assert doc.get_saved_state() is None 171 | await doc.insert() 172 | assert doc.get_saved_state() == state 173 | 174 | 175 | async def test_replace(saved_doc_default): 176 | saved_doc_default.num_1 = 100 177 | await saved_doc_default.replace() 178 | assert saved_doc_default.get_saved_state()["num_1"] == 100 179 | 180 | 181 | async def test_save_chages(saved_doc_default): 182 | saved_doc_default.num_1 = 100 183 | await saved_doc_default.save_changes() 184 | assert saved_doc_default.get_saved_state()["num_1"] == 100 185 | 186 | 187 | async def test_rollback(doc_default, state): 188 | doc_default.num_1 = 100 189 | doc_default.rollback() 190 | assert doc_default.num_1 == state["num_1"] 191 | -------------------------------------------------------------------------------- /tests/odm/documents/test_find.py: -------------------------------------------------------------------------------- 1 | import pymongo 2 | 3 | from beanie.odm.fields import PydanticObjectId 4 | from tests.odm.models import DocumentTestModel 5 | 6 | 7 | async def test_get(document): 8 | new_document = await DocumentTestModel.get(document.id) 9 | assert new_document == document 10 | 11 | 12 | async def test_get_not_found(document): 13 | new_document = await DocumentTestModel.get(PydanticObjectId()) 14 | assert new_document is None 15 | 16 | 17 | async def test_find_one(documents): 18 | inserted_one = await documents(1, "kipasa") 19 | await documents(10, "smthe else") 20 | expected_doc_id = 
PydanticObjectId(inserted_one[0]) 21 | new_document = await DocumentTestModel.find_one({"test_str": "kipasa"}) 22 | assert new_document.id == expected_doc_id 23 | 24 | 25 | async def test_find_one_not_found(documents): 26 | await documents(10, "smthe else") 27 | 28 | new_document = await DocumentTestModel.find_one({"test_str": "wrong"}) 29 | assert new_document is None 30 | 31 | 32 | async def test_find_one_more_than_one_found(documents): 33 | await documents(10, "one") 34 | await documents(10, "two") 35 | new_document = await DocumentTestModel.find_one({"test_str": "one"}) 36 | assert new_document.test_str == "one" 37 | 38 | 39 | async def test_find_all(documents): 40 | await documents(4, "uno") 41 | await documents(2, "dos") 42 | await documents(1, "cuatro") 43 | result = await DocumentTestModel.find_all().to_list() 44 | assert len(result) == 7 45 | 46 | 47 | async def test_find_all_limit(documents): 48 | await documents(4, "uno") 49 | await documents(2, "dos") 50 | await documents(1, "cuatro") 51 | result = await DocumentTestModel.find_all(limit=5).to_list() 52 | assert len(result) == 5 53 | 54 | 55 | async def test_find_all_skip(documents): 56 | await documents(4, "uno") 57 | await documents(2, "dos") 58 | await documents(1, "cuatro") 59 | result = await DocumentTestModel.find_all(skip=1).to_list() 60 | assert len(result) == 6 61 | 62 | 63 | async def test_find_all_sort(documents): 64 | await documents(4, "uno", True) 65 | await documents(2, "dos", True) 66 | await documents(1, "cuatro", True) 67 | result = await DocumentTestModel.find_all( 68 | sort=[ 69 | ("test_str", pymongo.ASCENDING), 70 | ("test_int", pymongo.DESCENDING), 71 | ] 72 | ).to_list() 73 | assert result[0].test_str == "cuatro" 74 | assert result[1].test_str == result[2].test_str == "dos" 75 | assert ( 76 | result[3].test_str 77 | == result[4].test_str 78 | == result[5].test_str 79 | == result[5].test_str 80 | == "uno" 81 | ) 82 | 83 | assert result[1].test_int >= result[2].test_int 84 | assert 
( 85 | result[3].test_int 86 | >= result[4].test_int 87 | >= result[5].test_int 88 | >= result[6].test_int 89 | ) 90 | 91 | 92 | async def test_find_many(documents): 93 | await documents(4, "uno") 94 | await documents(2, "dos") 95 | await documents(1, "cuatro") 96 | result = await DocumentTestModel.find_many( 97 | DocumentTestModel.test_str == "uno" 98 | ).to_list() 99 | assert len(result) == 4 100 | 101 | 102 | async def test_find_many_limit(documents): 103 | await documents(4, "uno") 104 | await documents(2, "dos") 105 | await documents(1, "cuatro") 106 | result = await DocumentTestModel.find_many( 107 | {"test_str": "uno"}, limit=2 108 | ).to_list() 109 | assert len(result) == 2 110 | 111 | 112 | async def test_find_many_skip(documents): 113 | await documents(4, "uno") 114 | await documents(2, "dos") 115 | await documents(1, "cuatro") 116 | result = await DocumentTestModel.find_many( 117 | {"test_str": "uno"}, skip=1 118 | ).to_list() 119 | assert len(result) == 3 120 | 121 | 122 | async def test_find_many_sort(documents): 123 | await documents(4, "uno", True) 124 | await documents(2, "dos", True) 125 | await documents(1, "cuatro", True) 126 | result = await DocumentTestModel.find_many( 127 | {"test_str": "uno"}, sort="test_int" 128 | ).to_list() 129 | assert ( 130 | result[0].test_int 131 | <= result[1].test_int 132 | <= result[2].test_int 133 | <= result[3].test_int 134 | ) 135 | 136 | result = await DocumentTestModel.find_many( 137 | {"test_str": "uno"}, sort=[("test_int", pymongo.DESCENDING)] 138 | ).to_list() 139 | assert ( 140 | result[0].test_int 141 | >= result[1].test_int 142 | >= result[2].test_int 143 | >= result[3].test_int 144 | ) 145 | 146 | 147 | async def test_find_many_not_found(documents): 148 | await documents(4, "uno") 149 | await documents(2, "dos") 150 | await documents(1, "cuatro") 151 | result = await DocumentTestModel.find_many({"test_str": "wrong"}).to_list() 152 | assert len(result) == 0 153 | 154 | 155 | async def 
test_get_with_session(document, session): 156 | new_document = await DocumentTestModel.get(document.id, session=session) 157 | assert new_document == document 158 | 159 | 160 | async def test_find_one_with_session(documents, session): 161 | inserted_one = await documents(1, "kipasa") 162 | await documents(10, "smthe else") 163 | 164 | expected_doc_id = PydanticObjectId(inserted_one[0]) 165 | 166 | new_document = await DocumentTestModel.find_one( 167 | {"test_str": "kipasa"}, session=session 168 | ) 169 | assert new_document.id == expected_doc_id 170 | 171 | 172 | async def test_find_all_with_session(documents, session): 173 | await documents(4, "uno") 174 | await documents(2, "dos") 175 | await documents(1, "cuatro") 176 | result = await DocumentTestModel.find_all(session=session).to_list() 177 | assert len(result) == 7 178 | 179 | 180 | async def test_find_many_with_session(documents, session): 181 | await documents(4, "uno") 182 | await documents(2, "dos") 183 | await documents(1, "cuatro") 184 | result = await DocumentTestModel.find_many( 185 | {"test_str": "uno"}, session=session 186 | ).to_list() 187 | assert len(result) == 4 188 | -------------------------------------------------------------------------------- /tests/odm/conftest.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from random import randint 3 | from typing import List 4 | 5 | import pytest 6 | 7 | from beanie.odm.utils.general import init_beanie 8 | from tests.odm.models import ( 9 | DocumentTestModel, 10 | DocumentTestModelWithIndexFlagsAliases, 11 | SubDocument, 12 | DocumentTestModelWithCustomCollectionName, 13 | DocumentTestModelWithSimpleIndex, 14 | DocumentTestModelWithIndexFlags, 15 | DocumentTestModelWithComplexIndex, 16 | DocumentTestModelFailInspection, 17 | DocumentWithCustomIdUUID, 18 | DocumentWithCustomIdInt, 19 | DocumentWithCustomFiledsTypes, 20 | DocumentWithBsonEncodersFiledsTypes, 21 | 
DocumentWithActions, 22 | DocumentWithTurnedOnStateManagement, 23 | DocumentWithTurnedOnReplaceObjects, 24 | DocumentWithTurnedOffStateManagement, 25 | DocumentWithValidationOnSave, 26 | DocumentWithRevisionTurnedOn, 27 | DocumentWithPydanticConfig, 28 | DocumentWithExtras, 29 | House, 30 | Window, 31 | Door, 32 | Roof, 33 | InheritedDocumentWithActions, 34 | DocumentForEncodingTest, 35 | DocumentForEncodingTestDate, 36 | DocumentMultiModelOne, 37 | DocumentMultiModelTwo, 38 | DocumentUnion, 39 | ) 40 | from tests.odm.views import TestView 41 | from tests.odm.models import ( 42 | Sample, 43 | Nested, 44 | Option2, 45 | Option1, 46 | GeoObject, 47 | ) 48 | 49 | 50 | @pytest.fixture 51 | def point(): 52 | return { 53 | "longitude": 13.404954, 54 | "latitude": 52.520008, 55 | } 56 | 57 | 58 | @pytest.fixture 59 | async def preset_documents(point): 60 | docs = [] 61 | for i in range(10): 62 | timestamp = datetime.utcnow() - timedelta(days=i) 63 | integer_1: int = i // 3 64 | integer_2: int = i // 2 65 | float_num = integer_1 + 0.3 66 | string: str = f"test_{integer_1}" 67 | option_1 = Option1(s="TEST") 68 | option_2 = Option2(f=3.14) 69 | union = option_1 if i % 2 else option_2 70 | optional = option_2 if not i % 3 else None 71 | geo = GeoObject( 72 | coordinates=[ 73 | point["longitude"] + i / 10, 74 | point["latitude"] + i / 10, 75 | ] 76 | ) 77 | nested = Nested( 78 | integer=integer_2, 79 | option_1=option_1, 80 | union=union, 81 | optional=optional, 82 | ) 83 | 84 | sample = Sample( 85 | timestamp=timestamp, 86 | increment=i, 87 | integer=integer_1, 88 | float_num=float_num, 89 | string=string, 90 | nested=nested, 91 | optional=optional, 92 | union=union, 93 | geo=geo, 94 | ) 95 | docs.append(sample) 96 | await Sample.insert_many(documents=docs) 97 | 98 | 99 | @pytest.fixture() 100 | def sample_doc_not_saved(point): 101 | nested = Nested( 102 | integer=0, 103 | option_1=Option1(s="TEST"), 104 | union=Option1(s="TEST"), 105 | optional=None, 106 | ) 107 | geo = 
GeoObject( 108 | coordinates=[ 109 | point["longitude"], 110 | point["latitude"], 111 | ] 112 | ) 113 | return Sample( 114 | timestamp=datetime.utcnow(), 115 | increment=0, 116 | integer=0, 117 | float_num=0, 118 | string="TEST_NOT_SAVED", 119 | nested=nested, 120 | optional=None, 121 | union=Option1(s="TEST"), 122 | geo=geo, 123 | ) 124 | 125 | 126 | @pytest.fixture() 127 | async def session(cli, loop): 128 | # s = await cli.start_session() 129 | # yield s 130 | # await s.end_session() 131 | return None 132 | 133 | 134 | @pytest.fixture(autouse=True) 135 | async def init(loop, db): 136 | models = [ 137 | DocumentWithExtras, 138 | DocumentWithPydanticConfig, 139 | DocumentTestModel, 140 | DocumentTestModelWithCustomCollectionName, 141 | DocumentTestModelWithSimpleIndex, 142 | DocumentTestModelWithIndexFlags, 143 | DocumentTestModelWithIndexFlagsAliases, 144 | DocumentTestModelWithComplexIndex, 145 | DocumentTestModelFailInspection, 146 | DocumentWithBsonEncodersFiledsTypes, 147 | DocumentWithCustomFiledsTypes, 148 | DocumentWithCustomIdUUID, 149 | DocumentWithCustomIdInt, 150 | Sample, 151 | DocumentWithActions, 152 | DocumentWithTurnedOnStateManagement, 153 | DocumentWithTurnedOnReplaceObjects, 154 | DocumentWithTurnedOffStateManagement, 155 | DocumentWithValidationOnSave, 156 | DocumentWithRevisionTurnedOn, 157 | House, 158 | Window, 159 | Door, 160 | Roof, 161 | InheritedDocumentWithActions, 162 | DocumentForEncodingTest, 163 | DocumentForEncodingTestDate, 164 | TestView, 165 | DocumentMultiModelOne, 166 | DocumentMultiModelTwo, 167 | DocumentUnion, 168 | ] 169 | await init_beanie( 170 | database=db, 171 | document_models=models, 172 | ) 173 | yield None 174 | 175 | for model in models: 176 | await model.get_motor_collection().drop() 177 | await model.get_motor_collection().drop_indexes() 178 | 179 | 180 | @pytest.fixture 181 | def document_not_inserted(): 182 | return DocumentTestModel( 183 | test_int=42, 184 | test_list=[SubDocument(test_str="foo"), 
SubDocument(test_str="bar")], 185 | test_doc=SubDocument(test_str="foobar"), 186 | test_str="kipasa", 187 | ) 188 | 189 | 190 | @pytest.fixture 191 | def documents_not_inserted(): 192 | def generate_documents( 193 | number: int, test_str: str = None, random: bool = False 194 | ) -> List[DocumentTestModel]: 195 | return [ 196 | DocumentTestModel( 197 | test_int=randint(0, 1000000) if random else i, 198 | test_list=[ 199 | SubDocument(test_str="foo"), 200 | SubDocument(test_str="bar"), 201 | ], 202 | test_doc=SubDocument(test_str="foobar"), 203 | test_str="kipasa" if test_str is None else test_str, 204 | ) 205 | for i in range(number) 206 | ] 207 | 208 | return generate_documents 209 | 210 | 211 | @pytest.fixture 212 | async def document(document_not_inserted, loop) -> DocumentTestModel: 213 | return await document_not_inserted.insert() 214 | 215 | 216 | @pytest.fixture 217 | def documents(documents_not_inserted): 218 | async def generate_documents( 219 | number: int, test_str: str = None, random: bool = False 220 | ): 221 | result = await DocumentTestModel.insert_many( 222 | documents_not_inserted(number, test_str, random) 223 | ) 224 | return result.inserted_ids 225 | 226 | return generate_documents 227 | -------------------------------------------------------------------------------- /tests/odm/query/test_update.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from beanie.odm.operators.update.general import Set, Max 6 | from tests.odm.models import Sample 7 | 8 | 9 | async def test_update_query(): 10 | q = ( 11 | Sample.find_many(Sample.integer == 1) 12 | .update(Set({Sample.integer: 10})) 13 | .update_query 14 | ) 15 | assert q == {"$set": {"integer": 10}} 16 | 17 | q = ( 18 | Sample.find_many(Sample.integer == 1) 19 | .update(Max({Sample.integer: 10}), Set({Sample.optional: None})) 20 | .update_query 21 | ) 22 | assert q == {"$max": {"integer": 10}, "$set": {"optional": None}} 
23 | 24 | q = ( 25 | Sample.find_many(Sample.integer == 1) 26 | .update(Set({Sample.integer: 10}), Set({Sample.optional: None})) 27 | .update_query 28 | ) 29 | assert q == {"$set": {"optional": None}} 30 | 31 | q = ( 32 | Sample.find_many(Sample.integer == 1) 33 | .update(Max({Sample.integer: 10})) 34 | .update(Set({Sample.optional: None})) 35 | .update_query 36 | ) 37 | assert q == {"$max": {"integer": 10}, "$set": {"optional": None}} 38 | 39 | q = ( 40 | Sample.find_many(Sample.integer == 1) 41 | .update(Set({Sample.integer: 10})) 42 | .update(Set({Sample.optional: None})) 43 | .update_query 44 | ) 45 | assert q == {"$set": {"optional": None}} 46 | 47 | with pytest.raises(TypeError): 48 | Sample.find_many(Sample.integer == 1).update(40).update_query 49 | 50 | 51 | async def test_update_many(preset_documents): 52 | await Sample.find_many(Sample.increment > 4).update( 53 | Set({Sample.increment: 100}) 54 | ) # noqa 55 | result = await Sample.find_many(Sample.increment == 100).to_list() 56 | assert len(result) == 5 57 | for sample in result: 58 | assert sample.increment == 100 59 | 60 | 61 | async def test_update_many_linked_method(preset_documents): 62 | await Sample.find_many(Sample.increment > 4).update_many( 63 | Set({Sample.increment: 100}) 64 | ) # noqa 65 | result = await Sample.find_many(Sample.increment == 100).to_list() 66 | assert len(result) == 5 67 | for sample in result: 68 | assert sample.increment == 100 69 | 70 | 71 | async def test_update_all(preset_documents): 72 | await Sample.update_all(Set({Sample.integer: 100})) 73 | result = await Sample.find_all().to_list() 74 | for sample in result: 75 | assert sample.integer == 100 76 | 77 | await Sample.find_all().update(Set({Sample.integer: 101})) 78 | result = await Sample.find_all().to_list() 79 | for sample in result: 80 | assert sample.integer == 101 81 | 82 | 83 | async def test_update_one(preset_documents): 84 | await Sample.find_one(Sample.integer == 1).update( 85 | Set({Sample.integer: 100}) 86 | 
) 87 | result = await Sample.find_many(Sample.integer == 100).to_list() 88 | assert len(result) == 1 89 | assert result[0].integer == 100 90 | 91 | await Sample.find_one(Sample.integer == 1).update_one( 92 | Set({Sample.integer: 101}) 93 | ) 94 | result = await Sample.find_many(Sample.integer == 101).to_list() 95 | assert len(result) == 1 96 | assert result[0].integer == 101 97 | 98 | 99 | async def test_update_self(preset_documents): 100 | sample = await Sample.find_one(Sample.integer == 1) 101 | await sample.update(Set({Sample.integer: 100})) 102 | assert sample.integer == 100 103 | 104 | result = await Sample.find_many(Sample.integer == 100).to_list() 105 | assert len(result) == 1 106 | assert result[0].integer == 100 107 | 108 | 109 | async def test_update_many_with_session(preset_documents, session): 110 | q = ( 111 | Sample.find_many(Sample.increment > 4) 112 | .update(Set({Sample.increment: 100})) 113 | .set_session(session=session) 114 | ) 115 | assert q.session == session 116 | 117 | q = Sample.find_many(Sample.increment > 4).update( 118 | Set({Sample.increment: 100}), session=session 119 | ) 120 | assert q.session == session 121 | 122 | q = Sample.find_many(Sample.increment > 4).update( 123 | Set({Sample.increment: 100}) 124 | ) 125 | assert q.session == session 126 | 127 | await q # noqa 128 | result = await Sample.find_many(Sample.increment == 100).to_list() 129 | assert len(result) == 5 130 | for sample in result: 131 | assert sample.increment == 100 132 | 133 | 134 | async def test_update_many_upsert_with_insert( 135 | preset_documents, sample_doc_not_saved 136 | ): 137 | await Sample.find_many(Sample.integer > 100000).upsert( 138 | Set({Sample.integer: 100}), on_insert=sample_doc_not_saved 139 | ) 140 | await asyncio.sleep(2) 141 | new_docs = await Sample.find_many( 142 | Sample.string == sample_doc_not_saved.string 143 | ).to_list() 144 | assert len(new_docs) == 1 145 | doc = new_docs[0] 146 | assert doc.integer == sample_doc_not_saved.integer 147 | 
async def _find_samples_with_string_of(sample):
    """Return every stored ``Sample`` whose ``string`` equals ``sample.string``.

    The upsert tests use this to check whether their ``on_insert`` document
    ended up in the collection.
    """
    return await Sample.find_many(Sample.string == sample.string).to_list()


async def test_update_many_upsert_without_insert(
    preset_documents, sample_doc_not_saved
):
    """``find_many(...).upsert`` must not insert when the filter matches.

    The upsert is awaited, so it has run before the read below. A fixed
    ``asyncio.sleep`` is not needed and would only slow the suite.
    """
    await Sample.find_many(Sample.integer > 1).upsert(
        Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
    )
    new_docs = await _find_samples_with_string_of(sample_doc_not_saved)
    assert len(new_docs) == 0


async def test_update_one_upsert_with_insert(
    preset_documents, sample_doc_not_saved
):
    """``find_one(...).upsert`` inserts ``on_insert`` when nothing matches."""
    # No preset sample has integer > 100000, so the upsert must insert.
    await Sample.find_one(Sample.integer > 100000).upsert(
        Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
    )
    new_docs = await _find_samples_with_string_of(sample_doc_not_saved)
    assert len(new_docs) == 1
    doc = new_docs[0]
    assert doc.integer == sample_doc_not_saved.integer


async def test_update_one_upsert_without_insert(
    preset_documents, sample_doc_not_saved
):
    """``find_one(...).upsert`` must not insert when the filter matches."""
    await Sample.find_one(Sample.integer > 1).upsert(
        Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
    )
    new_docs = await _find_samples_with_string_of(sample_doc_not_saved)
    assert len(new_docs) == 0


@pytest.mark.skip("Not supported")
async def test_update_pymongo_kwargs(preset_documents):
    """Extra keyword arguments to ``update`` are forwarded to the driver.

    An unknown keyword should raise ``TypeError``; ``hint`` should be accepted.
    """
    with pytest.raises(TypeError):
        await Sample.find_many(Sample.increment > 4).update(
            Set({Sample.increment: 100}), wrong="integer_1"
        )

    await Sample.find_many(Sample.increment > 4).update(
        Set({Sample.increment: 100}), hint="integer_1"
    )

    await Sample.find_one(Sample.increment > 4).update(
        Set({Sample.increment: 100}), hint="integer_1"
    )


# --------------------------------------------------------------------------------
# /tests/odm/documents/test_bulk_write.py:
# --------------------------------------------------------------------------------
import pytest
from pymongo.errors import BulkWriteError

from beanie.odm.bulk import BulkWriter
from beanie.odm.operators.update.general import Set
from tests.odm.models import DocumentTestModel

# Every test below is marked skip("Not supported"). Presumably BulkWriter is
# not available on the beanita backend; the bodies are kept for re-enabling.


@pytest.mark.skip("Not supported")
async def test_insert(documents_not_inserted):
    """Two inserts queued on one BulkWriter both land in the collection."""
    first, second = documents_not_inserted(2)
    async with BulkWriter() as writer:
        for pending in (first, second):
            await DocumentTestModel.insert_one(pending, bulk_writer=writer)

    stored = await DocumentTestModel.find_all().to_list()
    assert len(stored) == 2


@pytest.mark.skip("Not supported")
async def test_update(documents, document_not_inserted):
    """save_changes / update_one / update_many queued on a writer all apply."""
    await documents(5)
    target = await DocumentTestModel.find_one(DocumentTestModel.test_int == 0)
    target.test_int = 100
    async with BulkWriter() as writer:
        await target.save_changes(bulk_writer=writer)
        await DocumentTestModel.find_one(
            DocumentTestModel.test_int == 1
        ).update(Set({DocumentTestModel.test_int: 1000}), bulk_writer=writer)
        await DocumentTestModel.find(DocumentTestModel.test_int < 100).update(
            Set({DocumentTestModel.test_int: 2000}), bulk_writer=writer
        )

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 5
    # (stored test_int value, number of documents expected to carry it)
    for value, expected_count in ((100, 1), (1000, 1), (2000, 3)):
        matching = await DocumentTestModel.find(
            DocumentTestModel.test_int == value
        ).to_list()
        assert len(matching) == expected_count


@pytest.mark.skip("Not supported")
async def test_delete(documents, document_not_inserted):
    """Deletes queued on a writer remove four of the five documents."""
    await documents(5)
    target = await DocumentTestModel.find_one(DocumentTestModel.test_int == 0)
    async with BulkWriter() as writer:
        await target.delete(bulk_writer=writer)
        await DocumentTestModel.find_one(
            DocumentTestModel.test_int == 1
        ).delete(bulk_writer=writer)
        await DocumentTestModel.find(DocumentTestModel.test_int < 4).delete(
            bulk_writer=writer
        )

    remaining = await DocumentTestModel.find_all().to_list()
    assert len(remaining) == 1


@pytest.mark.skip("Not supported")
async def test_replace(documents, document_not_inserted):
    """replace / replace_one queued on a writer leave two docs with 100."""
    await documents(5)
    target = await DocumentTestModel.find_one(DocumentTestModel.test_int == 0)
    target.test_int = 100
    async with BulkWriter() as writer:
        await target.replace(bulk_writer=writer)
        document_not_inserted.test_int = 100
        await DocumentTestModel.find_one(
            DocumentTestModel.test_int == 1
        ).replace_one(document_not_inserted, bulk_writer=writer)

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 5
    replaced = await DocumentTestModel.find(
        DocumentTestModel.test_int == 100
    ).to_list()
    assert len(replaced) == 2


@pytest.mark.skip("Not supported")
async def test_upsert_find_many_not_found(documents, document_not_inserted):
    """A find-many upsert that matches nothing inserts ``on_insert``."""
    await documents(5)
    document_not_inserted.test_int = -10000
    async with BulkWriter() as writer:
        # NOTE(review): upsert is not given bulk_writer=writer, so it does not
        # go through the writer at all -- confirm this is intended.
        await DocumentTestModel.find(
            DocumentTestModel.test_int < -1000
        ).upsert(
            {"$set": {DocumentTestModel.test_int: 0}},
            on_insert=document_not_inserted,
        )
        await writer.commit()

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 6
    inserted = await DocumentTestModel.find(
        DocumentTestModel.test_int == -10000
    ).to_list()
    assert len(inserted) == 1


@pytest.mark.skip("Not supported")
async def test_upsert_find_one_not_found(documents, document_not_inserted):
    """A find-one upsert that matches nothing inserts ``on_insert``."""
    await documents(5)
    document_not_inserted.test_int = -10000
    async with BulkWriter() as writer:
        # NOTE(review): upsert is not passed bulk_writer=writer -- confirm.
        no_match = DocumentTestModel.find_one(
            DocumentTestModel.test_int < -1000
        )
        await no_match.upsert(
            {"$set": {DocumentTestModel.test_int: 0}},
            on_insert=document_not_inserted,
        )
        await writer.commit()

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 6
    inserted = await DocumentTestModel.find(
        DocumentTestModel.test_int == -10000
    ).to_list()
    assert len(inserted) == 1


@pytest.mark.skip("Not supported")
async def test_upsert_find_many_found(documents, document_not_inserted):
    """A find-many upsert that matches updates the match and inserts nothing."""
    await documents(5)
    async with BulkWriter() as writer:
        # NOTE(review): upsert is not passed bulk_writer=writer -- confirm.
        one_match = DocumentTestModel.find(DocumentTestModel.test_int == 1)
        await one_match.upsert(
            {"$set": {DocumentTestModel.test_int: -10000}},
            on_insert=document_not_inserted,
        )
        await writer.commit()

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 5
    updated = await DocumentTestModel.find(
        DocumentTestModel.test_int == -10000
    ).to_list()
    assert len(updated) == 1


@pytest.mark.skip("Not supported")
async def test_upsert_find_one_found(documents, document_not_inserted):
    """A find-one upsert that matches updates the match and inserts nothing."""
    await documents(5)
    async with BulkWriter() as writer:
        # NOTE(review): upsert is not passed bulk_writer=writer -- confirm.
        one_match = DocumentTestModel.find_one(DocumentTestModel.test_int == 1)
        await one_match.upsert(
            {"$set": {DocumentTestModel.test_int: -10000}},
            on_insert=document_not_inserted,
        )
        await writer.commit()

    everything = await DocumentTestModel.find_all().to_list()
    assert len(everything) == 5
    updated = await DocumentTestModel.find(
        DocumentTestModel.test_int == -10000
    ).to_list()
    assert len(updated) == 1
202 | 203 | 204 | @pytest.mark.skip("Not supported") 205 | async def test_internal_error(document): 206 | with pytest.raises(BulkWriteError): 207 | async with BulkWriter() as bulk_writer: 208 | await DocumentTestModel.insert_one( 209 | document, bulk_writer=bulk_writer 210 | ) 211 | -------------------------------------------------------------------------------- /tests/odm/test_relations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.exceptions import DocumentWasNotSaved 4 | from beanie.odm.fields import WriteRules, Link, DeleteRules 5 | from tests.odm.models import Window, Door, House, Roof 6 | 7 | 8 | @pytest.fixture 9 | def windows_not_inserted(): 10 | return [Window(x=10, y=10), Window(x=11, y=11)] 11 | 12 | 13 | @pytest.fixture 14 | def door_not_inserted(): 15 | return Door(t=10) 16 | 17 | 18 | @pytest.fixture 19 | def house_not_inserted(windows_not_inserted, door_not_inserted): 20 | return House( 21 | windows=windows_not_inserted, door=door_not_inserted, name="test" 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | async def house(house_not_inserted): 27 | return await house_not_inserted.insert(link_rule=WriteRules.WRITE) 28 | 29 | 30 | @pytest.fixture 31 | async def houses(): 32 | for i in range(10): 33 | roof = Roof() if i % 2 == 0 else None 34 | house = await House( 35 | door=Door(t=i), 36 | windows=[Window(x=10, y=10 + i), Window(x=11, y=11 + i)], 37 | roof=roof, 38 | name="test", 39 | height=i, 40 | ).insert(link_rule=WriteRules.WRITE) 41 | if i == 9: 42 | await house.windows[0].delete() 43 | await house.door.delete() 44 | 45 | 46 | @pytest.mark.skip("Not supported") 47 | class TestInsert: 48 | async def test_rule_do_nothing(self, house_not_inserted): 49 | with pytest.raises(DocumentWasNotSaved): 50 | await house_not_inserted.insert() 51 | 52 | async def test_rule_write(self, house_not_inserted): 53 | await house_not_inserted.insert(link_rule=WriteRules.WRITE) 54 | windows = await 
Window.all().to_list() 55 | assert len(windows) == 2 56 | doors = await Door.all().to_list() 57 | assert len(doors) == 1 58 | houses = await House.all().to_list() 59 | assert len(houses) == 1 60 | 61 | async def test_insert_with_link( 62 | self, house_not_inserted, door_not_inserted 63 | ): 64 | door = await door_not_inserted.insert() 65 | link = Door.link_from_id(door.id) 66 | house_not_inserted.door = link 67 | house = House.parse_obj(house_not_inserted) 68 | await house.insert(link_rule=WriteRules.WRITE) 69 | house.json() 70 | 71 | 72 | @pytest.mark.skip("Not supported") 73 | class TestFind: 74 | async def test_prefetch_find_many(self, houses): 75 | items = await House.find(House.height > 2).sort(House.height).to_list() 76 | assert len(items) == 7 77 | for window in items[0].windows: 78 | assert isinstance(window, Link) 79 | assert isinstance(items[0].door, Link) 80 | assert items[0].roof is None 81 | assert isinstance(items[1].roof, Link) 82 | 83 | items = ( 84 | await House.find(House.height > 2, fetch_links=True) 85 | .sort(House.height) 86 | .to_list() 87 | ) 88 | assert len(items) == 7 89 | for window in items[0].windows: 90 | assert isinstance(window, Window) 91 | assert isinstance(items[0].door, Door) 92 | assert items[0].roof is None 93 | assert isinstance(items[1].roof, Roof) 94 | 95 | houses = await House.find_many( 96 | House.height == 9, fetch_links=True 97 | ).to_list() 98 | assert len(houses[0].windows) == 1 99 | assert isinstance(houses[0].door, Link) 100 | await houses[0].fetch_link(House.door) 101 | assert isinstance(houses[0].door, Link) 102 | 103 | houses = await House.find_many( 104 | House.door.t > 5, fetch_links=True 105 | ).to_list() 106 | 107 | assert len(houses) == 3 108 | 109 | houses = await House.find_many( 110 | House.windows.y == 15, fetch_links=True 111 | ).to_list() 112 | 113 | assert len(houses) == 2 114 | 115 | houses = await House.find_many( 116 | House.height > 5, limit=3, fetch_links=True 117 | ).to_list() 118 | 119 | assert 
len(houses) == 3 120 | 121 | async def test_prefetch_find_one(self, house): 122 | house = await House.find_one(House.name == "test") 123 | for window in house.windows: 124 | assert isinstance(window, Link) 125 | assert isinstance(house.door, Link) 126 | 127 | house = await House.find_one(House.name == "test", fetch_links=True) 128 | for window in house.windows: 129 | assert isinstance(window, Window) 130 | assert isinstance(house.door, Door) 131 | 132 | house = await House.get(house.id, fetch_links=True) 133 | for window in house.windows: 134 | assert isinstance(window, Window) 135 | assert isinstance(house.door, Door) 136 | 137 | async def test_fetch(self, house): 138 | house = await House.find_one(House.name == "test") 139 | for window in house.windows: 140 | assert isinstance(window, Link) 141 | assert isinstance(house.door, Link) 142 | 143 | await house.fetch_all_links() 144 | for window in house.windows: 145 | assert isinstance(window, Window) 146 | assert isinstance(house.door, Door) 147 | 148 | house = await House.find_one(House.name == "test") 149 | assert isinstance(house.door, Link) 150 | await house.fetch_link(House.door) 151 | assert isinstance(house.door, Door) 152 | 153 | for window in house.windows: 154 | assert isinstance(window, Link) 155 | await house.fetch_link(House.windows) 156 | for window in house.windows: 157 | assert isinstance(window, Window) 158 | 159 | async def test_find_by_id_of_the_linked_docs(self, house): 160 | house_lst_1 = await House.find( 161 | House.door.id == house.door.id 162 | ).to_list() 163 | house_lst_2 = await House.find( 164 | House.door.id == house.door.id, fetch_links=True 165 | ).to_list() 166 | assert len(house_lst_1) == 1 167 | assert len(house_lst_2) == 1 168 | 169 | house_1 = await House.find_one(House.door.id == house.door.id) 170 | house_2 = await House.find_one( 171 | House.door.id == house.door.id, fetch_links=True 172 | ) 173 | assert house_1 is not None 174 | assert house_2 is not None 175 | 176 | 177 | 
@pytest.mark.skip("Not supported")
class TestReplace:
    """``replace`` honours the link write rule for linked documents."""

    async def test_do_nothing(self, house):
        # Without a link rule, the change to the linked door is not persisted.
        house.door.t = 100
        await house.replace()
        new_house = await House.get(house.id, fetch_links=True)
        assert new_house.door.t == 10

    async def test_write(self, house):
        house.door.t = 100
        await house.replace(link_rule=WriteRules.WRITE)
        new_house = await House.get(house.id, fetch_links=True)
        assert new_house.door.t == 100


@pytest.mark.skip("Not supported")
class TestSave:
    """``save`` honours the link write rule for linked documents."""

    async def test_do_nothing(self, house):
        house.door.t = 100
        await house.save()
        new_house = await House.get(house.id, fetch_links=True)
        assert new_house.door.t == 10

    async def test_write(self, house):
        house.door.t = 100
        house.windows = [Window(x=100, y=100)]
        await house.save(link_rule=WriteRules.WRITE)
        new_house = await House.get(house.id, fetch_links=True)
        assert new_house.door.t == 100
        # Pin the count: a bare loop over an empty list would pass vacuously.
        assert len(new_house.windows) == 1
        for window in new_house.windows:
            assert window.x == 100
            assert window.y == 100


@pytest.mark.skip("Not supported")
class TestDelete:
    """``delete`` honours the link delete rule for linked documents."""

    async def test_do_nothing(self, house):
        await house.delete()
        door = await Door.get(house.door.id)
        assert door is not None

        # ``to_list()`` always returns a list, so ``is not None`` proved
        # nothing. The house fixture stores two windows, and both must
        # survive when the default delete rule is used.
        windows = await Window.all().to_list()
        assert len(windows) == 2

    async def test_delete_links(self, house):
        await house.delete(link_rule=DeleteRules.DELETE_LINKS)
        door = await Door.get(house.door.id)
        assert door is None

        windows = await Window.all().to_list()
        assert windows == []


# --------------------------------------------------------------------------------
# /tests/odm/documents/test_init.py:
# --------------------------------------------------------------------------------
import pytest
from motor.motor_asyncio import AsyncIOMotorCollection
from yarl import URL

from beanie import Document, init_beanie 6 | from beanie.exceptions import CollectionWasNotInitialized 7 | from beanie.odm.utils.projection import get_projection 8 | from tests.odm.models import ( 9 | DocumentTestModel, 10 | DocumentTestModelWithCustomCollectionName, 11 | DocumentTestModelWithIndexFlagsAliases, 12 | DocumentTestModelWithSimpleIndex, 13 | DocumentTestModelWithIndexFlags, 14 | DocumentTestModelWithComplexIndex, 15 | DocumentTestModelStringImport, 16 | DocumentTestModelWithDroppedIndex, 17 | ) 18 | 19 | 20 | @pytest.mark.skip("Not supported") 21 | async def test_init_collection_was_not_initialized(): 22 | class NewDocument(Document): 23 | test_str: str 24 | 25 | with pytest.raises(CollectionWasNotInitialized): 26 | NewDocument(test_str="test") 27 | 28 | 29 | @pytest.mark.skip("Not supported") 30 | async def test_init_connection_string(settings): 31 | class NewDocumentCS(Document): 32 | test_str: str 33 | 34 | await init_beanie( 35 | connection_string=settings.mongodb_dsn, document_models=[NewDocumentCS] 36 | ) 37 | assert ( 38 | NewDocumentCS.get_motor_collection().database.name 39 | == URL(settings.mongodb_dsn).path[1:] 40 | ) 41 | 42 | 43 | @pytest.mark.skip("Not supported") 44 | async def test_init_wrong_params(settings, db): 45 | class NewDocumentCS(Document): 46 | test_str: str 47 | 48 | with pytest.raises(ValueError): 49 | await init_beanie( 50 | database=db, 51 | connection_string=settings.mongodb_dsn, 52 | document_models=[NewDocumentCS], 53 | ) 54 | 55 | with pytest.raises(ValueError): 56 | await init_beanie(document_models=[NewDocumentCS]) 57 | 58 | with pytest.raises(ValueError): 59 | await init_beanie(connection_string=settings.mongodb_dsn) 60 | 61 | 62 | @pytest.mark.skip("Not supported") 63 | async def test_collection_with_custom_name(): 64 | collection: AsyncIOMotorCollection = ( 65 | DocumentTestModelWithCustomCollectionName.get_motor_collection() 66 | ) 67 | assert collection.name == "custom" 68 | 69 | 70 | @pytest.mark.skip("Not 
supported") 71 | async def test_simple_index_creation(): 72 | collection: AsyncIOMotorCollection = ( 73 | DocumentTestModelWithSimpleIndex.get_motor_collection() 74 | ) 75 | index_info = await collection.index_information() 76 | assert index_info["test_int_1"] == {"key": [("test_int", 1)], "v": 2} 77 | assert index_info["test_str_text"]["key"] == [ 78 | ("_fts", "text"), 79 | ("_ftsx", 1), 80 | ] 81 | 82 | 83 | @pytest.mark.skip("Not supported") 84 | async def test_flagged_index_creation(): 85 | collection: AsyncIOMotorCollection = ( 86 | DocumentTestModelWithIndexFlags.get_motor_collection() 87 | ) 88 | index_info = await collection.index_information() 89 | assert index_info["test_int_1"] == { 90 | "key": [("test_int", 1)], 91 | "sparse": True, 92 | "v": 2, 93 | } 94 | assert index_info["test_str_-1"] == { 95 | "key": [("test_str", -1)], 96 | "unique": True, 97 | "v": 2, 98 | } 99 | 100 | 101 | @pytest.mark.skip("Not supported") 102 | async def test_flagged_index_creation_with_alias(): 103 | collection: AsyncIOMotorCollection = ( 104 | DocumentTestModelWithIndexFlagsAliases.get_motor_collection() 105 | ) 106 | index_info = await collection.index_information() 107 | assert index_info["testInt_1"] == { 108 | "key": [("testInt", 1)], 109 | "sparse": True, 110 | "v": 2, 111 | } 112 | assert index_info["testStr_-1"] == { 113 | "key": [("testStr", -1)], 114 | "unique": True, 115 | "v": 2, 116 | } 117 | 118 | 119 | @pytest.mark.skip("Not supported") 120 | async def test_complex_index_creation(): 121 | collection: AsyncIOMotorCollection = ( 122 | DocumentTestModelWithComplexIndex.get_motor_collection() 123 | ) 124 | index_info = await collection.index_information() 125 | assert index_info == { 126 | "_id_": {"key": [("_id", 1)], "v": 2}, 127 | "test_int_1": {"key": [("test_int", 1)], "v": 2}, 128 | "test_int_1_test_str_-1": { 129 | "key": [("test_int", 1), ("test_str", -1)], 130 | "v": 2, 131 | }, 132 | "test_string_index_DESCENDING": {"key": [("test_str", -1)], "v": 2}, 
133 | } 134 | 135 | 136 | @pytest.mark.skip("Not supported") 137 | async def test_index_dropping_is_allowed(db): 138 | await init_beanie( 139 | database=db, document_models=[DocumentTestModelWithComplexIndex] 140 | ) 141 | await init_beanie( 142 | database=db, 143 | document_models=[DocumentTestModelWithDroppedIndex], 144 | allow_index_dropping=True, 145 | ) 146 | 147 | collection: AsyncIOMotorCollection = ( 148 | DocumentTestModelWithComplexIndex.get_motor_collection() 149 | ) 150 | index_info = await collection.index_information() 151 | assert index_info == { 152 | "_id_": {"key": [("_id", 1)], "v": 2}, 153 | "test_int_1": {"key": [("test_int", 1)], "v": 2}, 154 | } 155 | 156 | 157 | @pytest.mark.skip("Not supported") 158 | async def test_index_dropping_is_not_allowed(db): 159 | await init_beanie( 160 | database=db, document_models=[DocumentTestModelWithComplexIndex] 161 | ) 162 | await init_beanie( 163 | database=db, 164 | document_models=[DocumentTestModelWithDroppedIndex], 165 | allow_index_dropping=False, 166 | ) 167 | 168 | collection: AsyncIOMotorCollection = ( 169 | DocumentTestModelWithComplexIndex.get_motor_collection() 170 | ) 171 | index_info = await collection.index_information() 172 | assert index_info == { 173 | "_id_": {"key": [("_id", 1)], "v": 2}, 174 | "test_int_1": {"key": [("test_int", 1)], "v": 2}, 175 | "test_int_1_test_str_-1": { 176 | "key": [("test_int", 1), ("test_str", -1)], 177 | "v": 2, 178 | }, 179 | "test_string_index_DESCENDING": {"key": [("test_str", -1)], "v": 2}, 180 | } 181 | 182 | 183 | @pytest.mark.skip("Not supported") 184 | async def test_index_dropping_is_not_allowed_as_default(db): 185 | await init_beanie( 186 | database=db, document_models=[DocumentTestModelWithComplexIndex] 187 | ) 188 | await init_beanie( 189 | database=db, 190 | document_models=[DocumentTestModelWithDroppedIndex], 191 | ) 192 | 193 | collection: AsyncIOMotorCollection = ( 194 | DocumentTestModelWithComplexIndex.get_motor_collection() 195 | ) 196 | 
index_info = await collection.index_information() 197 | assert index_info == { 198 | "_id_": {"key": [("_id", 1)], "v": 2}, 199 | "test_int_1": {"key": [("test_int", 1)], "v": 2}, 200 | "test_int_1_test_str_-1": { 201 | "key": [("test_int", 1), ("test_str", -1)], 202 | "v": 2, 203 | }, 204 | "test_string_index_DESCENDING": {"key": [("test_str", -1)], "v": 2}, 205 | } 206 | 207 | 208 | @pytest.mark.skip("Not supported") 209 | async def test_document_string_import(db): 210 | await init_beanie( 211 | database=db, 212 | document_models=[ 213 | "tests.odm.models.DocumentTestModelStringImport", 214 | ], 215 | ) 216 | document = DocumentTestModelStringImport(test_int=1) 217 | assert document.id is None 218 | await document.insert() 219 | assert document.id is not None 220 | 221 | with pytest.raises(ValueError): 222 | await init_beanie( 223 | database=db, 224 | document_models=[ 225 | "tests", 226 | ], 227 | ) 228 | 229 | with pytest.raises(AttributeError): 230 | await init_beanie( 231 | database=db, 232 | document_models=[ 233 | "tests.wrong", 234 | ], 235 | ) 236 | 237 | 238 | @pytest.mark.skip("Not supported") 239 | async def test_projection(): 240 | projection = get_projection(DocumentTestModel) 241 | assert projection == { 242 | "_id": 1, 243 | "test_int": 1, 244 | "test_list": 1, 245 | "test_str": 1, 246 | "test_doc": 1, 247 | "revision_id": 1, 248 | } 249 | -------------------------------------------------------------------------------- /tests/odm/documents/test_update.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beanie.exceptions import ( 4 | DocumentNotFound, 5 | ReplaceError, 6 | ) 7 | from beanie.odm.fields import PydanticObjectId 8 | from tests.odm.models import DocumentTestModel 9 | 10 | 11 | # REPLACE 12 | 13 | # 14 | # async def test_replace_one(document): 15 | # new_doc = DocumentTestModel( 16 | # test_int=0, test_str="REPLACED_VALUE", test_list=[] 17 | # ) 18 | # await 
#     DocumentTestModel.replace_one({"_id": document.id}, new_doc)
#     new_document = await DocumentTestModel.get(document.id)
#     assert new_document.test_str == "REPLACED_VALUE"


async def _find_by_str(value):
    """Return every stored DocumentTestModel whose ``test_str`` equals *value*."""
    return await DocumentTestModel.find_many({"test_str": value}).to_list()


async def _assert_persisted(doc):
    """Check that *doc* has a non-None id and round-trips unchanged via get()."""
    assert getattr(doc, "id", None) is not None
    from_db = await DocumentTestModel.get(doc.id)
    assert from_db == doc


async def test_replace_many(documents):
    """replace_many writes back every modified document in the batch."""
    await documents(10, "foo")
    stored = await _find_by_str("foo")

    batch = stored[:5]
    for doc in batch:
        doc.test_str = "REPLACED_VALUE"
    await DocumentTestModel.replace_many(batch)

    assert len(await _find_by_str("REPLACED_VALUE")) == 5


async def test_replace_many_not_all_the_docs_found(documents):
    """replace_many raises ReplaceError when one batch id is not stored."""
    await documents(10, "foo")
    stored = await _find_by_str("foo")
    # Re-point the first document at a freshly generated id.
    stored[0].id = PydanticObjectId()

    batch = stored[:5]
    for doc in batch:
        doc.test_str = "REPLACED_VALUE"

    with pytest.raises(ReplaceError):
        await DocumentTestModel.replace_many(batch)


async def test_replace(document):
    """replace() on an updated copy overwrites the stored document."""
    replacement = document.copy(update={"test_str": "REPLACED_VALUE"})
    await replacement.replace()
    reloaded = await DocumentTestModel.get(document.id)
    assert reloaded.test_str == "REPLACED_VALUE"


async def test_replace_not_saved(document_not_inserted):
    """replace() on the not-inserted fixture raises ValueError."""
    with pytest.raises(ValueError):
        await document_not_inserted.replace()


async def test_replace_not_found(document_not_inserted):
    """replace() with a freshly generated id raises DocumentNotFound."""
    document_not_inserted.id = PydanticObjectId()
    with pytest.raises(DocumentNotFound):
        await document_not_inserted.replace()


# SAVE


async def test_save(document):
    """save() on an updated copy overwrites the stored document."""
    replacement = document.copy(update={"test_str": "REPLACED_VALUE"})
    await replacement.save()
    reloaded = await DocumentTestModel.get(document.id)
    assert reloaded.test_str == "REPLACED_VALUE"


async def test_save_not_saved(document_not_inserted):
    """save() on the not-inserted fixture assigns an id and persists it."""
    await document_not_inserted.save()
    await _assert_persisted(document_not_inserted)


async def test_save_not_found(document_not_inserted):
    """save() with a freshly generated id persists the document."""
    document_not_inserted.id = PydanticObjectId()
    await document_not_inserted.save()
    await _assert_persisted(document_not_inserted)


# UPDATE


@pytest.mark.skip("Not supported")
async def test_update_one(document):
    """Positional ``$`` update of a matching list element (skipped)."""
    await DocumentTestModel.find_one(
        {"_id": document.id, "test_list.test_str": "foo"}
    ).update({"$set": {"test_list.$.test_str": "foo_foo"}})
    new_document = await DocumentTestModel.get(document.id)
    assert new_document.test_list[0].test_str == "foo_foo"


async def test_update_many(documents):
    """A filtered update rewrites only the matching documents."""
    await documents(10, "foo")
    await documents(7, "bar")
    await DocumentTestModel.find_many({"test_str": "foo"}).update(
        {"$set": {"test_str": "bar"}}
    )
    assert len(await _find_by_str("bar")) == 17
    assert len(await _find_by_str("foo")) == 0


async def test_update_all(documents):
    """update_all rewrites every stored document."""
    await documents(10, "foo")
    await documents(7, "bar")
    await DocumentTestModel.update_all(
        {"$set": {"test_str": "smth_else"}},
    )
    assert len(await _find_by_str("bar")) == 0
    assert len(await _find_by_str("foo")) == 0
    assert len(await _find_by_str("smth_else")) == 17


# WITH SESSION


# async def test_update_with_session(document: DocumentTestModel, session):
#     buf_len = len(document.test_list)
#     to_insert = SubDocument(test_str="test")
#     await document.update(
#         update_query={"$push": {"test_list": to_insert.dict()}},
#         session=session,
#     )
#     new_document = await DocumentTestModel.get(document.id, session=session)
#     assert len(new_document.test_list) == buf_len + 1
#
#
# async def test_replace_one_with_session(document, session):
#     new_doc = DocumentTestModel(
#         test_int=0, test_str="REPLACED_VALUE", test_list=[]
#     )
#     await DocumentTestModel.replace_one(
#         {"_id": document.id}, new_doc, session=session
#     )
#     new_document = await DocumentTestModel.get(document.id, session=session)
#     assert new_document.test_str == "REPLACED_VALUE"
#
#
# async def test_replace_with_session(document, session):
#     update_data = {"test_str": "REPLACED_VALUE"}
#     new_doc: DocumentTestModel = document.copy(update=update_data)
#     # document.test_str = "REPLACED_VALUE"
#     await new_doc.replace(session=session)
#     new_document = await DocumentTestModel.get(document.id, session=session)
#     assert new_document.test_str == "REPLACED_VALUE"
#
#
# async def test_update_one_with_session(document, session):
#     await DocumentTestModel.update_one(
#         update_query={"$set": {"test_list.$.test_str": "foo_foo"}},
#         filter_query={"_id": document.id, "test_list.test_str": "foo"},
#         session=session,
#     )
#     new_document = await DocumentTestModel.get(document.id, session=session)
#     assert new_document.test_list[0].test_str == "foo_foo"
#
#
# async def test_update_many_with_session(documents, session):
#     await documents(10, "foo")
#     await documents(7, "bar")
#     await DocumentTestModel.update_many(
#         update_query={"$set": {"test_str": "bar"}},
#         filter_query={"test_str": "foo"},
#         session=session,
#     )
#     bar_documetns = await DocumentTestModel.find_many(
#         {"test_str": "bar"}, session=session
#     ).to_list()
#     assert len(bar_documetns) == 17
#     foo_documetns = await DocumentTestModel.find_many(
#         {"test_str": "foo"}, session=session
#     ).to_list()
#     assert len(foo_documetns) == 0
#
#
# async def test_update_all_with_session(documents, session):
#     await documents(10, "foo")
#     await documents(7, "bar")
#     await DocumentTestModel.update_all(
#         update_query={"$set": {"test_str": "smth_else"}}, session=session
#     )
#     bar_documetns = await DocumentTestModel.find_many(
#         {"test_str": "bar"}, session=session
#     ).to_list()
#     assert len(bar_documetns) == 0
#     foo_documetns = await DocumentTestModel.find_many(
#         {"test_str": "foo"}, session=session
#     ).to_list()
#     assert len(foo_documetns) == 0
#     smth_else_documetns = await DocumentTestModel.find_many(
#         {"test_str": "smth_else"}, session=session
#     ).to_list()
#     assert len(smth_else_documetns) == 17

# --- /tests/odm/query/test_find.py ---
import datetime

import pytest
from pydantic import BaseModel
from pydantic.color import Color

from beanie.odm.enums import SortDirection
from tests.odm.models import Sample, DocumentWithBsonEncodersFiledsTypes, House


async def _assert_iterates_same(query, expected):
    """Async-iterate *query* and check it yields exactly *expected*.

    Every yielded item must be a member of *expected*, and the number of
    items yielded must equal ``len(expected)``.
    """
    count = 0
    async for item in query:
        assert item in expected
        count += 1
    assert count == len(expected)


def _assert_monotonic(values, descending=False):
    """Assert *values* is non-increasing (``descending``) or non-decreasing.

    Comparing adjacent pairs is equivalent to walking the list while
    remembering the previous element.
    """
    for previous, current in zip(values, values[1:]):
        if descending:
            assert previous >= current
        else:
            assert previous <= current


async def test_find_query():
    """find_many builds the expected filter documents, including chaining."""
    q = Sample.find_many(Sample.integer == 1).get_filter_query()
    assert q == {"integer": 1}

    q = Sample.find_many(
        Sample.integer == 1, Sample.nested.integer >= 2
    ).get_filter_query()
    assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gte": 2}}]}

    q = (
        Sample.find_many(Sample.integer == 1)
        .find_many(Sample.nested.integer >= 2)
        .get_filter_query()
    )
    assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gte": 2}}]}

    q = Sample.find().get_filter_query()
    assert q == {}


async def test_find_many(preset_documents):
    """to_list() and async iteration return the same matching documents."""
    result = await Sample.find_many(Sample.integer > 1).to_list()  # noqa
    assert len(result) == 4
    for a in result:
        assert a.integer > 1

    await _assert_iterates_same(Sample.find_many(Sample.integer > 1), result)


async def test_find_many_skip(preset_documents):
    """skip may be given as a kwarg or via .skip() and is applied."""
    q = Sample.find_many(Sample.integer > 1, skip=2)
    assert q.skip_number == 2

    q = Sample.find_many(Sample.integer > 1).skip(2)
    assert q.skip_number == 2

    result = await Sample.find_many(Sample.increment > 2).skip(1).to_list()
    assert len(result) == 7
    for sample in result:
        assert sample.increment > 2

    await _assert_iterates_same(
        Sample.find_many(Sample.increment > 2).skip(1), result
    )


async def test_find_many_limit(preset_documents):
    """limit may be given as a kwarg or via .limit() and is applied."""
    q = Sample.find_many(Sample.integer > 1, limit=2)
    assert q.limit_number == 2

    q = Sample.find_many(Sample.integer > 1).limit(2)
    assert q.limit_number == 2

    result = (
        await Sample.find_many(Sample.increment > 2)
        .sort(Sample.increment)
        .limit(2)
        .to_list()
    )  # noqa
    assert len(result) == 2
    for a in result:
        assert a.increment > 2

    await _assert_iterates_same(
        Sample.find_many(Sample.increment > 2).sort(Sample.increment).limit(2),
        result,
    )


async def test_find_all(preset_documents):
    """find_all returns every preset document."""
    result = await Sample.find_all().to_list()
    assert len(result) == 10

    await _assert_iterates_same(Sample.find_all(), result)


async def test_find_one(preset_documents):
    """find_one returns a match, or None when nothing matches."""
    a = await Sample.find_one(Sample.integer > 1)  # noqa
    assert a.integer > 1

    a = await Sample.find_one(Sample.integer > 100)  # noqa
    assert a is None


async def test_get(preset_documents):
    """get() accepts the id both as-is and as its string form."""
    a = await Sample.find_one(Sample.integer > 1)  # noqa
    assert a.integer > 1

    new_a = await Sample.get(a.id)
    assert new_a == a

    # check for another type
    new_a = await Sample.get(str(a.id))
    assert new_a == a


async def test_sort(preset_documents):
    """Every supported sort spelling yields correctly ordered results."""
    q = Sample.find_many(Sample.integer > 1, sort="-integer")
    assert q.sort_expressions == [("integer", SortDirection.DESCENDING)]

    q = Sample.find_many(Sample.integer > 1, sort="integer")
    assert q.sort_expressions == [("integer", SortDirection.ASCENDING)]

    q = Sample.find_many(Sample.integer > 1).sort("-integer")
    assert q.sort_expressions == [("integer", SortDirection.DESCENDING)]

    q = (
        Sample.find_many(Sample.integer > 1)
        .find_many(Sample.integer < 100)
        .sort("-integer")
    )
    assert q.sort_expressions == [("integer", SortDirection.DESCENDING)]

    for sort_spec, descending in (
        ("-integer", True),
        ("+integer", False),
        ("integer", False),
        (-Sample.integer, True),
    ):
        result = await Sample.find_many(
            Sample.integer > 1, sort=sort_spec
        ).to_list()
        # An empty result would make the ordering check pass vacuously;
        # the same filter matches 4 preset documents in test_find_many.
        assert result
        _assert_monotonic(
            [doc.integer for doc in result], descending=descending
        )

    with pytest.raises(TypeError):
        Sample.find_many(Sample.integer > 1, sort=1)


async def test_find_many_with_projection(preset_documents):
    """Projection models can be set via project() or find_many()."""

    class SampleProjection(BaseModel):
        string: str
        integer: int

    result = (
        await Sample.find_many(Sample.integer > 1)
        .project(projection_model=SampleProjection)
        .to_list()
    )
    assert len(result) == 4

    result = (
        await Sample.find_many(Sample.integer > 1)
        .find_many(projection_model=SampleProjection)
        .to_list()
    )
    assert isinstance(result[0], SampleProjection)


@pytest.mark.skip("Not supported")
async def test_find_many_with_custom_projection(preset_documents):
    """Custom projection expressions declared in Settings (skipped)."""

    class SampleProjection(BaseModel):
        string: str
        i: int

        class Settings:
            projection = {"string": 1, "i": "$nested.integer"}

    result = (
        await Sample.find_many(Sample.integer > 1)
        .project(projection_model=SampleProjection)
        .to_list()
    )
    assert result == [
        SampleProjection(string="test_2", i=3),
        SampleProjection(string="test_2", i=4),
    ]


async def test_find_many_with_session(preset_documents, session):
    """Queries carry the session given via set_session() or the kwarg."""
    q_1 = Sample.find_many(Sample.integer > 1).set_session(session)
    assert q_1.session == session

    # The session must be passed explicitly: without it this assertion only
    # holds when the fixture happens to yield ``None``.
    q_2 = Sample.find_many(Sample.integer > 1, session=session)
    assert q_2.session == session

    result = await q_2.to_list()

    assert len(result) == 4
    for a in result:
        assert a.integer > 1

    await _assert_iterates_same(Sample.find_many(Sample.integer > 1), result)


async def test_bson_encoders_filed_types():
    """Fields with custom bson_encoders can be stored and queried."""
    custom = DocumentWithBsonEncodersFiledsTypes(
        color="7fffd4", timestamp=datetime.datetime.utcnow()
    )
    c = await custom.insert()
    c_fromdb = await DocumentWithBsonEncodersFiledsTypes.find_one(
        DocumentWithBsonEncodersFiledsTypes.color == Color("7fffd4")
    )
    assert c_fromdb.color.as_hex() == c.color.as_hex()


@pytest.mark.skip("Not supported")
async def test_find_by_datetime(preset_documents):
    """Range queries on datetime fields (skipped)."""
    datetime_1 = datetime.datetime.utcnow() - datetime.timedelta(days=7)
    datetime_2 = datetime.datetime.utcnow() - datetime.timedelta(days=2)
    docs = await Sample.find(
        Sample.timestamp >= datetime_1,
        Sample.timestamp <= datetime_2,
    ).to_list()
    assert len(docs) == 5


async def test_find_first_or_none(preset_documents):
    """first_or_none returns the first sorted match, or None."""
    doc = (
        await Sample.find(Sample.increment > 1)
        .sort(-Sample.increment)
        .first_or_none()
    )
    assert doc.increment == 9

    doc = (
        await Sample.find(Sample.increment > 9)
        .sort(-Sample.increment)
        .first_or_none()
    )
    assert doc is None


@pytest.mark.skip("Not supported")
async def test_find_pymongo_kwargs(preset_documents):
    """Extra pymongo kwargs are forwarded; unknown ones raise (skipped)."""
    with pytest.raises(TypeError):
        await Sample.find_many(Sample.increment > 1, wrong=100).to_list()

    await Sample.find_many(
        Sample.increment > 1, Sample.integer > 1, allow_disk_use=True
    ).to_list()

    await Sample.find_many(
        Sample.increment > 1, Sample.integer > 1, hint="integer_1"
    ).to_list()

    await House.find_many(
        House.height > 1, fetch_links=True, hint="height_1"
    ).to_list()

    await House.find_many(
        House.height > 1, fetch_links=True, allowDiskUse=True
    ).to_list()

    await Sample.find_one(
        Sample.increment > 1, Sample.integer > 1, hint="integer_1"
    )

    await House.find_one(House.height > 1, fetch_links=True, hint="height_1")

# --- /tests/odm/models.py ---
import datetime
from decimal import Decimal
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
)
from pathlib import Path
from typing import List, Optional, Set, Tuple, Union
from uuid import UUID, uuid4

import pymongo
from pydantic import SecretBytes, SecretStr, Extra
from pydantic.color import Color
from pydantic import BaseModel, Field
from pymongo import IndexModel

from beanie import Document, Indexed, Insert, Replace, ValidateOnSave
from beanie.odm.actions import before_event, after_event, Delete
from beanie.odm.fields import Link
from beanie.odm.settings.timeseries import TimeSeriesConfig
from beanie.odm.union_doc import UnionDoc


# Plain pydantic sub-models used to build Sample.nested / Sample.union.
class Option2(BaseModel):
    f: float


class Option1(BaseModel):
    s: str


class Nested(BaseModel):
    integer: int
    option_1: Option1
    union: Union[Option1, Option2]
    optional: Optional[Option2]


class GeoObject(BaseModel):
    type: str = "Point"
    coordinates: Tuple[float, float]


# Main query-test document; `increment` and `integer` are declared Indexed.
class Sample(Document):
    timestamp: datetime.datetime
    increment: Indexed(int)
    integer: Indexed(int)
    float_num: float
    string: str
    nested: Nested
    optional: Optional[Option2]
    union: Union[Option1, Option2]
    geo: GeoObject


class SubDocument(BaseModel):
    test_str: str
    test_int: int = 42


# Settings turn on caching (10 s expiry, capacity 5) and state management.
class DocumentTestModel(Document):
    test_int: int
    test_list: List[SubDocument] = Field(hidden=True)
    test_doc: SubDocument
    test_str: str

    class Settings:
        use_cache = True
        cache_expiration_time = datetime.timedelta(seconds=10)
        cache_capacity = 5
        use_state_management = True


# Configured through the inner `Collection` class; the equivalent `Settings`
# form is kept commented out below.
class DocumentTestModelWithCustomCollectionName(Document):
    test_int: int
    test_list: List[SubDocument]
    test_str: str

    class Collection:
        name = "custom"

    # class Settings:
    #     name = "custom"


class DocumentTestModelWithSimpleIndex(Document):
    test_int: Indexed(int)
    test_list: List[SubDocument]
    test_str: Indexed(str, index_type=pymongo.TEXT)


class DocumentTestModelWithIndexFlags(Document):
    test_int: Indexed(int, sparse=True)
    test_str: Indexed(str, index_type=pymongo.DESCENDING, unique=True)


# Same index flags as above, but the fields are stored under camelCase aliases.
class DocumentTestModelWithIndexFlagsAliases(Document):
    test_int: Indexed(int, sparse=True) = Field(alias="testInt")
    test_str: Indexed(str, index_type=pymongo.DESCENDING, unique=True) = Field(
        alias="testStr"
    )


# Declares three indexes on "docs_with_index": a single-field name, a compound
# key list, and an explicitly named IndexModel.
class DocumentTestModelWithComplexIndex(Document):
    test_int: int
    test_list: List[SubDocument]
    test_str: str

    class Collection:
        name = "docs_with_index"
        indexes = [
            "test_int",
            [
                ("test_int", pymongo.ASCENDING),
                ("test_str", pymongo.DESCENDING),
            ],
            IndexModel(
                [("test_str", pymongo.DESCENDING)],
                name="test_string_index_DESCENDING",
            ),
        ]

    # class Settings:
    #     name = "docs_with_index"
    #     indexes = [
    #         "test_int",
    #         [
    #             ("test_int", pymongo.ASCENDING),
    #             ("test_str", pymongo.DESCENDING),
    #         ],
    #         IndexModel(
    #             [("test_str", pymongo.DESCENDING)],
    #             name="test_string_index_DESCENDING",
    #         ),
    #     ]


# Shares "docs_with_index" with DocumentTestModelWithComplexIndex but
# declares only the "test_int" index (in both Collection and Settings).
class DocumentTestModelWithDroppedIndex(Document):
    test_int: int
    test_list: List[SubDocument]
    test_str: str

    class Collection:
        name = "docs_with_index"
        indexes = [
            "test_int",
        ]

    class Settings:
        name = "docs_with_index"
        indexes = [
            "test_int",
        ]


class DocumentTestModelStringImport(Document):
    test_int: int


# Points at the "DocumentTestModel" collection with an incompatible field
# (test_int_2); presumably used to make inspection fail - confirm in tests.
class DocumentTestModelFailInspection(Document):
    test_int_2: int

    class Collection:
        name = "DocumentTestModel"

    class Settings:
        name = "DocumentTestModel"


class DocumentWithCustomIdUUID(Document):
    id: UUID = Field(default_factory=uuid4)
    name: str


class DocumentWithCustomIdInt(Document):
    id: int
    name: str


# One field per non-trivial pydantic/stdlib type, for encoding round-trips.
class DocumentWithCustomFiledsTypes(Document):
    color: Color
    decimal: Decimal
    secret_bytes: SecretBytes
    secret_string: SecretStr
    ipv4address: IPv4Address
    ipv4interface: IPv4Interface
    ipv4network: IPv4Network
    ipv6address: IPv6Address
    ipv6interface: IPv6Interface
    ipv6network: IPv6Network
    timedelta: datetime.timedelta
    set_type: Set[str]
    tuple_type: Tuple[int, str]
    path: Path


# Custom bson_encoders: Color is stored via as_rgb(), datetime via an
# isoformat string with microsecond precision.
class DocumentWithBsonEncodersFiledsTypes(Document):
    color: Color
    timestamp: datetime.datetime

    class Settings:
        bson_encoders = {
            Color: lambda c: c.as_rgb(),
            datetime.datetime: lambda o: o.isoformat(timespec="microseconds"),
        }


# Event hooks mutate observable fields so tests can see which hook ran:
# before Insert capitalizes name; before Insert/Replace increments num_1;
# after Insert decrements num_2; after Replace decrements num_3; the Delete
# hooks write to class-level counters on Inner.
class DocumentWithActions(Document):
    name: str
    num_1: int = 0
    num_2: int = 10
    num_3: int = 100

    class Inner:
        inner_num_1 = 0
        inner_num_2 = 0

    @before_event(Insert)
    def capitalize_name(self):
        self.name = self.name.capitalize()

    @before_event([Insert, Replace])
    async def add_one(self):
        self.num_1 += 1

    @after_event(Insert)
    def num_2_change(self):
        self.num_2 -= 1

    @after_event(Replace)
    def num_3_change(self):
        self.num_3 -= 1

    @before_event(Delete)
    def inner_num_to_one(self):
        self.Inner.inner_num_1 = 1

    @after_event(Delete)
    def inner_num_to_two(self):
        self.Inner.inner_num_2 = 2


# Adds nothing; presumably checks that event hooks are inherited.
class InheritedDocumentWithActions(DocumentWithActions):
    ...
# Nested sub-model with defaults; embedded as `internal` in the
# state-management documents below.
class InternalDoc(BaseModel):
    num: int = 100
    string: str = "test"
    lst: List[int] = [1, 2, 3, 4, 5]


# Opts into state management through Settings.
class DocumentWithTurnedOnStateManagement(Document):
    num_1: int
    num_2: int
    internal: InternalDoc

    class Settings:
        use_state_management = True


# Like the class above, plus `state_management_replace_objects = True`.
class DocumentWithTurnedOnReplaceObjects(Document):
    num_1: int
    num_2: int
    internal: InternalDoc

    class Settings:
        use_state_management = True
        state_management_replace_objects = True


# No Settings, so state management is not turned on here.
class DocumentWithTurnedOffStateManagement(Document):
    num_1: int
    num_2: int


# validate_on_save is enabled; the after_event(ValidateOnSave) hook increments
# num_2, presumably so tests can observe that validation ran.
class DocumentWithValidationOnSave(Document):
    num_1: int
    num_2: int

    @after_event(ValidateOnSave)
    def num_2_plus_1(self):
        self.num_2 += 1

    class Settings:
        validate_on_save = True
        use_state_management = True


# Enables revision tracking together with state management.
class DocumentWithRevisionTurnedOn(Document):
    num_1: int
    num_2: int

    class Settings:
        use_revision = True
        use_state_management = True


# Extends Document.Config with pydantic's validate_assignment.
class DocumentWithPydanticConfig(Document):
    num_1: int

    class Config(Document.Config):
        validate_assignment = True


# Allows extra fields via an inner Config subclass...
class DocumentWithExtras(Document):
    num_1: int

    class Config(Document.Config):
        extra = Extra.allow


# ...and via the class-keyword form of the same option.
class DocumentWithExtrasKw(Document, extra=Extra.allow):
    num_1: int


# Link targets used by House.
class Window(Document):
    x: int
    y: int


class Door(Document):
    t: int = 10


class Roof(Document):
    r: int = 100


# Holds a list of linked Windows, a required Door link and an optional Roof
# link; `name` and `height` are indexed.
class House(Document):
    windows: List[Link[Window]]
    door: Link[Door]
    roof: Optional[Link[Roof]]
    name: Indexed(str) = Field(hidden=True)
    height: Indexed(int) = 2


class DocumentForEncodingTest(Document):
    bytes_field: Optional[bytes]
    datetime_field: Optional[datetime.datetime]


# Time-series collection keyed on `ts`, with expire_after_seconds=2.
class DocumentWithTimeseries(Document):
    ts: datetime.datetime = Field(default_factory=datetime.datetime.now)

    class Settings:
        timeseries = TimeSeriesConfig(time_field="ts", expire_after_seconds=2)


# Stored in "test_date"; its bson_encoder turns a date into a datetime at
# midnight of that day.
class DocumentForEncodingTestDate(Document):
    date_field: datetime.date = Field(default_factory=datetime.date.today)

    class Settings:
        name = "test_date"
        bson_encoders = {
            datetime.date: lambda dt: datetime.datetime(
                year=dt.year,
                month=dt.month,
                day=dt.day,
                hour=0,
                minute=0,
                second=0,
            )
        }


# Union document over the "multi_model" collection; the two models below
# register with it via Settings.union_doc.
class DocumentUnion(UnionDoc):
    class Settings:
        name = "multi_model"


class DocumentMultiModelOne(Document):
    int_filed: int = 0
    shared: int = 0

    class Settings:
        union_doc = DocumentUnion


# Also carries an optional link to a DocumentMultiModelOne.
class DocumentMultiModelTwo(Document):
    str_filed: str = "test"
    shared: int = 0
    linked_doc: Optional[Link[DocumentMultiModelOne]] = None

    class Settings:
        union_doc = DocumentUnion