├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .isort.cfg ├── CHANGELOG.md ├── LICENSE ├── README.md ├── atomdb ├── __init__.py ├── base.py ├── nosql.py ├── py.typed └── sql.py ├── makefile ├── pytest.ini ├── setup.py └── tests ├── test_base.py ├── test_benchmark.py ├── test_json.py ├── test_nosql.py ├── test_sql.py └── test_sql_benchmark.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [push] 3 | jobs: 4 | test-postgres: 5 | runs-on: ubuntu-latest 6 | env: 7 | DATABASE_URL: 'postgres://user:password@localhost:5432/test_atomdb' 8 | strategy: 9 | fail-fast: false 10 | matrix: 11 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 12 | services: 13 | postgres: 14 | image: postgres 15 | ports: 16 | - 5432:5432 17 | env: 18 | POSTGRES_USER: user 19 | POSTGRES_PASSWORD: password 20 | POSTGRES_DB: test_atomdb 21 | # Set health checks to wait until postgres has started 22 | options: >- 23 | --health-cmd pg_isready 24 | --health-interval 10s 25 | --health-timeout 5s 26 | --health-retries 5 27 | steps: 28 | - uses: actions/checkout@v3 29 | - name: Setup python ${{ matrix.python-version}} 30 | uses: actions/setup-python@v3 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | - name: Install dependencies 34 | run: pip install -U aiopg 'sqlalchemy<1.5' codecov pytest pytest-benchmark pytest-cov pytest-asyncio 35 | - name: Install atom-db 36 | run: pip install -e ./ 37 | - name: Run tests 38 | run: pytest -v tests --cov atomdb --cov-report xml --asyncio-mode auto 39 | - name: Coverage 40 | run: codecov 41 | test-mysql: 42 | runs-on: ubuntu-latest 43 | env: 44 | DATABASE_URL: 'mysql://user:password@localhost:3306/test_atomdb' 45 | strategy: 46 | fail-fast: false 47 | matrix: 48 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 49 | services: 50 | mariadb: 51 | image: mariadb:latest 52 | ports: 53 | - 3306:3306 54 | env: 55 | MYSQL_USER: user 56 | 
MYSQL_PASSWORD: password 57 | MYSQL_DATABASE: test_atomdb 58 | MYSQL_ROOT_PASSWORD: root 59 | options: >- 60 | --health-cmd="healthcheck.sh --connect --innodb_initialized" 61 | --health-interval=10s 62 | --health-timeout=5s 63 | --health-retries=3 64 | steps: 65 | - uses: actions/checkout@v3 66 | - name: Setup python ${{ matrix.python-version}} 67 | uses: actions/setup-python@v3 68 | with: 69 | python-version: ${{ matrix.python-version }} 70 | - name: Install dependencies 71 | run: pip install -U aiomysql 'sqlalchemy<1.4' codecov pytest pytest-benchmark pytest-cov pytest-asyncio 72 | - name: Install atom-db 73 | run: pip install -e ./ 74 | - name: Run tests 75 | run: pytest -v tests --cov atomdb --cov-report xml --asyncio-mode auto 76 | - name: Coverage 77 | run: codecov 78 | test-mongo: 79 | runs-on: ubuntu-latest 80 | env: 81 | MONGO_URL: 'mongodb://localhost:27017' 82 | strategy: 83 | fail-fast: false 84 | matrix: 85 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 86 | 87 | services: 88 | mongodb: 89 | image: mongo 90 | ports: 91 | - 27017:27017 92 | steps: 93 | - uses: actions/checkout@v3 94 | - name: Setup python ${{ matrix.python-version}} 95 | uses: actions/setup-python@v3 96 | with: 97 | python-version: ${{ matrix.python-version }} 98 | - name: Install dependencies 99 | run: pip install -U motor codecov pytest pytest-benchmark pytest-cov pytest-asyncio 100 | - name: Install atom-db 101 | run: pip install -e ./ 102 | - name: Run tests 103 | run: pytest -v tests --cov atomdb --cov-report xml --asyncio-mode auto 104 | - name: Coverage 105 | run: codecov 106 | check-code: 107 | runs-on: ubuntu-latest 108 | strategy: 109 | matrix: 110 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 111 | steps: 112 | - uses: actions/checkout@v3 113 | - name: Setup python ${{ matrix.python-version}} 114 | uses: actions/setup-python@v3 115 | with: 116 | python-version: ${{ matrix.python-version }} 117 | - name: Install dependencies 118 | run: pip 
install -U motor aiopg aiomysql 'sqlalchemy<2' mypy black isort flake8 119 | - name: Run checks 120 | run: | 121 | isort atomdb tests --check --diff 122 | black atomdb tests --check --diff 123 | mypy atomdb --ignore-missing-imports 124 | flake8 --ignore=E501,W503 atomdb tests 125 | 126 | 127 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # kde 107 | *.kate-swp 108 | *.kdev4 109 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.8.1 2 | 3 | - Rebase changes from 0.7.10 and 0.7.11 4 | - Rework Relation & RelatedList so the return value is still a list instance 5 | - Flatten builtin enum.Enum types to their value 6 | - Add first, last, earliest, latest to QuerySet and support Meta get_latest_by 7 | - Support using `_id` as alias to the primary key when doing filtering 8 | 9 | # 0.8.0 10 | 11 | - **breaking** Make `Relation` return a 
RelatedList with `save` and `load` methods. 12 | - Don't rewrite bytecode 13 | - Pass onclause when using join to prevent sqlalchemy from picking incorrect relation 14 | - Fix select_related with duplicate joins (eg `select_related('a', 'a__b')`) 15 | - Change Enum database name to include the table name 16 | - Add builtin set / tuple support for JSONModel 17 | 18 | # 0.7.11 19 | 20 | - Support doing an or filter using django style queries by passing a dict arg 21 | 22 | # 0.7.10 23 | 24 | - Add group by 25 | 26 | 27 | # 0.7.9 28 | 29 | - Fix error with 3.11 30 | 31 | # 0.7.8 32 | 33 | - Fix problem with flatten/unflatten creating invalid columns. 34 | 35 | # 0.7.7 36 | 37 | - Return from bulk_create if values list is empty 38 | - Fix problem with order_by not having __bool__ defined 39 | 40 | # 0.7.6 41 | 42 | - Change internal data types to set to speed up building queries and allow caching 43 | - Add a `bulk_create` method to the model manager. 44 | - Add `py.typed` to package 45 | 46 | # 0.7.5 47 | 48 | - Fix a bug preventing select_related on multiple fields from working 49 | - Add django style `exclude` to filter 50 | 51 | # 0.7.4 52 | 53 | - Do not save Property members by default 54 | - Annotate Model `objects` and `serializer` with `ClassVar` 55 | - Change import sorting and cleanup errors found with flake8 56 | 57 | # 0.7.3 58 | 59 | - Revert force restore items from direct query even if in the cache. 60 | Now queries can accept a `force_restore=True` to do this. 61 | See https://en.wikipedia.org/wiki/Isolation_%28database_systems%29 62 | 63 | # 0.7.2 64 | 65 | - Support prefetching of one to one "Related" members. 
66 | - Remove _id field in base Model and JSONModel as it has no meaning there 67 | 68 | # 0.7.1 69 | 70 | - Always force restore items from direct query even if in the cache 71 | - Make prefetch use parent queries connection 72 | 73 | # 0.7.0 74 | 75 | - Use generated functions to speed up save and restore 76 | - BREAKING: To save memory (by avoiding overriding members) the `_id` and `__ref__` fields 77 | were changed to an `Int`. 78 | 79 | # 0.6.4 80 | 81 | - Fix queries joining through multiple tables 82 | - Add initial implementation of prefetch_related 83 | 84 | # 0.6.3 85 | 86 | - Add workaround for subclassed pk handling 87 | 88 | # 0.6.2 89 | 90 | - Add support for using multiple databases 91 | - Fix non-abstract subclasses throwing multiple primary key error. 92 | - Make update work with renamed fields 93 | 94 | # 0.6.1 95 | 96 | - Merge `composite_indexes` and typing branches 97 | 98 | # 0.6.0 99 | 100 | - Add type hints 101 | - Drop python 3.6 support 102 | - Fix bug with renamed fields excluded fields 103 | 104 | # 0.5.8 105 | 106 | - Add `composite_indexes` to Model Meta. 107 | 108 | # 0.5.7 109 | 110 | - Add `distinct` to queryset 111 | 112 | # 0.5.6 113 | 114 | - Add `outer_join` to queryset to allow using a left outer join with select related 115 | 116 | # 0.5.5 117 | 118 | - Add builtin JSON serializer for `UUID` 119 | 120 | 121 | # 0.5.4 122 | 123 | - Add builtin JSON serializer for `Decimal` 124 | 125 | # 0.5.3 126 | 127 | - Add field types for `Decimal` and `timedelta` 128 | - Fix bug with enum field name on postgres 129 | - Fix array field with instance child types 130 | - Add support for `update_fields` on save to only fields specified 131 | - Add support for `fields` on load to only load fields specified 132 | 133 | 134 | # 0.5.2 135 | 136 | - Add support for table database `triggers`. 
See https://docs.sqlalchemy.org/en/14/core/ddl.html 137 | - Fix bug in create_table where errors are not raised 138 | 139 | # 0.5.1 140 | 141 | - Add `update` method using `Model.objects.update(**values)` 142 | 143 | # 0.5.0 144 | 145 | - Replace usage of Unicode with Str to support atom 0.6.0 146 | 147 | # 0.4.1 148 | 149 | - Change order by to use `-` as desc instead of `~` 150 | - Add default constraint naming conventions https://alembic.sqlalchemy.org/en/latest/naming.html#the-importance-of-naming-constraints 151 | - Allow setting a `constraints` list on the Model `Meta` class 152 | - Fix issue with `connection` arg not working properly when filtering 153 | 154 | # 0.4.0 155 | 156 | - Refactor SQL queries so they can be chained 157 | ex `Model.objects.filter(name="Something").filter(age__gt=18)` 158 | - Add `order_by`, `limit`, and `offset`, slicing, and `exists` 159 | - Support filtering using django-style reverse foreign key lookups, 160 | ex `Model.objects.filter(groups_in=[group1, group2])` 161 | - Refactor count to support counting over joins 162 | 163 | # 0.3.11 164 | 165 | - Let a member be tagged with a custom `flatten` function 166 | 167 | # 0.3.10 168 | 169 | - Fix bug in SQLModel load using `_id` which is not a valid field 170 | 171 | # 0.3.9 172 | 173 | - Let a member be tagged with a custom `unflatten` function 174 | 175 | # 0.3.8 176 | - Properly restore JSONModel instances that have no `__model__` in the dict 177 | when migrated from a dict or regular JSON field. 178 | 179 | # 0.3.7 180 | 181 | - Add a `__restored__` member to models and a `load` method so foreign keys 182 | do not restore as None if not in the cache. 183 | - Update to support atom 0.5.0 184 | 185 | 186 | # 0.3.6 187 | 188 | - Add `cache` option to SQLModelManager to determine if restoring 189 | should always be done even if the object is in the cache. 
190 | 191 | # 0.3.5 192 | 193 | - Set column type to json if the type is a JSONModel subclass 194 | 195 | # 0.3.4 196 | 197 | - Fix bug when saving using a generated id 198 | 199 | # 0.3.3 200 | 201 | - Change __setstate__ to __restorestate__ to not conflict with normal pickling 202 | 203 | # 0.3.2 204 | 205 | - Support lookups on foreign key fields 206 | - Add ability to specify `get_column` and `get_column_type` to let `atom.catom.Member` 207 | subclasses use custom sql columns if needed. 208 | 209 | # 0.3.1 210 | 211 | - Support lookups using renamed column fields 212 | 213 | # 0.3.0 214 | 215 | - The create and drop have been renamed to `create_table` and `drop_table` respectively. 216 | - Add a shortcut `SomeModel.object.create(**state)` method 217 | - Allow passing a db connection to manager methods ( 218 | get, get_or_create, filter, delete, etc...) to better support transactions 219 | 220 | 221 | # 0.2.4 222 | 223 | - Fix the nosql serialization registry not being loaded properly 224 | 225 | # 0.2.3 226 | 227 | - Fix packaging issue with 0.2.2 228 | 229 | # 0.2.2 230 | 231 | - Fix bug with fk types 232 | - Allow passing Model instances as filter parameters #8 by @youpsla 233 | 234 | # 0.2.1 235 | 236 | - Add a JSONModel that simply can be serialized and restored using JSON. 237 | 238 | # 0.2.0 239 | 240 | - Add ability to set an SQL model as `abstract` so that no sqlalchemy table is 241 | created.
242 | 243 | # 0.1.0 244 | 245 | - Initial release 246 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 CodeLV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![status](https://github.com/codelv/atom-db/actions/workflows/ci.yml/badge.svg)](https://github.com/codelv/atom-db/actions) 2 | [![codecov](https://codecov.io/gh/codelv/atom-db/branch/master/graph/badge.svg)](https://codecov.io/gh/codelv/atom-db) 3 | 4 | atom-db is a database abstraction layer for the 5 | [atom](https://github.com/nucleic/atom) framework. 
This package provides api's for 6 | seamlessly saving and restoring atom objects from json based document databases 7 | and SQL databases supported by sqlalchemy. 8 | 9 | 10 | The main reason for building this is to make it easier have database integration 11 | with [enaml](https://github.com/nucleic/enaml) applications so a separate 12 | framework is not needed to define database models. 13 | 14 | This was originally a part of [enaml-web](https://github.com/codelv/enaml-web) 15 | but has been pulled out to a separate package. 16 | 17 | 18 | ### Overview 19 | 20 | - Supports MySQL and Postgres 21 | - Uses django like queries or raw sqlalchemy queries 22 | - Works with alembic database migrations 23 | - Supports MongoDB using motor 24 | 25 | ### Structure 26 | 27 | The design is based somewhat on django. 28 | 29 | There is a "manager" called `Model.objects` to do queries on the database table 30 | created for each subclass. 31 | 32 | Serialization and deserialization is done with `Model.serializer`. 33 | 34 | > Note: As of 0.3.11 serialization can be customizer per member by tagging the 35 | member with a `flatten` or `unflatten` which should be a async callable which 36 | accepts the value and scope. 37 | 38 | Each `Model` has async `save`, `delete`, and `restore` methods to interact with 39 | the database. This can be customized if needed using 40 | `__restorestate__` and `__getstate__`. 41 | 42 | 43 | # MySQL and Postgres support 44 | 45 | You can use atom-db to save and restore atom subclasses to MySQL and Postgres. 46 | 47 | Just define models using atom members, but subclass the `SQLModel` and atom-db 48 | will convert the builtin atom members of your model to sqlalchemy table columns 49 | and create a `sqlalchemy.Table` for your model. 50 | 51 | 52 | ### Customizing table creation 53 | 54 | To customize how table columns are created you can tag members with information 55 | needed for sqlalchemy columns, ex `Str().tag(length=40)` will make a `sa.String(40)`. 
56 | See https://docs.sqlalchemy.org/en/latest/core/type_basics.html. Tagging any 57 | member with `store=False` will make the member be excluded from the db. 58 | 59 | atomdb will attempt to determine the proper column type, but if you need more 60 | control, you can tag the member to specify the column type with 61 | `type=sa.` or specify the full column definition with 62 | `column=sa.Column(...)`. 63 | 64 | If you have a custom member, you can define a `def get_column(self, model)` 65 | or `def get_column_type(self, model)` method to create the table column for the 66 | given model. 67 | 68 | 69 | ##### Primary keys 70 | 71 | You can tag a member with `primary_key=True` to make it the pk. If no member 72 | is tagged with `primary_key` it will create and use `_id` as the primary key. 73 | The`_id` member will be always alias to the actual primary key. Use the `__pk__` 74 | attribute of the class to get the name of the primary key member. 75 | 76 | ##### Table metadata 77 | 78 | Like in Django a nested `Meta` class can be added to specify the `db_name`, 79 | `unique_together`, `composite indexes` and `constraints`. 80 | 81 | If no `db_name` is specified on a Meta class, the table name defaults the what 82 | is set in the `__model__` member. This defaults to the qualname of the class, 83 | eg `myapp.SomeModel`. 84 | 85 | 86 | ```python 87 | 88 | class SomeModel(SQLModel): 89 | # ... 90 | 91 | class Meta: 92 | db_table = 'custom_table_name' 93 | 94 | ``` 95 | 96 | 97 | 98 | `composite indexes` must be a list of each composite index description. See [sqlachemy's Index](https://docs.sqlalchemy.org/en/14/core/constraints.html#sqlalchemy.schema.Index) for index tuple description. 99 | 100 | First element is the index name. If `None`, the name will be auto-generated according the convention. Followings elements are columns table. Order of columns matters. 101 | 102 | ```python 103 | 104 | class SomeModel(SQLModel): 105 | # ... 
106 | 107 | class Meta: 108 | composite_indexes = [(None, 'a', 'b'), ('ìndex_b_c', 'b', 'c')] 109 | 110 | ``` 111 | 112 | In the exemple above, two composite indexes are created. The first is on columns 'a' and 'b' with name `ix_SomeModel_a_b`. The second one on columns 'b' and 'c' with name `ìndex_b_c`. 113 | 114 | 115 | ##### Table creation / dropping 116 | 117 | Once your tables are defined as atom models, create and drop tables using 118 | `create_table` and `drop_table` of `Model.objects` respectively For example: 119 | 120 | ```python 121 | 122 | from atomdb.sql import SQLModel, SQLModelManager 123 | 124 | # Call create_tables to create sqlalchemy tables. This does NOT write them to 125 | # the db but ensures that all ForeignKey relations are created 126 | mgr = SQLModelManager.instance() 127 | mgr.create_tables() 128 | 129 | # Now actually drop/create for each of your models 130 | 131 | # Drop the table for this model (will raise sqlalchemy's error if it doesn't exist) 132 | await User.objects.drop_table() 133 | 134 | # Create the user table 135 | await User.objects.create_table() 136 | 137 | 138 | ``` 139 | 140 | The `mgr.create_tables()` method will create the sqlalchemy tables for each 141 | imported SQLModel subclass (anything in the manager's `registry` dict). This 142 | should be called after all of your models are imported so sqlalchemy can 143 | properly setup any foreign key relations. 144 | 145 | The manager also has a `metadata` member which holds the `sqlalchemy.MetaData` 146 | needed for migrations. 147 | 148 | Once the tables are created, they are accessible via `Model.objects.table`. 149 | 150 | > Note: The sqlachemy table is also assigned to the `__table__` attribute of 151 | each model class, however this will not be defined until the manager has 152 | created it. 153 | 154 | 155 | #### Database setup 156 | 157 | Before accessing the DB you must assign a "database engine" to the manager's 158 | `database` member. 
159 | 160 | > Note: As of `0.6.2` you can also specify this as a dictionary to use multiple 161 | databases. 162 | 163 | ```python 164 | import os 165 | import re 166 | from aiomysql.sa import create_engine 167 | from atomdb.sql import SQLModelManager 168 | 169 | DATABASE_URL = os.environ.get('MYSQL_URL') 170 | 171 | # Parse the DB url 172 | m = re.match(r'mysql://(.+):(.*)@(.+):(\d+)/(.+)', DATABASE_URL) 173 | user, pwd, host, port, db = m.groups() 174 | 175 | # Create the engine 176 | engine = await create_engine( 177 | db=db, user=user, password=pwd, host=host, port=port) 178 | 179 | # Assign it to the manager 180 | mgr = SQLModelManager.instance() 181 | mgr.database = engine 182 | 183 | 184 | ``` 185 | 186 | This engine will then be used by the manager to execute queries. You can 187 | retrieve the database engine from any Model by using `Model.objects.engine`. 188 | 189 | 190 | ###### Multiple database 191 | 192 | If you need to use more than one database it looks like this. 193 | 194 | ```python 195 | 196 | # Multiple databases 197 | mgr = SQLModelManager.instance() 198 | mgr.database = { 199 | 'default': await create_engine(**default_db_params), 200 | 'other': await create_engine(**other_db_params), 201 | } 202 | 203 | ``` 204 | 205 | To specify which database is used either using the `__database__` class field 206 | or specify it as the `db_name` on the model Meta. 207 | 208 | ```python 209 | 210 | class ExternalData(SQLModel): 211 | 212 | # ... fields 213 | class Meta: 214 | db_name = 'other' 215 | db_table = 'external_data' 216 | 217 | 218 | ``` 219 | 220 | 221 | #### Django style queries 222 | 223 | Only very basic ORM style queries are implemented for common use cases. These 224 | are `get`, `get_or_create`, `filter`, and `all`. These all accept 225 | "django style" queries using `=` or `__=`. 
226 | 227 | For example: 228 | 229 | ```python 230 | 231 | john, created = await User.objects.get_or_create( 232 | name="John Doe", email="jon@example.com", age=21, active=True) 233 | assert created 234 | 235 | jane, created = await User.objects.get_or_create( 236 | name="Jane Doe", email="jane@example.com", age=48, active=False, 237 | rating=10.0) 238 | assert created 239 | 240 | # Startswith 241 | u = await User.objects.get(name__startswith="John") 242 | assert u.name == john.name 243 | 244 | # In query 245 | users = await User.objects.filter(name__in=[john.name, jane.name]) 246 | assert len(users) == 2 247 | 248 | # Is query 249 | users = await User.objects.filter(active__is=False) 250 | assert len(users) == 1 and users[0].active == False 251 | 252 | ``` 253 | 254 | See [sqlachemy's ColumnElement](https://docs.sqlalchemy.org/en/latest/core/sqlelement.html?highlight=column#sqlalchemy.sql.expression.ColumnElement) 255 | for which queries can be used in this way. Also the tests check that these 256 | actually work as intended. 257 | 258 | > Note: As of `0.4.0` you can pass sqlalchemy filters as non-keyword arguments 259 | directly to the filter method. 260 | 261 | 262 | ###### Caching, select related, and prefetch related 263 | 264 | Foreign key relations can automatically be loaded using `select_related` and 265 | `prefetch_related`. Select related will perform a 266 | join while prefetch related does a separate query. 267 | 268 | Each Model has a cache available at `Model.objects.cache` which uses weakrefs to 269 | ensure the same object is returned each time. You can manually prefetch objects 270 | and atom-db will pull them from it's internal cache when restoring objects. 
271 | 272 | For example with a simple many to one relationship like this: 273 | 274 | ```python 275 | 276 | class Category(SQLModel): 277 | name = Str() 278 | products = Relation(lambda: Product) 279 | 280 | class Product(SQLModel): 281 | title = Str() 282 | category = Typed(Category) 283 | 284 | category = await Category.objects.create(name="PCB") 285 | await Product.objects.create(title="Stepper driver", category=category) 286 | 287 | ``` 288 | 289 | Use select related to load the product's category foreign key automatically. 290 | 291 | ```python 292 | # In this case the category of each product will automatically be loaded 293 | products = await Product.objects.select_related('category').filter(title__icontains="driver") 294 | # The __restored__ flag can be used check if the model has been loaded 295 | assert products[0].category.name == "PCB" 296 | ``` 297 | 298 | > If a foreign key relation is NOT in the cache or in the state from a joined row 299 | it will create an "unloaded" model with only the primary key populated. In this 300 | case the `__restored__` flag will be set to `False`. 301 | 302 | From the other direction use prefetch related. 303 | 304 | ```python 305 | category = await Category.objects.prefetch_related('products').get(name="PCB") 306 | assert category.products[0].title == "Stepper driver" 307 | ``` 308 | 309 | > Note: prefetch_related does not apply a limit. If the query has a lot of rows 310 | this may be a problem. 311 | 312 | Alternatively you can prefetch the related objects and they will be 313 | automatically pulled from the internal cache (eg `TheModel.objects.cache`). 314 | 315 | ```python 316 | all_categories = await Category.objects.all() 317 | products = await Product.objects.filter(title__icontains="driver") 318 | assert products[0].category in all_categories 319 | ``` 320 | 321 | 322 | #### Advanced / raw sqlalchemy queries 323 | 324 | For more advanced queries using joins, etc.. 
you must build the query with 325 | sqlalchemy then execute it. The `sa.Table` for an atom model can be retrieved 326 | using `Model.objects.table` on which you can use select, where, etc... to build 327 | up whatever query you need. 328 | 329 | Then use `fetchall`, `fetchone`, `fetchmany`, or `execute` to do these queries. 330 | 331 | These methods do NOT return an object but the row from the database so they 332 | must manually be restored. 333 | 334 | When joining you'll usually want to pass `use_labels=True`. For example: 335 | 336 | ```python 337 | 338 | q = Job.objects.table.join(JobRole.objects.table).select(use_labels=True) 339 | 340 | for row in await Job.objects.fetchall(q): 341 | # Restore each manually, it handles pulling out the fields that are it's own 342 | job = await Job.restore(row) 343 | role = await JobRole.restore(row) 344 | 345 | ``` 346 | 347 | Depending on the relationships, you may need to then post-process these so they 348 | can be accessed in a more pythonic way. This is trade off between complexity 349 | and ease of use. 350 | 351 | 352 | ### Connections and Transactions 353 | 354 | A connection can be retrieved using `Model.objects.connection()` and used 355 | like normal aiomysql / aiopg connection. A transaction is done in the same way 356 | as defined in the docs for those libraries eg. 357 | 358 | ```python 359 | 360 | async with Job.objects.connection() as conn: 361 | trans = await conn.begin() 362 | try: 363 | # Do your queries here and pass the `connection` to each 364 | job, created = await Job.objects.get_or_create(connection=conn, **state) 365 | except: 366 | await trans.rollback() 367 | raise 368 | else: 369 | await trans.commit() 370 | 371 | ``` 372 | 373 | When using a transaction you need to pass the active connection to 374 | each call or it will use a different connection outside of the transaction! 375 | 376 | The connection argument is removed from the filters/state. 
If your model happens 377 | to have a member named `connection` you can rename the connection argument by 378 | with `Model.object.connection_kwarg = 'connection_'` or whatever name you like. 379 | 380 | ### Migrations 381 | 382 | Migrations work using [alembic](https://alembic.sqlalchemy.org/en/latest/autogenerate.html). The metadata needed 383 | to autogenerate migrations can be retrieved from `SQLModelManager.instance().metadata` so add the following 384 | in your alembic's env.py: 385 | 386 | ```python 387 | # Import your db models first 388 | from myapp.models import * 389 | 390 | from atomdb.sql import SQLModelManager 391 | manager = SQLModelManager.instance() 392 | manager.create_tables() # Create sa tables 393 | target_metadata = manager.metadata 394 | 395 | ``` 396 | 397 | The rest is handled by alembic. 398 | 399 | 400 | > Note: As of 0.4.1 the constraint naming conventions can be set using 401 | manager.constraints, this must be done before any tables are imported. 402 | 403 | 404 | 405 | # NoSQL support 406 | 407 | You can also use atom-db to save and restore atom subclasses to MongoDB. 408 | 409 | The NoSQL version is very basic as mongo is much more relaxed. No restriction 410 | is imposed on what type of manager is used, leaving that to whichever database 411 | library is preferred but it's tested (and currently used) with [motor](https://motor.readthedocs.io/en/stable/) 412 | and [tornado](https://www.tornadoweb.org/en/stable/index.html). 413 | 414 | Just define models using atom members, but subclass the `NoSQLModel`. 
415 | 416 | ```python 417 | 418 | from atom.api import Unicode, Int, Instance, List 419 | from atomdb.nosql import NoSQLModel, NoSQLModelManager 420 | from motor.motor_asyncio import AsyncIOMotorClient 421 | 422 | # Set DB 423 | client = AsyncIOMotorClient() 424 | mgr = NoSQLModelManager.instance() 425 | mgr.database = client.test_db 426 | 427 | 428 | class Group(NoSQLModel): 429 | name = Unicode() 430 | 431 | class User(NoSQLModel): 432 | name = Unicode() 433 | age = Int() 434 | groups = List(Group) 435 | 436 | 437 | ``` 438 | 439 | Then we can create an instance and save it. It will perform an upsert or replace 440 | the existing entry. 441 | 442 | ```python 443 | 444 | admins = Group(name="Admins") 445 | await admins.save() 446 | 447 | # It will save admins using it's ObjectID 448 | bob = User(name="Bob", age=32, groups=[admins]) 449 | await bob.save() 450 | 451 | tom = User(name="Tom", age=34, groups=[admins]) 452 | await tom.save() 453 | 454 | ``` 455 | 456 | To fetch from the DB each model has a `ModelManager` called `objects` that will 457 | simply return the collection for the model type. For example. 458 | 459 | ```python 460 | 461 | # Fetch from db, you can use any MongoDB queries here 462 | state = await User.objects.find_one({'name': "James"}) 463 | if state: 464 | james = await User.restore(state) 465 | 466 | # etc... 467 | ``` 468 | 469 | Restoring is async because it will automatically fetch any related objects 470 | (ex the groups in this case). It saves objects using the ObjectID when present. 471 | 472 | And finally you can either delete using queries on the manager directly or 473 | call on the object. 474 | 475 | ```python 476 | await tom.delete() 477 | assert not await User.objects.find_one({'name': "Tom"}) 478 | 479 | ``` 480 | 481 | You can exclude members from being saved to the DB by tagging them 482 | with `.tag(store=False)`. 
483 | 484 | 485 | ## Contributing 486 | 487 | This is currently used in a few projects but not considered mature by 488 | any means. 489 | 490 | Pull requests and feature requests are welcome! 491 | -------------------------------------------------------------------------------- /atomdb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codelv/atom-db/2928b0d90a991d5ca5995b72868cf2c2cd2a4364/atomdb/__init__.py -------------------------------------------------------------------------------- /atomdb/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018-2022, Jairus Martin. 3 | 4 | Distributed under the terms of the MIT License. 5 | 6 | The full license is in the file LICENSE.text, distributed with this software. 7 | 8 | Created on Jun 12, 2018 9 | """ 10 | 11 | import asyncio 12 | import enum 13 | import logging 14 | from base64 import b64decode, b64encode 15 | from collections.abc import MutableMapping 16 | from datetime import date, datetime, time 17 | from decimal import Decimal 18 | from pprint import pformat 19 | from random import getrandbits 20 | from typing import Any, Callable, ClassVar 21 | from typing import Dict as DictType 22 | from typing import List as ListType 23 | from typing import Optional 24 | from typing import Tuple as TupleType 25 | from typing import Type, TypeVar 26 | from uuid import UUID 27 | 28 | from atom.api import ( 29 | Atom, 30 | AtomMeta, 31 | Bool, 32 | Coerced, 33 | Dict, 34 | Float, 35 | Instance, 36 | Int, 37 | List, 38 | Member, 39 | Property, 40 | Str, 41 | Typed, 42 | Value, 43 | set_default, 44 | ) 45 | 46 | T = TypeVar("T") 47 | M = TypeVar("M", bound="Model") 48 | ScopeType = DictType[int, Any] 49 | StateType = DictType[str, Any] 50 | GetStateFn = Callable[[M, Optional[ScopeType]], StateType] 51 | RestoreStateFn = Callable[[M, StateType, Optional[ScopeType]], None] 52 | log = 
log = logging.getLogger("atomdb")


def find_subclasses(cls: Type[T]) -> ListType[Type[T]]:
    """Recursively find all (direct and transitive) subclasses of ``cls``."""
    classes = []
    for subclass in cls.__subclasses__():
        classes.append(subclass)
        classes.extend(find_subclasses(subclass))
    return classes


def is_db_field(m: Member) -> bool:
    """Check if the member should be saved into the database. Any member that
    does not start with an underscore, is not a Property, and is not tagged
    with `store=False` is considered to be a field to save into the database.

    Parameters
    ----------
    m: Member
        The atom member to check.

    Returns
    -------
    result: bool
        Whether the member should be saved into the database.

    """
    metadata = m.metadata
    default = not m.name.startswith("_")
    if metadata is not None:
        # An explicit `store` tag always wins over the underscore default
        return metadata.get("store", default)
    if isinstance(m, Property):
        return False  # Users can override this by tagging it with store=True
    return default


def is_primitive_member(m: Member) -> Optional[bool]:
    """Check if the member can be serialized without calling flatten. If the
    member references a field that is not yet resolved it returns None
    indicating that it cannot determine whether it is primitive yet.

    Parameters
    ----------
    m: Member
        The atom member to check.

    Returns
    -------
    result: Optional[bool]
        Whether the member is a primitive type that can be intrinsicly
        converted.

    """
    if isinstance(m, (Bool, Str, Int, Float)):
        return True
    if hasattr(m, "resolve"):
        # These cannot be resolved until their dependencies are available
        return None
    if isinstance(m, (List, Typed, Instance, Dict, Coerced)):
        try:
            types = resolve_member_types(m, resolve=False)
        except UnresolvableError:
            return None
        if types is None:
            return False  # Value can be any type
        if types and all(t in (int, float, bool, str) for t in types):
            return True
    return False


def resolve_member_types(
    member: Member, resolve: bool = True
) -> Optional[TupleType[type, ...]]:
    """Determine the validation types specified on a member.

    Parameters
    ----------
    member: Member
        The member to retrieve the type from
    resolve: bool
        Whether to resolve "Forward" members.

    Returns
    -------
    types: Optional[Tuple[Model|Member|type, ..]]
        The member types. If types is `None` then the member does not do any
        type validation.

    Raises
    ------
    UnresolvableError
        If `resolve=False` and the member has a nested forwarded member this
        will raise an UnresolvableError with the unresolved member.

    """
    # TODO: This should really use the validate mode...
    if hasattr(member, "resolve"):
        if not resolve:
            raise UnresolvableError(member)  # Do not resolve now
        types = member.resolve()  # type: ignore
    elif isinstance(member, Coerced):
        types = member.validate_mode[-1][0]
    else:
        types = member.validate_mode[-1]
    if types is None:
        return None
    if isinstance(types, tuple):
        # Dict may have a member in the types list, so walk the types
        # and resolve all of those.
        resolved: ListType[type] = []
        for t in types:
            if isinstance(t, Member):
                r = resolve_member_types(t, resolve)
                if r is None:
                    # TODO: Think about whether this is correct to bail out here
                    return None
                resolved.extend(r)
            else:
                resolved.append(t)
        return tuple(resolved)
    if isinstance(types, Member):
        # Follow the chain. For example if the member is defined
        # as `List(Tuple(float)))` lookup the types of the nested Tuple().
        return resolve_member_types(types, resolve)
    if isinstance(types, str):
        return None  # Custom validation method
    return (types,)


class UnresolvableError(Exception):
    """Error raised when a Forwarded Member cannot be resolved at the time
    when the resolve_member_types is called.

    """

    def __init__(self, member):
        self.member = member
        super().__init__(f"Cannot resolve {member}")
class ModelSerializer(Atom):
    """Handles serializing and deserializing of Model subclasses. It
    will automatically save and restore references where present.

    """

    #: Hold one instance per subclass for easy reuse
    _instances: ClassVar[DictType[Type["ModelSerializer"], "ModelSerializer"]] = {}

    #: Store all registered models
    registry = Dict()

    #: Mapping of type name to coercer function used to rebuild python
    #: objects that were flattened into a `__py__` tagged dict
    coercers = Dict(
        default={
            "datetime.date": lambda v, scope: date(**v),
            "datetime.datetime": lambda v, scope: datetime(**v),
            "datetime.time": lambda v, scope: time(**v),
            "bytes": lambda v, scope: b64decode(v["bytes"]),
            "decimal": lambda v, scope: Decimal(v["value"]),
            "uuid": lambda v, scope: UUID(v["id"]),
        }
    )

    @classmethod
    def instance(cls: Type["ModelSerializer"]) -> "ModelSerializer":
        """Get the shared serializer instance for this class."""
        if cls not in ModelSerializer._instances:
            ModelSerializer._instances[cls] = cls()
        return ModelSerializer._instances[cls]

    def flatten(self, v: Any, scope: Optional[ScopeType] = None) -> Any:
        """Convert Model objects to a dict

        Parameters
        ----------
        v: Object
            The object to flatten
        scope: Dict
            The scope of references available for circular lookups

        Returns
        -------
        result: Object
            The flattened object

        """
        flatten = self.flatten
        scope = scope or {}

        # Handle circular reference
        if isinstance(v, Model):
            return v.serializer.flatten_object(v, scope)
        elif isinstance(v, (list, tuple, set)):
            return [flatten(item, scope) for item in v]
        elif isinstance(v, (dict, MutableMapping)):
            return {k: flatten(item, scope) for k, item in v.items()}
        elif isinstance(v, enum.Enum):
            return v.value
        # TODO: Handle other object types
        return v

    def flatten_object(self, obj: "Model", scope: ScopeType) -> Any:
        """Serialize a model for entering into the database

        Parameters
        ----------
        obj: Model
            The object to flatten
        scope: Dict
            The scope of references available for circular lookups

        Returns
        -------
        result: Object
            The flattened object

        """
        raise NotImplementedError

    async def unflatten(self, v: Any, scope: Optional[ScopeType] = None) -> Any:
        """Convert dict or list to Models

        Parameters
        ----------
        v: Dict or List
            The object(s) to unflatten
        scope: Dict
            The scope of references available for circular lookups

        Returns
        -------
        result: Object
            The unflattened object

        """
        if isinstance(v, dict):
            unflatten = self.unflatten
            # Circular reference
            if scope and "__ref__" in v:
                ref = v["__ref__"]
                if ref in scope:
                    return scope[ref]

            # Create the object
            if "__model__" in v:
                cls = self.registry[v["__model__"]]
                return await cls.serializer.unflatten_object(cls, v, scope)

            # Convert py types
            if "__py__" in v:
                # NOTE: pop is intentional so the remaining keys can be
                # passed directly to the coercer (e.g. date(**v))
                py_type = v.pop("__py__")
                coercer = self.coercers.get(py_type)
                if coercer:
                    if asyncio.iscoroutinefunction(coercer):
                        return await coercer(v, scope)
                    return coercer(v, scope)
                elif py_type == "set" or py_type == "atomset":
                    # BUGFIX: propagate the scope so circular references
                    # inside sets resolve, matching the dict/list branches
                    return {await unflatten(i, scope) for i in v["values"]}
                elif py_type == "tuple":
                    # BUGFIX: propagate the scope here as well
                    return tuple([await unflatten(i, scope) for i in v["values"]])
            return {k: await unflatten(i, scope) for k, i in v.items()}
        elif isinstance(v, list):
            unflatten = self.unflatten
            return [await unflatten(item, scope) for item in v]
        return v

    async def unflatten_object(
        self, cls: Type["Model"], state: StateType, scope: ScopeType
    ) -> Optional["Model"]:
        """Restore the object for the given class, state, and scope.
        If a reference is given the scope should be updated with the newly
        created object using the given ref.

        Parameters
        ----------
        cls: Class
            The type of object expected
        state: Dict
            The state of the object to restore

        Returns
        -------
        result: object or None
            A the newly created object (or an existing object if using a cache)
            or None if this object does not exist in the database.
        """
        _id = state.get("_id")

        # Get the object for this id, retrieve from cache if needed
        obj, created = await self.get_or_create(cls, state, scope)

        # Lookup the object if needed
        if created and _id is not None:
            # If a new object was created lookup the state for that object
            state = await self.get_object_state(obj, state, scope)
            if state is None:
                return None

        # Child objects may have circular references to this object
        # so we must update the scope with this reference to handle this
        # before restoring any children
        if scope and "__ref__" in state:
            scope[state["__ref__"]] = obj

        # If not restoring from cache update the state
        if created:
            await obj.__restorestate__(state, scope)
        return obj

    async def get_or_create(
        self, cls: Type["Model"], state: Any, scope: ScopeType
    ) -> TupleType["Model", bool]:
        """Get a cached object for this _id or create a new one. Subclasses
        should override this as needed to provide object caching if desired.

        Parameters
        ----------
        cls: Class
            The type of object expected
        state: Dict
            Unflattened state of object to restore
        scope: Dict
            Scope of objects available when flattened

        Returns
        -------
        result: Tuple[object, bool]
            A tuple of the object and a flag stating if it was created or not.

        """
        return (cls.__new__(cls), True)

    async def get_object_state(self, obj: "Model", state: Any, scope: ScopeType) -> Any:
        """Lookup the state needed to restore the given object id and class.

        Parameters
        ----------
        obj: Model
            The object created by `get_or_create`
        state: Dict
            Unflattened state of object to restore
        scope: Dict
            Scope of objects available when flattened

        Returns
        -------
        result: Any
            The model state needed to restore this object

        """
        raise NotImplementedError


class ModelManager(Atom):
    """A descriptor so you can use this somewhat like Django's models.
    Assuming you're using motor.

    Examples
    --------
    MyModel.objects.find_one({'_id': 'someid'})

    """

    #: Stores instances of each class so we can easily reuse them if desired
    _instances: ClassVar[DictType[Type["ModelManager"], "ModelManager"]] = {}

    @classmethod
    def instance(cls) -> "ModelManager":
        """Get the shared manager instance for this class."""
        if cls not in ModelManager._instances:
            ModelManager._instances[cls] = cls()
        return ModelManager._instances[cls]

    #: Used to access the database
    database = Value()

    def _default_database(self) -> Any:
        raise NotImplementedError

    def __get__(self, obj: T, cls: Optional[Type[T]] = None):
        """Handle access from the class that owns the manager. Subclasses
        should override this as needed.

        """
        raise NotImplementedError
def generate_getstate(cls: Type["Model"]) -> GetStateFn:
    """Generate an optimized __getstate__ function for the given model.

    Parameters
    ----------
    cls: Type[Model]
        The class to generate a getstate function for.

    Returns
    -------
    result: GetStateFn
        A function optimized to generate the state for the given model class.

    """
    # Each template entry becomes one line of the generated function body
    # (joined with a 4-space indent below).
    template = [
        "def __getstate__(self, scope=None):",
        "scope = scope or {}",
        "scope[self.__ref__] = self",
        "state = {",
    ]
    default_flatten = cls.serializer.flatten
    members = cls.members()
    namespace = {
        "default_flatten": default_flatten,
    }
    for f in cls.__fields__:
        # Since f is potentially an untrusted input, make sure it is a valid
        # python identifier to prevent unintended code being generated.
        if not f.isidentifier():
            raise ValueError(f"Field '{f}' cannot be used for code generation")
        m = members[f]
        meta = m.metadata or {}
        flatten = meta.get("flatten", default_flatten)
        if flatten is default_flatten:
            if is_primitive_member(m):
                # Primitive values can be emitted directly with no conversion
                expr = f"self.{f}"
            else:
                expr = f"default_flatten(self.{f}, scope)"
        else:
            # A custom flatten function was tagged on the member
            namespace[f"flatten_{f}"] = flatten
            expr = f"flatten_{f}(self.{f}, scope)"
        template.append(f'    "{f}": {expr},')

    template.append('    "__model__": self.__model__,')
    template.append('    "__ref__": self.__ref__,')
    template.append("}")
    if "_id" in members:
        # Only include _id in the state when it has been assigned
        template.append("if self._id:")
        template.append('    state["_id"] = self._id')
    template.append("return state")
    source = "\n    ".join(template)
    return generate_function(source, namespace, "__getstate__")


def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
    """Generate an optimized __restorestate__ function for the given model.

    Parameters
    ----------
    cls: Type[Model]
        The class to generate a restorestate function for.

    Returns
    -------
    result: RestoreStateFn
        A function optimized to restore the state for the given model class.

    """
    # Python must do some caching because using key in state and state[key]
    # seems to be faster than using get
    template = [
        "async def __restorestate__(self, state, scope=None):",
        "if '__model__' in state and state['__model__'] != self.__model__:",
        "    name = state['__model__']",
        "    raise ValueError(",
        "        f'Trying to use {name} state for {self.__model__} object'",
        "    )",
        "if '__ref__' in state and state['__ref__'] is not None:",
        "    scope = scope or {}",
        "    scope[state['__ref__']] = self",
    ]

    default_unflatten = cls.serializer.unflatten
    members = cls.members()
    # Internal members that must never be restored from a saved state
    excluded = (
        "__ref__",
        "__restored__",
    )
    setters = []
    for f, m in members.items():
        if f in excluded:
            continue
        meta = m.metadata or {}
        order = meta.get("setstate_order", 1000)

        # Allow tagging a custom unflatten fn
        unflatten = meta.get("unflatten", default_unflatten)

        setters.append((order, f, m, unflatten))
    # Restore members in their tagged setstate_order (default 1000)
    setters.sort(key=lambda it: it[0])

    on_error = cls.__on_error__

    namespace: DictType[str, Any] = {
        "default_unflatten": default_unflatten,
    }
    for order, f, m, unflatten in setters:
        # Since f is potentially an untrusted input, make sure it is a valid
        # python identifier to prevent unintended code being generated.
        if not f.isidentifier():
            raise ValueError(f"Field '{f}' cannot be used for code generation")

        template.append(f"if '{f}' in state:")

        # Determine the expression to unflatten the value
        if unflatten is default_unflatten:
            RelModel = None
            # If the member is typed we can shortcut looking up the __model__
            # type from the state and restore it directly.
            # Note that this does not work for instances.
            if isinstance(m, Typed):
                types = resolve_member_types(m, resolve=False)
                if types and len(types) == 1 and issubclass(types[0], Model):
                    RelModel = types[0]
            if RelModel is not None:
                namespace[f"rel_model_{f}"] = RelModel
                expr = f"await rel_model_{f}.restore(state['{f}'])"
            elif is_primitive_member(m):
                # Direct assignment
                expr = f"state['{f}']"
            else:
                # Default flatten
                expr = f"await default_unflatten(state['{f}'], scope)"
        else:
            namespace[f"unflatten_{f}"] = unflatten
            if asyncio.iscoroutinefunction(unflatten):
                expr = f"await unflatten_{f}(state['{f}'], scope)"
            else:
                expr = f"unflatten_{f}(state['{f}'], scope)"

        # Do the assignment. Unless __on_error__ is "raise", each setattr is
        # wrapped in a try/except so stale state cannot abort the restore.
        if on_error == "raise":
            template.append(f"    self.{f} = {expr}")
        else:
            if on_error == "log":
                handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
            else:
                handler = "pass"
            template.extend(
                [
                    "    try:",
                    f"        self.{f} = {expr}",
                    "    except Exception as e:",
                    f"        {handler}",
                ]
            )

    # Update restored state
    template.append("self.__restored__ = True")
    source = "\n    ".join(template)
    return generate_function(source, namespace, "__restorestate__")
def generate_function(
    source: str,
    namespace: DictType[str, Any],
    fn_name: str,
) -> Callable[..., Any]:
    """Generate an optimized function from source code.

    Parameters
    ----------
    source: str
        The function source code. It must define `fn_name` as a plain or
        async function at the top level.
    namespace: dict
        Namespace used as the globals of the generated function.
    fn_name: str
        The name of the generated function.

    Returns
    -------
    fn: function
        The function generated.

    Raises
    ------
    RuntimeError
        If the source does not start by defining `fn_name` or does not
        compile.

    """
    # Validate with an explicit check instead of `assert` so the guard is
    # not stripped when running under `python -O`.
    if not source.startswith((f"def {fn_name}", f"async def {fn_name}")):
        raise RuntimeError(f"Could not generate code: bad fn_name:\n{source}")
    try:
        code = compile(source, __name__, "exec", optimize=1)
    except Exception as e:
        raise RuntimeError(f"Could not generate code: {e}:\n{source}") from e

    # Exec with `namespace` as globals so names stored there (e.g. the
    # default flatten/unflatten helpers) resolve inside the function.
    result: DictType[str, Any] = {}
    exec(code, namespace, result)
    return result[fn_name]
class ModelMeta(AtomMeta):
    """Metaclass that fills in the serialization attributes of a Model
    subclass and generates its optimized state functions.

    """

    def __new__(meta, name, bases, dct):
        cls = AtomMeta.__new__(meta, name, bases, dct)

        # Fields that are saved in the db. By default it uses all atom members
        # that don't start with an underscore and are not tagged with store.
        if "__fields__" not in dct:
            cls.__fields__ = [
                name for name, m in cls.members().items() if is_db_field(m)
            ]

        # Model name used so the serializer knows what class to recreate
        # when restoring
        if "__model__" not in dct:
            cls.__model__ = f"{cls.__module__}.{cls.__name__}"

        # Generate optimized get and restore functions
        # Some general testing indicates this improves getstate by about 2x
        # and restorestate by about 20% but it depends on the model.
        if "__generated_getstate__" not in dct:
            cls.__generated_getstate__ = generate_getstate(cls)

        if "__generated_restorestate__" not in dct:
            cls.__generated_restorestate__ = generate_restorestate(cls)

        return cls


class Model(Atom, metaclass=ModelMeta):
    """An atom model that can be serialized and deserialized to and from
    a database.

    """

    # --------------------------------------------------------------------------
    # Class attributes
    # --------------------------------------------------------------------------
    __slots__ = "__weakref__"

    #: List of database field member names
    __fields__: ClassVar[ListType[str]]

    #: Table name used when saving into the database
    __model__: ClassVar[str]

    #: Error handling
    __on_error__: ClassVar[str] = "log"  # "ignore" or "raise"

    # --------------------------------------------------------------------------
    # Internal model members
    # --------------------------------------------------------------------------

    #: A unique ID used to handle cyclical serialization and deserialization
    __ref__ = Int(factory=lambda: getrandbits(32))

    #: Flag to indicate if this model has been restored or saved
    __restored__ = Bool().tag(store=False)

    # --------------------------------------------------------------------------
    # Serialization API
    # --------------------------------------------------------------------------

    #: Handles encoding and decoding. Subclasses should redefine this to a
    #: subclass of ModelSerializer
    serializer: ClassVar[ModelSerializer] = ModelSerializer.instance()

    #: Optimized serialize functions. These are generated by the metaclass.
    __generated_getstate__: ClassVar[GetStateFn]
    __generated_restorestate__: ClassVar[RestoreStateFn]

    def __getstate__(self, scope: Optional[ScopeType] = None) -> StateType:
        """Get the serialized model state. By default this delegates to an
        optimized function generated by the ModelMeta class.

        Parameters
        ----------
        scope: Optional[ScopeType]
            The scope to lookup circular references.

        Returns
        -------
        state: StateType
            The state of the object.

        """
        return self.__generated_getstate__(scope)

    async def __restorestate__(
        self, state: StateType, scope: Optional[ScopeType] = None
    ):
        """Restore an object from a state from the database. This is
        async as it will lookup any referenced objects from the DB.

        State is restored by calling setattr(k, v) for every item in the state
        that has an associated atom member. Members can be tagged with a
        `setstate_order=` to define the order of setattr calls. Errors
        from setattr are caught and logged instead of raised.

        Parameters
        ----------
        state: Dict
            A dictionary of state keys and values
        scope: Dict or None
            A namespace to use to resolve any possible circular references.
            The __ref__ value is used as the keys.

        """
        await self.__generated_restorestate__(state, scope)  # type: ignore

    def __log_restore_error__(
        self, e: Exception, k: str, state: StateType, scope: Optional[ScopeType]
    ):
        """Log details when restoring a member fails. This typically only will
        occur if the state has data from an old model after a schema change.

        """
        obj = state.get(k)
        log.warning(
            f"Error loading state:"
            f"{self.__model__}.{k} = {pformat(obj)}:"
            f"\nRef: {self.__ref__}"
            f"\nScope: {pformat(scope)}"
            f"\nState: {pformat(state)}"
            f"\n{e}"
        )

    # --------------------------------------------------------------------------
    # Database API
    # --------------------------------------------------------------------------

    #: Handles database access. Subclasses should redefine this.
    objects: ClassVar[ModelManager] = ModelManager()

    @classmethod
    async def restore(cls: Type[M], state: StateType, **kwargs: Any) -> M:
        """Restore an object from the database state"""
        # NOTE(review): **kwargs is unused here; presumably accepted for
        # subclass overrides — confirm against sql/nosql implementations.
        obj = cls.__new__(cls)
        await obj.__restorestate__(state)
        return obj

    async def load(self):
        """Alias to load this object from the database"""
        raise NotImplementedError

    async def save(self):
        """Alias to save this object to the database"""
        raise NotImplementedError

    async def delete(self):
        """Alias to delete this object in the database"""
        raise NotImplementedError
class JSONSerializer(ModelSerializer):
    """Serializer that flattens models into json-compatible structures,
    tagging non-json python types with a `__py__` key so they can be
    coerced back on restore.

    """

    def flatten(self, v: Any, scope: Optional[ScopeType] = None):
        """Flatten date, datetime, time, decimal, and bytes as a dict with
        a __py__ field and arguments to reconstruct it. Also see the coercers

        """
        if isinstance(v, (date, datetime, time)):
            # This is inefficient space wise but still allows queries
            s: DictType[str, Any] = {
                "__py__": f"{v.__class__.__module__}.{v.__class__.__name__}"
            }
            if isinstance(v, (date, datetime)):
                s.update({"year": v.year, "month": v.month, "day": v.day})
            if isinstance(v, (time, datetime)):
                s.update(
                    {
                        "hour": v.hour,
                        "minute": v.minute,
                        "second": v.second,
                        "microsecond": v.microsecond,
                        # TODO: Timezones
                    }
                )
            return s
        if isinstance(v, bytes):
            return {"__py__": "bytes", "bytes": b64encode(v).decode()}
        if isinstance(v, Decimal):
            return {"__py__": "decimal", "value": str(v)}
        if isinstance(v, UUID):
            return {"__py__": "uuid", "id": str(v)}
        if isinstance(v, (tuple, set)):
            flatten = self.flatten
            type_name = v.__class__.__name__
            # BUGFIX: propagate the scope so models nested inside tuples and
            # sets are flattened as references (matching the base class
            # handling of lists) instead of being re-serialized in full.
            return {"__py__": type_name, "values": [flatten(it, scope) for it in v]}
        return super().flatten(v, scope)

    def flatten_object(self, obj: Model, scope: ScopeType) -> DictType[str, Any]:
        """Flatten to just json but add in keys to know how to restore it."""
        ref = obj.__ref__
        if ref in scope:
            # Already serialized in this pass, emit a reference instead
            return {"__ref__": ref, "__model__": obj.__model__}
        else:
            scope[ref] = obj
            return obj.__getstate__(scope)

    async def get_object_state(self, obj: Any, state: StateType, scope: ScopeType):
        """State should be contained in the dict"""
        return state

    def _default_registry(self) -> DictType[str, Type[Model]]:
        """Register every JSONModel subclass by its __model__ name."""
        return {m.__model__: m for m in find_subclasses(JSONModel)}


class JSONModel(Model):
    """A simple model that can be serialized to json. Useful for embedding
    within other models.

    """

    serializer = JSONSerializer.instance()
    #: JSON models are restored synchronously from embedded state so they
    #: are always considered restored
    __restored__ = set_default(True)  # type: ignore
28 | """ 29 | # Check if this is in the cache 30 | pk = state.get("_id") 31 | cache = cls.objects.cache 32 | if pk is not None: 33 | obj = cache.get(pk) 34 | else: 35 | obj = None 36 | if obj is None: 37 | # Create and cache it 38 | obj = cls.__new__(cls) 39 | if pk is not None: 40 | cache[pk] = obj 41 | 42 | # This ideally should only be done if created 43 | return (obj, True) 44 | return (obj, False) 45 | 46 | async def get_object_state(self, obj, state, scope): 47 | ModelType = obj.__class__ 48 | return await ModelType.objects.find_one({"_id": state["_id"]}) 49 | 50 | def flatten_object(self, obj, scope): 51 | ref = obj.__ref__ 52 | if ref in scope: 53 | return {"__ref__": ref, "__model__": obj.__model__} 54 | scope[ref] = obj 55 | state = obj.__getstate__(scope) 56 | _id = state.get("_id") 57 | if _id is None: 58 | return state 59 | return {"_id": _id, "__ref__": ref, "__model__": obj.__model__} 60 | 61 | def _default_registry(self): 62 | """Add all nosql and json models to the registry""" 63 | registry = JSONSerializer.instance().registry.copy() 64 | registry.update({m.__model__: m for m in find_subclasses(NoSQLModel)}) 65 | return registry 66 | 67 | 68 | class NoSQLDatabaseProxy(Atom): 69 | """A proxy to the collection which holds a cache of model objects.""" 70 | 71 | #: Object cache 72 | cache = Typed(weakref.WeakValueDictionary, ()) 73 | 74 | #: Database handle 75 | table = Value() 76 | 77 | def __getattr__(self, name): 78 | return getattr(self.table, name) 79 | 80 | 81 | class NoSQLModelManager(ModelManager): 82 | """A descriptor so you can use this somewhat like Django's models. 83 | Assuming your using motor or txmongo. 
    def __get__(self, obj, cls=None):
        """Handle objects from the class that owns the manager.

        Returns a NoSQLDatabaseProxy wrapping the collection named after
        the model's ``__model__``. Proxies are created lazily and cached
        per model class in ``self.proxies``. When accessed from a class
        that is not a Model subclass, the manager itself is returned.
        """
        cls = cls or obj.__class__
        if not issubclass(cls, Model):
            return self  # Only return the collection when used from a Model
        proxy = self.proxies.get(cls)
        if proxy is None:
            # Lazily create and cache the proxy for this model's collection
            proxy = self.proxies[cls] = NoSQLDatabaseProxy(
                table=self.database[cls.__model__]
            )
        return proxy
132 | """ 133 | pk = state["_id"] 134 | if pk: 135 | # Check if this is in the cache 136 | cache = cls.objects.cache 137 | obj = cache.get(pk) 138 | else: 139 | obj = None 140 | 141 | # Restore 142 | if obj is None: 143 | # Create and cache it 144 | obj = cls.__new__(cls) 145 | if pk: 146 | cache[pk] = obj 147 | restore = True 148 | else: 149 | restore = force 150 | 151 | if restore: 152 | await obj.__restorestate__(state) 153 | 154 | return obj 155 | 156 | async def load(self): 157 | """Alias to load this object from the database""" 158 | pk = self._id 159 | if self.__restored__ or pk is None: 160 | return # Already loaded or nothing to load 161 | state = await self.objects.find_one({"_id": pk}) 162 | if state is not None: 163 | await self.__restorestate__(state) 164 | 165 | async def save(self): 166 | """Alias to delete this object to the database""" 167 | db = self.objects 168 | state = self.__getstate__() 169 | if self._id is None: 170 | r = await db.insert_one(state) 171 | self._id = r.inserted_id 172 | db.cache[self._id] = self 173 | else: 174 | r = await db.replace_one({"_id": self._id}, state, upsert=True) 175 | self.__restored__ = True 176 | return r 177 | 178 | async def delete(self): 179 | """Alias to delete this object in the database""" 180 | db = self.objects 181 | pk = self._id 182 | if pk: 183 | r = await db.delete_one({"_id": pk}) 184 | del db.cache[pk] 185 | del self._id 186 | return r 187 | -------------------------------------------------------------------------------- /atomdb/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codelv/atom-db/2928b0d90a991d5ca5995b72868cf2c2cd2a4364/atomdb/py.typed -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | docs: 2 | cd docs 3 | make html 4 | isort: 5 | isort atomdb 6 | isort tests 7 | typecheck: 8 | mypy 
"""
Copyright (c) 2019-2021, Jairus Martin.

Distributed under the terms of the MIT License.

The full license is in the file LICENSE.text, distributed with this software.

Created on Feb 21, 2019

@author: jrm
"""
from setuptools import find_packages, setup

setup(
    name="atom-db",
    version="0.8.1",
    author="CodeLV",
    author_email="frmdstryr@gmail.com",
    url="https://github.com/codelv/atom-db",
    description="Database abstraction layer for atom objects",
    license="MIT",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    python_requires=">=3.7",
    install_requires=["atom>=0.7.0"],
    # ``optional_requires`` is not a setuptools keyword and was silently
    # ignored (the deprecated ``requires`` keyword is dropped as well).
    # Declare the optional backends as extras so users can install them
    # with e.g. ``pip install atom-db[sql]`` or ``atom-db[nosql]``.
    extras_require={
        "sql": ["sqlalchemy<2", "aiomysql", "aiopg", "aiosqlite"],
        "nosql": ["motor"],
    },
    packages=find_packages(),
    package_data={"atomdb": ["py.typed"]},
)
async def test_manager():
    """The abstract ModelManager raises until a backend defines access."""
    mgr = ModelManager.instance()

    # Not implemented for abstract manager
    with pytest.raises(NotImplementedError):
        mgr.database

    # Not implemented for abstract manager
    with pytest.raises(NotImplementedError):
        AbstractModel.objects
async def test_model():
    """Abstract Model database methods raise, restore() validates the
    model name, and stale/changed state fields are tolerated.
    """
    m = AbstractModel()

    # Not implemented for abstract models
    with pytest.raises(NotImplementedError):
        await m.load()

    with pytest.raises(NotImplementedError):
        await m.save()

    with pytest.raises(NotImplementedError):
        await m.delete()

    with pytest.raises(ValueError):
        state = {"__model__": "not.this.Model"}
        await AbstractModel.restore(state)

    # Old state fields do not blow up
    state = m.__getstate__()
    state["removed_field"] = "no-longer-exists"
    state["rating"] = 3.5  # Type changed
    obj = await AbstractModel.restore(state)
    assert obj.rating == 0
def test_gen_fn():
    """generate_function compiles source into a callable and rejects
    source that does not actually define the requested function.
    """
    source = "def foo(v):\n    return str(v)"
    fn = generate_function(source, {"str": str}, "foo")
    assert callable(fn)
    assert fn(1) == "1"

    # Source is an expression, not a function definition
    with pytest.raises(RuntimeError):
        generate_function('__import__("os").path.exists()', {}, "__import__")
210 | """ 211 | 212 | class C(Model): 213 | old_field = Int() 214 | new_field = Int() 215 | 216 | with caplog.at_level(logging.DEBUG): 217 | c = await C.restore({"old_field": "str", "new_field": 1}) 218 | assert c.old_field == 0 219 | assert c.new_field == 1 220 | assert "Error loading state:" in caplog.text 221 | assert f"{C.__model__}.old_field" in caplog.text 222 | assert "object must be of type 'int'" in caplog.text 223 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pytest 4 | from atom.api import Bool, Dict, Float, Int, List, Str, Typed 5 | 6 | from atomdb.base import JSONModel, Model 7 | 8 | NOW = datetime.now() 9 | 10 | flat_state = dict( 11 | title="This is a test", 12 | desc="""Lorem ipsum dolor sit amet, consectetur adipiscing elit, 13 | sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 14 | Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris 15 | nisi ut aliquip ex ea commodo consequat. 
class Product(Model):
    # Flat model exercised by the serialization benchmarks
    title = Str()
    desc = Str()
    enabled = Bool()
    rating = Float()
    sku = Int()
    tags = List(str)
    meta = Dict()
    # Stored as a POSIX timestamp via the flatten/unflatten tags
    created = Typed(datetime, factory=datetime.now).tag(
        flatten=lambda v, scope: v.timestamp() if v else None,
        unflatten=lambda v, scope: datetime.fromtimestamp(v) if v else None,
    )
@pytest.mark.benchmark(group="base")
def test_restore_nested(benchmark, event_loop):
    """Benchmark restoring a nested Page state dict."""
    benchmark(lambda: event_loop.run_until_complete(Page.restore(nested_state)))
async def test_json_decimal():
    """Decimal values must survive a JSON round trip exactly."""
    d = Decimal("3.9")
    amount = Amount(total=d)
    payload = json.dumps(amount.__getstate__())
    restored = await Amount.restore(json.loads(payload))
    assert restored.total == d
async def test_json_cyclical():
    """A cyclical reference between two models must round-trip through
    JSON without infinite recursion, restoring the cycle intact.
    """
    b = Tree(name="b")
    a = Tree(name="a", related=b)

    # Create a cyclical ref
    b.related = a

    state = a.__getstate__()
    data = json.dumps(state)
    # Removed leftover debug print(data)
    r = await Tree.restore(json.loads(data))
    assert r.name == "a"
    assert r.related.name == b.name
    assert r.related.related == r
async def test_simple_save_restore_delete(db):
    """Round-trip a model: save, restore (from cache), update, delete."""
    await User.objects.drop()

    # Save
    user = User(name="name", email="name@ex.com", active=True)
    await user.save()
    assert user._id is not None

    # Restore
    state = await User.objects.find_one({"name": user.name})
    assert state

    u = await User.restore(state)
    assert u is user  # Same instance returned from the cache
    assert u._id == user._id
    assert u.name == user.name
    assert u.email == user.email
    assert u.active == user.active

    # Update
    user.active = False
    await user.save()

    state = await User.objects.find_one({"name": user.name})
    assert state
    u = await User.restore(state)
    assert not u.active

    # Create second user
    another_user = User(name="other", email="other@ex.com", active=True)
    await another_user.save()

    # Delete
    await user.delete()
    state = await User.objects.find_one({"name": user.name})
    assert not state

    # Make sure second user still exists
    state = await User.objects.find_one({"name": another_user.name})
    assert state
async def test_circular(db):
    # Test that a circular reference is properly stored as a reference
    # and doesn't create an infinite loop
    await Page.objects.drop()

    p = Page(title="Home", body="HomeBody")
    related_page = Page(title="Other", body="OtherBody", related=[p])

    # Create a circular reference
    p.related = [related_page]
    await p.save()

    # Make sure it restores properly
    state = await Page.objects.find_one({"_id": p._id})
    # Show the stored document when the assertions below fail
    pprint(state)
    r = await Page.restore(state)
    assert r.title == p.title
    assert r.related[0].title == related_page.title
    assert r.related[0].related[0] == r
try:
    import sqlalchemy as sa

    if IS_MYSQL:
        from pymysql.err import IntegrityError
    elif IS_SQLITE:
        # logging.getLogger("aiosqlite").setLevel(logging.DEBUG)
        logging.getLogger("aiosqlite.sa").setLevel(logging.DEBUG)
        from aiosqlite import IntegrityError
    else:
        from psycopg2.errors import UniqueViolation as IntegrityError
except ImportError as e:
    # Skip the whole module when no SQL driver is installed
    # (fixed typo: "aisosqlite" -> "aiosqlite")
    pytest.skip(
        f"aiomysql, aiosqlite, aiopg not available {e}", allow_module_level=True
    )
class AbstractUser(SQLModel):
    # Columns shared by all concrete user models; ``length`` sets the
    # varchar size.
    email = Str().tag(length=64)
    hashed_password = Bytes()

    class Meta:
        # Abstract: no table is created for this model itself
        abstract = True
Instance(JobRole) 133 | desc = Str().tag(length=20) 134 | 135 | 136 | class ImageInfo(JSONModel): 137 | depth = Int() 138 | 139 | 140 | async def unflatten_image_info(v, scope): 141 | if v is None: 142 | return ImageInfo() 143 | return await ImageInfo.restore(v) 144 | 145 | 146 | class Image(SQLModel): 147 | name = Str().tag(length=100) 148 | path = Str().tag(length=200) 149 | metadata = Typed(dict).tag(nullable=True) 150 | alpha = Range(low=0, high=255) 151 | data = Instance(bytes).tag(nullable=True) 152 | 153 | # Maps to sa.ARRAY, must include the item_type tag 154 | # size = Tuple(int).tag(nullable=True) 155 | 156 | #: Maps to sa.JSON 157 | info = Instance(ImageInfo, ()).tag(unflatten=unflatten_image_info) 158 | 159 | 160 | class BigInt(Int): 161 | """Custom member which defines the sa column type using get_column method""" 162 | 163 | def get_column(self, model): 164 | return sa.Column(self.name, sa.BigInteger()) 165 | 166 | 167 | class Page(SQLModel): 168 | title = Str().tag(length=60) 169 | status = Enum("preview", "live") 170 | body = Str().tag(type=sa.UnicodeText()) 171 | author = Instance(User) 172 | if DATABASE_URL.startswith("postgres"): 173 | images = List(Instance(Image)) 174 | related = List(ForwardInstance(lambda: Page)).tag(nullable=True) 175 | tags = List(str) 176 | 177 | visits = BigInt() 178 | date = Instance(date) 179 | last_updated = Instance(datetime) 180 | rating = Instance(Decimal) 181 | ranking = Float().tag(name="order") 182 | 183 | # A bit verbose but provides a custom column specification 184 | data = Instance(object).tag(column=sa.Column("data", sa.LargeBinary())) 185 | 186 | class Meta: 187 | get_latest_by = "date" 188 | 189 | 190 | class PageImage(SQLModel): 191 | # Example through table for job role 192 | page = Instance(Page).tag(nullable=False) 193 | image = Instance(Image).tag(nullable=False) 194 | 195 | class Meta: 196 | db_table = "page_image_m2m" 197 | unique_together = ("page", "image") 198 | 199 | 200 | class 
class Comment(SQLModel):
    page = Instance(Page)
    author = Instance(User)
    status = Enum("pending", "approved")
    body = Str().tag(type=sa.UnicodeText())
    # Self-referencing FK for threaded replies.
    reply_to = ForwardInstance(lambda: Comment).tag(nullable=True)
    when = Instance(time)


class Email(SQLModel):
    id = Int().tag(name="email_id", primary_key=True)
    to = Str().tag(length=120)
    # Member is `from_` (python keyword); column is renamed to "from".
    from_ = Str().tag(name="from").tag(length=120)
    body = Str().tag(length=1024)
    attachments = Relation(lambda: Attachment)
    # Many-to-many through the EmailTag table.
    tags = Relation(lambda: Tag, through=lambda: EmailTag)


class Tag(SQLModel):
    name = Str().tag(length=100)


class EmailTag(SQLModel):
    tag = Instance(Tag).tag(nullable=False)
    email = Instance(Email).tag(nullable=False)


class Attachment(SQLModel):
    id = Int().tag(name="attachment_id", primary_key=True)
    email = Instance(Email).tag(name="email_id", nullable=False)
    name = Str().tag(length=100)
    size = Int()
    data = Bytes()


class Ticket(SQLModel):
    # String primary key (no autoincrement id).
    code = Str().tag(length=64, primary_key=True)
    desc = Str().tag(length=500)


class ImportedTicket(Ticket):
    meta = Dict()


class Document(SQLModel):
    name = Str().tag(length=32)
    uuid = Str().tag(length=64, primary_key=True)

    #: Reference to the project that is not included in the state
    #: You can also use ForwardInstance().tag(store=False)
    project = RelatedInstance(lambda: Project)


class Project(SQLModel):
    title = Str().tag(length=32)
    doc = Instance(Document)


class Node(SQLModel):
    id = Int().tag(primary_key=True)
    name = Str().tag(length=10)
    type = ForwardInstance(lambda: NodeType).tag(nullable=False, ondelete="CASCADE")


class NodeType(SQLModel):
    id = Int().tag(primary_key=True)
    name = Str().tag(length=10)
    # This creates a cyclical FK; use_alter defers creation to an ALTER TABLE.
    default_node = Instance(Node).tag(use_alter=True, ondelete="SET NULL")


def test_build_tables():
    # Trigger table creation
    SQLModelManager.instance().create_tables()


def test_custom_table_name():
    table_name = "some_table.test"

    class Test(SQLModel):
        __model__ = table_name

    assert Test.objects.table.name == table_name


def test_sanity_pk_and_fields():
    # Default pk is the implicit _id column.
    class A(SQLModel):
        foo = Str()

    assert A.__pk__ == "_id"
    assert A.__fields__ == ["_id", "foo"]


def test_sanity_pk_override():
    class A(SQLModel):
        id = Int().tag(primary_key=True)
        foo = Str()

    assert A._id is A.id
    assert A.__pk__ == "id"
    assert A.__fields__ == ["id", "foo"]


def test_sanity_pk_renamed():
    # __pk__ uses the column name; __fields__ uses the member name.
    class A(SQLModel):
        id = Int().tag(primary_key=True, name="table_id")
        foo = Str()

    assert A._id is A.id
    assert A.__pk__ == "table_id"
    assert A.__fields__ == ["id", "foo"]


def test_sanity_relation_exluded():
    # NOTE(review): "exluded" is a typo in this test's name (left unchanged
    # so test ids stay stable).
    class Child(SQLModel):
        pass

    class Parent(SQLModel):
        children = Relation(lambda: Child)

    assert "children" in Parent.__excluded_fields__


async def test_sanity_flatten_unflatten():
    # Round-trip a custom flatten/unflatten pair through restore/__getstate__.
    async def unflatten_date(v: str, scope=None):
        return datetime.strptime(v, "%Y-%m-%d").date()

    def flatten_date(v: date, scope=None):
        return v.strftime("%Y-%m-%d")

    class TableOfUnformattedGarbage(SQLModel):
        created = Instance(date).tag(
            type=sa.String(length=10), flatten=flatten_date, unflatten=unflatten_date
        )

    r = await TableOfUnformattedGarbage.restore(
        {
            "__model__": TableOfUnformattedGarbage.__model__,
            "_id": 1,
            "created": "2020-10-28",
        }
    )
    assert r.created == date(2020, 10, 28)
| assert r.__getstate__()["created"] == "2020-10-28" 344 | 345 | 346 | def test_sanity_renamed_fields(): 347 | class A(SQLModel): 348 | some_field = Str().tag(name="SomeField") 349 | 350 | class B(SQLModel): 351 | other_field = Str() 352 | 353 | A.__renamed_fields__ == {"some_field": "SomeField"} 354 | B.__renamed_fields__ == {"some_field": "SomeField"} 355 | 356 | 357 | def test_table_subclass(): 358 | # Test that a non-abstract table can be subclassed 359 | class Base(SQLModel): 360 | class Meta: 361 | abstract = True 362 | 363 | class A(Base): 364 | id = Int().tag(primary_key=True) 365 | 366 | class Meta: 367 | db_table = "test_a" 368 | 369 | assert A.__pk__ == "id" 370 | assert A.__model__ == "test_a" 371 | assert A.__fields__ == ["id"] 372 | 373 | class B(A): 374 | extra_col = Dict() 375 | 376 | class Meta: 377 | db_table = "test_b" 378 | 379 | assert B.__pk__ == "id" 380 | assert B.__model__ == "test_b" 381 | assert B.__fields__ == ["id", "extra_col"] 382 | 383 | 384 | async def reset_tables(*models): 385 | ignore_list = ("Unknown table", "does not exist", "no such table", "doesn't exist") 386 | for Model in models: 387 | try: 388 | await Model.objects.drop_alter_foreign_keys() 389 | except Exception as e: 390 | msg = str(e) 391 | if not any(it in msg for it in ignore_list): 392 | raise # Unexpected error 393 | for Model in models: 394 | try: 395 | await Model.objects.drop_table() 396 | except Exception as e: 397 | msg = str(e) 398 | if not any(it in msg for it in ignore_list): 399 | raise # Unexpected error 400 | await Model.objects.create_table() 401 | for Model in models: 402 | await Model.objects.create_alter_foreign_keys() 403 | 404 | 405 | @pytest.fixture 406 | async def db(): 407 | if DATABASE_URL.startswith("sqlite"): 408 | m = re.match(r"(.+)://(.+)", DATABASE_URL) 409 | assert m, "DATABASE_URL is an invalid format" 410 | schema, db = m.groups() 411 | params = dict(database=db) 412 | else: 413 | m = re.match(r"(.+)://(.+):(.*)@(.+):(\d+)/(.+)", 
DATABASE_URL) 414 | assert m, "DATABASE_URL is an invalid format" 415 | schema, user, pwd, host, port, db = m.groups() 416 | params = dict(host=host, port=int(port), user=user, password=pwd) 417 | 418 | if schema == "mysql": 419 | from aiomysql import connect 420 | from aiomysql.sa import create_engine 421 | elif schema == "postgres": 422 | from aiopg import connect 423 | from aiopg.sa import create_engine 424 | elif schema == "sqlite": 425 | from aiosqlite import connect 426 | from aiosqlite.sa import create_engine 427 | else: 428 | raise ValueError("Unsupported database schema: %s" % schema) 429 | 430 | if schema == "mysql": 431 | params["autocommit"] = True 432 | 433 | params["loop"] = asyncio.get_running_loop() 434 | 435 | if schema == "sqlite": 436 | params["isolation_level"] = None # autocommit 437 | if os.path.exists(db): 438 | os.remove(db) 439 | else: 440 | if schema == "postgres": 441 | params["database"] = "postgres" 442 | 443 | async with connect(**params) as conn: 444 | async with conn.cursor() as c: 445 | # WARNING: Not safe 446 | await c.execute("DROP DATABASE IF EXISTS %s;" % db) 447 | await c.execute("CREATE DATABASE %s;" % db) 448 | 449 | if schema == "mysql": 450 | params["db"] = db 451 | elif schema == "postgres": 452 | params["database"] = db 453 | 454 | if os.environ.get("ECHO", "").lower() == "true": 455 | params["echo"] = True 456 | 457 | async with create_engine(**params) as engine: 458 | mgr = SQLModelManager.instance() 459 | mgr.database = {"default": engine} 460 | yield engine 461 | 462 | 463 | def test_query_ops_valid(): 464 | """Test that operators are all valid""" 465 | from sqlalchemy.sql.expression import ColumnElement 466 | 467 | from atomdb.sql import QUERY_OPS 468 | 469 | for k, v in QUERY_OPS.items(): 470 | assert hasattr(ColumnElement, v) 471 | 472 | 473 | async def test_drop_create_table(db): 474 | await reset_tables(User) 475 | 476 | 477 | async def test_simple_save_restore_delete(db): 478 | await reset_tables(User) 479 | 480 
| # Save 481 | user = User(name="Bob", email="bob@example.com", active=True) 482 | await user.delete() # Deleting unsaved item does nothing 483 | await user.save() 484 | assert user._id is not None 485 | 486 | # Restore 487 | u = await User.objects.get(name=user.name) 488 | assert u 489 | assert u._id == user._id 490 | assert u.name == user.name 491 | assert u.email == user.email 492 | assert u.active == user.active 493 | 494 | # Update 495 | user.active = False 496 | await user.save() 497 | 498 | u = await User.objects.get(name=user.name) 499 | assert u 500 | assert not u.active 501 | 502 | # Create second user 503 | another_user = User(name="Jill", email="jill@example.com", active=True) 504 | await another_user.save() 505 | 506 | # Delete 507 | await user.delete() 508 | assert await User.objects.get(name=user.name) is None 509 | 510 | # Make sure second user still exists 511 | assert await User.objects.get(name=another_user.name) is not None 512 | 513 | 514 | async def test_query(db): 515 | await reset_tables(User) 516 | 517 | # Create second user 518 | for i in range(10): 519 | user = User( 520 | name=f"name-{i}", email="email-{i}@example.com", age=20, active=True 521 | ) 522 | await user.save() 523 | 524 | for user in await User.objects.all(): 525 | print(user) 526 | 527 | for user in await User.objects.filter(name=user.name): 528 | print(user) 529 | 530 | assert await User.objects.filter(name=user.name).exists() 531 | assert not await User.objects.filter(name="I DO NOT EXIST").exists() 532 | 533 | # Delete one 534 | await User.objects.delete(name=user.name) 535 | assert len(await User.objects.all()) == 9 536 | 537 | # Test mapping _id to pk 538 | u = await User.objects.first() 539 | assert u.id 540 | await User.objects.filter(_id=u.id).count() == 1 541 | 542 | # Delete them all 543 | await User.objects.delete(active=True) 544 | assert len(await User.objects.all()) == 0 545 | 546 | 547 | async def test_query_related(db): 548 | await reset_tables(User, Job, 
JobSkill, JobRole) 549 | 550 | job = await Job.objects.create(name="Chef") 551 | job1 = await Job.objects.create(name="Waitress") 552 | job2 = await Job.objects.create(name="Manager") 553 | 554 | await JobRole.objects.create(job=job, name="Cooking") 555 | await JobRole.objects.create(job=job, name="Grilling") 556 | 557 | await JobRole.objects.create(job=job1, name="Serving") 558 | role2 = await JobRole.objects.create(job=job2, name="Managing") 559 | 560 | roles = await JobRole.objects.filter(job__name__in=[job.name, job2.name]) 561 | assert len(roles) == 3 562 | assert await JobRole.objects.count(job__name__in=[job.name, job2.name]) == 3 563 | 564 | roles = await JobRole.objects.filter(job__name=job2.name) 565 | assert len(roles) == 1 566 | assert await JobRole.objects.count(job__name=job2.name) == 1 567 | 568 | roles = await JobRole.objects.filter(job=job2) 569 | assert len(roles) == 1 and roles[0] == role2 570 | 571 | roles = await JobRole.objects.filter(job__in=[job2]) 572 | assert len(roles) == 1 and roles[0] == role2 573 | 574 | roles = await JobRole.objects.filter(job__name__not="none of the above") 575 | assert len(roles) == 4 576 | 577 | # Test related list 578 | assert len(job.roles) == 0 # Not loaded 579 | await job.roles.load() 580 | assert len(job.roles) == 2 581 | job.roles.append(JobRole(name="Baking", job=job)) 582 | job.roles.sort(key=lambda it: it.name) 583 | assert [it.name for it in job.roles] == ["Baking", "Cooking", "Grilling"] 584 | await job.roles.save() 585 | assert await JobRole.objects.filter(job__name=job.name).count() == 3 586 | 587 | # Cant do multiple joins 588 | with pytest.raises(ValueError): 589 | roles = await JobRole.objects.get(job__name__other=1) 590 | 591 | 592 | async def test_query_related_reverse(db): 593 | await reset_tables(User, Job, JobSkill, JobRole) 594 | 595 | job = await Job.objects.create(name="Chef") 596 | job1 = await Job.objects.create(name="Waitress") 597 | job2 = await Job.objects.create(name="Manager") 598 | 
599 | role = await JobRole.objects.create(job=job, name="Cooking") 600 | role1 = await JobRole.objects.create(job=job1, name="Serving") 601 | role2 = await JobRole.objects.create(job=job2, name="Managing") 602 | 603 | jobs = await Job.objects.filter(roles__name=role1.name) 604 | assert jobs == [job1] 605 | 606 | jobs = await Job.objects.filter(roles__in=[role, role2]) 607 | assert jobs == [job, job2] or jobs == [job2, job] 608 | 609 | assert await Job.objects.filter(roles__in=[role, role2]).count() == 2 610 | 611 | 612 | async def test_query_related_renamed(db): 613 | await reset_tables(Email, Attachment) 614 | email = await Email.objects.create( 615 | to="alice@example.com", 616 | from_="bob@example.com", 617 | ) 618 | await Attachment.objects.create(email=email, size=50) 619 | email2 = await Email.objects.create( 620 | to="alice@example.com", 621 | from_="jill@example.com", 622 | ) 623 | await Attachment.objects.create(email=email2, size=100) 624 | 625 | # from_ is renamed to from 626 | r = await Attachment.objects.get(email__from_="jill@example.com") 627 | assert r.size == 100 628 | 629 | 630 | async def test_query_order_by(db): 631 | await reset_tables(User) 632 | # Create second user 633 | users = [] 634 | for i in range(3): 635 | user = User(name=f"Name{i}", email=f"{i}@a.com", age=i, active=True) 636 | await user.save() 637 | users.append(user) 638 | 639 | users.sort(key=lambda it: it.name) 640 | assert await User.objects.order_by("name").all() == users 641 | 642 | assert await User.objects.order_by("name").first() == users[0] 643 | assert await User.objects.order_by("name").last() == users[-1] 644 | 645 | users.reverse() 646 | assert await User.objects.order_by("-name").all() == users 647 | 648 | assert await User.objects.order_by("-name").first() == users[0] 649 | assert await User.objects.order_by("-name").last() == users[-1] 650 | 651 | with pytest.raises(ValueError): 652 | assert await User.objects.last() # Cannot do this on un-ordered query 653 | 654 | 
async def test_query_order_by_latest_earliest(db):
    """Make sure update takes account for any renamed columns"""
    await reset_tables(User, Page)

    p1 = await Page.objects.create(
        title="Test1",
        ranking=20,
        date=date(2024, 10, 1),
        last_updated=datetime(2024, 10, 4, 12, 00),
    )
    p2 = await Page.objects.create(
        title="Test2",
        date=date(2024, 10, 2),
        last_updated=datetime(2024, 10, 4, 12, 30),
    )
    p3 = await Page.objects.create(
        title="Test3",
        date=date(2024, 10, 3),
        last_updated=datetime(2024, 10, 4, 11, 30),
    )

    # Default field comes from Meta.get_latest_by ("date").
    assert await Page.objects.earliest() == p1
    assert await Page.objects.latest() == p3
    assert await Page.objects.latest("last_updated") == p2
    assert await Page.objects.earliest("last_updated") == p3

    with pytest.raises(TypeError):
        await User.objects.earliest()  # Does not have a get latest by defined on meta


async def test_query_limit(db):
    await reset_tables(User)
    # Create a few users to page through
    users = []
    for i in range(3):
        user = User(name=f"Name{i}", email=f"{i}@a.com", age=i, active=True)
        await user.save()
        users.append(user)

    assert len(await User.objects.limit(2).all()) == 2
    assert len(await User.objects.offset(2).all()) == 1

    # Slicing maps to offset/limit.
    assert len(await User.objects.filter()[1:2].all()) == 1
    assert len(await User.objects.filter()[1:].all()) == 2
    assert len(await User.objects.filter()[0].all()) == 1

    # Keys must be integers
    with pytest.raises(TypeError):
        User.objects.filter()[1.2]

    with pytest.raises(TypeError):
        User.objects.filter()[1.2:3]

    # No negative offests
    with pytest.raises(ValueError):
        User.objects.filter()[-1]

    # No negative limits
    with pytest.raises(ValueError):
        User.objects.filter()[0:-1]


async def test_query_pk(db):
    # String primary key lookups return the cached instance.
    await reset_tables(Ticket)
    t = await Ticket.objects.create(code="special")
    assert await Ticket.objects.get(code="special") is t


async def test_query_subclassed_pk(db):
    await reset_tables(ImportedTicket)
    t = await ImportedTicket.objects.create(code="special", meta={"source": "db"})
    assert await ImportedTicket.objects.get(code="special") is t


async def test_query_renamed_pk(db):
    await reset_tables(Email)
    email = await Email.objects.create(
        to="bob@example.com", from_="alice@example.com", body="Hello ;)"
    )
    email_id = email._id
    assert await Email.objects.get(id=email_id) is email
    del email
    gc.collect()

    # Make sure renamed field is restored
    email = await Email.objects.get(id=email_id)
    assert email.from_ == "alice@example.com"


async def test_requery_update_not_force_restored(db):
    await reset_tables(Ticket)
    a = await Ticket.objects.create(code="a", desc="In progress")
    b = await Ticket.objects.create(code="b", desc="In progress")
    c = await Ticket.objects.create(code="c", desc="Fixed")
    results = await Ticket.objects.order_by("code").all()
    assert results == [a, b, c]

    # This bypasses updating the object in memory
    await Ticket.objects.filter(desc="In progress").update(desc="Fixed")

    # The objects remain the same, the cached values are still kept
    updated_results = await Ticket.objects.order_by("code").all()
    assert updated_results == [a, b, c]
    assert a.desc == "In progress" and b.desc == "In progress"


async def test_requery_update_force_restored(db):
    await reset_tables(Ticket)
    a = await Ticket.objects.create(code="a", desc="In progress")
    b = await Ticket.objects.create(code="b", desc="In progress")
    c = await Ticket.objects.create(code="c", desc="Fixed")
    results = await Ticket.objects.order_by("code").all()
    assert results == [a, b, c]

    # This bypasses updating the object in memory
    await Ticket.objects.filter(desc="In progress").update(desc="Fixed")

    # The objects remain the same but this force restores any updated fields
    updated_results = await Ticket.objects.order_by("code").all(force_restore=True)
    assert updated_results == [a, b, c]
    assert a.desc == "Fixed" and b.desc == "Fixed"


async def test_query_bad_column_name(db):
    await reset_tables(Ticket)
    await Ticket.objects.create(code="special")
    with pytest.raises(ValueError):
        await Ticket.objects.get(unknown="special")


async def test_query_select_related(db):
    await reset_tables(User, Job, JobSkill, JobRole)
    # One role with a job set and one without
    job = await Job.objects.create(name="Chef")
    await JobRole.objects.create(job=job, name="Cooking")
    await JobRole.objects.create(name="Accounting")
    del job

    Job.objects.cache.clear()
    JobRole.objects.cache.clear()

    # Without select related it only has the Job with its pk
    roles = await JobRole.objects.all()
    assert len(roles) == 2
    assert roles[0].job.__restored__ is False
    del roles

    # TODO: Shouldn't have to do this here...
    Job.objects.cache.clear()
    JobRole.objects.cache.clear()

    # With select related the job is fully loaded
    # since the second role does not set a job it is excluded due to the
    # default inner join
    roles = await JobRole.objects.select_related("job").all()
    assert len(roles) == 1
    assert roles[0].job.__restored__ is True

    # Using outer join includes related fields that are null
    roles = await JobRole.objects.select_related("job", outer_join=True).all()
    assert len(set(roles)) == 2
    assert roles[0].job.__restored__ is True
    assert roles[1].job is None


async def test_query_select_related_multiple(db):
    await reset_tables(User, Job, JobSkill, JobRole)
    await JobRole.objects.create(
        job=await Job.objects.create(name="Manager"),
        skill=await JobSkill.objects.create(name="Excel"),
        name="Sr Manager",
    )
    await JobRole.objects.create(
        job=await Job.objects.create(name="Dev"),
        skill=await JobSkill.objects.create(name="Python"),
        name="Sr Dev",
    )

    Job.objects.cache.clear()
    JobRole.objects.cache.clear()
    JobSkill.objects.cache.clear()

    # With select related the job is fully loaded
    # since the second role does not set a job it is excluded due to the
    # default inner join
    q = JobRole.objects.select_related("job", "skill").order_by("name").all()
    roles = await q
    # for role in roles:
    #     print((role.name, role.job.name, role.skill.name))
    assert len(roles) == 2
    assert roles[0].name == "Sr Dev"
    assert roles[0].job.__restored__ is True
    assert roles[0].job.name == "Dev"
    assert roles[0].skill.__restored__ is True
    assert roles[0].skill.name == "Python"

    assert roles[1].name == "Sr Manager"
    assert roles[1].job.__restored__ is True
    assert roles[1].job.name == "Manager"
    assert roles[1].skill.__restored__ is True
    assert roles[1].skill.name == "Excel"


async def test_query_select_related_filter(db):
    await reset_tables(User, Job, JobSkill, JobRole)

    boss = await User.objects.create(name="Boss man")
    monkey = await User.objects.create(name="Code monkey")

    await JobRole.objects.create(
        job=await Job.objects.create(name="Manager", manager=boss),
        skill=await JobSkill.objects.create(name="Excel"),
        name="Sr Manager",
    )
    dev = await Job.objects.create(name="Dev", manager=boss)  # FIXME:, lead=monkey)
    await JobRole.objects.create(
        job=dev,
        skill=await JobSkill.objects.create(name="C++"),
        name="Sr Dev",
    )
    await JobRole.objects.create(
        job=dev,
        skill=await JobSkill.objects.create(name="Python"),
        name="Jr Dev",
    )
    del dev, boss, monkey
    Job.objects.cache.clear()
    JobRole.objects.cache.clear()
    JobSkill.objects.cache.clear()
    User.objects.cache.clear()

    r = await JobRole.objects.select_related("job", "skill").get(
        job__enabled=True, skill__name__startswith="C"
    )
    assert r.name == "Sr Dev"
    assert r.job.name == "Dev"
    assert not r.job.manager.__restored__
    assert r.skill.name == "C++"

    del r
    Job.objects.cache.clear()
    JobRole.objects.cache.clear()
    JobSkill.objects.cache.clear()
    User.objects.cache.clear()

    # Test that duplicate select on job does not lead to an error
    r = await JobRole.objects.select_related("job", "job__manager", "skill").get(
        job__manager__name__contains="Boss",
        # FIXME: job__lead__name__contains="monkey",
        skill__name__startswith="P",
    )
    assert r.name == "Jr Dev"
    assert r.job.name == "Dev"
    assert r.job.manager.name == "Boss man"
    # FIXME: assert r.job.lead.name == "Code monkey"
    assert r.skill.name == "Python"


async def test_query_prefetch_related_invalid(db):
    await reset_tables(Email, Attachment)
    with pytest.raises(ValueError):
        await Email.objects.prefetch_related("comments").all()


async def test_query_prefetch_related_instance(db):
    await reset_tables(Document, Project)
    doc1 = await Document.objects.create(name="first", uuid="1")
    doc2 = await Document.objects.create(name="second", uuid="2")
    await Project.objects.create(title="pack", doc=doc1)
    await Project.objects.create(title="ship", doc=doc2)

    del doc1, doc2
    Document.objects.cache.clear()
    Project.objects.cache.clear()
    gc.collect()

    # Related instances are not populated without prefetch
    docs = await Document.objects.all()
    assert len(docs) == 2
    assert all(doc.project is None for doc in docs)

    del docs
    Document.objects.cache.clear()
    Project.objects.cache.clear()
    gc.collect()

    docs = await Document.objects.prefetch_related("project").all()
    assert len(docs) == 2
    assert all(doc.project.__restored__ for doc in docs)
    assert docs[0].project.title == "pack"
    assert docs[1].project.title == "ship"


async def test_query_prefetch_related_list(db):
    await reset_tables(Email, Attachment)

    email = await Email.objects.create(
        to="alice@example.com",
        from_="bob@example.com",
        body="Please checkout this project",
    )
    await Attachment.objects.create(email=email, name="a.txt", data=b"a")
    await Attachment.objects.create(email=email, name="b.txt", data=b"b")

    email = await Email.objects.create(
        to="bob@example.com", from_="alice@example.com", body="Cat pictures!"
    )
    await Attachment.objects.create(email=email, name="new.jpg", data=b"photo")

    # Purge cache
    del email
    Email.objects.cache.clear()
    Attachment.objects.cache.clear()
    gc.collect()

    # No prefetch
    emails = await Email.objects.all()
    assert len(emails) == 2
    for email in emails:
        assert len(email.attachments) == 0

    # Purge cache
    del email, emails
    Email.objects.cache.clear()
    Attachment.objects.cache.clear()
    gc.collect()

    emails = await Email.objects.prefetch_related("attachments").all()
    assert len(emails) == 2

    email = emails[0]
    assert len(email.attachments) == 2
    attachment = email.attachments[0]
    assert attachment.name == "a.txt"
    assert attachment.data == b"a"
    assert attachment.email is email
    attachment = email.attachments[1]
    assert attachment.name == "b.txt"
    assert attachment.data == b"b"
    assert attachment.email is email

    email = emails[1]
    assert len(email.attachments) == 1
    attachment = email.attachments[0]
    assert attachment.name == "new.jpg"
    assert attachment.data == b"photo"
    assert attachment.email is email

    email = await Email.objects.prefetch_related("attachments").get(
        to="bob@example.com"
    )
    assert len(email.attachments) == 1
    attachment = email.attachments[0]
    assert attachment.name == "new.jpg"
    assert attachment.data == b"photo"
    assert attachment.email is email

    emails = await Email.objects.prefetch_related("attachments").filter(
        body__contains="pictures"
    )
    assert len(emails) == 1

    email = emails[0]
    assert len(email.attachments) == 1
    attachment = email.attachments[0]
    assert attachment.name == "new.jpg"
    assert attachment.data == b"photo"
    assert attachment.email is email
async def test_query_prefetch_related_updates(db): 1026 | await reset_tables(Email, Attachment) 1027 | 1028 | email = await Email.objects.create( 1029 | to="alice@example.com", 1030 | from_="bob@example.com", 1031 | body="Please checkout this project", 1032 | ) 1033 | await Attachment.objects.create(email=email, name="a.txt", data=b"a") 1034 | await Attachment.objects.create(email=email, name="b.txt", data=b"b") 1035 | 1036 | email = await Email.objects.prefetch_related("attachments").get( 1037 | to="alice@example.com", 1038 | force_restore=True, 1039 | ) 1040 | assert len(email.attachments) == 2 1041 | 1042 | await Attachment.objects.create(email=email, name="c.txt", data=b"c") 1043 | email = await Email.objects.prefetch_related("attachments").get( 1044 | to="alice@example.com", 1045 | force_restore=True, 1046 | ) 1047 | assert len(email.attachments) == 3 1048 | 1049 | 1050 | async def test_query_values(db): 1051 | await reset_tables(User) 1052 | # Create second user 1053 | user = User(name="Bob", email="bob@email.com", age=40, active=True) 1054 | await user.save() 1055 | 1056 | user1 = User(name="Jack", email="jack@ex.com", age=30, active=False) 1057 | await user1.save() 1058 | 1059 | user2 = User(name="Bob", email="bob@other.com", age=20, active=False) 1060 | await user2.save() 1061 | 1062 | vals = await User.objects.filter(active=True).values() 1063 | assert len(vals) == 1 and vals[0]["email"] == user.email 1064 | 1065 | assert await User.objects.order_by("name").values("name", distinct=True) == [ 1066 | ("Bob",), 1067 | ("Jack",), 1068 | ] 1069 | 1070 | assert await User.objects.order_by("age").values("age", flat=True) == [20, 30, 40] 1071 | 1072 | assert await User.objects.filter(active=True).values("age", flat=True) == [40] 1073 | 1074 | # Cannot use flat with multiple values 1075 | with pytest.raises(ValueError): 1076 | await User.objects.values("name", "age", flat=True) 1077 | 1078 | 1079 | @pytest.mark.skipif(IS_MYSQL, reason="Distinct and count doesn't 
work") 1080 | async def test_query_distinct(db): 1081 | await reset_tables(User) 1082 | # Create second user 1083 | user = User(name="Bob", email="bob@email.com", age=40, active=True) 1084 | await user.save() 1085 | 1086 | user1 = User(name="Jack", email="jack@ex.com", age=30, active=False) 1087 | await user1.save() 1088 | 1089 | user2 = User(name="Bob", email="bob@other.com", age=20, active=False) 1090 | await user2.save() 1091 | 1092 | num_names = await User.objects.distinct("name").count() 1093 | assert num_names == 2 1094 | distinct_names = ( 1095 | await User.objects.distinct("name").order_by("name").values("name", flat=True) 1096 | ) 1097 | assert distinct_names == ["Bob", "Jack"] 1098 | 1099 | num_ages = await User.objects.distinct("age").count() 1100 | assert num_ages == 3 1101 | num_ages = await User.objects.filter(age__gt=25).distinct("age").count() 1102 | assert num_ages == 2 1103 | 1104 | 1105 | async def test_get_or_create(db): 1106 | await reset_tables(User, Job, JobSkill, JobRole) 1107 | 1108 | name = "Bob" 1109 | email = "bob@example.com" 1110 | 1111 | user, created = await User.objects.get_or_create(name=name, email=email) 1112 | assert created 1113 | assert user._id and user.name == name and user.email == user.email 1114 | 1115 | u, created = await User.objects.get_or_create(name=user.name, email=user.email) 1116 | assert u._id == user._id 1117 | assert not created and user.name == name and user.email == user.email 1118 | 1119 | # Test passing model 1120 | job, created = await Job.objects.get_or_create(name="Accountant") 1121 | assert job and created 1122 | 1123 | role, created = await JobRole.objects.get_or_create(job=job, name="Accounting") 1124 | assert role and created and role.job._id == job._id 1125 | 1126 | role_check, created = await JobRole.objects.get_or_create(job=job, name=role.name) 1127 | assert role_check._id == role._id and not created 1128 | 1129 | 1130 | async def test_create(db): 1131 | await reset_tables(User, Job, JobSkill, 
JobRole)

    job = await Job.objects.create(name="Chef")
    assert job and job._id

    # DB should enforce uniqueness on the name column
    with pytest.raises(IntegrityError):
        await Job.objects.create(name=job.name)


async def test_bulk_create(db):
    """``bulk_create`` inserts all rows in one statement; MySQL does not
    report per-row ids for bulk inserts."""
    await reset_tables(User)
    assert await User.objects.count() == 0
    # TODO: Get the id's of the rows inserted?
    users = await User.objects.bulk_create([User(name=f"user-{i}") for i in range(10)])
    for u in users:
        if not IS_MYSQL:
            assert u._id
    assert await User.objects.count() == 10


async def test_transaction_rollback(db):
    """Rows created within a transaction must disappear after rollback."""
    await reset_tables(User, Job, JobSkill, JobRole)

    with pytest.raises(ValueError):
        async with Job.objects.connection() as conn:
            trans = await conn.begin()
            try:
                # Must pass in the connection parameter for transactions
                job = await Job.objects.create(name="Job", connection=conn)
                assert job._id is not None
                for i in range(3):
                    role = await JobRole.objects.create(
                        job=job, name=f"Role{i}", connection=conn
                    )
                    assert role._id is not None
                complete = True
                raise ValueError("Oh crap, I didn't want to do that")
            except Exception:
                await trans.rollback()
                rollback = True
                raise
            else:
                rollback = False
                await trans.commit()

    # Both flags prove the try body ran to the raise and the rollback fired
    assert complete and rollback
    assert len(await Job.objects.all()) == 0
    assert len(await JobRole.objects.all()) == 0


async def test_transaction_commit(db):
    """Rows created within a committed transaction must persist."""
    await reset_tables(User, Job, JobSkill, JobRole)

    async with Job.objects.connection() as conn:
        trans = await conn.begin()
        try:
            # Must pass in the connection parameter for transactions
            job = await Job.objects.create(name="Job", connection=conn)
            assert job._id is not None
            for i in range(3):
                role = await JobRole.objects.create(
                    job=job, name=f"Role{i}", connection=conn
                )
                assert role._id is not None
        except Exception:
            await trans.rollback()
            raise
        else:
            await trans.commit()

    assert len(await Job.objects.all()) == 1
    assert len(await JobRole.objects.all()) == 3


async def test_transaction_delete(db):
    """Deletes issued inside a transaction are applied on commit."""
    await reset_tables(User)

    name = "Name"
    async with User.objects.connection() as conn:
        trans = await conn.begin()
        try:
            # Must pass in the connection parameter for transactions
            user = await User.objects.create(
                name=name, email="test@ex.com", age=20, active=True, connection=conn
            )
            assert user._id is not None
            await User.objects.delete(name=name, connection=conn)
        except Exception:
            await trans.rollback()
            raise
        else:
            await trans.commit()

    assert not await User.objects.exists(name=name)


async def test_filters(db):
    """Exercise the ``field__op`` filter syntax and its error cases."""
    await reset_tables(User)

    user, created = await User.objects.get_or_create(
        name="Bob", email="bob@ex.com", age=21, active=True
    )
    assert created

    user2, created = await User.objects.get_or_create(
        name="Tom", email="tom@ex.com", age=48, active=False, rating=10.0
    )
    assert created

    # Startswith
    u = await User.objects.get(name__startswith="B")
    assert u.name == user.name
    assert u is user  # Now cached

    # In query
    users = await User.objects.filter(name__in=[user.name, user2.name])
    assert len(users) == 2

    # Test use of count
    assert await User.objects.count(name__in=[user.name, user2.name]) == 2

    # Is query
    users = await User.objects.filter(active__is=False)
    assert len(users) == 1 and users[0].active is False
    assert \
users[0] is user2  # Now cached

    # Not query
    users = await User.objects.filter(rating__isnot=None)
    assert len(users) == 1 and users[0].rating is not None

    # Lt query
    users = await User.objects.filter(age__lt=30)
    assert len(users) == 1 and users[0].age == user.age

    users = await User.objects.exclude(age=21)
    assert len(users) == 1 and users[0].age == 48

    # Or query (a dict positional arg is OR'd together)
    users = await User.objects.filter(dict(age__lt=18, age__gt=40))
    assert len(users) == 1 and users[0].age == 48

    # Exclude or
    users = await User.objects.exclude(dict(age__lt=18, age__gt=40))
    assert len(users) == 1 and users[0].age == 21

    # Not supported
    with pytest.raises(ValueError):
        users = await User.objects.filter(age__xor=1)

    # Missing op
    with pytest.raises(ValueError):
        users = await User.objects.filter(age__=1)

    # Invalid name
    with pytest.raises(ValueError):
        users = await User.objects.filter(does_not_exist=True)


async def test_filter_exclude(db):
    """``filter`` and ``exclude`` can be chained; exclude with multiple
    kwargs negates the AND of its conditions."""
    await reset_tables(User)
    # Create second user
    await User.objects.create(name="Bob", email="bob@other.com", age=40, active=True)
    await User.objects.create(
        name="Jack", email="jack@company.com", age=30, active=False
    )
    await User.objects.create(name="Bob", email="bob@company.com", age=20, active=False)

    users = await User.objects.filter(name__startswith="B").exclude(
        email__endswith="other.com"
    )
    assert len(users) == 1 and users[0].email == "bob@company.com"

    users = await User.objects.exclude(active=True, age__lt=25)
    assert len(users) == 1 and users[0].name == "Jack"

    users = await User.objects.exclude(name="Bob")
    assert len(users) == 1 and users[0].name == "Jack"


async def test_update(db):
    """Bulk ``update()`` applies to the filtered set, or every row when no
    filter is active."""
    await reset_tables(User)
    # Create second user
    user = User(name="Bob", email="bob@ex.com", age=40, active=True)
    await user.save()

    user1 = User(name="Jack", email="jack@ex.com", age=30, active=False)
    await user1.save()

    user2 = User(name="Bob", email="bob@other.com", age=20, active=False)
    await user2.save()

    assert await User.objects.filter(age=20).exists()
    await User.objects.filter(age=20).update(age=25)
    assert not await User.objects.filter(age=20).exists()

    assert await User.objects.filter(active=False).exists()
    await User.objects.update(active=True)
    assert not await User.objects.filter(active=False).exists()

    assert await User.objects.filter(active=False).count() == 0
    await User.objects.filter(name="Bob").update(active=False)
    assert await User.objects.filter(active=False).count() == 2


async def test_update_renamed(db):
    """Make sure update takes account for any renamed columns"""
    await reset_tables(User, Page)

    await Page.objects.create(title="Test1", status="live", ranking=100)
    await Page.objects.create(title="Test2", status="live", ranking=100)
    await Page.objects.create(title="Test3", status="live", ranking=1)

    assert await Page.objects.filter(ranking=100).count() == 2
    await Page.objects.filter(ranking=100).update(ranking=3)
    assert await Page.objects.filter(ranking=100).count() == 0
    assert await Page.objects.filter(ranking=3).count() == 2


async def test_column_rename(db):
    """Columns can be tagged with custom names.
Verify that it works.""" 1352 | await reset_tables(Email) 1353 | 1354 | e = Email(from_="jack@ex.com", to="jill@ex.com", body="Did you see this?") 1355 | await e.save() 1356 | 1357 | # Check without use labels 1358 | table = Email.objects.table 1359 | q = table.select().where(table.c.to == e.to) 1360 | row = await Email.objects.fetchone(q) 1361 | assert row["from"] == e.from_, "Column rename failed" 1362 | 1363 | # Check with use labels 1364 | q = table.select(use_labels=True).where(table.c.to == e.to) 1365 | row = await Email.objects.fetchone(q) 1366 | assert row[f"{table.name}_from"] == e.from_, "Column rename failed" 1367 | 1368 | # Restoring a renamed column needs to work 1369 | restored = await Email.objects.get(to=e.to) 1370 | restored.from_ == e.from_ 1371 | 1372 | 1373 | async def test_query_many_to_one(db): 1374 | await reset_tables(User, Job, JobSkill, JobRole) 1375 | 1376 | jobs = [] 1377 | 1378 | for i in range(5): 1379 | job = Job(name=f"Job{i}") 1380 | await job.save() 1381 | jobs.append(job) 1382 | 1383 | for i in range(random.randint(1, 5)): 1384 | role = JobRole(name=f"Role{i}", job=job) 1385 | await role.save() 1386 | 1387 | loaded = [] 1388 | q = Job.objects.table.join(JobRole.objects.table).select(use_labels=True) 1389 | 1390 | print(q) 1391 | 1392 | r = await Job.objects.execute(q) 1393 | assert r.returns_rows 1394 | 1395 | for row in await JobRole.objects.fetchall(q): 1396 | #: TODO: combine the joins back up 1397 | role = await JobRole.restore(row) 1398 | 1399 | # Job should be restored from the cache 1400 | assert role.job is not None 1401 | assert role.job.__restored__ is True 1402 | # for role in job.roles: 1403 | # assert role.job == job 1404 | loaded.append(job) 1405 | 1406 | assert len(await Job.objects.fetchmany(q, size=2)) == 2 1407 | 1408 | # Make sure they pull from cache 1409 | roles = await JobRole.objects.all() 1410 | for role in roles: 1411 | assert role.job is not None 1412 | assert role.job.__restored__ is True 1413 | 1414 | # 
Clear cache and ensure it doesn't pull from cache now 1415 | Job.objects.cache.clear() 1416 | JobRole.objects.cache.clear() 1417 | 1418 | roles = await JobRole.objects.all() 1419 | used = set() 1420 | for role in roles: 1421 | assert role.job is not None 1422 | if role.job not in used: 1423 | assert role.job.__restored__ is False 1424 | used.add(role.job) 1425 | await role.job.load() 1426 | assert role.job.__restored__ is True 1427 | 1428 | 1429 | async def test_query_multiple_joins(db): 1430 | await reset_tables(User, Job, JobSkill, JobRole, JobTask) 1431 | 1432 | ceo = await Job.objects.create(name="CEO") 1433 | cfo = await Job.objects.create(name="CFO") 1434 | swe = await Job.objects.create(name="SWE") 1435 | 1436 | ceo_role = await JobRole.objects.create(name="CEO", job=ceo) 1437 | cfo_role = await JobRole.objects.create(name="CFO", job=cfo) 1438 | swe_role = await JobRole.objects.create(name="SWE", job=swe) 1439 | 1440 | await JobTask.objects.create(desc="Code", role=swe_role) 1441 | await JobTask.objects.create(desc="Hire", role=ceo_role) 1442 | await JobTask.objects.create(desc="Fire", role=ceo_role) 1443 | await JobTask.objects.create(desc="Account", role=cfo_role) 1444 | 1445 | jobs = await Job.objects.filter(roles__tasks__desc="Fire") 1446 | assert jobs == [ceo] 1447 | 1448 | jobs = await Job.objects.order_by("name").filter( 1449 | roles__tasks__desc__notin=["Hire", "Fire"] 1450 | ) 1451 | assert jobs == [cfo, swe] 1452 | 1453 | 1454 | async def test_save_update_fields(db): 1455 | """Test that using save with update_fields only updates the fields 1456 | specified 1457 | 1458 | """ 1459 | await reset_tables(User) 1460 | await reset_tables(Page) 1461 | 1462 | page = await Page.objects.create( 1463 | title="Test", body="This is only a test", status="live" 1464 | ) 1465 | assert page.visits == 0 1466 | page.visits += 1 1467 | page.body = "New body" 1468 | await page.save(update_fields=["visits"]) 1469 | 1470 | del Page.objects.cache[page._id] 1471 | page = 
await Page.objects.get(title="Test") 1472 | 1473 | # This field should not be saved 1474 | assert page.body == "This is only a test" 1475 | # But this should be saved 1476 | assert page.visits == 1 1477 | 1478 | 1479 | async def test_load_fields(db): 1480 | """Test that using load with fields only loads the given field""" 1481 | await reset_tables(User) 1482 | await reset_tables(Page) 1483 | 1484 | page = await Page.objects.create( 1485 | title="Test", body="This is only a test", status="live" 1486 | ) 1487 | assert page.visits == 0 1488 | 1489 | # Update outside the orm 1490 | t = Page.objects.table 1491 | q = t.update().where(t.c._id == page._id).values(visits=1288821, title="New title") 1492 | async with Page.objects.connection() as conn: 1493 | await conn.execute(q) 1494 | 1495 | page.body = "This has changed" 1496 | 1497 | # Reload the visits 1498 | await page.load(fields=["visits"]) 1499 | 1500 | # This should be the only field that updates 1501 | assert page.visits == 1288821 1502 | 1503 | # This should not change 1504 | assert page.title == "Test" 1505 | assert page.body == "This has changed" 1506 | 1507 | # Reload the title 1508 | await page.load(fields=["title", "visits"]) 1509 | assert page.title == "New title" 1510 | 1511 | 1512 | async def test_save_errors(db): 1513 | await reset_tables(User) 1514 | 1515 | u = User() 1516 | with pytest.raises(ValueError): 1517 | # Cant do both 1518 | await u.save(force_insert=True, force_update=True) 1519 | 1520 | # Updating unsaved will not work 1521 | r = await u.save(force_update=True) 1522 | assert r.rowcount == 0 1523 | 1524 | 1525 | async def test_object_caching(db): 1526 | await reset_tables(Email) 1527 | 1528 | e = Email(from_="a", to="b", body="c") 1529 | await e.save() 1530 | pk = e._id 1531 | aref = Email.objects.cache.get(pk) 1532 | assert aref is e, "Cached object is invalid" 1533 | 1534 | # Delete 1535 | del e 1536 | del aref 1537 | 1538 | gc.collect() 1539 | 1540 | # Make sure cache was cleaned up 1541 | 
aref = Email.objects.cache.get(pk) 1542 | assert aref is None, "Cached object was not released" 1543 | 1544 | 1545 | async def test_fk_custom_type(db): 1546 | await reset_tables(Document, Project) 1547 | doc = await Document.objects.create(uuid="foo") 1548 | await Project.objects.create(doc=doc) 1549 | col = Project.objects.table.columns["doc"] 1550 | assert isinstance(col.type, sa.String) 1551 | 1552 | 1553 | async def test_relation_many_to_one_save(db): 1554 | await reset_tables(Email, Attachment) 1555 | email = await Email.objects.create( 1556 | to="alice@example.com", 1557 | from_="bob@example.com", 1558 | ) 1559 | email.attachments = [ 1560 | Attachment(email=email, name="test.pdf"), 1561 | Attachment(email=email, name="funny.jpg"), 1562 | ] 1563 | assert isinstance(email.attachments, list) 1564 | await email.attachments.save() 1565 | assert (await Attachment.objects.filter(email=email).count()) == 2 1566 | 1567 | all_attachments = email.attachments + [ 1568 | Attachment(email=email, name="new.jpg"), 1569 | ] 1570 | email.attachments = all_attachments 1571 | await email.attachments.save() 1572 | assert (await Attachment.objects.filter(email=email).count()) == 3 1573 | 1574 | email.attachments.pop() 1575 | email.attachments.pop() 1576 | await email.attachments.save() 1577 | assert (await Attachment.objects.filter(email=email).count()) == 1 1578 | 1579 | # Check RelatedList 1580 | # Check iter 1581 | assert [a.email is email for a in email.attachments] 1582 | # Check getitem 1583 | assert email.attachments[0].name == "test.pdf" 1584 | assert len(email.attachments) == 1 1585 | 1586 | a = email.attachments[0] 1587 | assert a in email.attachments 1588 | 1589 | email.attachments.insert(0, Attachment(email=email, name="new.docx")) 1590 | await email.attachments.save() 1591 | assert (await Attachment.objects.filter(email=email).count()) == 2 1592 | 1593 | email.attachments = email.attachments[-1:] 1594 | await email.attachments.save() 1595 | assert (await 
Attachment.objects.filter(email=email).count()) == 1

    # Make sure errors still work
    with pytest.raises(TypeError):
        email.attachments.append(Image())
    with pytest.raises(TypeError):
        email.attachments = [Image()]


async def test_relation_many_to_many_save(db):
    """Saving a many-to-many relation syncs rows of the through table
    (EmailTag) to match the assigned list."""
    await reset_tables(Email, Tag, EmailTag)
    email = await Email.objects.create(
        to="alice@example.com",
        from_="bob@example.com",
    )
    inbox = await Tag.objects.create(name="Inbox")
    starred = await Tag.objects.create(name="Starred")
    draft = await Tag.objects.create(name="Draft")

    email.tags = [inbox, starred]
    await email.tags.save()
    assert (await EmailTag.objects.count()) == 2
    email.tags = [inbox]
    await email.tags.save()
    assert (await EmailTag.objects.count()) == 1
    email.tags = [starred, draft]
    await email.tags.save()
    assert (await EmailTag.objects.count()) == 2


async def test_cyclical_foreign_keys(db):
    """Models with mutually-referencing FKs can be created, queried with
    select_related, and honor their ondelete rules."""
    await reset_tables(NodeType, Node)

    link_node_type = await NodeType.objects.create(
        name="link",
    )
    web_node = await Node.objects.create(
        name="web",
        type=link_node_type,
    )
    await Node.objects.create(
        name="file",
        type=link_node_type,
    )
    link_node_type.default_node = web_node
    await link_node_type.save()
    del link_node_type

    NodeType.objects.cache.clear()
    Node.objects.cache.clear()
    assert (await Node.objects.filter(type__name="link").count()) == 2
    link_node = await NodeType.objects.select_related("default_node").get(name="link")
    assert link_node.default_node.name == "web"

    # Check ondelete="SET NULL"
    await web_node.delete()
    del link_node
    NodeType.objects.cache.clear()
    Node.objects.cache.clear()
    # outer_join is required since default_node is now NULL
    link_node = await NodeType.objects.select_related(
        "default_node", outer_join=True
    ).get(name="link")
    assert link_node.default_node is None

    # Check ondelete="CASCADE"
    assert (await Node.objects.count()) == 1
    await link_node.delete()
    assert (await Node.objects.count()) == 0


def test_invalid_meta_field():
    """Unknown Meta attributes are rejected at class creation time."""
    with pytest.raises(TypeError):

        class TestTable(SQLModel):
            id = Int().tag(primary_key=True)

            class Meta:
                # table_name is invalid, use db_table
                table_name = "use db_table"


def test_invalid_multiple_pk():
    """Only a single primary key per model is supported."""
    with pytest.raises(NotImplementedError):

        class TestTable(SQLModel):
            id = Int().tag(primary_key=True)
            id2 = Int().tag(primary_key=True)


def test_abstract_tables():
    """Abstract models have no manager/table; subclasses become concrete
    unless they explicitly inherit an abstract Meta."""
    class AbstractUser(SQLModel):
        name = Str().tag(length=60)

        class Meta:
            abstract = True

    class CustomUser(AbstractUser):
        data = Dict()

    class CustomUserWithMeta(AbstractUser):
        data = Dict()

        class Meta:
            db_table = "custom_user"

    class AbstractCustomUser(AbstractUser):
        data = Dict()

        class Meta(AbstractUser.Meta):
            db_table = "custom_user2"

    class CustomUser2(AbstractCustomUser):
        pass

    class CustomUser3(AbstractCustomUser):
        class Meta(AbstractCustomUser.Meta):
            abstract = False

    # Attempts to invoke create_table on abstract models should fail
    with pytest.raises(NotImplementedError):
        AbstractUser.objects

    # Subclasses of abstract models become concrete so this is ok
    assert CustomUser.objects

    # Subclasses of abstract models become with Meta concrete so this is ok
    assert CustomUserWithMeta.objects

    # Subclasses that inherit Meta, inherit Meta :)
    with \
pytest.raises(NotImplementedError): 1725 | AbstractCustomUser.objects # Abstract is inherited in this case 1726 | 1727 | # This is okay too 1728 | CustomUser2.objects 1729 | CustomUser3.objects 1730 | -------------------------------------------------------------------------------- /tests/test_sql_benchmark.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from test_sql import Image, Page, db, reset_tables 3 | 4 | assert db # fix flake8 5 | 6 | 7 | @pytest.mark.benchmark(group="sql-create") 8 | def test_benchmark_create(db, event_loop, benchmark): 9 | event_loop.run_until_complete(reset_tables(Image)) 10 | 11 | async def create_images(): 12 | for i in range(100): 13 | await Image.objects.create( 14 | name=f"Image {i}", 15 | path=f"/media/some/path/{i}", 16 | alpha=i % 255, 17 | # size=(320, 240), 18 | data=b"12345678", 19 | metadata={"tag": "sunset"}, 20 | ) 21 | 22 | def run(): 23 | event_loop.run_until_complete(create_images()) 24 | 25 | benchmark(run) 26 | 27 | 28 | def prepare_benchmark(event_loop, n: int): 29 | event_loop.run_until_complete(reset_tables(Image)) 30 | images = [] 31 | 32 | async def make(): 33 | images.append( 34 | await Image.objects.create( 35 | name=f"Image {i}", 36 | path=f"/media/some/path/{i}", 37 | alpha=i % 255, 38 | # size=(320, 240), 39 | data=b"12345678", 40 | metadata={"tag": "sunset"}, 41 | ) 42 | ) 43 | 44 | for i in range(n): 45 | event_loop.run_until_complete(make()) 46 | 47 | return images 48 | 49 | 50 | @pytest.mark.benchmark(group="sql-get") 51 | def test_benchmark_get(db, event_loop, benchmark): 52 | """Do a normal get query where the item is not in the cache""" 53 | prepare_benchmark(event_loop, n=1) 54 | Image.objects.cache.clear() 55 | 56 | async def task(): 57 | image = await Image.objects.get(name="Image 0") 58 | assert image.name == "Image 0" 59 | 60 | benchmark(lambda: event_loop.run_until_complete(task())) 61 | 62 | 63 | @pytest.mark.benchmark(group="sql-get") 64 
def test_benchmark_get_cached(db, event_loop, benchmark):
    """Do a normal get query where the item is in the cache"""
    # Keeping a reference to the created images keeps the cache populated
    images = prepare_benchmark(event_loop, n=1)
    assert images

    async def task():
        image = await Image.objects.get(name="Image 0")
        assert image.name == "Image 0"

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-get")
def test_benchmark_get_raw(db, event_loop, benchmark):
    """Do a prebuilt get query with restoring without cache"""
    prepare_benchmark(event_loop, n=1)
    # Build the query once, outside the benchmarked task
    q = Image.objects.filter(name="Image 0").query("select")
    Image.objects.cache.clear()

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            row = await cursor.fetchone()
            image = await Image.restore(row)
            assert image.name == "Image 0"

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-get")
def test_benchmark_get_raw_str(db, event_loop, benchmark):
    """Do a prebuilt get query with restoring without cache"""
    images = prepare_benchmark(event_loop, n=1)
    assert images
    # Pre-stringified SQL; parameters must then be passed explicitly
    q = str(Image.objects.filter(name="Image 0").query("select"))

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q, {"name_1": "Image 0"})
            row = await cursor.fetchone()
            image = await Image.restore(row)
            assert image.name == "Image 0"

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-get")
def test_benchmark_get_raw_cached(db, event_loop, benchmark):
    """Do a prebuilt get query with restoring with cache"""
    images = prepare_benchmark(event_loop, n=1)
    assert images
    q = Image.objects.filter(name="Image 0").query("select")
    # Image.objects.cache.clear()

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            row = await cursor.fetchone()
            image = await Image.restore(row)
            assert image.name == "Image 0"

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-get")
def test_benchmark_get_raw_row(db, event_loop, benchmark):
    """Do a prebuilt get query without restoring"""
    prepare_benchmark(event_loop, n=1)
    q = Image.objects.filter(name="Image 0").query("select")

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            row = await cursor.fetchone()
            assert row["name"] == "Image 0"
            # No restore

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-filter")
def test_benchmark_filter(db, event_loop, benchmark):
    """Do a filter query where no items are in the cache"""
    prepare_benchmark(event_loop, n=1000)
    Image.objects.cache.clear()

    async def task():
        # alpha == 0 for i in {0, 255, 510, 765}, hence 996 matches
        results = await Image.objects.filter(alpha__ne=0)
        assert len(results) == 996

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-filter")
def test_benchmark_filter_cached(db, event_loop, benchmark):
    """Do a filter query where all items are in the cache"""
    images = prepare_benchmark(event_loop, n=1000)
    assert images

    async def task():
        results = await Image.objects.filter(alpha__ne=0)
        assert len(results) == 996

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-filter")
def test_benchmark_filter_raw(db, event_loop, benchmark):
    """Do a raw filter query where no items are in the cache"""
    prepare_benchmark(event_loop, n=1000)
    Image.objects.cache.clear()
    q = Image.objects.filter(alpha__ne=0).query("select")

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            results = [await Image.restore(row) for row in await cursor.fetchall()]
            assert len(results) == 996

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-filter")
def test_benchmark_filter_raw_cached(db, event_loop, benchmark):
    """Do a raw filter query where all items are in the cache"""
    images = prepare_benchmark(event_loop, n=1000)
    assert images
    # Image.objects.cache.clear()
    q = Image.objects.filter(alpha__ne=0).query("select")

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            results = [await Image.restore(row) for row in await cursor.fetchall()]
            assert len(results) == 996

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-filter")
def test_benchmark_filter_raw_row(db, event_loop, benchmark):
    """Do a raw filter query without restoring item from rows"""
    prepare_benchmark(event_loop, n=1000)
    # Image.objects.cache.clear()
    q = Image.objects.filter(alpha__ne=0).query("select")

    async def task():
        async with Image.objects.connection() as conn:
            cursor = await conn.execute(q)
            # No restore
            results = [row for row in await cursor.fetchall()]
            assert len(results) == 996

    benchmark(lambda: event_loop.run_until_complete(task()))


@pytest.mark.benchmark(group="sql-build-query")
def test_benchmark_filter_related_query(db, benchmark):
    """Benchmark building (not executing) a query joining a relation."""
    def query():
        Page.objects.filter(author__name="Tom", status="live")

    benchmark(query)


@pytest.mark.benchmark(group="sql-build-query")
def test_benchmark_filter_query(db, benchmark):
    """Benchmark building (not executing) a simple filter query."""
    def query():
        Page.objects.filter(status="live")

    benchmark(query)


@pytest.mark.benchmark(group="sql-build-query")
def test_benchmark_filter_query_ordered(db, benchmark):
    """Benchmark building (not executing) a filtered + ordered query."""
    def query():
        Page.objects.filter(status="live").order_by("last_updated")

    benchmark(query)
--------------------------------------------------------------------------------