├── .codecov.yml ├── .github ├── ISSUE_TEMPLATE │ ├── 2-bug-report.md │ ├── 3-feature-request.md │ └── config.yml └── workflows │ ├── publish.yml │ └── test-suite.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── docs ├── declaring_models.md ├── index.md ├── making_queries.md └── relationships.md ├── mkdocs.yml ├── orm ├── __init__.py ├── constants.py ├── exceptions.py ├── fields.py ├── models.py └── sqlalchemy_fields.py ├── requirements.txt ├── scripts ├── README.md ├── build ├── check ├── clean ├── coverage ├── docs ├── install ├── lint ├── publish └── test ├── setup.cfg ├── setup.py └── tests ├── conftest.py ├── settings.py ├── test_columns.py ├── test_foreignkey.py └── test_models.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | range: "80...100" 5 | 6 | status: 7 | project: yes 8 | patch: yes 9 | changes: yes 10 | 11 | comment: off 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report a bug to help improve this project 4 | --- 5 | 6 | ### Checklist 7 | 8 | 9 | 10 | - [ ] The bug is reproducible against the latest release and/or `master`. 11 | - [ ] There are no similar issues or pull requests to fix it yet. 
12 | 13 | ### Describe the bug 14 | 15 | 16 | 17 | ### To reproduce 18 | 19 | 25 | 26 | ### Expected behavior 27 | 28 | 29 | 30 | ### Actual behavior 31 | 32 | 33 | 34 | ### Debugging material 35 | 36 | 42 | 43 | ### Environment 44 | 45 | - OS: 46 | - Python version: 47 | - ORM version: 48 | 49 | ### Additional context 50 | 51 | 54 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3-feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project. 4 | --- 5 | 6 | ### Checklist 7 | 8 | 9 | 10 | - [ ] There are no similar issues or pull requests for this yet. 11 | - [ ] I discussed this idea on the [community chat](https://gitter.im/encode/community) and feedback is positive. 12 | 13 | ### Is your feature related to a problem? Please describe. 14 | 15 | 18 | 19 | ## Describe the solution you would like. 20 | 21 | 25 | 26 | ## Describe alternatives you considered 27 | 28 | 30 | 31 | ## Additional context 32 | 33 | 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser 2 | blank_issues_enabled: true 3 | contact_links: 4 | - name: Question 5 | url: https://gitter.im/encode/community 6 | about: > 7 | Ask a question 8 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Publish 3 | 4 | on: 5 | push: 6 | tags: 7 | - '*' 8 | 9 | jobs: 10 | publish: 11 | name: "Publish release" 12 | runs-on: "ubuntu-latest" 13 | 14 | steps: 15 | - uses: 
"actions/checkout@v2" 16 | - uses: "actions/setup-python@v2" 17 | with: 18 | python-version: 3.7 19 | - name: "Install dependencies" 20 | run: "scripts/install" 21 | - name: "Build package & docs" 22 | run: "scripts/build" 23 | - name: "Publish to PyPI & deploy docs" 24 | run: "scripts/publish" 25 | env: 26 | TWINE_USERNAME: __token__ 27 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 28 | -------------------------------------------------------------------------------- /.github/workflows/test-suite.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Test Suite 3 | 4 | on: 5 | push: 6 | branches: ["master"] 7 | pull_request: 8 | branches: ["master"] 9 | 10 | jobs: 11 | tests: 12 | name: "Python ${{ matrix.python-version }}" 13 | runs-on: "ubuntu-latest" 14 | 15 | strategy: 16 | matrix: 17 | python-version: ["3.7", "3.8", "3.9", "3.10"] 18 | 19 | services: 20 | mysql: 21 | image: mysql:5.7 22 | env: 23 | MYSQL_USER: username 24 | MYSQL_PASSWORD: password 25 | MYSQL_ROOT_PASSWORD: password 26 | MYSQL_DATABASE: testsuite 27 | ports: 28 | - 3306:3306 29 | options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 30 | 31 | postgres: 32 | image: postgres:10.8 33 | env: 34 | POSTGRES_USER: username 35 | POSTGRES_PASSWORD: password 36 | POSTGRES_DB: testsuite 37 | ports: 38 | - 5432:5432 39 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 40 | 41 | steps: 42 | - uses: "actions/checkout@v2" 43 | - uses: "actions/setup-python@v2" 44 | with: 45 | python-version: "${{ matrix.python-version }}" 46 | - name: "Install dependencies" 47 | run: "scripts/install" 48 | - name: "Run linting checks" 49 | run: "scripts/check" 50 | - name: "Build package & docs" 51 | run: "scripts/build" 52 | - name: "Run tests with PostgreSQL" 53 | env: 54 | TEST_DATABASE_URL: "postgresql://username:password@localhost:5432/testsuite" 55 | run: "scripts/test" 56 | - 
name: "Run tests with MySQL" 57 | env: 58 | TEST_DATABASE_URL: "mysql://username:password@localhost:3306/testsuite" 59 | run: "scripts/test" 60 | - name: "Run tests with SQLite" 61 | env: 62 | TEST_DATABASE_URL: "sqlite:///testsuite" 63 | run: "scripts/test" 64 | - name: "Enforce coverage" 65 | run: "scripts/coverage" 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .coverage 3 | .pytest_cache/ 4 | .mypy_cache/ 5 | *.egg-info/ 6 | htmlcov/ 7 | venv/ 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/). 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ORM 2 | 3 |

4 | 5 | Build Status 6 | 7 | 8 | Coverage 9 | 10 | 11 | Package version 12 | 13 |

14 | 15 | The `orm` package is an async ORM for Python, with support for Postgres, 16 | MySQL, and SQLite. ORM is built with: 17 | 18 | * [SQLAlchemy core][sqlalchemy-core] for query building. 19 | * [`databases`][databases] for cross-database async support. 20 | * [`typesystem`][typesystem] for data validation. 21 | 22 | Because ORM is built on SQLAlchemy core, you can use Alembic to provide 23 | database migrations. 24 | 25 | --- 26 | 27 | **Documentation**: [https://www.encode.io/orm](https://www.encode.io/orm) 28 | 29 | --- 30 | 31 | ## Installation 32 | 33 | ```shell 34 | $ pip install orm 35 | ``` 36 | 37 | You can install the required database drivers with: 38 | 39 | ```shell 40 | $ pip install orm[postgresql] 41 | $ pip install orm[mysql] 42 | $ pip install orm[sqlite] 43 | ``` 44 | 45 | Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite]. 46 | 47 | --- 48 | 49 | ## Quickstart 50 | 51 | **Note**: Use `ipython` to try this from the console, since it supports `await`. 52 | 53 | ```python 54 | import databases 55 | import orm 56 | 57 | database = databases.Database("sqlite:///db.sqlite") 58 | models = orm.ModelRegistry(database=database) 59 | 60 | 61 | class Note(orm.Model): 62 | tablename = "notes" 63 | registry = models 64 | fields = { 65 | "id": orm.Integer(primary_key=True), 66 | "text": orm.String(max_length=100), 67 | "completed": orm.Boolean(default=False), 68 | } 69 | 70 | # Create the tables 71 | await models.create_all() 72 | 73 | await Note.objects.create(text="Buy the groceries.", completed=False) 74 | 75 | note = await Note.objects.get(id=1) 76 | print(note) 77 | # Note(id=1) 78 | ``` 79 | 80 |

— 🗃 —

81 |

ORM is BSD licensed code. Designed & built in Brighton, England.

82 | 83 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ 84 | [asyncpg]: https://github.com/MagicStack/asyncpg 85 | [aiomysql]: https://github.com/aio-libs/aiomysql 86 | [aiosqlite]: https://github.com/jreese/aiosqlite 87 | 88 | [databases]: https://github.com/encode/databases 89 | [typesystem]: https://github.com/encode/typesystem 90 | [typesystem-fields]: https://www.encode.io/typesystem/fields/ 91 | -------------------------------------------------------------------------------- /docs/declaring_models.md: -------------------------------------------------------------------------------- 1 | ## Declaring models 2 | 3 | You can define models by inheriting from `orm.Model` and 4 | defining model fields in the `fields` attribute. 5 | For each defined model you need to set two special variables: 6 | 7 | * `registry` an instance of `orm.ModelRegistry` 8 | * `fields` a `dict` of `orm` fields 9 | 10 | You can also specify the table name in database by setting `tablename` attribute. 11 | 12 | ```python 13 | import databases 14 | import orm 15 | 16 | database = databases.Database("sqlite:///db.sqlite") 17 | models = orm.ModelRegistry(database=database) 18 | 19 | 20 | class Note(orm.Model): 21 | tablename = "notes" 22 | registry = models 23 | fields = { 24 | "id": orm.Integer(primary_key=True), 25 | "text": orm.String(max_length=100), 26 | "completed": orm.Boolean(default=False), 27 | } 28 | ``` 29 | 30 | ORM can create or drop database and tables from models using SQLAlchemy. 31 | You can use the following methods: 32 | 33 | ```python 34 | await models.create_all() 35 | 36 | await models.drop_all() 37 | ``` 38 | 39 | ## Data types 40 | 41 | The following keyword arguments are supported on all field types. 42 | 43 | * `primary_key` - A boolean. Determine if column is primary key. 44 | * `allow_null` - A boolean. Determine if column is nullable. 45 | * `default` - A value or a callable (function). 46 | * `index` - A boolean. 
Determine if database indexes should be created. 47 | * `unique` - A boolean. Determine if a unique constraint should be created. 48 | 49 | All fields are required unless one of the following is set: 50 | 51 | * `allow_null` - A boolean. Determine if column is nullable. Sets the default to `None`. 52 | * `allow_blank` - A boolean. Determine if empty strings are allowed. Sets the default to `""`. 53 | * `default` - A value or a callable (function). 54 | 55 | Special keyword arguments for `DateTime` and `Date` fields: 56 | 57 | * `auto_now` - Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. 58 | * `auto_now_add` - Automatically set the field to now when the object is first created. Useful for creation timestamps. 59 | 60 | When either option is set, the field defaults to `datetime.date.today` for `Date` fields and `datetime.datetime.now` for `DateTime` fields. 61 | 62 | !!! note 63 | Setting `auto_now` or `auto_now_add` to `True` makes the field read-only. 64 | 65 | The following column types are supported. 66 | See `TypeSystem` for [type-specific validation keyword arguments][typesystem-fields]. 67 | 68 | * `orm.BigInteger()` 69 | * `orm.Boolean()` 70 | * `orm.Date(auto_now, auto_now_add)` 71 | * `orm.DateTime(auto_now, auto_now_add)` 72 | * `orm.Decimal()` 73 | * `orm.Email(max_length)` 74 | * `orm.Enum()` 75 | * `orm.Float()` 76 | * `orm.Integer()` 77 | * `orm.IPAddress()` 78 | * `orm.String(max_length)` 79 | * `orm.Text()` 80 | * `orm.Time()` 81 | * `orm.URL(max_length)` 82 | * `orm.UUID()` 83 | * `orm.JSON()` 84 | 85 | [typesystem-fields]: https://www.encode.io/typesystem/fields/ 86 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # ORM 2 | 3 |

4 | 5 | Build Status 6 | 7 | 8 | Coverage 9 | 10 | 11 | Package version 12 | 13 |

14 | 15 | The `orm` package is an async ORM for Python, with support for Postgres, 16 | MySQL, and SQLite. ORM is built with: 17 | 18 | * [SQLAlchemy core][sqlalchemy-core] for query building. 19 | * [`databases`][databases] for cross-database async support. 20 | * [`typesystem`][typesystem] for data validation. 21 | 22 | Because ORM is built on SQLAlchemy core, you can use Alembic to provide 23 | database migrations. 24 | 25 | **ORM is still under development: We recommend pinning any dependencies with `orm~=0.2`** 26 | 27 | --- 28 | 29 | ## Installation 30 | 31 | ```shell 32 | $ pip install orm 33 | ``` 34 | 35 | You can install the required database drivers with: 36 | 37 | ```shell 38 | $ pip install orm[postgresql] 39 | $ pip install orm[mysql] 40 | $ pip install orm[sqlite] 41 | ``` 42 | 43 | Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite]. 44 | 45 | --- 46 | 47 | ## Quickstart 48 | 49 | **Note**: Use `ipython` to try this from the console, since it supports `await`. 
50 | 51 | ```python 52 | import databases 53 | import orm 54 | 55 | database = databases.Database("sqlite:///db.sqlite") 56 | models = orm.ModelRegistry(database=database) 57 | 58 | 59 | class Note(orm.Model): 60 | tablename = "notes" 61 | registry = models 62 | fields = { 63 | "id": orm.Integer(primary_key=True), 64 | "text": orm.String(max_length=100), 65 | "completed": orm.Boolean(default=False), 66 | } 67 | 68 | # Create the database and tables 69 | await models.create_all() 70 | 71 | await Note.objects.create(text="Buy the groceries.", completed=False) 72 | 73 | note = await Note.objects.get(id=1) 74 | print(note) 75 | # Note(id=1) 76 | ``` 77 | 78 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ 79 | [asyncpg]: https://github.com/MagicStack/asyncpg 80 | [aiomysql]: https://github.com/aio-libs/aiomysql 81 | [aiosqlite]: https://github.com/jreese/aiosqlite 82 | 83 | [databases]: https://github.com/encode/databases 84 | [typesystem]: https://github.com/encode/typesystem 85 | [typesystem-fields]: https://www.encode.io/typesystem/fields/ 86 | -------------------------------------------------------------------------------- /docs/making_queries.md: -------------------------------------------------------------------------------- 1 | Let's say you have the following model defined: 2 | 3 | ```python 4 | import databases 5 | import orm 6 | 7 | database = databases.Database("sqlite:///db.sqlite") 8 | models = orm.ModelRegistry(database=database) 9 | 10 | 11 | class Note(orm.Model): 12 | tablename = "notes" 13 | registry = models 14 | fields = { 15 | "id": orm.Integer(primary_key=True), 16 | "text": orm.String(max_length=100), 17 | "completed": orm.Boolean(default=False), 18 | } 19 | ``` 20 | 21 | ORM supports two types of queryset methods. 
22 | Some queryset methods return another queryset and can be chained together, like `.filter()` and `.order_by()`: 23 | 24 | ```python 25 | Note.objects.filter(completed=True).order_by("id") 26 | ``` 27 | 28 | Other queryset methods return results and should be used as the final method on the queryset, like `.all()`: 29 | 30 | ```python 31 | await Note.objects.filter(completed=True).all() 32 | ``` 33 | 34 | ## Returning Querysets 35 | 36 | ### .exclude() 37 | 38 | To exclude instances: 39 | 40 | ```python 41 | notes = await Note.objects.exclude(completed=False).all() 42 | ``` 43 | 44 | ### .filter() 45 | 46 | #### Django-style lookup 47 | 48 | To filter instances: 49 | 50 | ```python 51 | notes = await Note.objects.filter(completed=True).all() 52 | ``` 53 | 54 | There are some special operators defined automatically on every column: 55 | 56 | * `in` - SQL `IN` operator. 57 | * `exact` - filter instances matching the exact value. 58 | * `iexact` - same as `exact` but case-insensitive. 59 | * `contains` - filter instances containing the value. 60 | * `icontains` - same as `contains` but case-insensitive. 61 | * `lt` - filter instances having value `Less Than`. 62 | * `lte` - filter instances having value `Less Than or Equal`. 63 | * `gt` - filter instances having value `Greater Than`. 64 | * `gte` - filter instances having value `Greater Than or Equal`. 65 | 66 | Example usage: 67 | 68 | ```python 69 | notes = await Note.objects.filter(text__icontains="mum").all() 70 | 71 | notes = await Note.objects.filter(id__in=[1, 2, 3]).all() 72 | ``` 73 | 74 | #### SQLAlchemy filter operators 75 | 76 | The `filter` method also accepts SQLAlchemy filter operators: 77 | 78 | ```python 79 | notes = await Note.objects.filter(Note.columns.text.contains("mum")).all() 80 | 81 | notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all() 82 | ``` 83 | 84 | Here `Note.columns` refers to the columns of the underlying SQLAlchemy table. 85 | 86 | !!!
note 87 | Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields. 88 | 89 | ### .limit() 90 | 91 | To limit the number of results: 92 | 93 | ```python 94 | await Note.objects.limit(1).all() 95 | ``` 96 | 97 | ### .offset() 98 | 99 | To apply an offset to query results: 100 | 101 | ```python 102 | await Note.objects.offset(1).all() 103 | ``` 104 | 105 | As mentioned before, you can chain multiple queryset methods together to form a query. 106 | As an example: 107 | 108 | ```python 109 | await Note.objects.order_by("id").limit(1).offset(1).all() 110 | await Note.objects.filter(text__icontains="mum").limit(2).all() 111 | ``` 112 | 113 | ### .order_by() 114 | 115 | To order query results: 116 | 117 | ```python 118 | notes = await Note.objects.order_by("text", "-id").all() 119 | ``` 120 | 121 | !!! note 122 | This will sort by ascending `text` and descending `id`. 123 | 124 | ## Returning results 125 | 126 | ### .all() 127 | 128 | To retrieve all the instances: 129 | 130 | ```python 131 | notes = await Note.objects.all() 132 | ``` 133 | 134 | ### .create() 135 | 136 | You need to pass the required model attributes and values to the `.create()` method: 137 | 138 | ```python 139 | await Note.objects.create(text="Buy the groceries.", completed=False) 140 | await Note.objects.create(text="Call Mum.", completed=True) 141 | await Note.objects.create(text="Send invoices.", completed=True) 142 | ``` 143 | 144 | ### .bulk_create() 145 | 146 | You need to pass a list of dictionaries of required fields to create multiple objects: 147 | 148 | ```python 149 | await Note.objects.bulk_create( 150 | [ 151 | {"text": "Buy the groceries.", "completed": False}, 152 | {"text": "Call Mum.", "completed": True}, 153 | 154 | ] 155 | ) 156 | ``` 157 | 158 | ### .delete() 159 | 160 | You can `delete` instances by calling `.delete()` on a queryset: 161 | 162 | ```python 163 | await
Note.objects.filter(completed=True).delete() 164 | ``` 165 | 166 | It's not very common, but to delete all rows in a table: 167 | 168 | ```python 169 | await Note.objects.delete() 170 | ``` 171 | 172 | You can also call `.delete()` on a queried instance: 173 | 174 | ```python 175 | note = await Note.objects.first() 176 | 177 | await note.delete() 178 | ``` 179 | 180 | ### .exists() 181 | 182 | To check if any instances matching the query exist. Returns `True` or `False`. 183 | 184 | ```python 185 | await Note.objects.filter(completed=True).exists() 186 | ``` 187 | 188 | ### .first() 189 | 190 | This will return the first instance or `None`: 191 | 192 | ```python 193 | note = await Note.objects.filter(completed=True).first() 194 | ``` 195 | 196 | `pk` always refers to the model's primary key field: 197 | 198 | ```python 199 | note = await Note.objects.get(pk=2) 200 | note.pk # 2 201 | ``` 202 | 203 | ### .get() 204 | 205 | To get only one instance: 206 | 207 | ```python 208 | note = await Note.objects.get(id=1) 209 | ``` 210 | 211 | !!! note 212 | `.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`. 213 | 214 | ### .update() 215 | 216 | You can update instances by calling `.update()` on a queryset: 217 | 218 | ```python 219 | await Note.objects.filter(completed=True).update(completed=False) 220 | ``` 221 | 222 | It's not very common, but to update all rows in a table: 223 | 224 | ```python 225 | await Note.objects.update(completed=False) 226 | ``` 227 | 228 | You can also call `.update()` on a queried instance: 229 | 230 | ```python 231 | note = await Note.objects.first() 232 | 233 | await note.update(completed=True) 234 | ``` 235 | 236 | ## Convenience Methods 237 | 238 | ### .get_or_create() 239 | 240 | To get an existing instance matching the query, or create a new one. 241 | This will return a tuple of `instance` and `created`. 
242 | 243 | ```python 244 | note, created = await Note.objects.get_or_create( 245 | text="Going to car wash", defaults={"completed": False} 246 | ) 247 | ``` 248 | 249 | This will query a `Note` with `text` as `"Going to car wash"`, 250 | if it doesn't exist, it will use `defaults` argument to create the new instance. 251 | 252 | !!! note 253 | Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. 254 | 255 | 256 | ### .update_or_create() 257 | 258 | To update an existing instance matching the query, or create a new one. 259 | This will return a tuple of `instance` and `created`. 260 | 261 | ```python 262 | note, created = await Note.objects.update_or_create( 263 | text="Going to car wash", defaults={"completed": True} 264 | ) 265 | ``` 266 | 267 | This will query a `Note` with `text` as `"Going to car wash"`, 268 | if an instance is found, it will use the `defaults` argument to update the instance. 269 | If it matches no records, it will use the combination of arguments to create the new instance. 270 | 271 | !!! note 272 | Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. 273 | -------------------------------------------------------------------------------- /docs/relationships.md: -------------------------------------------------------------------------------- 1 | ## ForeignKey 2 | 3 | ### Defining and querying relationships 4 | 5 | ORM supports loading and filtering across foreign keys. 
6 | 7 | Let's say you have the following models defined: 8 | 9 | ```python 10 | import databases 11 | import orm 12 | 13 | database = databases.Database("sqlite:///db.sqlite") 14 | models = orm.ModelRegistry(database=database) 15 | 16 | 17 | class Album(orm.Model): 18 | tablename = "albums" 19 | registry = models 20 | fields = { 21 | "id": orm.Integer(primary_key=True), 22 | "name": orm.String(max_length=100), 23 | } 24 | 25 | 26 | class Track(orm.Model): 27 | tablename = "tracks" 28 | registry = models 29 | fields = { 30 | "id": orm.Integer(primary_key=True), 31 | "album": orm.ForeignKey(Album), 32 | "title": orm.String(max_length=100), 33 | "position": orm.Integer(), 34 | } 35 | ``` 36 | 37 | You can create some `Album` and `Track` instances: 38 | 39 | ```python 40 | malibu = await Album.objects.create(name="Malibu") 41 | await Track.objects.create(album=malibu, title="The Bird", position=1) 42 | await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) 43 | await Track.objects.create(album=malibu, title="The Waters", position=3) 44 | 45 | fantasies = await Album.objects.create(name="Fantasies") 46 | await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) 47 | await Track.objects.create(album=fantasies, title="Sick Muse", position=2) 48 | ``` 49 | 50 | To fetch an instance, without loading a foreign key relationship on it: 51 | 52 | ```python 53 | track = await Track.objects.get(title="The Bird") 54 | 55 | # We have an album instance, but it only has the primary key populated 56 | print(track.album) # Album(id=1) [sparse] 57 | print(track.album.pk) # 1 58 | print(track.album.name) # Raises AttributeError 59 | ``` 60 | 61 | You can load the relationship from the database: 62 | 63 | ```python 64 | await track.album.load() 65 | assert track.album.name == "Malibu" 66 | ``` 67 | 68 | You can also fetch an instance, loading the foreign key relationship with it: 69 | 70 | ```python 71 | track = await 
Track.objects.select_related("album").get(title="The Bird") 72 | assert track.album.name == "Malibu" 73 | ``` 74 | 75 | To fetch instances, filtering across a foreign key relationship: 76 | 77 | ```python 78 | tracks = await Track.objects.filter(album__name="Fantasies").all() 79 | assert len(tracks) == 2 80 | 81 | tracks = await Track.objects.filter(album__name__iexact="fantasies").all() 82 | assert len(tracks) == 2 83 | ``` 84 | 85 | ### ForeignKey constraints 86 | 87 | `ForeignKey` supports specifying a constraint through the `on_delete` argument. 88 | 89 | This adds an SQL `ON DELETE` clause to the foreign key constraint, which the database applies when the referenced object is removed. 90 | 91 | With the following definition: 92 | 93 | ```python 94 | class Album(orm.Model): 95 | tablename = "albums" 96 | registry = models 97 | fields = { 98 | "id": orm.Integer(primary_key=True), 99 | "name": orm.String(max_length=100), 100 | } 101 | 102 | 103 | class Track(orm.Model): 104 | tablename = "tracks" 105 | registry = models 106 | fields = { 107 | "id": orm.Integer(primary_key=True), 108 | "album": orm.ForeignKey(Album, on_delete=orm.CASCADE), 109 | "title": orm.String(max_length=100), 110 | } 111 | ``` 112 | 113 | `Track` model defines `orm.ForeignKey(Album, on_delete=orm.CASCADE)` so whenever an `Album` object is removed, 114 | all `Track` objects referencing that `Album` will also be removed. 115 | 116 | Available options for `on_delete` are: 117 | 118 | * `CASCADE` 119 | 120 | This will remove all referencing objects. 121 | 122 | * `RESTRICT` 123 | 124 | This prevents removing the referenced object while other objects still reference it. 125 | A database driver exception will be raised. 126 | 127 | * `SET_NULL` 128 | 129 | This will set the `ForeignKey` column of referencing objects to `NULL`. 130 | The `ForeignKey` defined here should also have `allow_null=True`.
131 | 132 | 133 | ## OneToOne 134 | 135 | `OneToOne` creates a one-to-one relationship between models. It behaves like 136 | `ForeignKey`, but the foreign key column is created with `unique=True`: 137 | 138 | ```python 139 | class Profile(orm.Model): 140 | registry = models 141 | fields = { 142 | "id": orm.Integer(primary_key=True), 143 | "website": orm.String(max_length=100), 144 | } 145 | 146 | 147 | class Person(orm.Model): 148 | registry = models 149 | fields = { 150 | "id": orm.Integer(primary_key=True), 151 | "email": orm.String(max_length=100), 152 | "profile": orm.OneToOne(Profile), 153 | } 154 | ``` 155 | 156 | You can create `Profile` and `Person` instances: 157 | 158 | ```python 159 | profile = await Profile.objects.create(website="https://encode.io") 160 | await Person.objects.create(email="info@encode.io", profile=profile) 161 | ``` 162 | 163 | Now creating another `Person` using the same `profile` will fail 164 | and will raise an exception: 165 | 166 | ```python 167 | await Person.objects.create(email="info@encode.io", profile=profile) 168 | ``` 169 | 170 | `OneToOne` accepts the same `on_delete` options as `ForeignKey`, which are 171 | described [here](#foreignkey-constraints). 172 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: ORM 2 | site_description: An async ORM.
3 | 4 | theme: 5 | name: 'material' 6 | 7 | repo_name: encode/orm 8 | repo_url: https://github.com/encode/orm 9 | # edit_uri: "" 10 | 11 | nav: 12 | - Introduction: 'index.md' 13 | - Declaring Models: 'declaring_models.md' 14 | - Making Queries: 'making_queries.md' 15 | - Relationships: 'relationships.md' 16 | 17 | markdown_extensions: 18 | - mkautodoc 19 | - admonition 20 | - pymdownx.highlight 21 | - pymdownx.superfences 22 | -------------------------------------------------------------------------------- /orm/__init__.py: -------------------------------------------------------------------------------- 1 | from orm.constants import CASCADE, RESTRICT, SET_NULL 2 | from orm.exceptions import MultipleMatches, NoMatch 3 | from orm.fields import ( 4 | JSON, 5 | URL, 6 | UUID, 7 | BigInteger, 8 | Boolean, 9 | Date, 10 | DateTime, 11 | Decimal, 12 | Email, 13 | Enum, 14 | Float, 15 | ForeignKey, 16 | Integer, 17 | IPAddress, 18 | OneToOne, 19 | String, 20 | Text, 21 | Time, 22 | ) 23 | from orm.models import Model, ModelRegistry 24 | 25 | __version__ = "0.3.1" 26 | __all__ = [ 27 | "CASCADE", 28 | "RESTRICT", 29 | "SET_NULL", 30 | "NoMatch", 31 | "MultipleMatches", 32 | "BigInteger", 33 | "Boolean", 34 | "Date", 35 | "DateTime", 36 | "Decimal", 37 | "Email", 38 | "Enum", 39 | "Float", 40 | "ForeignKey", 41 | "Integer", 42 | "IPAddress", 43 | "JSON", 44 | "OneToOne", 45 | "String", 46 | "Text", 47 | "Time", 48 | "URL", 49 | "UUID", 50 | "Model", 51 | "ModelRegistry", 52 | ] 53 | -------------------------------------------------------------------------------- /orm/constants.py: -------------------------------------------------------------------------------- 1 | CASCADE = "CASCADE" 2 | RESTRICT = "RESTRICT" 3 | SET_NULL = "SET NULL" 4 | -------------------------------------------------------------------------------- /orm/exceptions.py: -------------------------------------------------------------------------------- 1 | class NoMatch(Exception): 2 | pass 3 | 4 | 5 | class 
MultipleMatches(Exception): 6 | pass 7 | -------------------------------------------------------------------------------- /orm/fields.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from datetime import date, datetime 3 | 4 | import sqlalchemy 5 | import typesystem 6 | 7 | from orm.sqlalchemy_fields import GUID, GenericIP 8 | 9 | 10 | class ModelField: 11 | def __init__( 12 | self, 13 | primary_key: bool = False, 14 | index: bool = False, 15 | unique: bool = False, 16 | **kwargs: typing.Any, 17 | ) -> None: 18 | if primary_key: 19 | kwargs["read_only"] = True 20 | self.allow_null = kwargs.get("allow_null", False) 21 | self.primary_key = primary_key 22 | self.index = index 23 | self.unique = unique 24 | self.validator = self.get_validator(**kwargs) 25 | 26 | def get_column(self, name: str) -> sqlalchemy.Column: 27 | column_type = self.get_column_type() 28 | constraints = self.get_constraints() 29 | return sqlalchemy.Column( 30 | name, 31 | column_type, 32 | *constraints, 33 | primary_key=self.primary_key, 34 | nullable=self.allow_null and not self.primary_key, 35 | index=self.index, 36 | unique=self.unique, 37 | ) 38 | 39 | def get_validator(self, **kwargs) -> typesystem.Field: 40 | raise NotImplementedError() # pragma: no cover 41 | 42 | def get_column_type(self) -> sqlalchemy.types.TypeEngine: 43 | raise NotImplementedError() # pragma: no cover 44 | 45 | def get_constraints(self): 46 | return [] 47 | 48 | def expand_relationship(self, value): 49 | return value 50 | 51 | 52 | class String(ModelField): 53 | def __init__(self, **kwargs): 54 | assert "max_length" in kwargs, "max_length is required" 55 | super().__init__(**kwargs) 56 | 57 | def get_validator(self, **kwargs) -> typesystem.Field: 58 | return typesystem.String(**kwargs) 59 | 60 | def get_column_type(self): 61 | return sqlalchemy.String(length=self.validator.max_length) 62 | 63 | 64 | class Text(ModelField): 65 | def get_validator(self, **kwargs) 
-> typesystem.Field: 66 | return typesystem.Text(**kwargs) 67 | 68 | def get_column_type(self): 69 | return sqlalchemy.Text() 70 | 71 | 72 | class Integer(ModelField): 73 | def get_validator(self, **kwargs) -> typesystem.Field: 74 | return typesystem.Integer(**kwargs) 75 | 76 | def get_column_type(self): 77 | return sqlalchemy.Integer() 78 | 79 | 80 | class Float(ModelField): 81 | def get_validator(self, **kwargs) -> typesystem.Field: 82 | return typesystem.Float(**kwargs) 83 | 84 | def get_column_type(self): 85 | return sqlalchemy.Float() 86 | 87 | 88 | class BigInteger(ModelField): 89 | def get_validator(self, **kwargs) -> typesystem.Field: 90 | return typesystem.Integer(**kwargs) 91 | 92 | def get_column_type(self): 93 | return sqlalchemy.BigInteger() 94 | 95 | 96 | class Boolean(ModelField): 97 | def get_validator(self, **kwargs) -> typesystem.Field: 98 | return typesystem.Boolean(**kwargs) 99 | 100 | def get_column_type(self): 101 | return sqlalchemy.Boolean() 102 | 103 | 104 | class AutoNowMixin(ModelField): 105 | def __init__(self, auto_now=False, auto_now_add=False, **kwargs): 106 | self.auto_now = auto_now 107 | self.auto_now_add = auto_now_add 108 | if auto_now_add and auto_now: 109 | raise ValueError("auto_now and auto_now_add cannot be both True") 110 | if auto_now_add or auto_now: 111 | kwargs["read_only"] = True 112 | super().__init__(**kwargs) 113 | 114 | 115 | class DateTime(AutoNowMixin): 116 | def get_validator(self, **kwargs) -> typesystem.Field: 117 | if self.auto_now_add or self.auto_now: 118 | kwargs["default"] = datetime.now 119 | return typesystem.DateTime(**kwargs) 120 | 121 | def get_column_type(self): 122 | return sqlalchemy.DateTime() 123 | 124 | 125 | class Date(AutoNowMixin): 126 | def get_validator(self, **kwargs) -> typesystem.Field: 127 | if self.auto_now_add or self.auto_now: 128 | kwargs["default"] = date.today 129 | return typesystem.Date(**kwargs) 130 | 131 | def get_column_type(self): 132 | return sqlalchemy.Date() 133 | 134 | 
135 | class Time(ModelField): 136 | def get_validator(self, **kwargs) -> typesystem.Field: 137 | return typesystem.Time(**kwargs) 138 | 139 | def get_column_type(self): 140 | return sqlalchemy.Time() 141 | 142 | 143 | class JSON(ModelField): 144 | def get_validator(self, **kwargs) -> typesystem.Field: 145 | return typesystem.Any(**kwargs) 146 | 147 | def get_column_type(self): 148 | return sqlalchemy.JSON() 149 | 150 | 151 | class ForeignKey(ModelField): 152 | class ForeignKeyValidator(typesystem.Field): 153 | def validate(self, value): 154 | return value.pk 155 | 156 | def __init__( 157 | self, to, allow_null: bool = False, on_delete: typing.Optional[str] = None 158 | ): 159 | super().__init__(allow_null=allow_null) 160 | self.to = to 161 | self.on_delete = on_delete 162 | 163 | @property 164 | def target(self): 165 | if not hasattr(self, "_target"): 166 | if isinstance(self.to, str): 167 | self._target = self.registry.models[self.to] 168 | else: 169 | self._target = self.to 170 | return self._target 171 | 172 | def get_validator(self, **kwargs) -> typesystem.Field: 173 | return self.ForeignKeyValidator(**kwargs) 174 | 175 | def get_column(self, name: str) -> sqlalchemy.Column: 176 | target = self.target 177 | to_field = target.fields[target.pkname] 178 | 179 | column_type = to_field.get_column_type() 180 | constraints = [ 181 | sqlalchemy.schema.ForeignKey( 182 | f"{target.tablename}.{target.pkname}", ondelete=self.on_delete 183 | ) 184 | ] 185 | return sqlalchemy.Column( 186 | name, 187 | column_type, 188 | *constraints, 189 | nullable=self.allow_null, 190 | ) 191 | 192 | def expand_relationship(self, value): 193 | target = self.target 194 | if isinstance(value, target): 195 | return value 196 | return target(pk=value) 197 | 198 | 199 | class OneToOne(ForeignKey): 200 | def get_column(self, name: str) -> sqlalchemy.Column: 201 | target = self.target 202 | to_field = target.fields[target.pkname] 203 | 204 | column_type = to_field.get_column_type() 205 | 
constraints = [ 206 | sqlalchemy.schema.ForeignKey( 207 | f"{target.tablename}.{target.pkname}", ondelete=self.on_delete 208 | ), 209 | ] 210 | 211 | return sqlalchemy.Column( 212 | name, 213 | column_type, 214 | *constraints, 215 | nullable=self.allow_null, 216 | unique=True, 217 | ) 218 | 219 | 220 | class Enum(ModelField): 221 | def __init__(self, enum, **kwargs): 222 | super().__init__(**kwargs) 223 | self.enum = enum 224 | 225 | def get_validator(self, **kwargs) -> typesystem.Field: 226 | return typesystem.Any(**kwargs) 227 | 228 | def get_column_type(self): 229 | return sqlalchemy.Enum(self.enum) 230 | 231 | 232 | class Decimal(ModelField): 233 | def __init__(self, max_digits: int, decimal_places: int, **kwargs): 234 | assert max_digits, "max_digits is required" 235 | assert decimal_places, "decimal_places is required" 236 | self.max_digits = max_digits 237 | self.decimal_places = decimal_places 238 | super().__init__(**kwargs) 239 | 240 | def get_validator(self, **kwargs) -> typesystem.Field: 241 | return typesystem.Decimal(**kwargs) 242 | 243 | def get_column_type(self): 244 | return sqlalchemy.Numeric(precision=self.max_digits, scale=self.decimal_places) 245 | 246 | 247 | class UUID(ModelField): 248 | def get_validator(self, **kwargs) -> typesystem.Field: 249 | return typesystem.UUID(**kwargs) 250 | 251 | def get_column_type(self): 252 | return GUID() 253 | 254 | 255 | class Email(String): 256 | def get_validator(self, **kwargs) -> typesystem.Field: 257 | return typesystem.Email(**kwargs) 258 | 259 | def get_column_type(self): 260 | return sqlalchemy.String(length=self.validator.max_length) 261 | 262 | 263 | class IPAddress(ModelField): 264 | def get_validator(self, **kwargs) -> typesystem.Field: 265 | return typesystem.IPAddress(**kwargs) 266 | 267 | def get_column_type(self): 268 | return GenericIP() 269 | 270 | 271 | class URL(String): 272 | def get_validator(self, **kwargs) -> typesystem.Field: 273 | return typesystem.URL(**kwargs) 274 | 275 | def 
get_column_type(self): 276 | return sqlalchemy.String(length=self.validator.max_length) 277 | -------------------------------------------------------------------------------- /orm/models.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import databases 4 | import sqlalchemy 5 | import typesystem 6 | from sqlalchemy.ext.asyncio import create_async_engine 7 | 8 | from orm.exceptions import MultipleMatches, NoMatch 9 | from orm.fields import Date, DateTime, String, Text 10 | 11 | FILTER_OPERATORS = { 12 | "exact": "__eq__", 13 | "iexact": "ilike", 14 | "contains": "like", 15 | "icontains": "ilike", 16 | "in": "in_", 17 | "gt": "__gt__", 18 | "gte": "__ge__", 19 | "lt": "__lt__", 20 | "lte": "__le__", 21 | } 22 | 23 | 24 | def _update_auto_now_fields(values, fields): 25 | for key, value in fields.items(): 26 | if isinstance(value, (DateTime, Date)) and value.auto_now: 27 | values[key] = value.validator.get_default_value() 28 | return values 29 | 30 | 31 | class ModelRegistry: 32 | def __init__(self, database: databases.Database) -> None: 33 | self.database = database 34 | self.models = {} 35 | self._metadata = sqlalchemy.MetaData() 36 | 37 | @property 38 | def metadata(self): 39 | for model_cls in self.models.values(): 40 | model_cls.build_table() 41 | return self._metadata 42 | 43 | async def create_all(self): 44 | url = self._get_database_url() 45 | engine = create_async_engine(url) 46 | 47 | async with self.database: 48 | async with engine.begin() as conn: 49 | await conn.run_sync(self.metadata.create_all) 50 | 51 | await engine.dispose() 52 | 53 | async def drop_all(self): 54 | url = self._get_database_url() 55 | engine = create_async_engine(url) 56 | 57 | async with self.database: 58 | async with engine.begin() as conn: 59 | await conn.run_sync(self.metadata.drop_all) 60 | 61 | await engine.dispose() 62 | 63 | def _get_database_url(self) -> str: 64 | url = self.database.url 65 | if not url.driver: 66 | 
if url.dialect == "postgresql": 67 | url = url.replace(driver="asyncpg") 68 | elif url.dialect == "mysql": 69 | url = url.replace(driver="aiomysql") 70 | elif url.dialect == "sqlite": 71 | url = url.replace(driver="aiosqlite") 72 | return str(url) 73 | 74 | 75 | class ModelMeta(type): 76 | def __new__(cls, name, bases, attrs): 77 | model_class = super().__new__(cls, name, bases, attrs) 78 | 79 | if "registry" in attrs: 80 | model_class.database = attrs["registry"].database 81 | attrs["registry"].models[name] = model_class 82 | 83 | if "tablename" not in attrs: 84 | setattr(model_class, "tablename", name.lower()) 85 | 86 | for name, field in attrs.get("fields", {}).items(): 87 | setattr(field, "registry", attrs.get("registry")) 88 | if field.primary_key: 89 | model_class.pkname = name 90 | 91 | return model_class 92 | 93 | @property 94 | def table(cls): 95 | if not hasattr(cls, "_table"): 96 | cls._table = cls.build_table() 97 | return cls._table 98 | 99 | @property 100 | def columns(cls) -> sqlalchemy.sql.ColumnCollection: 101 | return cls._table.columns 102 | 103 | 104 | class QuerySet: 105 | ESCAPE_CHARACTERS = ["%", "_"] 106 | 107 | def __init__( 108 | self, 109 | model_cls=None, 110 | filter_clauses=None, 111 | select_related=None, 112 | limit_count=None, 113 | offset=None, 114 | order_by=None, 115 | ): 116 | self.model_cls = model_cls 117 | self.filter_clauses = [] if filter_clauses is None else filter_clauses 118 | self._select_related = [] if select_related is None else select_related 119 | self.limit_count = limit_count 120 | self.query_offset = offset 121 | self._order_by = [] if order_by is None else order_by 122 | 123 | def __get__(self, instance, owner): 124 | return self.__class__(model_cls=owner) 125 | 126 | @property 127 | def database(self): 128 | return self.model_cls.registry.database 129 | 130 | @property 131 | def table(self) -> sqlalchemy.Table: 132 | return self.model_cls.table 133 | 134 | @property 135 | def schema(self): 136 | fields = {key: 
field.validator for key, field in self.model_cls.fields.items()} 137 | return typesystem.Schema(fields=fields) 138 | 139 | @property 140 | def pkname(self): 141 | return self.model_cls.pkname 142 | 143 | def _build_select_expression(self): 144 | tables = [self.table] 145 | select_from = self.table 146 | 147 | for item in self._select_related: 148 | model_cls = self.model_cls 149 | select_from = self.table 150 | for part in item.split("__"): 151 | model_cls = model_cls.fields[part].target 152 | table = model_cls.table 153 | select_from = sqlalchemy.sql.join(select_from, table) 154 | tables.append(table) 155 | 156 | expr = sqlalchemy.sql.select(tables) 157 | expr = expr.select_from(select_from) 158 | 159 | if self.filter_clauses: 160 | if len(self.filter_clauses) == 1: 161 | clause = self.filter_clauses[0] 162 | else: 163 | clause = sqlalchemy.sql.and_(*self.filter_clauses) 164 | expr = expr.where(clause) 165 | 166 | if self._order_by: 167 | order_by = list(map(self._prepare_order_by, self._order_by)) 168 | expr = expr.order_by(*order_by) 169 | 170 | if self.limit_count: 171 | expr = expr.limit(self.limit_count) 172 | 173 | if self.query_offset: 174 | expr = expr.offset(self.query_offset) 175 | 176 | return expr 177 | 178 | def filter( 179 | self, 180 | clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None, 181 | **kwargs: typing.Any, 182 | ): 183 | if clause is not None: 184 | self.filter_clauses.append(clause) 185 | return self 186 | else: 187 | return self._filter_query(**kwargs) 188 | 189 | def exclude( 190 | self, 191 | clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None, 192 | **kwargs: typing.Any, 193 | ): 194 | if clause is not None: 195 | self.filter_clauses.append(clause) 196 | return self 197 | else: 198 | return self._filter_query(_exclude=True, **kwargs) 199 | 200 | def _filter_query(self, _exclude: bool = False, **kwargs): 201 | clauses = [] 202 | filter_clauses = self.filter_clauses 203 | select_related = 
list(self._select_related) 204 | 205 | if kwargs.get("pk"): 206 | pk_name = self.model_cls.pkname 207 | kwargs[pk_name] = kwargs.pop("pk") 208 | 209 | for key, value in kwargs.items(): 210 | if "__" in key: 211 | parts = key.split("__") 212 | 213 | # Determine if we should treat the final part as a 214 | # filter operator or as a related field. 215 | if parts[-1] in FILTER_OPERATORS: 216 | op = parts[-1] 217 | field_name = parts[-2] 218 | related_parts = parts[:-2] 219 | else: 220 | op = "exact" 221 | field_name = parts[-1] 222 | related_parts = parts[:-1] 223 | 224 | model_cls = self.model_cls 225 | if related_parts: 226 | # Add any implied select_related 227 | related_str = "__".join(related_parts) 228 | if related_str not in select_related: 229 | select_related.append(related_str) 230 | 231 | # Walk the relationships to the actual model class 232 | # against which the comparison is being made. 233 | for part in related_parts: 234 | model_cls = model_cls.fields[part].target 235 | 236 | column = model_cls.table.columns[field_name] 237 | 238 | else: 239 | op = "exact" 240 | column = self.table.columns[key] 241 | 242 | # Map the operation code onto SQLAlchemy's ColumnElement 243 | # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement 244 | op_attr = FILTER_OPERATORS[op] 245 | has_escaped_character = False 246 | 247 | if op in ["contains", "icontains"]: 248 | has_escaped_character = any( 249 | c for c in self.ESCAPE_CHARACTERS if c in value 250 | ) 251 | if has_escaped_character: 252 | # enable escape modifier 253 | for char in self.ESCAPE_CHARACTERS: 254 | value = value.replace(char, f"\\{char}") 255 | value = f"%{value}%" 256 | 257 | if isinstance(value, Model): 258 | value = value.pk 259 | 260 | clause = getattr(column, op_attr)(value) 261 | clause.modifiers["escape"] = "\\" if has_escaped_character else None 262 | 263 | clauses.append(clause) 264 | 265 | if _exclude: 266 | 
filter_clauses.append(sqlalchemy.not_(sqlalchemy.sql.and_(*clauses))) 267 | else: 268 | filter_clauses += clauses 269 | 270 | return self.__class__( 271 | model_cls=self.model_cls, 272 | filter_clauses=filter_clauses, 273 | select_related=select_related, 274 | limit_count=self.limit_count, 275 | offset=self.query_offset, 276 | order_by=self._order_by, 277 | ) 278 | 279 | def search(self, term: typing.Any): 280 | if not term: 281 | return self 282 | 283 | filter_clauses = list(self.filter_clauses) 284 | value = f"%{term}%" 285 | 286 | # has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in term) 287 | # if has_escaped_character: 288 | # # enable escape modifier 289 | # for char in self.ESCAPE_CHARACTERS: 290 | # term = term.replace(char, f'\\{char}') 291 | # term = f"%{value}%" 292 | # 293 | # clause.modifiers['escape'] = '\\' if has_escaped_character else None 294 | 295 | search_fields = [ 296 | name 297 | for name, field in self.model_cls.fields.items() 298 | if isinstance(field, (String, Text)) 299 | ] 300 | search_clauses = [ 301 | self.table.columns[name].ilike(value) for name in search_fields 302 | ] 303 | 304 | if len(search_clauses) > 1: 305 | filter_clauses.append(sqlalchemy.sql.or_(*search_clauses)) 306 | else: 307 | filter_clauses.extend(search_clauses) 308 | 309 | return self.__class__( 310 | model_cls=self.model_cls, 311 | filter_clauses=filter_clauses, 312 | select_related=self._select_related, 313 | limit_count=self.limit_count, 314 | offset=self.query_offset, 315 | order_by=self._order_by, 316 | ) 317 | 318 | def order_by(self, *order_by): 319 | return self.__class__( 320 | model_cls=self.model_cls, 321 | filter_clauses=self.filter_clauses, 322 | select_related=self._select_related, 323 | limit_count=self.limit_count, 324 | offset=self.query_offset, 325 | order_by=order_by, 326 | ) 327 | 328 | def select_related(self, related): 329 | if not isinstance(related, (list, tuple)): 330 | related = [related] 331 | 332 | related = 
list(self._select_related) + related 333 | return self.__class__( 334 | model_cls=self.model_cls, 335 | filter_clauses=self.filter_clauses, 336 | select_related=related, 337 | limit_count=self.limit_count, 338 | offset=self.query_offset, 339 | order_by=self._order_by, 340 | ) 341 | 342 | async def exists(self) -> bool: 343 | expr = self._build_select_expression() 344 | expr = sqlalchemy.exists(expr).select() 345 | return await self.database.fetch_val(expr) 346 | 347 | def limit(self, limit_count: int): 348 | return self.__class__( 349 | model_cls=self.model_cls, 350 | filter_clauses=self.filter_clauses, 351 | select_related=self._select_related, 352 | limit_count=limit_count, 353 | offset=self.query_offset, 354 | order_by=self._order_by, 355 | ) 356 | 357 | def offset(self, offset: int): 358 | return self.__class__( 359 | model_cls=self.model_cls, 360 | filter_clauses=self.filter_clauses, 361 | select_related=self._select_related, 362 | limit_count=self.limit_count, 363 | offset=offset, 364 | order_by=self._order_by, 365 | ) 366 | 367 | async def count(self) -> int: 368 | expr = self._build_select_expression().alias("subquery_for_count") 369 | expr = sqlalchemy.func.count().select().select_from(expr) 370 | return await self.database.fetch_val(expr) 371 | 372 | async def all(self, **kwargs): 373 | if kwargs: 374 | return await self.filter(**kwargs).all() 375 | 376 | expr = self._build_select_expression() 377 | rows = await self.database.fetch_all(expr) 378 | return [ 379 | self.model_cls._from_row(row, select_related=self._select_related) 380 | for row in rows 381 | ] 382 | 383 | async def get(self, **kwargs): 384 | if kwargs: 385 | return await self.filter(**kwargs).get() 386 | 387 | expr = self._build_select_expression().limit(2) 388 | rows = await self.database.fetch_all(expr) 389 | 390 | if not rows: 391 | raise NoMatch() 392 | if len(rows) > 1: 393 | raise MultipleMatches() 394 | return self.model_cls._from_row(rows[0], select_related=self._select_related) 395 
| 396 | async def first(self, **kwargs): 397 | if kwargs: 398 | return await self.filter(**kwargs).first() 399 | 400 | rows = await self.limit(1).all() 401 | if rows: 402 | return rows[0] 403 | 404 | def _validate_kwargs(self, **kwargs): 405 | fields = self.model_cls.fields 406 | validator = typesystem.Schema( 407 | fields={key: value.validator for key, value in fields.items()} 408 | ) 409 | kwargs = validator.validate(kwargs) 410 | for key, value in fields.items(): 411 | if value.validator.read_only and value.validator.has_default(): 412 | kwargs[key] = value.validator.get_default_value() 413 | return kwargs 414 | 415 | async def create(self, **kwargs): 416 | kwargs = self._validate_kwargs(**kwargs) 417 | instance = self.model_cls(**kwargs) 418 | expr = self.table.insert().values(**kwargs) 419 | 420 | if self.pkname not in kwargs: 421 | instance.pk = await self.database.execute(expr) 422 | else: 423 | await self.database.execute(expr) 424 | 425 | return instance 426 | 427 | async def bulk_create(self, objs: typing.List[typing.Dict]) -> None: 428 | new_objs = [self._validate_kwargs(**obj) for obj in objs] 429 | 430 | expr = self.table.insert().values(new_objs) 431 | await self.database.execute(expr) 432 | 433 | async def delete(self) -> None: 434 | expr = self.table.delete() 435 | for filter_clause in self.filter_clauses: 436 | expr = expr.where(filter_clause) 437 | 438 | await self.database.execute(expr) 439 | 440 | async def update(self, **kwargs) -> None: 441 | fields = { 442 | key: field.validator 443 | for key, field in self.model_cls.fields.items() 444 | if key in kwargs 445 | } 446 | validator = typesystem.Schema(fields=fields) 447 | kwargs = _update_auto_now_fields( 448 | validator.validate(kwargs), self.model_cls.fields 449 | ) 450 | expr = self.table.update().values(**kwargs) 451 | 452 | for filter_clause in self.filter_clauses: 453 | expr = expr.where(filter_clause) 454 | 455 | await self.database.execute(expr) 456 | 457 | async def get_or_create( 458 | 
self, defaults: typing.Dict[str, typing.Any], **kwargs 459 | ) -> typing.Tuple[typing.Any, bool]: 460 | try: 461 | instance = await self.get(**kwargs) 462 | return instance, False 463 | except NoMatch: 464 | kwargs.update(defaults) 465 | instance = await self.create(**kwargs) 466 | return instance, True 467 | 468 | async def update_or_create( 469 | self, defaults: typing.Dict[str, typing.Any], **kwargs 470 | ) -> typing.Tuple[typing.Any, bool]: 471 | try: 472 | instance = await self.get(**kwargs) 473 | await instance.update(**defaults) 474 | return instance, False 475 | except NoMatch: 476 | kwargs.update(defaults) 477 | instance = await self.create(**kwargs) 478 | return instance, True 479 | 480 | def _prepare_order_by(self, order_by: str): 481 | reverse = order_by.startswith("-") 482 | order_by = order_by.lstrip("-") 483 | order_col = self.table.columns[order_by] 484 | return order_col.desc() if reverse else order_col 485 | 486 | 487 | class Model(metaclass=ModelMeta): 488 | objects = QuerySet() 489 | 490 | def __init__(self, **kwargs): 491 | if "pk" in kwargs: 492 | kwargs[self.pkname] = kwargs.pop("pk") 493 | for key, value in kwargs.items(): 494 | if key not in self.fields: 495 | raise ValueError( 496 | f"Invalid keyword {key} for class {self.__class__.__name__}" 497 | ) 498 | setattr(self, key, value) 499 | 500 | @property 501 | def pk(self): 502 | return getattr(self, self.pkname) 503 | 504 | @pk.setter 505 | def pk(self, value): 506 | setattr(self, self.pkname, value) 507 | 508 | def __repr__(self): 509 | return f"<{self.__class__.__name__}: {self}>" 510 | 511 | def __str__(self): 512 | return f"{self.__class__.__name__}({self.pkname}={self.pk})" 513 | 514 | @classmethod 515 | def build_table(cls): 516 | tablename = cls.tablename 517 | metadata = cls.registry._metadata 518 | columns = [] 519 | for name, field in cls.fields.items(): 520 | columns.append(field.get_column(name)) 521 | return sqlalchemy.Table(tablename, metadata, *columns, extend_existing=True) 
522 | 523 | @property 524 | def table(self) -> sqlalchemy.Table: 525 | return self.__class__.table 526 | 527 | async def update(self, **kwargs): 528 | fields = { 529 | key: field.validator for key, field in self.fields.items() if key in kwargs 530 | } 531 | validator = typesystem.Schema(fields=fields) 532 | kwargs = _update_auto_now_fields(validator.validate(kwargs), self.fields) 533 | pk_column = getattr(self.table.c, self.pkname) 534 | expr = self.table.update().values(**kwargs).where(pk_column == self.pk) 535 | await self.database.execute(expr) 536 | 537 | # Update the model instance. 538 | for key, value in kwargs.items(): 539 | setattr(self, key, value) 540 | 541 | async def delete(self) -> None: 542 | pk_column = getattr(self.table.c, self.pkname) 543 | expr = self.table.delete().where(pk_column == self.pk) 544 | 545 | await self.database.execute(expr) 546 | 547 | async def load(self): 548 | # Build the select expression. 549 | pk_column = getattr(self.table.c, self.pkname) 550 | expr = self.table.select().where(pk_column == self.pk) 551 | 552 | # Perform the fetch. 553 | row = await self.database.fetch_one(expr) 554 | 555 | # Update the instance. 556 | for key, value in dict(row._mapping).items(): 557 | setattr(self, key, value) 558 | 559 | @classmethod 560 | def _from_row(cls, row, select_related=[]): 561 | """ 562 | Instantiate a model instance, given a database row. 563 | """ 564 | item = {} 565 | 566 | # Instantiate any child instances first. 567 | for related in select_related: 568 | if "__" in related: 569 | first_part, remainder = related.split("__", 1) 570 | model_cls = cls.fields[first_part].target 571 | item[first_part] = model_cls._from_row(row, select_related=[remainder]) 572 | else: 573 | model_cls = cls.fields[related].target 574 | item[related] = model_cls._from_row(row) 575 | 576 | # Pull out the regular column values. 
577 | for column in cls.table.columns: 578 | if column.name not in item: 579 | item[column.name] = row[column] 580 | 581 | return cls(**item) 582 | 583 | def __setattr__(self, key, value): 584 | if key in self.fields: 585 | # Setting a relationship to a raw pk value should set a 586 | # fully-fledged relationship instance, with just the pk loaded. 587 | value = self.fields[key].expand_relationship(value) 588 | super().__setattr__(key, value) 589 | 590 | def __eq__(self, other): 591 | if self.__class__ != other.__class__: 592 | return False 593 | for key in self.fields.keys(): 594 | if getattr(self, key, None) != getattr(other, key, None): 595 | return False 596 | return True 597 | -------------------------------------------------------------------------------- /orm/sqlalchemy_fields.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | import uuid 3 | 4 | import sqlalchemy 5 | 6 | 7 | class GUID(sqlalchemy.TypeDecorator): 8 | """ 9 | Platform-independent GUID type. 10 | 11 | Uses PostgreSQL's UUID type, otherwise uses 12 | CHAR(32), storing as stringified hex values. 13 | """ 14 | 15 | impl = sqlalchemy.CHAR 16 | cache_ok = True 17 | 18 | def load_dialect_impl(self, dialect): 19 | if dialect.name == "postgresql": 20 | return dialect.type_descriptor(sqlalchemy.dialects.postgresql.UUID()) 21 | else: 22 | return dialect.type_descriptor(sqlalchemy.CHAR(32)) 23 | 24 | def process_bind_param(self, value, dialect): 25 | if value is None: 26 | return value 27 | 28 | if dialect.name == "postgresql": 29 | return str(value) 30 | else: 31 | return value.hex 32 | 33 | def process_result_value(self, value, dialect): 34 | if value is None: 35 | return value 36 | 37 | if not isinstance(value, uuid.UUID): 38 | value = uuid.UUID(value) 39 | return value 40 | 41 | 42 | class GenericIP(sqlalchemy.TypeDecorator): 43 | """ 44 | Platform-independent IP Address type. 
45 | 46 | Uses PostgreSQL's INET type, otherwise uses 47 | CHAR(45), storing as stringified values. 48 | """ 49 | 50 | impl = sqlalchemy.CHAR 51 | cache_ok = True 52 | 53 | def load_dialect_impl(self, dialect): 54 | if dialect.name == "postgresql": 55 | return dialect.type_descriptor(sqlalchemy.dialects.postgresql.INET()) 56 | else: 57 | return dialect.type_descriptor(sqlalchemy.CHAR(45)) 58 | 59 | def process_bind_param(self, value, dialect): 60 | if value is not None: 61 | return str(value) 62 | 63 | def process_result_value(self, value, dialect): 64 | if value is None: 65 | return value 66 | 67 | if not isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)): 68 | value = ipaddress.ip_address(value) 69 | return value 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | databases[postgresql, mysql, sqlite] 2 | typesystem 3 | 4 | # Packaging 5 | twine 6 | wheel 7 | 8 | # Testing 9 | anyio>=3.0.0,<4 10 | autoflake 11 | black 12 | codecov 13 | flake8 14 | isort 15 | mypy 16 | pytest 17 | pytest-cov 18 | 19 | # Documentation 20 | mkdocs 21 | mkdocs-material 22 | mkautodoc 23 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Development Scripts 2 | 3 | * `scripts/build` - Build package and documentation. 4 | * `scripts/check` - Check linting and formatting. 5 | * `scripts/clean` - Delete any build artifacts. 6 | * `scripts/coverage` - Check test coverage. 7 | * `scripts/docs` - Run documentation server locally. 8 | * `scripts/install` - Install dependencies in a virtual environment. 9 | * `scripts/lint` - Run the code linting. 10 | * `scripts/publish` - Publish the latest version to PyPI. 11 | * `scripts/test` - Run the test suite.
12 | 13 | Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all). 14 | -------------------------------------------------------------------------------- /scripts/build: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | if [ -d 'venv' ] ; then 4 | PREFIX="venv/bin/" 5 | else 6 | PREFIX="" 7 | fi 8 | 9 | set -x 10 | 11 | ${PREFIX}python setup.py sdist bdist_wheel 12 | ${PREFIX}twine check dist/* 13 | ${PREFIX}mkdocs build 14 | -------------------------------------------------------------------------------- /scripts/check: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | export SOURCE_FILES="orm tests" 8 | 9 | set -x 10 | 11 | ${PREFIX}isort --check --diff --project=orm $SOURCE_FILES 12 | ${PREFIX}black --check --diff $SOURCE_FILES 13 | ${PREFIX}flake8 $SOURCE_FILES 14 | # ${PREFIX}mypy $SOURCE_FILES 15 | -------------------------------------------------------------------------------- /scripts/clean: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | PACKAGE="orm" 4 | 5 | if [ -d 'dist' ] ; then 6 | rm -r dist 7 | fi 8 | if [ -d 'site' ] ; then 9 | rm -r site 10 | fi 11 | if [ -d 'htmlcov' ] ; then 12 | rm -r htmlcov 13 | fi 14 | if [ -d "${PACKAGE}.egg-info" ] ; then 15 | rm -r "${PACKAGE}.egg-info" 16 | fi 17 | 18 | find ${PACKAGE} -type f -name "*.py[co]" -delete 19 | find ${PACKAGE} -type d -name __pycache__ -delete 20 | -------------------------------------------------------------------------------- /scripts/coverage: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -x 9 | 10 | ${PREFIX}coverage report --show-missing 
--skip-covered --fail-under=100 11 | -------------------------------------------------------------------------------- /scripts/docs: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -x 9 | 10 | ${PREFIX}mkdocs serve 11 | -------------------------------------------------------------------------------- /scripts/install: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | # Use the Python executable provided from the `-p` option, or a default. 4 | [ "$1" = "-p" ] && PYTHON=$2 || PYTHON="python3" 5 | 6 | REQUIREMENTS="requirements.txt" 7 | VENV="venv" 8 | 9 | set -x 10 | 11 | if [ -z "$GITHUB_ACTIONS" ]; then 12 | "$PYTHON" -m venv "$VENV" 13 | PIP="$VENV/bin/pip" 14 | else 15 | PIP="pip" 16 | fi 17 | 18 | "$PIP" install -r "$REQUIREMENTS" 19 | "$PIP" install -e . 20 | -------------------------------------------------------------------------------- /scripts/lint: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | export SOURCE_FILES="orm tests" 8 | 9 | set -x 10 | 11 | ${PREFIX}autoflake --in-place --recursive $SOURCE_FILES 12 | ${PREFIX}isort --project=orm $SOURCE_FILES 13 | ${PREFIX}black $SOURCE_FILES 14 | -------------------------------------------------------------------------------- /scripts/publish: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | 3 | VERSION_FILE="orm/__init__.py" 4 | 5 | if [ -d 'venv' ] ; then 6 | PREFIX="venv/bin/" 7 | else 8 | PREFIX="" 9 | fi 10 | 11 | if [ ! 
-z "$GITHUB_ACTIONS" ]; then 12 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" 13 | git config --local user.name "GitHub Action" 14 | 15 | VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` 16 | 17 | if [ "refs/tags/${VERSION}" != "${GITHUB_REF}" ] ; then 18 | echo "GitHub Ref '${GITHUB_REF}' did not match package version '${VERSION}'" 19 | exit 1 20 | fi 21 | fi 22 | 23 | set -x 24 | 25 | ${PREFIX}twine upload dist/* 26 | ${PREFIX}mkdocs gh-deploy --force 27 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export PREFIX="" 4 | if [ -d 'venv' ] ; then 5 | export PREFIX="venv/bin/" 6 | fi 7 | 8 | set -ex 9 | 10 | if [ -z $GITHUB_ACTIONS ]; then 11 | scripts/check 12 | fi 13 | 14 | ${PREFIX}coverage run -a -m pytest $@ 15 | 16 | if [ -z $GITHUB_ACTIONS ]; then 17 | scripts/coverage 18 | fi 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, B305 3 | max-line-length = 88 4 | 5 | [mypy] 6 | disallow_untyped_defs = True 7 | ignore_missing_imports = True 8 | 9 | [tool:isort] 10 | profile = black 11 | combine_as_imports = True 12 | 13 | [tool:pytest] 14 | addopts = 15 | -rxXs 16 | --strict-config 17 | --strict-markers 18 | xfail_strict=True 19 | filterwarnings= 20 | # Turn warnings that aren't filtered into exceptions 21 | error 22 | ignore::DeprecationWarning 23 | ignore::sqlalchemy.exc.SAWarning 24 | 25 | [coverage:run] 26 | source_pkgs = orm, tests 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | 
import re 6 | 7 | from setuptools import setup 8 | 9 | 10 | PACKAGE = "orm" 11 | URL = "https://github.com/encode/orm" 12 | 13 | 14 | def get_version(package): 15 | """ 16 | Return package version as listed in `__version__` in `init.py`. 17 | """ 18 | with open(os.path.join(package, "__init__.py")) as f: 19 | return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) 20 | 21 | 22 | def get_long_description(): 23 | """ 24 | Return the README. 25 | """ 26 | with open("README.md", encoding="utf8") as f: 27 | return f.read() 28 | 29 | 30 | def get_packages(package): 31 | """ 32 | Return root package and all sub-packages. 33 | """ 34 | return [ 35 | dirpath 36 | for dirpath, dirnames, filenames in os.walk(package) 37 | if os.path.exists(os.path.join(dirpath, "__init__.py")) 38 | ] 39 | 40 | 41 | setup( 42 | name=PACKAGE, 43 | version=get_version(PACKAGE), 44 | url=URL, 45 | license="BSD", 46 | description="An async ORM.", 47 | long_description=get_long_description(), 48 | long_description_content_type="text/markdown", 49 | author="Tom Christie", 50 | author_email="tom@tomchristie.com", 51 | packages=get_packages(PACKAGE), 52 | package_data={PACKAGE: ["py.typed"]}, 53 | install_requires=["databases~=0.5", "typesystem==0.3.1"], 54 | extras_require={ 55 | "postgresql": ["asyncpg"], 56 | "mysql": ["aiomysql"], 57 | "sqlite": ["aiosqlite"], 58 | "postgresql+aiopg": ["aiopg"], 59 | }, 60 | classifiers=[ 61 | "Development Status :: 3 - Alpha", 62 | "Environment :: Web Environment", 63 | "Intended Audience :: Developers", 64 | "License :: OSI Approved :: BSD License", 65 | "Operating System :: OS Independent", 66 | "Topic :: Internet :: WWW/HTTP", 67 | "Programming Language :: Python :: 3", 68 | "Programming Language :: Python :: 3.7", 69 | "Programming Language :: Python :: 3.8", 70 | "Programming Language :: Python :: 3.9", 71 | "Programming Language :: Python :: 3.10", 72 | "Programming Language :: Python :: 3 :: Only", 73 | ], 74 | ) 75 | 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(scope="module") 5 | def anyio_backend(): 6 | return ("asyncio", {"debug": True}) 7 | -------------------------------------------------------------------------------- /tests/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | assert "TEST_DATABASE_URL" in os.environ, "TEST_DATABASE_URL is not set." 4 | 5 | DATABASE_URL = os.environ["TEST_DATABASE_URL"] 6 | -------------------------------------------------------------------------------- /tests/test_columns.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import decimal 3 | import ipaddress 4 | import uuid 5 | from enum import Enum 6 | 7 | import databases 8 | import pytest 9 | 10 | import orm 11 | from tests.settings import DATABASE_URL 12 | 13 | pytestmark = pytest.mark.anyio 14 | 15 | database = databases.Database(DATABASE_URL) 16 | models = orm.ModelRegistry(database=database) 17 | 18 | 19 | def time(): 20 | return datetime.datetime.now().time() 21 | 22 | 23 | class StatusEnum(Enum): 24 | DRAFT = "Draft" 25 | RELEASED = "Released" 26 | 27 | 28 | class Product(orm.Model): 29 | registry = models 30 | fields = { 31 | "id": orm.Integer(primary_key=True), 32 | "uuid": orm.UUID(allow_null=True), 33 | "created": orm.DateTime(default=datetime.datetime.now), 34 | "created_day": orm.Date(default=datetime.date.today), 35 | "created_time": orm.Time(default=time), 36 | "created_date": orm.Date(auto_now_add=True), 37 | "created_datetime": orm.DateTime(auto_now_add=True), 38 | "updated_datetime": orm.DateTime(auto_now=True), 39 | "updated_date": orm.Date(auto_now=True), 40 | "data": orm.JSON(default={}), 41 | "description": orm.Text(allow_blank=True), 42 | "huge_number": 
orm.BigInteger(default=0), 43 | "price": orm.Decimal(max_digits=5, decimal_places=2, allow_null=True), 44 | "status": orm.Enum(StatusEnum, default=StatusEnum.DRAFT), 45 | "value": orm.Float(allow_null=True), 46 | } 47 | 48 | 49 | class User(orm.Model): 50 | registry = models 51 | fields = { 52 | "id": orm.UUID(primary_key=True, default=uuid.uuid4), 53 | "name": orm.String(allow_null=True, max_length=16), 54 | "email": orm.Email(allow_null=True, max_length=256), 55 | "ipaddress": orm.IPAddress(allow_null=True), 56 | "url": orm.URL(allow_null=True, max_length=2048), 57 | } 58 | 59 | 60 | @pytest.fixture(autouse=True, scope="module") 61 | async def create_test_database(): 62 | await models.create_all() 63 | yield 64 | await models.drop_all() 65 | 66 | 67 | @pytest.fixture(autouse=True) 68 | async def rollback_transactions(): 69 | with database.force_rollback(): 70 | async with database: 71 | yield 72 | 73 | 74 | async def test_model_crud(): 75 | product = await Product.objects.create() 76 | product = await Product.objects.get(pk=product.pk) 77 | assert product.created.year == datetime.datetime.now().year 78 | assert product.created_day == datetime.date.today() 79 | assert product.created_date == datetime.date.today() 80 | assert product.created_datetime.date() == datetime.datetime.now().date() 81 | assert product.updated_date == datetime.date.today() 82 | assert product.updated_datetime.date() == datetime.datetime.now().date() 83 | assert product.data == {} 84 | assert product.description == "" 85 | assert product.huge_number == 0 86 | assert product.price is None 87 | assert product.status == StatusEnum.DRAFT 88 | assert product.value is None 89 | assert product.uuid is None 90 | 91 | await product.update( 92 | data={"foo": 123}, 93 | value=123.456, 94 | status=StatusEnum.RELEASED, 95 | price=decimal.Decimal("999.99"), 96 | uuid=uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b"), 97 | ) 98 | 99 | product = await Product.objects.get() 100 | assert product.value == 
123.456 101 | assert product.data == {"foo": 123} 102 | assert product.status == StatusEnum.RELEASED 103 | assert product.price == decimal.Decimal("999.99") 104 | assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") 105 | 106 | last_updated_datetime = product.updated_datetime 107 | last_updated_date = product.updated_date 108 | user = await User.objects.create() 109 | assert isinstance(user.pk, uuid.UUID) 110 | 111 | user = await User.objects.get() 112 | assert user.email is None 113 | assert user.ipaddress is None 114 | assert user.url is None 115 | 116 | await user.update( 117 | ipaddress="192.168.1.1", 118 | name="Chris", 119 | email="chirs@encode.io", 120 | url="https://encode.io", 121 | ) 122 | 123 | user = await User.objects.get() 124 | assert isinstance(user.ipaddress, (ipaddress.IPv4Address, ipaddress.IPv6Address)) 125 | assert user.url == "https://encode.io" 126 | # Test auto_now update 127 | await product.update( 128 | data={"foo": 1234}, 129 | ) 130 | assert product.updated_datetime != last_updated_datetime 131 | assert product.updated_date == last_updated_date 132 | 133 | 134 | async def test_both_auto_now_and_auto_now_add_raise_error(): 135 | with pytest.raises(ValueError): 136 | 137 | class Product(orm.Model): 138 | registry = models 139 | fields = { 140 | "id": orm.Integer(primary_key=True), 141 | "created_datetime": orm.DateTime(auto_now_add=True, auto_now=True), 142 | } 143 | 144 | await Product.objects.create() 145 | 146 | 147 | async def test_bulk_create(): 148 | await Product.objects.bulk_create( 149 | [ 150 | {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, 151 | {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, 152 | ] 153 | ) 154 | products = await Product.objects.all() 155 | assert len(products) == 2 156 | assert products[0].data == {"foo": 123} 157 | assert products[0].value == 123.456 158 | assert products[0].status == StatusEnum.RELEASED 159 | assert products[1].data == 
{"foo": 456} 160 | assert products[1].value == 456.789 161 | assert products[1].status == StatusEnum.DRAFT 162 | -------------------------------------------------------------------------------- /tests/test_foreignkey.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | 3 | import asyncpg 4 | import databases 5 | import pymysql 6 | import pytest 7 | 8 | import orm 9 | from tests.settings import DATABASE_URL 10 | 11 | pytestmark = pytest.mark.anyio 12 | 13 | database = databases.Database(DATABASE_URL) 14 | models = orm.ModelRegistry(database=database) 15 | 16 | 17 | class Album(orm.Model): 18 | registry = models 19 | fields = { 20 | "id": orm.Integer(primary_key=True), 21 | "name": orm.String(max_length=100), 22 | } 23 | 24 | 25 | class Track(orm.Model): 26 | registry = models 27 | fields = { 28 | "id": orm.Integer(primary_key=True), 29 | "album": orm.ForeignKey("Album", on_delete=orm.CASCADE), 30 | "title": orm.String(max_length=100), 31 | "position": orm.Integer(), 32 | } 33 | 34 | 35 | class Organisation(orm.Model): 36 | registry = models 37 | fields = { 38 | "id": orm.Integer(primary_key=True), 39 | "ident": orm.String(max_length=100), 40 | } 41 | 42 | 43 | class Team(orm.Model): 44 | registry = models 45 | fields = { 46 | "id": orm.Integer(primary_key=True), 47 | "org": orm.ForeignKey(Organisation, on_delete=orm.RESTRICT), 48 | "name": orm.String(max_length=100), 49 | } 50 | 51 | 52 | class Member(orm.Model): 53 | registry = models 54 | fields = { 55 | "id": orm.Integer(primary_key=True), 56 | "team": orm.ForeignKey(Team, on_delete=orm.SET_NULL, allow_null=True), 57 | "email": orm.String(max_length=100), 58 | } 59 | 60 | 61 | class Profile(orm.Model): 62 | registry = models 63 | fields = { 64 | "id": orm.Integer(primary_key=True), 65 | "website": orm.String(max_length=100), 66 | } 67 | 68 | 69 | class Person(orm.Model): 70 | registry = models 71 | fields = { 72 | "id": orm.Integer(primary_key=True), 73 | 
"email": orm.String(max_length=100), 74 | "profile": orm.OneToOne(Profile), 75 | } 76 | 77 | 78 | @pytest.fixture(autouse=True, scope="module") 79 | async def create_test_database(): 80 | await models.create_all() 81 | yield 82 | await models.drop_all() 83 | 84 | 85 | @pytest.fixture(autouse=True) 86 | async def rollback_connections(): 87 | with database.force_rollback(): 88 | async with database: 89 | yield 90 | 91 | 92 | async def test_model_crud(): 93 | album = await Album.objects.create(name="Malibu") 94 | await Track.objects.create(album=album, title="The Bird", position=1) 95 | await Track.objects.create( 96 | album=album, title="Heart don't stand a chance", position=2 97 | ) 98 | await Track.objects.create(album=album, title="The Waters", position=3) 99 | 100 | track = await Track.objects.get(title="The Bird") 101 | assert track.album.pk == album.pk 102 | assert not hasattr(track.album, "name") 103 | await track.album.load() 104 | assert track.album.name == "Malibu" 105 | 106 | 107 | async def test_select_related(): 108 | album = await Album.objects.create(name="Malibu") 109 | await Track.objects.create(album=album, title="The Bird", position=1) 110 | await Track.objects.create( 111 | album=album, title="Heart don't stand a chance", position=2 112 | ) 113 | await Track.objects.create(album=album, title="The Waters", position=3) 114 | 115 | fantasies = await Album.objects.create(name="Fantasies") 116 | await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) 117 | await Track.objects.create(album=fantasies, title="Sick Muse", position=2) 118 | await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) 119 | 120 | track = await Track.objects.select_related("album").get(title="The Bird") 121 | assert track.album.name == "Malibu" 122 | 123 | tracks = await Track.objects.select_related("album").all() 124 | assert len(tracks) == 6 125 | 126 | 127 | async def test_fk_filter(): 128 | malibu = await 
Album.objects.create(name="Malibu") 129 | await Track.objects.create(album=malibu, title="The Bird", position=1) 130 | await Track.objects.create( 131 | album=malibu, title="Heart don't stand a chance", position=2 132 | ) 133 | await Track.objects.create(album=malibu, title="The Waters", position=3) 134 | 135 | fantasies = await Album.objects.create(name="Fantasies") 136 | await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) 137 | await Track.objects.create(album=fantasies, title="Sick Muse", position=2) 138 | await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) 139 | 140 | tracks = ( 141 | await Track.objects.select_related("album") 142 | .filter(album__name="Fantasies") 143 | .all() 144 | ) 145 | assert len(tracks) == 3 146 | for track in tracks: 147 | assert track.album.name == "Fantasies" 148 | 149 | tracks = ( 150 | await Track.objects.select_related("album") 151 | .filter(album__name__icontains="fan") 152 | .all() 153 | ) 154 | assert len(tracks) == 3 155 | for track in tracks: 156 | assert track.album.name == "Fantasies" 157 | 158 | tracks = await Track.objects.filter(album__name__icontains="fan").all() 159 | assert len(tracks) == 3 160 | for track in tracks: 161 | assert track.album.name == "Fantasies" 162 | 163 | tracks = await Track.objects.filter(album=malibu).select_related("album").all() 164 | assert len(tracks) == 3 165 | for track in tracks: 166 | assert track.album.name == "Malibu" 167 | 168 | 169 | async def test_multiple_fk(): 170 | acme = await Organisation.objects.create(ident="ACME Ltd") 171 | red_team = await Team.objects.create(org=acme, name="Red Team") 172 | blue_team = await Team.objects.create(org=acme, name="Blue Team") 173 | await Member.objects.create(team=red_team, email="a@example.org") 174 | await Member.objects.create(team=red_team, email="b@example.org") 175 | await Member.objects.create(team=blue_team, email="c@example.org") 176 | await Member.objects.create(team=blue_team, 
email="d@example.org") 177 | 178 | other = await Organisation.objects.create(ident="Other ltd") 179 | team = await Team.objects.create(org=other, name="Green Team") 180 | await Member.objects.create(team=team, email="e@example.org") 181 | 182 | members = ( 183 | await Member.objects.select_related("team__org") 184 | .filter(team__org__ident="ACME Ltd") 185 | .all() 186 | ) 187 | assert len(members) == 4 188 | for member in members: 189 | assert member.team.org.ident == "ACME Ltd" 190 | 191 | 192 | async def test_queryset_delete_with_fk(): 193 | malibu = await Album.objects.create(name="Malibu") 194 | await Track.objects.create(album=malibu, title="The Bird", position=1) 195 | 196 | wall = await Album.objects.create(name="The Wall") 197 | await Track.objects.create(album=wall, title="The Wall", position=1) 198 | 199 | await Track.objects.filter(album=malibu).delete() 200 | assert await Track.objects.filter(album=malibu).count() == 0 201 | assert await Track.objects.filter(album=wall).count() == 1 202 | 203 | 204 | async def test_queryset_update_with_fk(): 205 | malibu = await Album.objects.create(name="Malibu") 206 | wall = await Album.objects.create(name="The Wall") 207 | await Track.objects.create(album=malibu, title="The Bird", position=1) 208 | 209 | await Track.objects.filter(album=malibu).update(album=wall) 210 | assert await Track.objects.filter(album=malibu).count() == 0 211 | assert await Track.objects.filter(album=wall).count() == 1 212 | 213 | 214 | @pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") 215 | async def test_on_delete_cascade(): 216 | album = await Album.objects.create(name="The Wall") 217 | await Track.objects.create(album=album, title="Hey You", position=1) 218 | await Track.objects.create(album=album, title="Breathe", position=2) 219 | 220 | assert await Track.objects.count() == 2 221 | 222 | await album.delete() 223 | 224 | assert await Track.objects.count() == 0 225 | 226 | 227 | 
@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") 228 | async def test_on_delete_retstrict(): 229 | organisation = await Organisation.objects.create(ident="Encode") 230 | await Team.objects.create(org=organisation, name="Maintainers") 231 | 232 | exceptions = ( 233 | asyncpg.exceptions.ForeignKeyViolationError, 234 | pymysql.err.IntegrityError, 235 | ) 236 | 237 | with pytest.raises(exceptions): 238 | await organisation.delete() 239 | 240 | 241 | @pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") 242 | async def test_on_delete_set_null(): 243 | organisation = await Organisation.objects.create(ident="Encode") 244 | team = await Team.objects.create(org=organisation, name="Maintainers") 245 | await Member.objects.create(email="member@encode.io", team=team) 246 | 247 | await team.delete() 248 | 249 | member = await Member.objects.first() 250 | assert member.team.pk is None 251 | 252 | 253 | async def test_one_to_one_crud(): 254 | profile = await Profile.objects.create(website="https://encode.io") 255 | await Person.objects.create(email="info@encode.io", profile=profile) 256 | 257 | person = await Person.objects.get(email="info@encode.io") 258 | assert person.profile.pk == profile.pk 259 | assert not hasattr(person.profile, "website") 260 | 261 | await person.profile.load() 262 | assert person.profile.website == "https://encode.io" 263 | 264 | exceptions = ( 265 | asyncpg.exceptions.UniqueViolationError, 266 | pymysql.err.IntegrityError, 267 | sqlite3.IntegrityError, 268 | ) 269 | 270 | with pytest.raises(exceptions): 271 | await Person.objects.create(email="contact@encode.io", profile=profile) 272 | 273 | 274 | async def test_nullable_foreign_key(): 275 | await Member.objects.create(email="dev@encode.io") 276 | 277 | member = await Member.objects.get() 278 | 279 | assert member.email == "dev@encode.io" 280 | assert member.team.pk is None 281 | 
-------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import databases 2 | import pytest 3 | import typesystem 4 | 5 | import orm 6 | from tests.settings import DATABASE_URL 7 | 8 | pytestmark = pytest.mark.anyio 9 | 10 | database = databases.Database(DATABASE_URL) 11 | models = orm.ModelRegistry(database=database) 12 | 13 | 14 | class User(orm.Model): 15 | tablename = "users" 16 | registry = models 17 | fields = { 18 | "id": orm.Integer(primary_key=True), 19 | "name": orm.String(max_length=100), 20 | "language": orm.String(max_length=100, allow_null=True), 21 | } 22 | 23 | 24 | class Product(orm.Model): 25 | tablename = "products" 26 | registry = models 27 | fields = { 28 | "id": orm.Integer(primary_key=True), 29 | "name": orm.String(max_length=100), 30 | "rating": orm.Integer(minimum=1, maximum=5), 31 | "in_stock": orm.Boolean(default=False), 32 | } 33 | 34 | 35 | @pytest.fixture(autouse=True, scope="function") 36 | async def create_test_database(): 37 | await models.create_all() 38 | yield 39 | await models.drop_all() 40 | 41 | 42 | @pytest.fixture(autouse=True) 43 | async def rollback_connections(): 44 | with database.force_rollback(): 45 | async with database: 46 | yield 47 | 48 | 49 | def test_model_class(): 50 | assert list(User.fields.keys()) == ["id", "name", "language"] 51 | assert isinstance(User.fields["id"], orm.Integer) 52 | assert User.fields["id"].primary_key is True 53 | assert isinstance(User.fields["name"], orm.String) 54 | assert User.fields["name"].validator.max_length == 100 55 | 56 | with pytest.raises(ValueError): 57 | User(invalid="123") 58 | 59 | assert User(id=1) != Product(id=1) 60 | assert User(id=1) != User(id=2) 61 | assert User(id=1) == User(id=1) 62 | 63 | assert str(User(id=1)) == "User(id=1)" 64 | assert repr(User(id=1)) == "" 65 | 66 | assert isinstance(User.objects.schema.fields["id"], 
typesystem.Integer) 67 | assert isinstance(User.objects.schema.fields["name"], typesystem.String) 68 | 69 | 70 | def test_model_pk(): 71 | user = User(pk=1) 72 | assert user.pk == 1 73 | assert user.id == 1 74 | assert User.objects.pkname == "id" 75 | 76 | 77 | async def test_model_crud(): 78 | users = await User.objects.all() 79 | assert users == [] 80 | 81 | user = await User.objects.create(name="Tom") 82 | users = await User.objects.all() 83 | assert user.name == "Tom" 84 | assert user.pk is not None 85 | assert users == [user] 86 | 87 | lookup = await User.objects.get() 88 | assert lookup == user 89 | 90 | await user.update(name="Jane") 91 | users = await User.objects.all() 92 | assert user.name == "Jane" 93 | assert user.pk is not None 94 | assert users == [user] 95 | 96 | await user.delete() 97 | users = await User.objects.all() 98 | assert users == [] 99 | 100 | 101 | async def test_model_get(): 102 | with pytest.raises(orm.NoMatch): 103 | await User.objects.get() 104 | 105 | user = await User.objects.create(name="Tom") 106 | lookup = await User.objects.get() 107 | assert lookup == user 108 | 109 | user = await User.objects.create(name="Jane") 110 | with pytest.raises(orm.MultipleMatches): 111 | await User.objects.get() 112 | 113 | same_user = await User.objects.get(pk=user.id) 114 | assert same_user.id == user.id 115 | assert same_user.pk == user.pk 116 | 117 | 118 | async def test_model_filter(): 119 | await User.objects.create(name="Tom") 120 | await User.objects.create(name="Jane") 121 | await User.objects.create(name="Lucy") 122 | 123 | user = await User.objects.get(name="Lucy") 124 | assert user.name == "Lucy" 125 | 126 | with pytest.raises(orm.NoMatch): 127 | await User.objects.get(name="Jim") 128 | 129 | await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) 130 | await Product.objects.create(name="Dress", rating=4) 131 | await Product.objects.create(name="Coat", rating=3, in_stock=True) 132 | 133 | product = await 
Product.objects.get(name__iexact="t-shirt", rating=5) 134 | assert product.pk is not None 135 | assert product.name == "T-Shirt" 136 | assert product.rating == 5 137 | 138 | products = await Product.objects.all(rating__gte=2, in_stock=True) 139 | assert len(products) == 2 140 | 141 | products = await Product.objects.all(name__icontains="T") 142 | assert len(products) == 2 143 | 144 | # Test escaping % character from icontains, contains, and iexact 145 | await Product.objects.create(name="100%-Cotton", rating=3) 146 | await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) 147 | await Product.objects.create(name="Cotton-100%", rating=3) 148 | products = Product.objects.filter(name__iexact="100%-cotton") 149 | assert await products.count() == 1 150 | 151 | products = Product.objects.filter(name__contains="%") 152 | assert await products.count() == 3 153 | 154 | products = Product.objects.filter(name__icontains="%") 155 | assert await products.count() == 3 156 | 157 | products = Product.objects.exclude(name__iexact="100%-cotton") 158 | assert await products.count() == 5 159 | 160 | products = Product.objects.exclude(name__contains="%") 161 | assert await products.count() == 3 162 | 163 | products = Product.objects.exclude(name__icontains="%") 164 | assert await products.count() == 3 165 | 166 | 167 | async def test_model_order_by(): 168 | await User.objects.create(name="Bob") 169 | await User.objects.create(name="Allen") 170 | await User.objects.create(name="Bob") 171 | 172 | users = await User.objects.order_by("name").all() 173 | assert users[0].name == "Allen" 174 | assert users[1].name == "Bob" 175 | 176 | users = await User.objects.order_by("-name").all() 177 | assert users[1].name == "Bob" 178 | assert users[2].name == "Allen" 179 | 180 | users = await User.objects.order_by("name", "-id").all() 181 | assert users[0].name == "Allen" 182 | assert users[0].id == 2 183 | assert users[1].name == "Bob" 184 | assert users[1].id == 3 185 | 186 | users = await 
User.objects.filter(name="Bob").order_by("-id").all() 187 | assert users[0].name == "Bob" 188 | assert users[0].id == 3 189 | assert users[1].name == "Bob" 190 | assert users[1].id == 1 191 | 192 | users = await User.objects.order_by("id").limit(1).all() 193 | assert users[0].name == "Bob" 194 | assert users[0].id == 1 195 | 196 | users = await User.objects.order_by("id").limit(1).offset(1).all() 197 | assert users[0].name == "Allen" 198 | assert users[0].id == 2 199 | 200 | 201 | async def test_model_exists(): 202 | await User.objects.create(name="Tom") 203 | assert await User.objects.filter(name="Tom").exists() is True 204 | assert await User.objects.filter(name="Jane").exists() is False 205 | 206 | 207 | async def test_model_count(): 208 | await User.objects.create(name="Tom") 209 | await User.objects.create(name="Jane") 210 | await User.objects.create(name="Lucy") 211 | 212 | assert await User.objects.count() == 3 213 | assert await User.objects.filter(name__icontains="T").count() == 1 214 | 215 | 216 | async def test_model_limit(): 217 | await User.objects.create(name="Tom") 218 | await User.objects.create(name="Jane") 219 | await User.objects.create(name="Lucy") 220 | 221 | assert len(await User.objects.limit(2).all()) == 2 222 | 223 | 224 | async def test_model_limit_with_filter(): 225 | await User.objects.create(name="Tom") 226 | await User.objects.create(name="Tom") 227 | await User.objects.create(name="Tom") 228 | 229 | assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 230 | 231 | 232 | async def test_offset(): 233 | await User.objects.create(name="Tom") 234 | await User.objects.create(name="Jane") 235 | 236 | users = await User.objects.offset(1).limit(1).all() 237 | assert users[0].name == "Jane" 238 | 239 | 240 | async def test_model_first(): 241 | tom = await User.objects.create(name="Tom") 242 | jane = await User.objects.create(name="Jane") 243 | 244 | assert await User.objects.first() == tom 245 | assert await 
User.objects.first(name="Jane") == jane 246 | assert await User.objects.filter(name="Jane").first() == jane 247 | assert await User.objects.filter(name="Lucy").first() is None 248 | 249 | 250 | async def test_model_search(): 251 | tom = await User.objects.create(name="Tom", language="English") 252 | tshirt = await Product.objects.create(name="T-Shirt", rating=5) 253 | 254 | assert await User.objects.search(term="").first() == tom 255 | assert await User.objects.search(term="tom").first() == tom 256 | assert await Product.objects.search(term="shirt").first() == tshirt 257 | 258 | 259 | async def test_model_get_or_create(): 260 | user, created = await User.objects.get_or_create( 261 | name="Tom", defaults={"language": "Spanish"} 262 | ) 263 | assert created is True 264 | assert user.name == "Tom" 265 | assert user.language == "Spanish" 266 | 267 | user, created = await User.objects.get_or_create( 268 | name="Tom", defaults={"language": "English"} 269 | ) 270 | assert created is False 271 | assert user.name == "Tom" 272 | assert user.language == "Spanish" 273 | 274 | 275 | async def test_queryset_delete(): 276 | shirt = await Product.objects.create(name="Shirt", rating=5) 277 | await Product.objects.create(name="Belt", rating=5) 278 | await Product.objects.create(name="Tie", rating=5) 279 | 280 | await Product.objects.filter(pk=shirt.id).delete() 281 | assert await Product.objects.count() == 2 282 | 283 | await Product.objects.delete() 284 | assert await Product.objects.count() == 0 285 | 286 | 287 | async def test_queryset_update(): 288 | shirt = await Product.objects.create(name="Shirt", rating=5) 289 | tie = await Product.objects.create(name="Tie", rating=5) 290 | 291 | await Product.objects.filter(pk=shirt.id).update(rating=3) 292 | shirt = await Product.objects.get(pk=shirt.id) 293 | assert shirt.rating == 3 294 | assert await Product.objects.get(pk=tie.id) == tie 295 | 296 | await Product.objects.update(rating=3) 297 | tie = await Product.objects.get(pk=tie.id) 
298 | assert tie.rating == 3 299 | 300 | 301 | async def test_model_update_or_create(): 302 | user, created = await User.objects.update_or_create( 303 | name="Tom", language="English", defaults={"name": "Jane"} 304 | ) 305 | assert created is True 306 | assert user.name == "Jane" 307 | assert user.language == "English" 308 | 309 | user, created = await User.objects.update_or_create( 310 | name="Jane", language="English", defaults={"name": "Tom"} 311 | ) 312 | assert created is False 313 | assert user.name == "Tom" 314 | assert user.language == "English" 315 | 316 | 317 | async def test_model_sqlalchemy_filter_operators(): 318 | user = await User.objects.create(name="George") 319 | 320 | assert user == await User.objects.filter(User.columns.name == "George").get() 321 | assert user == await User.objects.filter(User.columns.name.is_not(None)).get() 322 | assert ( 323 | user 324 | == await User.objects.filter(User.columns.name.startswith("G")) 325 | .filter(User.columns.name.endswith("e")) 326 | .get() 327 | ) 328 | 329 | assert user == await User.objects.exclude(User.columns.name != "Jack").get() 330 | 331 | shirt = await Product.objects.create(name="100%-Cotton", rating=3) 332 | assert ( 333 | shirt 334 | == await Product.objects.filter(Product.columns.name.contains("Cotton")).get() 335 | ) 336 | --------------------------------------------------------------------------------