├── .codecov.yml
├── .github
├── ISSUE_TEMPLATE
│ ├── 2-bug-report.md
│ ├── 3-feature-request.md
│ └── config.yml
└── workflows
│ ├── publish.yml
│ └── test-suite.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── docs
├── declaring_models.md
├── index.md
├── making_queries.md
└── relationships.md
├── mkdocs.yml
├── orm
├── __init__.py
├── constants.py
├── exceptions.py
├── fields.py
├── models.py
└── sqlalchemy_fields.py
├── requirements.txt
├── scripts
├── README.md
├── build
├── check
├── clean
├── coverage
├── docs
├── install
├── lint
├── publish
└── test
├── setup.cfg
├── setup.py
└── tests
├── conftest.py
├── settings.py
├── test_columns.py
├── test_foreignkey.py
└── test_models.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | precision: 2
3 | round: down
4 | range: "80...100"
5 |
6 | status:
7 | project: yes
8 | patch: yes
9 | changes: yes
10 |
11 | comment: off
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/2-bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Report a bug to help improve this project
4 | ---
5 |
6 | ### Checklist
7 |
8 |
9 |
10 | - [ ] The bug is reproducible against the latest release and/or `master`.
11 | - [ ] There are no similar issues or pull requests to fix it yet.
12 |
13 | ### Describe the bug
14 |
15 |
16 |
17 | ### To reproduce
18 |
19 |
25 |
26 | ### Expected behavior
27 |
28 |
29 |
30 | ### Actual behavior
31 |
32 |
33 |
34 | ### Debugging material
35 |
36 |
42 |
43 | ### Environment
44 |
45 | - OS:
46 | - Python version:
47 | - ORM version:
48 |
49 | ### Additional context
50 |
51 |
54 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/3-feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project.
4 | ---
5 |
6 | ### Checklist
7 |
8 |
9 |
10 | - [ ] There are no similar issues or pull requests for this yet.
11 | - [ ] I discussed this idea on the [community chat](https://gitter.im/encode/community) and feedback is positive.
12 |
13 | ### Is your feature related to a problem? Please describe.
14 |
15 |
18 |
19 | ## Describe the solution you would like.
20 |
21 |
25 |
26 | ## Describe alternatives you considered
27 |
28 |
30 |
31 | ## Additional context
32 |
33 |
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | # Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser
2 | blank_issues_enabled: true
3 | contact_links:
4 | - name: Question
5 | url: https://gitter.im/encode/community
6 | about: >
7 | Ask a question
8 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Publish
3 |
4 | on:
5 | push:
6 | tags:
7 | - '*'
8 |
9 | jobs:
10 | publish:
11 | name: "Publish release"
12 | runs-on: "ubuntu-latest"
13 |
14 | steps:
15 | - uses: "actions/checkout@v2"
16 | - uses: "actions/setup-python@v2"
17 | with:
18 | python-version: 3.7
19 | - name: "Install dependencies"
20 | run: "scripts/install"
21 | - name: "Build package & docs"
22 | run: "scripts/build"
23 | - name: "Publish to PyPI & deploy docs"
24 | run: "scripts/publish"
25 | env:
26 | TWINE_USERNAME: __token__
27 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
28 |
--------------------------------------------------------------------------------
/.github/workflows/test-suite.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Test Suite
3 |
4 | on:
5 | push:
6 | branches: ["master"]
7 | pull_request:
8 | branches: ["master"]
9 |
10 | jobs:
11 | tests:
12 | name: "Python ${{ matrix.python-version }}"
13 | runs-on: "ubuntu-latest"
14 |
15 | strategy:
16 | matrix:
17 | python-version: ["3.7", "3.8", "3.9", "3.10"]
18 |
19 | services:
20 | mysql:
21 | image: mysql:5.7
22 | env:
23 | MYSQL_USER: username
24 | MYSQL_PASSWORD: password
25 | MYSQL_ROOT_PASSWORD: password
26 | MYSQL_DATABASE: testsuite
27 | ports:
28 | - 3306:3306
29 | options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
30 |
31 | postgres:
32 | image: postgres:10.8
33 | env:
34 | POSTGRES_USER: username
35 | POSTGRES_PASSWORD: password
36 | POSTGRES_DB: testsuite
37 | ports:
38 | - 5432:5432
39 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
40 |
41 | steps:
42 | - uses: "actions/checkout@v2"
43 | - uses: "actions/setup-python@v2"
44 | with:
45 | python-version: "${{ matrix.python-version }}"
46 | - name: "Install dependencies"
47 | run: "scripts/install"
48 | - name: "Run linting checks"
49 | run: "scripts/check"
50 | - name: "Build package & docs"
51 | run: "scripts/build"
52 | - name: "Run tests with PostgreSQL"
53 | env:
54 | TEST_DATABASE_URL: "postgresql://username:password@localhost:5432/testsuite"
55 | run: "scripts/test"
56 | - name: "Run tests with MySQL"
57 | env:
58 | TEST_DATABASE_URL: "mysql://username:password@localhost:3306/testsuite"
59 | run: "scripts/test"
60 | - name: "Run tests with SQLite"
61 | env:
62 | TEST_DATABASE_URL: "sqlite:///testsuite"
63 | run: "scripts/test"
64 | - name: "Enforce coverage"
65 | run: "scripts/coverage"
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .coverage
3 | .pytest_cache/
4 | .mypy_cache/
5 | *.egg-info/
6 | htmlcov/
7 | venv/
8 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/).
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ORM
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | The `orm` package is an async ORM for Python, with support for Postgres,
16 | MySQL, and SQLite. ORM is built with:
17 |
18 | * [SQLAlchemy core][sqlalchemy-core] for query building.
19 | * [`databases`][databases] for cross-database async support.
20 | * [`typesystem`][typesystem] for data validation.
21 |
22 | Because ORM is built on SQLAlchemy core, you can use Alembic to provide
23 | database migrations.
24 |
25 | ---
26 |
27 | **Documentation**: [https://www.encode.io/orm](https://www.encode.io/orm)
28 |
29 | ---
30 |
31 | ## Installation
32 |
33 | ```shell
34 | $ pip install orm
35 | ```
36 |
37 | You can install the required database drivers with:
38 |
39 | ```shell
40 | $ pip install orm[postgresql]
41 | $ pip install orm[mysql]
42 | $ pip install orm[sqlite]
43 | ```
44 |
45 | Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite].
46 |
47 | ---
48 |
49 | ## Quickstart
50 |
51 | **Note**: Use `ipython` to try this from the console, since it supports `await`.
52 |
53 | ```python
54 | import databases
55 | import orm
56 |
57 | database = databases.Database("sqlite:///db.sqlite")
58 | models = orm.ModelRegistry(database=database)
59 |
60 |
61 | class Note(orm.Model):
62 | tablename = "notes"
63 | registry = models
64 | fields = {
65 | "id": orm.Integer(primary_key=True),
66 | "text": orm.String(max_length=100),
67 | "completed": orm.Boolean(default=False),
68 | }
69 |
70 | # Create the tables
71 | await models.create_all()
72 |
73 | await Note.objects.create(text="Buy the groceries.", completed=False)
74 |
75 | note = await Note.objects.get(id=1)
76 | print(note)
77 | # Note(id=1)
78 | ```
79 |
80 | — 🗃 —
81 | ORM is BSD licensed code. Designed & built in Brighton, England.
82 |
83 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
84 | [asyncpg]: https://github.com/MagicStack/asyncpg
85 | [aiomysql]: https://github.com/aio-libs/aiomysql
86 | [aiosqlite]: https://github.com/jreese/aiosqlite
87 |
88 | [databases]: https://github.com/encode/databases
89 | [typesystem]: https://github.com/encode/typesystem
90 | [typesystem-fields]: https://www.encode.io/typesystem/fields/
91 |
--------------------------------------------------------------------------------
/docs/declaring_models.md:
--------------------------------------------------------------------------------
1 | ## Declaring models
2 |
3 | You can define models by inheriting from `orm.Model` and
4 | defining model fields in the `fields` attribute.
5 | For each defined model you need to set two special variables:
6 |
7 | * `registry` an instance of `orm.ModelRegistry`
8 | * `fields` a `dict` of `orm` fields
9 |
10 | You can also specify the table name in the database by setting the `tablename` attribute.
11 |
12 | ```python
13 | import databases
14 | import orm
15 |
16 | database = databases.Database("sqlite:///db.sqlite")
17 | models = orm.ModelRegistry(database=database)
18 |
19 |
20 | class Note(orm.Model):
21 | tablename = "notes"
22 | registry = models
23 | fields = {
24 | "id": orm.Integer(primary_key=True),
25 | "text": orm.String(max_length=100),
26 | "completed": orm.Boolean(default=False),
27 | }
28 | ```
29 |
30 | ORM can create or drop database and tables from models using SQLAlchemy.
31 | You can use the following methods:
32 |
33 | ```python
34 | await models.create_all()
35 |
36 | await models.drop_all()
37 | ```
38 |
39 | ## Data types
40 |
41 | The following keyword arguments are supported on all field types.
42 |
43 | * `primary_key` - A boolean. Determine if column is primary key.
44 | * `allow_null` - A boolean. Determine if column is nullable.
45 | * `default` - A value or a callable (function).
46 | * `index` - A boolean. Determine if database indexes should be created.
47 | * `unique` - A boolean. Determine if unique constraint should be created.
48 |
49 | All fields are required unless one of the following is set:
50 |
51 | * `allow_null` - A boolean. Determine if column is nullable. Sets the default to `None`.
52 | * `allow_blank` - A boolean. Determine if empty strings are allowed. Sets the default to `""`.
53 | * `default` - A value or a callable (function).
54 |
55 | Special keyword arguments for `DateTime` and `Date` fields:
56 |
57 | * `auto_now` - Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
58 | * `auto_now_add` - Automatically set the field to now when the object is first created. Useful for creation of timestamps.
59 |
60 | The default is `datetime.date.today()` for `Date` and `datetime.datetime.now()` for `DateTime`.
61 |
62 | !!! note
63 | Setting `auto_now` or `auto_now_add` to True will cause the field to be read_only.
64 |
65 | The following column types are supported.
66 | See `TypeSystem` for [type-specific validation keyword arguments][typesystem-fields].
67 |
68 | * `orm.BigInteger()`
69 | * `orm.Boolean()`
70 | * `orm.Date(auto_now, auto_now_add)`
71 | * `orm.DateTime(auto_now, auto_now_add)`
72 | * `orm.Decimal()`
73 | * `orm.Email(max_length)`
74 | * `orm.Enum()`
75 | * `orm.Float()`
76 | * `orm.Integer()`
77 | * `orm.IPAddress()`
78 | * `orm.String(max_length)`
79 | * `orm.Text()`
80 | * `orm.Time()`
81 | * `orm.URL(max_length)`
82 | * `orm.UUID()`
83 | * `orm.JSON()`
84 |
85 | [typesystem-fields]: https://www.encode.io/typesystem/fields/
86 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # ORM
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | The `orm` package is an async ORM for Python, with support for Postgres,
16 | MySQL, and SQLite. ORM is built with:
17 |
18 | * [SQLAlchemy core][sqlalchemy-core] for query building.
19 | * [`databases`][databases] for cross-database async support.
20 | * [`typesystem`][typesystem] for data validation.
21 |
22 | Because ORM is built on SQLAlchemy core, you can use Alembic to provide
23 | database migrations.
24 |
25 | **ORM is still under development: We recommend pinning any dependencies with `orm~=0.3`**
26 |
27 | ---
28 |
29 | ## Installation
30 |
31 | ```shell
32 | $ pip install orm
33 | ```
34 |
35 | You can install the required database drivers with:
36 |
37 | ```shell
38 | $ pip install orm[postgresql]
39 | $ pip install orm[mysql]
40 | $ pip install orm[sqlite]
41 | ```
42 |
43 | Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite].
44 |
45 | ---
46 |
47 | ## Quickstart
48 |
49 | **Note**: Use `ipython` to try this from the console, since it supports `await`.
50 |
51 | ```python
52 | import databases
53 | import orm
54 |
55 | database = databases.Database("sqlite:///db.sqlite")
56 | models = orm.ModelRegistry(database=database)
57 |
58 |
59 | class Note(orm.Model):
60 | tablename = "notes"
61 | registry = models
62 | fields = {
63 | "id": orm.Integer(primary_key=True),
64 | "text": orm.String(max_length=100),
65 | "completed": orm.Boolean(default=False),
66 | }
67 |
68 | # Create the database and tables
69 | await models.create_all()
70 |
71 | await Note.objects.create(text="Buy the groceries.", completed=False)
72 |
73 | note = await Note.objects.get(id=1)
74 | print(note)
75 | # Note(id=1)
76 | ```
77 |
78 | [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
79 | [asyncpg]: https://github.com/MagicStack/asyncpg
80 | [aiomysql]: https://github.com/aio-libs/aiomysql
81 | [aiosqlite]: https://github.com/jreese/aiosqlite
82 |
83 | [databases]: https://github.com/encode/databases
84 | [typesystem]: https://github.com/encode/typesystem
85 | [typesystem-fields]: https://www.encode.io/typesystem/fields/
86 |
--------------------------------------------------------------------------------
/docs/making_queries.md:
--------------------------------------------------------------------------------
1 | Let's say you have the following model defined:
2 |
3 | ```python
4 | import databases
5 | import orm
6 |
7 | database = databases.Database("sqlite:///db.sqlite")
8 | models = orm.ModelRegistry(database=database)
9 |
10 |
11 | class Note(orm.Model):
12 | tablename = "notes"
13 | registry = models
14 | fields = {
15 | "id": orm.Integer(primary_key=True),
16 | "text": orm.String(max_length=100),
17 | "completed": orm.Boolean(default=False),
18 | }
19 | ```
20 |
21 | ORM supports two types of queryset methods.
22 | Some queryset methods return another queryset and can be chained together, like `.filter()` and `.order_by()`:
23 |
24 | ```python
25 | Note.objects.filter(completed=True).order_by("id")
26 | ```
27 |
28 | Other queryset methods return results and should be used as the final method on the queryset, like `.all()`:
29 |
30 | ```python
31 | Note.objects.filter(completed=True).all()
32 | ```
33 |
34 | ## Returning Querysets
35 |
36 | ### .exclude()
37 |
38 | To exclude instances:
39 |
40 | ```python
41 | notes = await Note.objects.exclude(completed=False).all()
42 | ```
43 |
44 | ### .filter()
45 |
46 | #### Django-style lookup
47 |
48 | To filter instances:
49 |
50 | ```python
51 | notes = await Note.objects.filter(completed=True).all()
52 | ```
53 |
54 | There are some special operators defined automatically on every column:
55 |
56 | * `in` - SQL `IN` operator.
57 | * `exact` - filter instances matching exact value.
58 | * `iexact` - same as `exact` but case-insensitive.
59 | * `contains` - filter instances containing value.
60 | * `icontains` - same as `contains` but case-insensitive.
61 | * `lt` - filter instances having value `Less Than`.
62 | * `lte` - filter instances having value `Less Than or Equal`.
63 | * `gt` - filter instances having value `Greater Than`.
64 | * `gte` - filter instances having value `Greater Than or Equal`.
65 |
66 | Example usage:
67 |
68 | ```python
69 | notes = await Note.objects.filter(text__icontains="mum").all()
70 |
71 | notes = await Note.objects.filter(id__in=[1, 2, 3]).all()
72 | ```
73 |
74 | #### SQLAlchemy filter operators
75 |
76 | The `filter` method also accepts SQLAlchemy filter operators:
77 |
78 | ```python
79 | notes = await Note.objects.filter(Note.columns.text.contains("mum")).all()
80 |
81 | notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all()
82 | ```
83 |
84 | Here `Note.columns` refers to the columns of the underlying SQLAlchemy table.
85 |
86 | !!! note
87 | Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields.
88 |
89 | ### .limit()
90 |
91 | To limit number of results:
92 |
93 | ```python
94 | await Note.objects.limit(1).all()
95 | ```
96 |
97 | ### .offset()
98 |
99 | To apply offset to query results:
100 |
101 | ```python
102 | await Note.objects.offset(1).all()
103 | ```
104 |
105 | As mentioned before, you can chain multiple queryset methods together to form a query.
106 | As an example:
107 |
108 | ```python
109 | await Note.objects.order_by("id").limit(1).offset(1).all()
110 | await Note.objects.filter(text__icontains="mum").limit(2).all()
111 | ```
112 |
113 | ### .order_by()
114 |
115 | To order query results:
116 |
117 | ```python
118 | notes = await Note.objects.order_by("text", "-id").all()
119 | ```
120 |
121 | !!! note
122 | This will sort by ascending `text` and descending `id`.
123 |
124 | ## Returning results
125 |
126 | ### .all()
127 |
128 | To retrieve all the instances:
129 |
130 | ```python
131 | notes = await Note.objects.all()
132 | ```
133 |
134 | ### .create()
135 |
136 | You need to pass the required model attributes and values to the `.create()` method:
137 |
138 | ```python
139 | await Note.objects.create(text="Buy the groceries.", completed=False)
140 | await Note.objects.create(text="Call Mum.", completed=True)
141 | await Note.objects.create(text="Send invoices.", completed=True)
142 | ```
143 |
144 | ### .bulk_create()
145 |
146 | You need to pass a list of dictionaries of required fields to create multiple objects:
147 |
148 | ```python
149 | await Product.objects.bulk_create(
150 | [
151 | {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
152 | {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
153 |
154 | ]
155 | )
156 | ```
157 |
158 | ### .delete()
159 |
160 | You can `delete` instances by calling `.delete()` on a queryset:
161 |
162 | ```python
163 | await Note.objects.filter(completed=True).delete()
164 | ```
165 |
166 | It's not very common, but to delete all rows in a table:
167 |
168 | ```python
169 | await Note.objects.delete()
170 | ```
171 |
172 | You can also call `.delete()` on a queried instance:
173 |
174 | ```python
175 | note = await Note.objects.first()
176 |
177 | await note.delete()
178 | ```
179 |
180 | ### .exists()
181 |
182 | Checks whether any instances matching the query exist. Returns `True` or `False`.
183 |
184 | ```python
185 | await Note.objects.filter(completed=True).exists()
186 | ```
187 |
188 | ### .first()
189 |
190 | This will return the first instance or `None`:
191 |
192 | ```python
193 | note = await Note.objects.filter(completed=True).first()
194 | ```
195 |
196 | `pk` always refers to the model's primary key field:
197 |
198 | ```python
199 | note = await Note.objects.get(pk=2)
200 | note.pk # 2
201 | ```
202 |
203 | ### .get()
204 |
205 | To get only one instance:
206 |
207 | ```python
208 | note = await Note.objects.get(id=1)
209 | ```
210 |
211 | !!! note
212 | `.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`.
213 |
214 | ### .update()
215 |
216 | You can update instances by calling `.update()` on a queryset:
217 |
218 | ```python
219 | await Note.objects.filter(completed=True).update(completed=False)
220 | ```
221 |
222 | It's not very common, but to update all rows in a table:
223 |
224 | ```python
225 | await Note.objects.update(completed=False)
226 | ```
227 |
228 | You can also call `.update()` on a queried instance:
229 |
230 | ```python
231 | note = await Note.objects.first()
232 |
233 | await note.update(completed=True)
234 | ```
235 |
236 | ## Convenience Methods
237 |
238 | ### .get_or_create()
239 |
240 | To get an existing instance matching the query, or create a new one.
241 | This will return a tuple of `instance` and `created`.
242 |
243 | ```python
244 | note, created = await Note.objects.get_or_create(
245 | text="Going to car wash", defaults={"completed": False}
246 | )
247 | ```
248 |
249 | This will query a `Note` with `text` as `"Going to car wash"`.
250 | If it doesn't exist, it will use the `defaults` argument to create the new instance.
251 |
252 | !!! note
253 | Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
254 |
255 |
256 | ### .update_or_create()
257 |
258 | To update an existing instance matching the query, or create a new one.
259 | This will return a tuple of `instance` and `created`.
260 |
261 | ```python
262 | note, created = await Note.objects.update_or_create(
263 | text="Going to car wash", defaults={"completed": True}
264 | )
265 | ```
266 |
267 | This will query a `Note` with `text` as `"Going to car wash"`.
268 | If an instance is found, it will use the `defaults` argument to update the instance.
269 | If it matches no records, it will use the combination of arguments to create the new instance.
270 |
271 | !!! note
272 | Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
273 |
--------------------------------------------------------------------------------
/docs/relationships.md:
--------------------------------------------------------------------------------
1 | ## ForeignKey
2 |
3 | ### Defining and querying relationships
4 |
5 | ORM supports loading and filtering across foreign keys.
6 |
7 | Let's say you have the following models defined:
8 |
9 | ```python
10 | import databases
11 | import orm
12 |
13 | database = databases.Database("sqlite:///db.sqlite")
14 | models = orm.ModelRegistry(database=database)
15 |
16 |
17 | class Album(orm.Model):
18 | tablename = "albums"
19 | registry = models
20 | fields = {
21 | "id": orm.Integer(primary_key=True),
22 | "name": orm.String(max_length=100),
23 | }
24 |
25 |
26 | class Track(orm.Model):
27 | tablename = "tracks"
28 | registry = models
29 | fields = {
30 | "id": orm.Integer(primary_key=True),
31 | "album": orm.ForeignKey(Album),
32 | "title": orm.String(max_length=100),
33 | "position": orm.Integer(),
34 | }
35 | ```
36 |
37 | You can create some `Album` and `Track` instances:
38 |
39 | ```python
40 | malibu = await Album.objects.create(name="Malibu")
41 | await Track.objects.create(album=malibu, title="The Bird", position=1)
42 | await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2)
43 | await Track.objects.create(album=malibu, title="The Waters", position=3)
44 |
45 | fantasies = await Album.objects.create(name="Fantasies")
46 | await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
47 | await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
48 | ```
49 |
50 | To fetch an instance, without loading a foreign key relationship on it:
51 |
52 | ```python
53 | track = await Track.objects.get(title="The Bird")
54 |
55 | # We have an album instance, but it only has the primary key populated
56 | print(track.album) # Album(id=1) [sparse]
57 | print(track.album.pk) # 1
58 | print(track.album.name) # Raises AttributeError
59 | ```
60 |
61 | You can load the relationship from the database:
62 |
63 | ```python
64 | await track.album.load()
65 | assert track.album.name == "Malibu"
66 | ```
67 |
68 | You can also fetch an instance, loading the foreign key relationship with it:
69 |
70 | ```python
71 | track = await Track.objects.select_related("album").get(title="The Bird")
72 | assert track.album.name == "Malibu"
73 | ```
74 |
75 | To fetch an instance, filtering across a foreign key relationship:
76 |
77 | ```python
78 | tracks = await Track.objects.filter(album__name="Fantasies").all()
79 | assert len(tracks) == 2
80 |
81 | tracks = await Track.objects.filter(album__name__iexact="fantasies").all()
82 | assert len(tracks) == 2
83 | ```
84 |
85 | ### ForeignKey constraints
86 |
87 | `ForeignKey` supports specifying a constraint through the `on_delete` argument.
88 |
89 | This will result in a SQL `ON DELETE` query being generated when the referenced object is removed.
90 |
91 | With the following definition:
92 |
93 | ```python
94 | class Album(orm.Model):
95 | tablename = "albums"
96 | registry = models
97 | fields = {
98 | "id": orm.Integer(primary_key=True),
99 | "name": orm.String(max_length=100),
100 | }
101 |
102 |
103 | class Track(orm.Model):
104 | tablename = "tracks"
105 | registry = models
106 | fields = {
107 | "id": orm.Integer(primary_key=True),
108 | "album": orm.ForeignKey(Album, on_delete=orm.CASCADE),
109 | "title": orm.String(max_length=100),
110 | }
111 | ```
112 |
113 | `Track` model defines `orm.ForeignKey(Album, on_delete=orm.CASCADE)` so whenever an `Album` object is removed,
114 | all `Track` objects referencing that `Album` will also be removed.
115 |
116 | Available options for `on_delete` are:
117 |
118 | * `CASCADE`
119 |
120 | This will remove all referencing objects.
121 |
122 | * `RESTRICT`
123 |
124 | This will prevent removal of the referenced object if there are objects referencing it.
125 | A database driver exception will be raised.
126 |
127 | * `SET_NULL`
128 |
129 | This will set the referencing objects' `ForeignKey` column to `NULL`.
130 | The `ForeignKey` defined here should also have `allow_null=True`.
131 |
132 |
133 | ## OneToOne
134 |
135 | A `OneToOne` relationship between models is basically
136 | the same as a `ForeignKey`, but it uses `unique=True` on the ForeignKey column:
137 |
138 | ```python
139 | class Profile(orm.Model):
140 | registry = models
141 | fields = {
142 | "id": orm.Integer(primary_key=True),
143 | "website": orm.String(max_length=100),
144 | }
145 |
146 |
147 | class Person(orm.Model):
148 | registry = models
149 | fields = {
150 | "id": orm.Integer(primary_key=True),
151 | "email": orm.String(max_length=100),
152 | "profile": orm.OneToOne(Profile),
153 | }
154 | ```
155 |
156 | You can create a `Profile` and `Person` instance:
157 |
158 | ```python
159 | profile = await Profile.objects.create(website="https://encode.io")
160 | await Person.objects.create(email="info@encode.io", profile=profile)
161 | ```
162 |
163 | Now creating another `Person` using the same `profile` will fail
164 | and will raise an exception:
165 |
166 | ```python
167 | await Person.objects.create(email="info@encode.io", profile=profile)
168 | ```
169 |
170 | `OneToOne` accepts the same `on_delete` parameters as `ForeignKey` which is
171 | described [here](#foreignkey-constraints).
172 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: ORM
2 | site_description: An async ORM.
3 |
4 | theme:
5 | name: 'material'
6 |
7 | repo_name: encode/orm
8 | repo_url: https://github.com/encode/orm
9 | # edit_uri: ""
10 |
11 | nav:
12 | - Introduction: 'index.md'
13 | - Declaring Models: 'declaring_models.md'
14 | - Making Queries: 'making_queries.md'
15 | - Relationships: 'relationships.md'
16 |
17 | markdown_extensions:
18 | - mkautodoc
19 | - admonition
20 | - pymdownx.highlight
21 | - pymdownx.superfences
22 |
--------------------------------------------------------------------------------
/orm/__init__.py:
--------------------------------------------------------------------------------
1 | from orm.constants import CASCADE, RESTRICT, SET_NULL
2 | from orm.exceptions import MultipleMatches, NoMatch
3 | from orm.fields import (
4 | JSON,
5 | URL,
6 | UUID,
7 | BigInteger,
8 | Boolean,
9 | Date,
10 | DateTime,
11 | Decimal,
12 | Email,
13 | Enum,
14 | Float,
15 | ForeignKey,
16 | Integer,
17 | IPAddress,
18 | OneToOne,
19 | String,
20 | Text,
21 | Time,
22 | )
23 | from orm.models import Model, ModelRegistry
24 |
25 | __version__ = "0.3.1"  # package version string
26 | __all__ = [  # names exported by `from orm import *`; mirrors the imports above
27 | "CASCADE",
28 | "RESTRICT",
29 | "SET_NULL",
30 | "NoMatch",
31 | "MultipleMatches",
32 | "BigInteger",
33 | "Boolean",
34 | "Date",
35 | "DateTime",
36 | "Decimal",
37 | "Email",
38 | "Enum",
39 | "Float",
40 | "ForeignKey",
41 | "Integer",
42 | "IPAddress",
43 | "JSON",
44 | "OneToOne",
45 | "String",
46 | "Text",
47 | "Time",
48 | "URL",
49 | "UUID",
50 | "Model",
51 | "ModelRegistry",
52 | ]
53 |
--------------------------------------------------------------------------------
/orm/constants.py:
--------------------------------------------------------------------------------
1 | CASCADE = "CASCADE"  # on_delete option; documented as removing all referencing objects
2 | RESTRICT = "RESTRICT"  # on_delete option; documented as blocking removal while references exist
3 | SET_NULL = "SET NULL"  # on_delete option; documented as setting the referencing column to NULL (note the space in the value)
4 |
--------------------------------------------------------------------------------
/orm/exceptions.py:
--------------------------------------------------------------------------------
1 | class NoMatch(Exception):  # documented as raised by `.get()`; presumably when no instance matches - confirm in models.py
2 | pass
3 |
4 |
5 | class MultipleMatches(Exception):  # documented as raised by `.get()`, which expects to find only one instance
6 | pass
7 |
--------------------------------------------------------------------------------
/orm/fields.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from datetime import date, datetime
3 |
4 | import sqlalchemy
5 | import typesystem
6 |
7 | from orm.sqlalchemy_fields import GUID, GenericIP
8 |
9 |
10 | class ModelField:  # Base class for fields: stores column flags and builds a typesystem validator.
11 | def __init__(
12 | self,
13 | primary_key: bool = False,  # forwarded to sqlalchemy.Column(primary_key=...) by get_column()
14 | index: bool = False,  # forwarded to sqlalchemy.Column(index=...)
15 | unique: bool = False,  # forwarded to sqlalchemy.Column(unique=...)
16 | **kwargs: typing.Any,  # remaining options are handed to get_validator()
17 | ) -> None:
18 | if primary_key:
19 | kwargs["read_only"] = True  # primary keys are always passed to the validator as read_only
20 | self.allow_null = kwargs.get("allow_null", False)  # read with .get(), so allow_null also stays in kwargs for the validator
21 | self.primary_key = primary_key
22 | self.index = index
23 | self.unique = unique
24 | self.validator = self.get_validator(**kwargs)  # subclasses supply the concrete typesystem field
25 |
26 | def get_column(self, name: str) -> sqlalchemy.Column:  # build the SQLAlchemy column named `name` for this field
27 | column_type = self.get_column_type()
28 | constraints = self.get_constraints()
29 | return sqlalchemy.Column(
30 | name,
31 | column_type,
32 | *constraints,
33 | primary_key=self.primary_key,
34 | nullable=self.allow_null and not self.primary_key,  # a primary-key column is never nullable, even with allow_null=True
35 | index=self.index,
36 | unique=self.unique,
37 | )
38 |
39 | def get_validator(self, **kwargs) -> typesystem.Field:  # subclass hook: return the typesystem field built from kwargs
40 | raise NotImplementedError() # pragma: no cover
41 |
42 | def get_column_type(self) -> sqlalchemy.types.TypeEngine:  # subclass hook: return the SQLAlchemy column type
43 | raise NotImplementedError() # pragma: no cover
44 |
45 | def get_constraints(self):  # extra positional arguments for sqlalchemy.Column; none by default
46 | return []
47 |
48 | def expand_relationship(self, value):  # identity in the base class; presumably overridden by relationship fields - confirm
49 | return value
50 |
51 |
class String(ModelField):
    """String field with a mandatory ``max_length``, stored as ``sqlalchemy.String``."""

    def __init__(self, **kwargs):
        # The column length is taken from the validator's max_length.
        assert "max_length" in kwargs, "max_length is required"
        super().__init__(**kwargs)

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the ``typesystem.String`` validator from the field options."""
        validator = typesystem.String(**kwargs)
        return validator

    def get_column_type(self):
        """Return a ``sqlalchemy.String`` sized to the validator's ``max_length``."""
        max_length = self.validator.max_length
        return sqlalchemy.String(length=max_length)
62 |
63 |
class Text(ModelField):
    """Text field validated by ``typesystem.Text`` and stored as ``sqlalchemy.Text``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Text(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Text()
        return column_type
70 |
71 |
class Integer(ModelField):
    """Integer field validated by ``typesystem.Integer``, stored as ``sqlalchemy.Integer``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Integer(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Integer()
        return column_type
78 |
79 |
class Float(ModelField):
    """Float field validated by ``typesystem.Float``, stored as ``sqlalchemy.Float``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Float(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Float()
        return column_type
86 |
87 |
class BigInteger(ModelField):
    """Integer field (``typesystem.Integer``) stored as ``sqlalchemy.BigInteger``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Integer(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.BigInteger()
        return column_type
94 |
95 |
class Boolean(ModelField):
    """Boolean field validated by ``typesystem.Boolean``, stored as ``sqlalchemy.Boolean``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Boolean(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Boolean()
        return column_type
102 |
103 |
class AutoNowMixin(ModelField):
    """
    Adds the mutually exclusive ``auto_now`` / ``auto_now_add`` options.

    Enabling either one marks the field's validator as read-only.
    """

    def __init__(self, auto_now=False, auto_now_add=False, **kwargs):
        # Stored before super().__init__() because get_validator(), which the
        # base initialiser calls, reads both flags.
        self.auto_now = auto_now
        self.auto_now_add = auto_now_add

        if auto_now and auto_now_add:
            raise ValueError("auto_now and auto_now_add cannot be both True")

        if auto_now or auto_now_add:
            kwargs["read_only"] = True

        super().__init__(**kwargs)
113 |
114 |
class DateTime(AutoNowMixin):
    """Datetime field; with ``auto_now``/``auto_now_add`` it defaults to ``datetime.now``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the ``typesystem.DateTime`` validator from the field options."""
        wants_auto_value = self.auto_now_add or self.auto_now
        if wants_auto_value:
            # Passed as a callable, not a value.
            kwargs["default"] = datetime.now
        return typesystem.DateTime(**kwargs)

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.DateTime()
        return column_type
123 |
124 |
class Date(AutoNowMixin):
    """Date field; with ``auto_now``/``auto_now_add`` it defaults to ``date.today``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the ``typesystem.Date`` validator from the field options."""
        wants_auto_value = self.auto_now_add or self.auto_now
        if wants_auto_value:
            # Passed as a callable, not a value.
            kwargs["default"] = date.today
        return typesystem.Date(**kwargs)

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Date()
        return column_type
133 |
134 |
class Time(ModelField):
    """Time field validated by ``typesystem.Time``, stored as ``sqlalchemy.Time``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Time(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.Time()
        return column_type
141 |
142 |
class JSON(ModelField):
    """JSON field: values go through ``typesystem.Any`` and are stored as ``sqlalchemy.JSON``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Any(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy column type."""
        column_type = sqlalchemy.JSON()
        return column_type
149 |
150 |
class ForeignKey(ModelField):
    """
    Relationship field stored as a column that references the primary key
    of another model.

    ``to`` is either a model class or the name of a model registered in the
    same registry; a name is resolved the first time ``target`` is read.
    ``on_delete`` is forwarded as the ``ondelete`` option of the SQLAlchemy
    foreign-key constraint.
    """

    class ForeignKeyValidator(typesystem.Field):
        """Validator that reduces a related model instance to its ``pk``."""

        def validate(self, value):
            # An explicit None on a nullable relation has no ``pk`` to read;
            # pass it through instead of failing with AttributeError.
            if value is None and self.allow_null:
                return None
            return value.pk

    def __init__(
        self, to, allow_null: bool = False, on_delete: typing.Optional[str] = None
    ):
        super().__init__(allow_null=allow_null)
        self.to = to
        self.on_delete = on_delete

    @property
    def target(self):
        """The related model class, looked up once and then cached."""
        if not hasattr(self, "_target"):
            if isinstance(self.to, str):
                # ``registry`` is attached to every field by ModelMeta.
                self._target = self.registry.models[self.to]
            else:
                self._target = self.to
        return self._target

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the pk-extracting validator from the field options."""
        return self.ForeignKeyValidator(**kwargs)

    def get_column(self, name: str) -> sqlalchemy.Column:
        """Build a column typed like the target's pk, with an FK constraint."""
        target = self.target
        to_field = target.fields[target.pkname]

        column_type = to_field.get_column_type()
        constraints = [
            sqlalchemy.schema.ForeignKey(
                f"{target.tablename}.{target.pkname}", ondelete=self.on_delete
            )
        ]
        return sqlalchemy.Column(
            name,
            column_type,
            *constraints,
            nullable=self.allow_null,
        )

    def expand_relationship(self, value):
        """Return ``value`` as a target instance, wrapping a raw pk if needed."""
        target = self.target
        if isinstance(value, target):
            return value
        # Only the pk is populated on the wrapper instance.
        return target(pk=value)
197 |
198 |
class OneToOne(ForeignKey):
    """Foreign key whose column is additionally declared ``unique``."""

    def get_column(self, name: str) -> sqlalchemy.Column:
        """Build a unique column typed like the target's pk, with an FK constraint."""
        related_model = self.target
        related_pk_field = related_model.fields[related_model.pkname]

        column_type = related_pk_field.get_column_type()
        reference = sqlalchemy.schema.ForeignKey(
            f"{related_model.tablename}.{related_model.pkname}",
            ondelete=self.on_delete,
        )

        return sqlalchemy.Column(
            name,
            column_type,
            reference,
            nullable=self.allow_null,
            unique=True,
        )
218 |
219 |
class Enum(ModelField):
    """Field stored as ``sqlalchemy.Enum(enum)``; values pass through ``typesystem.Any``."""

    def __init__(self, enum, **kwargs):
        super().__init__(**kwargs)
        # Read by get_column_type() when the column is built.
        self.enum = enum

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Any(**kwargs)
        return validator

    def get_column_type(self):
        """Return the SQLAlchemy enum type built from ``self.enum``."""
        column_type = sqlalchemy.Enum(self.enum)
        return column_type
230 |
231 |
class Decimal(ModelField):
    """
    Fixed-precision decimal field stored as ``sqlalchemy.Numeric``.

    ``max_digits`` becomes the column precision and ``decimal_places`` the
    column scale.
    """

    def __init__(self, max_digits: int, decimal_places: int, **kwargs):
        assert max_digits, "max_digits is required"
        # Compare against None rather than truthiness: a scale of 0
        # (whole numbers only) is a legitimate value.
        assert decimal_places is not None, "decimal_places is required"
        self.max_digits = max_digits
        self.decimal_places = decimal_places
        super().__init__(**kwargs)

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the ``typesystem.Decimal`` validator from the field options."""
        return typesystem.Decimal(**kwargs)

    def get_column_type(self):
        """Return a ``sqlalchemy.Numeric`` with the configured precision and scale."""
        return sqlalchemy.Numeric(precision=self.max_digits, scale=self.decimal_places)
245 |
246 |
class UUID(ModelField):
    """UUID field validated by ``typesystem.UUID`` and stored using the ``GUID`` type."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.UUID(**kwargs)
        return validator

    def get_column_type(self):
        """Return the ``GUID`` column type."""
        column_type = GUID()
        return column_type
253 |
254 |
class Email(String):
    """
    String field validated by ``typesystem.Email``.

    The column type (a ``sqlalchemy.String`` sized to the validator's
    ``max_length``) is inherited from ``String``.
    """

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.Email(**kwargs)
        return validator
261 |
262 |
class IPAddress(ModelField):
    """IP address field validated by ``typesystem.IPAddress``, stored using ``GenericIP``."""

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.IPAddress(**kwargs)
        return validator

    def get_column_type(self):
        """Return the ``GenericIP`` column type."""
        column_type = GenericIP()
        return column_type
269 |
270 |
class URL(String):
    """
    String field validated by ``typesystem.URL``.

    The column type (a ``sqlalchemy.String`` sized to the validator's
    ``max_length``) is inherited from ``String``.
    """

    def get_validator(self, **kwargs) -> typesystem.Field:
        """Build the validator from the field options."""
        validator = typesystem.URL(**kwargs)
        return validator
277 |
--------------------------------------------------------------------------------
/orm/models.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import databases
4 | import sqlalchemy
5 | import typesystem
6 | from sqlalchemy.ext.asyncio import create_async_engine
7 |
8 | from orm.exceptions import MultipleMatches, NoMatch
9 | from orm.fields import Date, DateTime, String, Text
10 |
# Maps the lookup suffix used in filter keyword arguments
# (e.g. ``name__icontains``) to the name of the SQLAlchemy column method
# that builds the comparison clause.
FILTER_OPERATORS = {
    "exact": "__eq__",
    # No wildcards are added for iexact, so this is a case-insensitive
    # comparison of the value as given.
    "iexact": "ilike",
    # For contains/icontains the value is wrapped in %...% by the query builder.
    "contains": "like",
    "icontains": "ilike",
    "in": "in_",
    "gt": "__gt__",
    "gte": "__ge__",
    "lt": "__lt__",
    "lte": "__le__",
}
22 |
23 |
24 | def _update_auto_now_fields(values, fields):
25 | for key, value in fields.items():
26 | if isinstance(value, (DateTime, Date)) and value.auto_now:
27 | values[key] = value.validator.get_default_value()
28 | return values
29 |
30 |
class ModelRegistry:
    """
    Groups models that share one database connection and one SQLAlchemy
    ``MetaData`` object.

    Model classes add themselves to ``models`` (keyed by class name) when
    they are declared with a ``registry`` attribute.
    """

    def __init__(self, database: databases.Database) -> None:
        self.database = database
        self.models = {}
        self._metadata = sqlalchemy.MetaData()

    @property
    def metadata(self):
        """The shared ``MetaData``, after (re)building every registered model's table."""
        for model_cls in self.models.values():
            model_cls.build_table()
        return self._metadata

    async def create_all(self):
        """Create the tables of all registered models."""
        url = self._get_database_url()
        engine = create_async_engine(url)

        # Dispose of the engine even when the DDL fails, so it does not leak.
        try:
            async with self.database:
                async with engine.begin() as conn:
                    await conn.run_sync(self.metadata.create_all)
        finally:
            await engine.dispose()

    async def drop_all(self):
        """Drop the tables of all registered models."""
        url = self._get_database_url()
        engine = create_async_engine(url)

        # Dispose of the engine even when the DDL fails, so it does not leak.
        try:
            async with self.database:
                async with engine.begin() as conn:
                    await conn.run_sync(self.metadata.drop_all)
        finally:
            await engine.dispose()

    def _get_database_url(self) -> str:
        """
        Return the database URL as a string, filling in an async driver
        (asyncpg / aiomysql / aiosqlite) when the URL does not name one.
        """
        url = self.database.url
        if not url.driver:
            if url.dialect == "postgresql":
                url = url.replace(driver="asyncpg")
            elif url.dialect == "mysql":
                url = url.replace(driver="aiomysql")
            elif url.dialect == "sqlite":
                url = url.replace(driver="aiosqlite")
        return str(url)
73 |
74 |
class ModelMeta(type):
    """
    Metaclass that wires a model class to its registry.

    When the class body defines them, it records the model in the registry,
    copies the registry's database onto the class, derives a default
    ``tablename`` from the class name, attaches the registry to every field
    and remembers the primary-key field name as ``pkname``.
    """

    def __new__(cls, name, bases, attrs):
        model_class = super().__new__(cls, name, bases, attrs)

        if "registry" in attrs:
            model_class.database = attrs["registry"].database
            attrs["registry"].models[name] = model_class

            if "tablename" not in attrs:
                setattr(model_class, "tablename", name.lower())

        # A distinct loop variable keeps the class ``name`` from being shadowed.
        for field_name, field in attrs.get("fields", {}).items():
            # Relationship fields use the registry to resolve string targets.
            setattr(field, "registry", attrs.get("registry"))
            if field.primary_key:
                model_class.pkname = field_name

        return model_class

    @property
    def table(cls):
        """The model's SQLAlchemy table, built on first access and cached."""
        if not hasattr(cls, "_table"):
            cls._table = cls.build_table()
        return cls._table

    @property
    def columns(cls) -> sqlalchemy.sql.ColumnCollection:
        """The columns of the model's table."""
        # Go through ``table`` so the table is built on demand; reading
        # ``_table`` directly fails if ``table`` was never accessed before.
        return cls.table.columns
102 |
103 |
class QuerySet:
    """
    Chainable query builder for a model class.

    Exposed on models through the ``objects`` descriptor. Builder methods
    (``filter``, ``exclude``, ``search``, ``order_by``, ``select_related``,
    ``limit``, ``offset``) return a new queryset and leave the receiver
    untouched; the coroutine methods run the query on the model registry's
    database.
    """

    # Characters with special meaning in LIKE patterns.
    ESCAPE_CHARACTERS = ["%", "_"]

    def __init__(
        self,
        model_cls=None,
        filter_clauses=None,
        select_related=None,
        limit_count=None,
        offset=None,
        order_by=None,
    ):
        self.model_cls = model_cls
        self.filter_clauses = [] if filter_clauses is None else filter_clauses
        self._select_related = [] if select_related is None else select_related
        self.limit_count = limit_count
        self.query_offset = offset
        self._order_by = [] if order_by is None else order_by

    def __get__(self, instance, owner):
        # Descriptor access yields a fresh queryset bound to the owner class.
        return self.__class__(model_cls=owner)

    def _clone(self, **overrides):
        """Return a copy of this queryset with the given constructor arguments replaced."""
        options = {
            "model_cls": self.model_cls,
            "filter_clauses": self.filter_clauses,
            "select_related": self._select_related,
            "limit_count": self.limit_count,
            "offset": self.query_offset,
            "order_by": self._order_by,
        }
        options.update(overrides)
        return self.__class__(**options)

    @property
    def database(self):
        """The database of the model's registry."""
        return self.model_cls.registry.database

    @property
    def table(self) -> sqlalchemy.Table:
        """The SQLAlchemy table of the model."""
        return self.model_cls.table

    @property
    def schema(self):
        """A typesystem schema built from the validators of all model fields."""
        fields = {key: field.validator for key, field in self.model_cls.fields.items()}
        return typesystem.Schema(fields=fields)

    @property
    def pkname(self):
        """Name of the model's primary-key field."""
        return self.model_cls.pkname

    def _build_select_expression(self):
        """Assemble the SELECT for the current filters, joins, ordering and paging."""
        tables = [self.table]
        select_from = self.table

        # NOTE(review): select_from is reset for every select_related entry,
        # so only the last entry's joins end up in the FROM clause while
        # every related table is still selected. Confirm that querysets with
        # several select_related entries behave as intended.
        for item in self._select_related:
            model_cls = self.model_cls
            select_from = self.table
            for part in item.split("__"):
                model_cls = model_cls.fields[part].target
                table = model_cls.table
                select_from = sqlalchemy.sql.join(select_from, table)
                tables.append(table)

        expr = sqlalchemy.sql.select(tables)
        expr = expr.select_from(select_from)

        if self.filter_clauses:
            if len(self.filter_clauses) == 1:
                clause = self.filter_clauses[0]
            else:
                clause = sqlalchemy.sql.and_(*self.filter_clauses)
            expr = expr.where(clause)

        if self._order_by:
            order_by = list(map(self._prepare_order_by, self._order_by))
            expr = expr.order_by(*order_by)

        # Compare against None so that an explicit limit(0) is honoured.
        if self.limit_count is not None:
            expr = expr.limit(self.limit_count)

        # An offset of 0 is a no-op, so a truthiness check is sufficient.
        if self.query_offset:
            expr = expr.offset(self.query_offset)

        return expr

    def filter(
        self,
        clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None,
        **kwargs: typing.Any,
    ):
        """
        Return a queryset narrowed by ``clause`` or by keyword lookups.

        When ``clause`` is given, the keyword arguments are ignored.
        """
        if clause is not None:
            return self._clone(filter_clauses=[*self.filter_clauses, clause])
        return self._filter_query(**kwargs)

    def exclude(
        self,
        clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None,
        **kwargs: typing.Any,
    ):
        """
        Return a queryset without the rows matching ``clause`` or the
        keyword lookups.

        When ``clause`` is given, the keyword arguments are ignored.
        """
        if clause is not None:
            negated = sqlalchemy.not_(clause)
            return self._clone(filter_clauses=[*self.filter_clauses, negated])
        return self._filter_query(_exclude=True, **kwargs)

    def _filter_query(self, _exclude: bool = False, **kwargs):
        """Translate keyword lookups into clauses and return the new queryset."""
        clauses = []
        # Work on copies so the receiver keeps its own clause list.
        filter_clauses = list(self.filter_clauses)
        select_related = list(self._select_related)

        # Membership test rather than truthiness, so pk=0 is remapped too.
        if "pk" in kwargs:
            pk_name = self.model_cls.pkname
            kwargs[pk_name] = kwargs.pop("pk")

        for key, value in kwargs.items():
            if "__" in key:
                parts = key.split("__")

                # Determine if we should treat the final part as a
                # filter operator or as a related field.
                if parts[-1] in FILTER_OPERATORS:
                    op = parts[-1]
                    field_name = parts[-2]
                    related_parts = parts[:-2]
                else:
                    op = "exact"
                    field_name = parts[-1]
                    related_parts = parts[:-1]

                model_cls = self.model_cls
                if related_parts:
                    # Add any implied select_related
                    related_str = "__".join(related_parts)
                    if related_str not in select_related:
                        select_related.append(related_str)

                    # Walk the relationships to the actual model class
                    # against which the comparison is being made.
                    for part in related_parts:
                        model_cls = model_cls.fields[part].target

                column = model_cls.table.columns[field_name]

            else:
                op = "exact"
                column = self.table.columns[key]

            # Map the operation code onto SQLAlchemy's ColumnElement
            # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
            op_attr = FILTER_OPERATORS[op]
            has_escaped_character = False

            if op in ["contains", "icontains"]:
                has_escaped_character = any(
                    c for c in self.ESCAPE_CHARACTERS if c in value
                )
                if has_escaped_character:
                    # enable escape modifier
                    for char in self.ESCAPE_CHARACTERS:
                        value = value.replace(char, f"\\{char}")
                value = f"%{value}%"

            # A model instance is compared through its primary key.
            if isinstance(value, Model):
                value = value.pk

            clause = getattr(column, op_attr)(value)
            clause.modifiers["escape"] = "\\" if has_escaped_character else None

            clauses.append(clause)

        if _exclude:
            # Excluding nothing is a no-op; never build not_(and_()).
            if clauses:
                filter_clauses.append(sqlalchemy.not_(sqlalchemy.sql.and_(*clauses)))
        else:
            filter_clauses += clauses

        return self._clone(
            filter_clauses=filter_clauses,
            select_related=select_related,
        )

    def search(self, term: typing.Any):
        """
        Return a queryset matching ``term`` case-insensitively as a
        substring of any ``String`` or ``Text`` field.

        A falsy ``term`` returns the queryset unchanged.
        """
        if not term:
            return self

        filter_clauses = list(self.filter_clauses)
        # NOTE: "%" and "_" inside ``term`` are not escaped here, so they
        # act as LIKE wildcards.
        value = f"%{term}%"

        search_fields = [
            name
            for name, field in self.model_cls.fields.items()
            if isinstance(field, (String, Text))
        ]
        search_clauses = [
            self.table.columns[name].ilike(value) for name in search_fields
        ]

        if len(search_clauses) > 1:
            filter_clauses.append(sqlalchemy.sql.or_(*search_clauses))
        else:
            filter_clauses.extend(search_clauses)

        return self._clone(filter_clauses=filter_clauses)

    def order_by(self, *order_by):
        """Return a queryset ordered by the given field names ("-name" for descending)."""
        return self._clone(order_by=order_by)

    def select_related(self, related):
        """Return a queryset that also joins and loads the given relation path(s)."""
        if not isinstance(related, (list, tuple)):
            related = [related]

        # Unpacking accepts both lists and tuples.
        return self._clone(select_related=[*self._select_related, *related])

    async def exists(self) -> bool:
        """Return whether the query matches at least one row."""
        expr = self._build_select_expression()
        expr = sqlalchemy.exists(expr).select()
        return await self.database.fetch_val(expr)

    def limit(self, limit_count: int):
        """Return a queryset limited to ``limit_count`` rows."""
        return self._clone(limit_count=limit_count)

    def offset(self, offset: int):
        """Return a queryset that skips the first ``offset`` rows."""
        return self._clone(offset=offset)

    async def count(self) -> int:
        """Return the number of rows the query matches."""
        expr = self._build_select_expression().alias("subquery_for_count")
        expr = sqlalchemy.func.count().select().select_from(expr)
        return await self.database.fetch_val(expr)

    async def all(self, **kwargs):
        """Return all matching rows as model instances, optionally filtering first."""
        if kwargs:
            return await self.filter(**kwargs).all()

        expr = self._build_select_expression()
        rows = await self.database.fetch_all(expr)
        return [
            self.model_cls._from_row(row, select_related=self._select_related)
            for row in rows
        ]

    async def get(self, **kwargs):
        """
        Return the single matching instance.

        Raises ``NoMatch`` when nothing matches and ``MultipleMatches`` when
        more than one row does.
        """
        if kwargs:
            return await self.filter(**kwargs).get()

        # Two rows are enough to tell "exactly one" from "several".
        expr = self._build_select_expression().limit(2)
        rows = await self.database.fetch_all(expr)

        if not rows:
            raise NoMatch()
        if len(rows) > 1:
            raise MultipleMatches()
        return self.model_cls._from_row(rows[0], select_related=self._select_related)

    async def first(self, **kwargs):
        """Return the first matching instance, or ``None`` when there is none."""
        if kwargs:
            return await self.filter(**kwargs).first()

        rows = await self.limit(1).all()
        if rows:
            return rows[0]
        return None

    def _validate_kwargs(self, **kwargs):
        """Validate creation values and fill in defaults of read-only fields."""
        fields = self.model_cls.fields
        validator = typesystem.Schema(
            fields={key: value.validator for key, value in fields.items()}
        )
        kwargs = validator.validate(kwargs)
        # Read-only fields always take their default, overriding any input.
        for key, value in fields.items():
            if value.validator.read_only and value.validator.has_default():
                kwargs[key] = value.validator.get_default_value()
        return kwargs

    async def create(self, **kwargs):
        """Validate ``kwargs``, insert a row and return the new instance."""
        kwargs = self._validate_kwargs(**kwargs)
        instance = self.model_cls(**kwargs)
        expr = self.table.insert().values(**kwargs)

        if self.pkname not in kwargs:
            # Use the value returned by the insert as the primary key.
            instance.pk = await self.database.execute(expr)
        else:
            await self.database.execute(expr)

        return instance

    async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
        """Validate every mapping in ``objs`` and insert them in one statement."""
        new_objs = [self._validate_kwargs(**obj) for obj in objs]

        expr = self.table.insert().values(new_objs)
        await self.database.execute(expr)

    async def delete(self) -> None:
        """Delete every row matching the current filters."""
        expr = self.table.delete()
        for filter_clause in self.filter_clauses:
            expr = expr.where(filter_clause)

        await self.database.execute(expr)

    async def update(self, **kwargs) -> None:
        """Validate ``kwargs`` and apply them to every row matching the filters."""
        # Only the fields actually being updated are validated.
        fields = {
            key: field.validator
            for key, field in self.model_cls.fields.items()
            if key in kwargs
        }
        validator = typesystem.Schema(fields=fields)
        kwargs = _update_auto_now_fields(
            validator.validate(kwargs), self.model_cls.fields
        )
        expr = self.table.update().values(**kwargs)

        for filter_clause in self.filter_clauses:
            expr = expr.where(filter_clause)

        await self.database.execute(expr)

    async def get_or_create(
        self, defaults: typing.Dict[str, typing.Any], **kwargs
    ) -> typing.Tuple[typing.Any, bool]:
        """
        Fetch the instance matching ``kwargs`` or create it from ``kwargs``
        merged with ``defaults``.

        Returns ``(instance, created)``.
        """
        try:
            instance = await self.get(**kwargs)
            return instance, False
        except NoMatch:
            kwargs.update(defaults)
            instance = await self.create(**kwargs)
            return instance, True

    async def update_or_create(
        self, defaults: typing.Dict[str, typing.Any], **kwargs
    ) -> typing.Tuple[typing.Any, bool]:
        """
        Update the instance matching ``kwargs`` with ``defaults``, or create
        it from ``kwargs`` merged with ``defaults``.

        Returns ``(instance, created)``.
        """
        try:
            instance = await self.get(**kwargs)
            await instance.update(**defaults)
            return instance, False
        except NoMatch:
            kwargs.update(defaults)
            instance = await self.create(**kwargs)
            return instance, True

    def _prepare_order_by(self, order_by: str):
        """Turn "name" / "-name" into an ascending / descending column expression."""
        reverse = order_by.startswith("-")
        order_by = order_by.lstrip("-")
        order_col = self.table.columns[order_by]
        return order_col.desc() if reverse else order_col
485 |
486 |
class Model(metaclass=ModelMeta):
    """
    Base class for ORM models.

    ``ModelMeta`` reads the ``registry``, ``fields`` and optional
    ``tablename`` attributes of a subclass body and provides ``pkname``,
    ``database`` and ``table`` in return. Queries go through ``objects``.
    """

    objects = QuerySet()

    def __init__(self, **kwargs):
        # "pk" is accepted as an alias for the primary-key field.
        if "pk" in kwargs:
            kwargs[self.pkname] = kwargs.pop("pk")
        for key, value in kwargs.items():
            if key not in self.fields:
                raise ValueError(
                    f"Invalid keyword {key} for class {self.__class__.__name__}"
                )
            setattr(self, key, value)

    @property
    def pk(self):
        """The value of the primary-key field."""
        return getattr(self, self.pkname)

    @pk.setter
    def pk(self, value):
        setattr(self, self.pkname, value)

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.__class__.__name__}({self.pkname}={self.pk})"

    @classmethod
    def build_table(cls):
        """Build (or extend) the SQLAlchemy table for this model in the registry metadata."""
        tablename = cls.tablename
        metadata = cls.registry._metadata
        columns = [field.get_column(name) for name, field in cls.fields.items()]
        return sqlalchemy.Table(tablename, metadata, *columns, extend_existing=True)

    @property
    def table(self) -> sqlalchemy.Table:
        """The SQLAlchemy table of the model class."""
        return self.__class__.table

    async def update(self, **kwargs):
        """Validate ``kwargs``, write them to this row and mirror them on the instance."""
        # Only the fields actually being updated are validated.
        fields = {
            key: field.validator for key, field in self.fields.items() if key in kwargs
        }
        validator = typesystem.Schema(fields=fields)
        kwargs = _update_auto_now_fields(validator.validate(kwargs), self.fields)
        pk_column = getattr(self.table.c, self.pkname)
        expr = self.table.update().values(**kwargs).where(pk_column == self.pk)
        await self.database.execute(expr)

        # Update the model instance.
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def delete(self) -> None:
        """Delete this instance's row."""
        pk_column = getattr(self.table.c, self.pkname)
        expr = self.table.delete().where(pk_column == self.pk)

        await self.database.execute(expr)

    async def load(self):
        """
        Reload every column of this instance from the database.

        Raises ``NoMatch`` when no row with this primary key exists.
        """
        # Build the select expression.
        pk_column = getattr(self.table.c, self.pkname)
        expr = self.table.select().where(pk_column == self.pk)

        # Perform the fetch.
        row = await self.database.fetch_one(expr)
        if row is None:
            # The row is gone; report it explicitly instead of failing
            # with an AttributeError on None below.
            raise NoMatch()

        # Update the instance.
        for key, value in dict(row._mapping).items():
            setattr(self, key, value)

    @classmethod
    def _from_row(cls, row, select_related=None):
        """
        Instantiate a model instance, given a database row.

        ``select_related`` lists the relation paths ("a" or "a__b") whose
        related instances should be built from the same row.
        """
        item = {}

        # Instantiate any child instances first.
        for related in select_related or ():
            if "__" in related:
                first_part, remainder = related.split("__", 1)
                model_cls = cls.fields[first_part].target
                item[first_part] = model_cls._from_row(row, select_related=[remainder])
            else:
                model_cls = cls.fields[related].target
                item[related] = model_cls._from_row(row)

        # Pull out the regular column values.
        for column in cls.table.columns:
            if column.name not in item:
                item[column.name] = row[column]

        return cls(**item)

    def __setattr__(self, key, value):
        if key in self.fields:
            # Setting a relationship to a raw pk value should set a
            # fully-fledged relationship instance, with just the pk loaded.
            value = self.fields[key].expand_relationship(value)
        super().__setattr__(key, value)

    def __eq__(self, other):
        # Equal when both are the same class and agree on every field value.
        if self.__class__ != other.__class__:
            return False
        for key in self.fields.keys():
            if getattr(self, key, None) != getattr(other, key, None):
                return False
        return True
597 |
--------------------------------------------------------------------------------
/orm/sqlalchemy_fields.py:
--------------------------------------------------------------------------------
1 | import ipaddress
2 | import uuid
3 |
4 | import sqlalchemy
5 |
6 |
class GUID(sqlalchemy.TypeDecorator):
    """
    Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses
    CHAR(32), storing as stringified hex values.
    """

    impl = sqlalchemy.CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Pick the native UUID type on PostgreSQL and CHAR(32) elsewhere."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(sqlalchemy.dialects.postgresql.UUID())
        else:
            return dialect.type_descriptor(sqlalchemy.CHAR(32))

    def process_bind_param(self, value, dialect):
        """Convert a value to what is sent to the database."""
        if value is None:
            return value

        if dialect.name == "postgresql":
            return str(value)

        # Accept anything uuid.UUID() can parse (e.g. a string used in a
        # filter), not just UUID instances, before taking the 32-char hex.
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        return value.hex

    def process_result_value(self, value, dialect):
        """Convert a database value back into a ``uuid.UUID``."""
        if value is None:
            return value

        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        return value
40 |
41 |
class GenericIP(sqlalchemy.TypeDecorator):
    """
    Platform-independent IP Address type.

    Uses PostgreSQL's INET type, otherwise uses
    CHAR(45), storing as stringified values.
    """

    impl = sqlalchemy.CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Pick the native INET type on PostgreSQL and CHAR(45) elsewhere."""
        if dialect.name != "postgresql":
            return dialect.type_descriptor(sqlalchemy.CHAR(45))
        return dialect.type_descriptor(sqlalchemy.dialects.postgresql.INET())

    def process_bind_param(self, value, dialect):
        """Send addresses to the database as strings; ``None`` stays ``None``."""
        if value is None:
            return None
        return str(value)

    def process_result_value(self, value, dialect):
        """Convert a database value back into an ``ipaddress`` address object."""
        if value is None:
            return value

        address_types = (ipaddress.IPv4Address, ipaddress.IPv6Address)
        if isinstance(value, address_types):
            return value
        return ipaddress.ip_address(value)
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | databases[postgresql, mysql, sqlite]
2 | typesystem
3 |
4 | # Packaging
5 | twine
6 | wheel
7 |
8 | # Testing
9 | anyio>=3.0.0,<4
10 | autoflake
11 | black
12 | codecov
13 | flake8
14 | isort
15 | mypy
16 | pytest
17 | pytest-cov
18 |
19 | # Documentation
20 | mkdocs
21 | mkdocs-material
22 | mkautodoc
23 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Development Scripts
2 |
3 | * `scripts/build` - Build package and documentation.
4 | * `scripts/check` - Check linting and formatting.
5 | * `scripts/clean` - Delete any build artifacts.
6 | * `scripts/coverage` - Check test coverage.
7 | * `scripts/docs` - Run documentation server locally.
8 | * `scripts/install` - Install dependencies in a virtual environment.
9 | * `scripts/lint` - Run the code linting.
10 | * `scripts/publish` - Publish the latest version to PyPI.
11 | * `scripts/test` - Run the test suite.
12 |
13 | Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all).
14 |
--------------------------------------------------------------------------------
/scripts/build:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Build the sdist/wheel, check them with twine, and build the docs.
# Tools come from ./venv/bin when that directory exists, else from PATH.
PREFIX=""
if [ -d 'venv' ]; then
    PREFIX="venv/bin/"
fi

# Echo each command before it runs.
set -x

${PREFIX}python setup.py sdist bdist_wheel
${PREFIX}twine check dist/*
${PREFIX}mkdocs build
14 |
--------------------------------------------------------------------------------
/scripts/check:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Check import order, formatting and lint without modifying any file.
# Tools come from ./venv/bin when that directory exists, else from PATH.
if [ -d 'venv' ]; then
    PREFIX="venv/bin/"
else
    PREFIX=""
fi
export PREFIX

SOURCE_FILES="orm tests"
export SOURCE_FILES

# Echo each command before it runs.
set -x

# $SOURCE_FILES is left unquoted on purpose: it holds several paths.
${PREFIX}isort --check --diff --project=orm $SOURCE_FILES
${PREFIX}black --check --diff $SOURCE_FILES
${PREFIX}flake8 $SOURCE_FILES
# ${PREFIX}mypy $SOURCE_FILES
15 |
--------------------------------------------------------------------------------
/scripts/clean:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Remove build artifacts: dist/, site/, htmlcov/, the egg-info directory,
# compiled Python files and __pycache__ directories inside the package.
PACKAGE="orm"

for artifact_dir in 'dist' 'site' 'htmlcov' "${PACKAGE}.egg-info"; do
    if [ -d "${artifact_dir}" ]; then
        rm -r "${artifact_dir}"
    fi
done

# Delete the bytecode files first, then the cache directories.
find ${PACKAGE} -type f -name "*.py[co]" -delete
find ${PACKAGE} -type d -name __pycache__ -delete
20 |
--------------------------------------------------------------------------------
/scripts/coverage:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Print the coverage report and fail unless coverage is 100%.
# Tools come from ./venv/bin when that directory exists, else from PATH.
if [ -d 'venv' ]; then
    PREFIX="venv/bin/"
else
    PREFIX=""
fi
export PREFIX

# Echo each command before it runs.
set -x

${PREFIX}coverage report --show-missing --skip-covered --fail-under=100
11 |
--------------------------------------------------------------------------------
/scripts/docs:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Run the mkdocs development server.
# Tools come from ./venv/bin when that directory exists, else from PATH.
if [ -d 'venv' ]; then
    PREFIX="venv/bin/"
else
    PREFIX=""
fi
export PREFIX

# Echo each command before it runs.
set -x

${PREFIX}mkdocs serve
11 |
--------------------------------------------------------------------------------
/scripts/install:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Install the requirements and the package itself in editable mode.
# Outside GitHub Actions a virtualenv is created in ./venv first.

# Use the Python executable provided from the `-p` option, or a default.
if [ "$1" = "-p" ]; then
    PYTHON=$2
else
    PYTHON="python3"
fi

REQUIREMENTS="requirements.txt"
VENV="venv"

# Echo each command before it runs.
set -x

if [ -n "$GITHUB_ACTIONS" ]; then
    # On CI, install straight into the runner's Python environment.
    PIP="pip"
else
    "$PYTHON" -m venv "$VENV"
    PIP="$VENV/bin/pip"
fi

"$PIP" install -r "$REQUIREMENTS"
"$PIP" install -e .
20 |
--------------------------------------------------------------------------------
/scripts/lint:
--------------------------------------------------------------------------------
#!/bin/sh -e

# Rewrite the sources in place: drop unused imports, sort imports, format.
# Tools come from ./venv/bin when that directory exists, else from PATH.
if [ -d 'venv' ]; then
    PREFIX="venv/bin/"
else
    PREFIX=""
fi
export PREFIX

SOURCE_FILES="orm tests"
export SOURCE_FILES

# Echo each command before it runs.
set -x

# $SOURCE_FILES is left unquoted on purpose: it holds several paths.
${PREFIX}autoflake --in-place --recursive $SOURCE_FILES
${PREFIX}isort --project=orm $SOURCE_FILES
${PREFIX}black $SOURCE_FILES
14 |
--------------------------------------------------------------------------------
/scripts/publish:
--------------------------------------------------------------------------------
#!/bin/sh -e
# Upload the built distributions with twine and deploy the docs with mkdocs.

VERSION_FILE="orm/__init__.py"

# Prefer the tools inside a local `venv` directory when one exists.
PREFIX=""
if [ -d venv ]; then
    PREFIX="venv/bin/"
fi

# On GitHub Actions, configure the git identity and refuse to publish unless
# the pushed ref is the tag named after the package's `__version__`.
if [ -n "$GITHUB_ACTIONS" ]; then
    git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
    git config --local user.name "GitHub Action"

    VERSION=$(grep __version__ "${VERSION_FILE}" | grep -o '[0-9][^"]*')

    if [ "${GITHUB_REF}" != "refs/tags/${VERSION}" ]; then
        echo "GitHub Ref '${GITHUB_REF}' did not match package version '${VERSION}'"
        exit 1
    fi
fi

# Echo each command before running it.
set -x

"${PREFIX}twine" upload dist/*
"${PREFIX}mkdocs" gh-deploy --force
27 |
--------------------------------------------------------------------------------
/scripts/test:
--------------------------------------------------------------------------------
#!/bin/sh
# Run the test suite under coverage; any arguments are forwarded to pytest.

export PREFIX=""
if [ -d 'venv' ] ; then
    export PREFIX="venv/bin/"
fi

set -ex

# Outside GitHub Actions, run scripts/check before the tests.
if [ -z "$GITHUB_ACTIONS" ]; then
    scripts/check
fi

# Quoted "$@" forwards each argument intact (e.g. -k "a and b"); the unquoted
# form would re-split arguments that contain whitespace.
${PREFIX}coverage run -a -m pytest "$@"

# Outside GitHub Actions, finish with the coverage report.
if [ -z "$GITHUB_ACTIONS" ]; then
    scripts/coverage
fi
19 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = W503, E203, B305
3 | max-line-length = 88
4 |
5 | [mypy]
6 | disallow_untyped_defs = True
7 | ignore_missing_imports = True
8 |
9 | [tool:isort]
10 | profile = black
11 | combine_as_imports = True
12 |
13 | [tool:pytest]
14 | addopts =
15 | -rxXs
16 | --strict-config
17 | --strict-markers
18 | xfail_strict=True
19 | filterwarnings=
20 | # Turn warnings that aren't filtered into exceptions
21 | error
22 | ignore::DeprecationWarning
23 | ignore::sqlalchemy.exc.SAWarning
24 |
25 | [coverage:run]
26 | source_pkgs = orm, tests
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import re
6 |
7 | from setuptools import setup
8 |
9 |
# Package directory name; also passed to `setup()` as the distribution name.
PACKAGE = "orm"
# Project home page, passed to `setup(url=...)`.
URL = "https://github.com/encode/orm"
12 |
13 |
def get_version(package):
    """
    Return package version as listed in `__version__` in `__init__.py`.

    `package` is the path of the package directory. Raises `RuntimeError`
    when `__init__.py` has no `__version__ = "..."` assignment.
    """
    init_path = os.path.join(package, "__init__.py")
    with open(init_path, encoding="utf8") as f:
        match = re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read())
    if match is None:
        # Fail with an actionable message instead of AttributeError on None.
        raise RuntimeError(f"Unable to find __version__ string in {init_path}")
    return match.group(1)
20 |
21 |
def get_long_description():
    """
    Read `README.md` from the current directory and return its text.
    """
    with open("README.md", encoding="utf8") as readme:
        contents = readme.read()
    return contents
28 |
29 |
def get_packages(package):
    """
    Collect the root package and every sub-package beneath it.

    A directory is included when it contains an `__init__.py`.
    """
    found = []
    for dirpath, _dirnames, _filenames in os.walk(package):
        marker = os.path.join(dirpath, "__init__.py")
        if os.path.exists(marker):
            found.append(dirpath)
    return found
39 |
40 |
# Package metadata. The version, long description and package list are
# computed from the source tree by the helper functions defined above.
setup(
    name=PACKAGE,
    version=get_version(PACKAGE),
    url=URL,
    license="BSD",
    description="An async ORM.",
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    author="Tom Christie",
    author_email="tom@tomchristie.com",
    packages=get_packages(PACKAGE),
    package_data={PACKAGE: ["py.typed"]},
    install_requires=["databases~=0.5", "typesystem==0.3.1"],
    # Optional extras, each pulling in one database driver package.
    extras_require={
        "postgresql": ["asyncpg"],
        "mysql": ["aiomysql"],
        "sqlite": ["aiosqlite"],
        "postgresql+aiopg": ["aiopg"],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Environment :: Web Environment",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
        "Topic :: Internet :: WWW/HTTP",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3 :: Only",
    ],
)
75 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
@pytest.fixture(scope="module")
def anyio_backend():
    """Select the asyncio backend, with debug mode on, for anyio-marked tests."""
    backend_name = "asyncio"
    backend_options = {"debug": True}
    return (backend_name, backend_options)
7 |
--------------------------------------------------------------------------------
/tests/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 |
# Fail at import time with a readable message when the variable is missing.
assert "TEST_DATABASE_URL" in os.environ, "TEST_DATABASE_URL is not set."

# Connection URL the test modules import to build their `databases.Database`.
DATABASE_URL = os.environ["TEST_DATABASE_URL"]
6 |
--------------------------------------------------------------------------------
/tests/test_columns.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import decimal
3 | import ipaddress
4 | import uuid
5 | from enum import Enum
6 |
7 | import databases
8 | import pytest
9 |
10 | import orm
11 | from tests.settings import DATABASE_URL
12 |
# Apply the `anyio` marker to every test in this module.
pytestmark = pytest.mark.anyio

# Database handle and model registry shared by the models declared below.
database = databases.Database(DATABASE_URL)
models = orm.ModelRegistry(database=database)
17 |
18 |
def time():
    """Return the time-of-day part of the current local datetime."""
    now = datetime.datetime.now()
    return now.time()
21 |
22 |
class StatusEnum(Enum):
    """Status values used by the `status` enum column of `Product`."""

    DRAFT = "Draft"
    RELEASED = "Released"
26 |
27 |
# Model covering UUID, date/time, JSON, text, numeric, enum and float columns.
# Only `id` has no default and is not nullable/blank, so `create()` with no
# arguments is used by the tests below.
class Product(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "uuid": orm.UUID(allow_null=True),
        # Defaults given as callables.
        "created": orm.DateTime(default=datetime.datetime.now),
        "created_day": orm.Date(default=datetime.date.today),
        "created_time": orm.Time(default=time),
        # Timestamps managed through `auto_now_add` / `auto_now`.
        "created_date": orm.Date(auto_now_add=True),
        "created_datetime": orm.DateTime(auto_now_add=True),
        "updated_datetime": orm.DateTime(auto_now=True),
        "updated_date": orm.Date(auto_now=True),
        # NOTE(review): `{}` is a mutable default; confirm the ORM does not
        # share this one dict object between rows.
        "data": orm.JSON(default={}),
        "description": orm.Text(allow_blank=True),
        "huge_number": orm.BigInteger(default=0),
        "price": orm.Decimal(max_digits=5, decimal_places=2, allow_null=True),
        "status": orm.Enum(StatusEnum, default=StatusEnum.DRAFT),
        "value": orm.Float(allow_null=True),
    }
47 |
48 |
# Model with a UUID primary key (default `uuid.uuid4`) and nullable
# string, email, IP address and URL columns.
class User(orm.Model):
    registry = models
    fields = {
        "id": orm.UUID(primary_key=True, default=uuid.uuid4),
        "name": orm.String(allow_null=True, max_length=16),
        "email": orm.Email(allow_null=True, max_length=256),
        "ipaddress": orm.IPAddress(allow_null=True),
        "url": orm.URL(allow_null=True, max_length=2048),
    }
58 |
59 |
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
    """Call `models.create_all()` before the module's tests and `drop_all()` after."""
    await models.create_all()
    yield
    await models.drop_all()
65 |
66 |
@pytest.fixture(autouse=True)
async def rollback_transactions():
    """Run each test inside `database.force_rollback()` with the database entered."""
    with database.force_rollback():
        async with database:
            yield
72 |
73 |
async def test_model_crud():
    """Create, fetch and update `Product` and `User` rows and check each column."""
    # A product created without arguments must come back with its defaults.
    product = await Product.objects.create()
    product = await Product.objects.get(pk=product.pk)
    assert product.created.year == datetime.datetime.now().year
    assert product.created_day == datetime.date.today()
    assert product.created_date == datetime.date.today()
    assert product.created_datetime.date() == datetime.datetime.now().date()
    assert product.updated_date == datetime.date.today()
    assert product.updated_datetime.date() == datetime.datetime.now().date()
    assert product.data == {}
    assert product.description == ""
    assert product.huge_number == 0
    assert product.price is None
    assert product.status == StatusEnum.DRAFT
    assert product.value is None
    assert product.uuid is None

    # Updated values must round-trip through the database.
    await product.update(
        data={"foo": 123},
        value=123.456,
        status=StatusEnum.RELEASED,
        price=decimal.Decimal("999.99"),
        uuid=uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b"),
    )

    product = await Product.objects.get()
    assert product.value == 123.456
    assert product.data == {"foo": 123}
    assert product.status == StatusEnum.RELEASED
    assert product.price == decimal.Decimal("999.99")
    assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b")

    # Remember the timestamps so the `auto_now` check at the end can compare.
    last_updated_datetime = product.updated_datetime
    last_updated_date = product.updated_date
    # The UUID primary key is filled in from its default.
    user = await User.objects.create()
    assert isinstance(user.pk, uuid.UUID)

    user = await User.objects.get()
    assert user.email is None
    assert user.ipaddress is None
    assert user.url is None

    await user.update(
        ipaddress="192.168.1.1",
        name="Chris",
        email="chirs@encode.io",
        url="https://encode.io",
    )

    user = await User.objects.get()
    assert isinstance(user.ipaddress, (ipaddress.IPv4Address, ipaddress.IPv6Address))
    assert user.url == "https://encode.io"
    # Test auto_now update
    await product.update(
        data={"foo": 1234},
    )
    assert product.updated_datetime != last_updated_datetime
    assert product.updated_date == last_updated_date
132 |
133 |
async def test_both_auto_now_and_auto_now_add_raise_error():
    """Setting both `auto_now_add` and `auto_now` on a field must raise `ValueError`."""
    # NOTE(review): indentation was lost in this listing; the create call is
    # assumed to sit inside the `pytest.raises` block. Confirm against the repo.
    with pytest.raises(ValueError):

        class Product(orm.Model):
            registry = models
            fields = {
                "id": orm.Integer(primary_key=True),
                "created_datetime": orm.DateTime(auto_now_add=True, auto_now=True),
            }

        await Product.objects.create()
145 |
146 |
async def test_bulk_create():
    """`bulk_create` inserts every given row, keeping JSON, float and enum values."""
    await Product.objects.bulk_create(
        [
            {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
            {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
        ]
    )
    products = await Product.objects.all()
    assert len(products) == 2
    # Rows are expected back in the order they were passed in.
    assert products[0].data == {"foo": 123}
    assert products[0].value == 123.456
    assert products[0].status == StatusEnum.RELEASED
    assert products[1].data == {"foo": 456}
    assert products[1].value == 456.789
    assert products[1].status == StatusEnum.DRAFT
162 |
--------------------------------------------------------------------------------
/tests/test_foreignkey.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 |
3 | import asyncpg
4 | import databases
5 | import pymysql
6 | import pytest
7 |
8 | import orm
9 | from tests.settings import DATABASE_URL
10 |
# Apply the `anyio` marker to every test in this module.
pytestmark = pytest.mark.anyio

# Database handle and model registry shared by the models declared below.
database = databases.Database(DATABASE_URL)
models = orm.ModelRegistry(database=database)
15 |
16 |
# Target of `Track.album`.
class Album(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "name": orm.String(max_length=100),
    }
23 |
24 |
# References `Album` by name (a string) with `on_delete=orm.CASCADE`.
class Track(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "album": orm.ForeignKey("Album", on_delete=orm.CASCADE),
        "title": orm.String(max_length=100),
        "position": orm.Integer(),
    }
33 |
34 |
# Target of `Team.org`.
class Organisation(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "ident": orm.String(max_length=100),
    }
41 |
42 |
# References `Organisation` by class with `on_delete=orm.RESTRICT`.
class Team(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "org": orm.ForeignKey(Organisation, on_delete=orm.RESTRICT),
        "name": orm.String(max_length=100),
    }
50 |
51 |
# Nullable reference to `Team` with `on_delete=orm.SET_NULL`.
class Member(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "team": orm.ForeignKey(Team, on_delete=orm.SET_NULL, allow_null=True),
        "email": orm.String(max_length=100),
    }
59 |
60 |
# Target of the one-to-one relation declared on `Person`.
class Profile(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "website": orm.String(max_length=100),
    }
67 |
68 |
# Holds a `OneToOne` relation to `Profile`.
class Person(orm.Model):
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "email": orm.String(max_length=100),
        "profile": orm.OneToOne(Profile),
    }
76 |
77 |
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
    """Call `models.create_all()` before the module's tests and `drop_all()` after."""
    await models.create_all()
    yield
    await models.drop_all()
83 |
84 |
@pytest.fixture(autouse=True)
async def rollback_connections():
    """Run each test inside `database.force_rollback()` with the database entered."""
    with database.force_rollback():
        async with database:
            yield
90 |
91 |
async def test_model_crud():
    """A fetched track's album has only its pk until `load()` is awaited."""
    album = await Album.objects.create(name="Malibu")
    await Track.objects.create(album=album, title="The Bird", position=1)
    await Track.objects.create(
        album=album, title="Heart don't stand a chance", position=2
    )
    await Track.objects.create(album=album, title="The Waters", position=3)

    track = await Track.objects.get(title="The Bird")
    assert track.album.pk == album.pk
    # Without `select_related`, the related object has no `name` attribute yet.
    assert not hasattr(track.album, "name")
    await track.album.load()
    assert track.album.name == "Malibu"
105 |
106 |
async def test_select_related():
    """`select_related("album")` makes the album's fields available without `load()`."""
    album = await Album.objects.create(name="Malibu")
    await Track.objects.create(album=album, title="The Bird", position=1)
    await Track.objects.create(
        album=album, title="Heart don't stand a chance", position=2
    )
    await Track.objects.create(album=album, title="The Waters", position=3)

    fantasies = await Album.objects.create(name="Fantasies")
    await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
    await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
    await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)

    track = await Track.objects.select_related("album").get(title="The Bird")
    assert track.album.name == "Malibu"

    # Three tracks on each of the two albums.
    tracks = await Track.objects.select_related("album").all()
    assert len(tracks) == 6
125 |
126 |
async def test_fk_filter():
    """Filter tracks by related-album fields and by album instance."""
    malibu = await Album.objects.create(name="Malibu")
    await Track.objects.create(album=malibu, title="The Bird", position=1)
    await Track.objects.create(
        album=malibu, title="Heart don't stand a chance", position=2
    )
    await Track.objects.create(album=malibu, title="The Waters", position=3)

    fantasies = await Album.objects.create(name="Fantasies")
    await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
    await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
    await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)

    # Exact match on the related album's name.
    tracks = (
        await Track.objects.select_related("album")
        .filter(album__name="Fantasies")
        .all()
    )
    assert len(tracks) == 3
    for track in tracks:
        assert track.album.name == "Fantasies"

    # Lookup suffix (`icontains`) on the related album's name.
    tracks = (
        await Track.objects.select_related("album")
        .filter(album__name__icontains="fan")
        .all()
    )
    assert len(tracks) == 3
    for track in tracks:
        assert track.album.name == "Fantasies"

    # Same related-field filter, this time without an explicit `select_related`.
    tracks = await Track.objects.filter(album__name__icontains="fan").all()
    assert len(tracks) == 3
    for track in tracks:
        assert track.album.name == "Fantasies"

    # Filtering by a model instance, with `select_related` applied after `filter`.
    tracks = await Track.objects.filter(album=malibu).select_related("album").all()
    assert len(tracks) == 3
    for track in tracks:
        assert track.album.name == "Malibu"
167 |
168 |
async def test_multiple_fk():
    """Select and filter across two foreign-key hops (member -> team -> org)."""
    acme = await Organisation.objects.create(ident="ACME Ltd")
    red_team = await Team.objects.create(org=acme, name="Red Team")
    blue_team = await Team.objects.create(org=acme, name="Blue Team")
    await Member.objects.create(team=red_team, email="a@example.org")
    await Member.objects.create(team=red_team, email="b@example.org")
    await Member.objects.create(team=blue_team, email="c@example.org")
    await Member.objects.create(team=blue_team, email="d@example.org")

    # A member of another organisation that the filter must exclude.
    other = await Organisation.objects.create(ident="Other ltd")
    team = await Team.objects.create(org=other, name="Green Team")
    await Member.objects.create(team=team, email="e@example.org")

    members = (
        await Member.objects.select_related("team__org")
        .filter(team__org__ident="ACME Ltd")
        .all()
    )
    assert len(members) == 4
    for member in members:
        assert member.team.org.ident == "ACME Ltd"
190 |
191 |
async def test_queryset_delete_with_fk():
    """Deleting a queryset filtered by album removes only that album's tracks."""
    malibu = await Album.objects.create(name="Malibu")
    await Track.objects.create(album=malibu, title="The Bird", position=1)

    wall = await Album.objects.create(name="The Wall")
    await Track.objects.create(album=wall, title="The Wall", position=1)

    await Track.objects.filter(album=malibu).delete()
    assert await Track.objects.filter(album=malibu).count() == 0
    assert await Track.objects.filter(album=wall).count() == 1
202 |
203 |
async def test_queryset_update_with_fk():
    """A queryset `update` can reassign a foreign key to another instance."""
    malibu = await Album.objects.create(name="Malibu")
    wall = await Album.objects.create(name="The Wall")
    await Track.objects.create(album=malibu, title="The Bird", position=1)

    await Track.objects.filter(album=malibu).update(album=wall)
    assert await Track.objects.filter(album=malibu).count() == 0
    assert await Track.objects.filter(album=wall).count() == 1
212 |
213 |
@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite")
async def test_on_delete_cascade():
    """Deleting an album removes its tracks (`Track.album` uses `orm.CASCADE`)."""
    album = await Album.objects.create(name="The Wall")
    await Track.objects.create(album=album, title="Hey You", position=1)
    await Track.objects.create(album=album, title="Breathe", position=2)

    assert await Track.objects.count() == 2

    await album.delete()

    assert await Track.objects.count() == 0
225 |
226 |
@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite")
async def test_on_delete_retstrict():
    """Deleting an organisation that still has a team fails (`orm.RESTRICT`)."""
    # NOTE(review): the function name misspells "restrict"; left unchanged here.
    organisation = await Organisation.objects.create(ident="Encode")
    await Team.objects.create(org=organisation, name="Maintainers")

    # Either driver's integrity error is accepted.
    exceptions = (
        asyncpg.exceptions.ForeignKeyViolationError,
        pymysql.err.IntegrityError,
    )

    with pytest.raises(exceptions):
        await organisation.delete()
239 |
240 |
@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite")
async def test_on_delete_set_null():
    """Deleting a team leaves its member with a null team (`orm.SET_NULL`)."""
    organisation = await Organisation.objects.create(ident="Encode")
    team = await Team.objects.create(org=organisation, name="Maintainers")
    await Member.objects.create(email="member@encode.io", team=team)

    await team.delete()

    member = await Member.objects.first()
    assert member.team.pk is None
251 |
252 |
async def test_one_to_one_crud():
    """A `OneToOne` target loads lazily and cannot be linked to a second row."""
    profile = await Profile.objects.create(website="https://encode.io")
    await Person.objects.create(email="info@encode.io", profile=profile)

    person = await Person.objects.get(email="info@encode.io")
    assert person.profile.pk == profile.pk
    # Only the pk is present until `load()` is awaited.
    assert not hasattr(person.profile, "website")

    await person.profile.load()
    assert person.profile.website == "https://encode.io"

    # Any of the three drivers' uniqueness errors is accepted.
    exceptions = (
        asyncpg.exceptions.UniqueViolationError,
        pymysql.err.IntegrityError,
        sqlite3.IntegrityError,
    )

    # Reusing the same profile for a second person must fail.
    with pytest.raises(exceptions):
        await Person.objects.create(email="contact@encode.io", profile=profile)
272 |
273 |
async def test_nullable_foreign_key():
    """A member created without a team comes back with `team.pk` set to None."""
    await Member.objects.create(email="dev@encode.io")

    member = await Member.objects.get()

    assert member.email == "dev@encode.io"
    assert member.team.pk is None
281 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import databases
2 | import pytest
3 | import typesystem
4 |
5 | import orm
6 | from tests.settings import DATABASE_URL
7 |
# Apply the `anyio` marker to every test in this module.
pytestmark = pytest.mark.anyio

# Database handle and model registry shared by the models declared below.
database = databases.Database(DATABASE_URL)
models = orm.ModelRegistry(database=database)
12 |
13 |
# Model with an explicit table name, a required name and a nullable language.
class User(orm.Model):
    tablename = "users"
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "name": orm.String(max_length=100),
        "language": orm.String(max_length=100, allow_null=True),
    }
22 |
23 |
# Model with a rating bounded to 1..5 and an `in_stock` flag defaulting to False.
class Product(orm.Model):
    tablename = "products"
    registry = models
    fields = {
        "id": orm.Integer(primary_key=True),
        "name": orm.String(max_length=100),
        "rating": orm.Integer(minimum=1, maximum=5),
        "in_stock": orm.Boolean(default=False),
    }
33 |
34 |
@pytest.fixture(autouse=True, scope="function")
async def create_test_database():
    """Call `models.create_all()` before every test and `drop_all()` after it."""
    await models.create_all()
    yield
    await models.drop_all()
40 |
41 |
@pytest.fixture(autouse=True)
async def rollback_connections():
    """Run each test inside `database.force_rollback()` with the database entered."""
    with database.force_rollback():
        async with database:
            yield
47 |
48 |
def test_model_class():
    """Check field declarations, constructor validation, equality and str/repr."""
    assert list(User.fields.keys()) == ["id", "name", "language"]
    assert isinstance(User.fields["id"], orm.Integer)
    assert User.fields["id"].primary_key is True
    assert isinstance(User.fields["name"], orm.String)
    assert User.fields["name"].validator.max_length == 100

    # Unknown keyword arguments are rejected.
    with pytest.raises(ValueError):
        User(invalid="123")

    # Equality depends on both the model class and the primary key.
    assert User(id=1) != Product(id=1)
    assert User(id=1) != User(id=2)
    assert User(id=1) == User(id=1)

    assert str(User(id=1)) == "User(id=1)"
    # NOTE(review): this listing showed an empty string here, which looks like
    # "<...>" markup stripped during extraction; the angle-bracket form is
    # restored. Confirm against `Model.__repr__`.
    assert repr(User(id=1)) == "<User: User(id=1)>"

    assert isinstance(User.objects.schema.fields["id"], typesystem.Integer)
    assert isinstance(User.objects.schema.fields["name"], typesystem.String)
68 |
69 |
def test_model_pk():
    """`pk` aliases the primary-key field, whose name is exposed as `pkname`."""
    instance = User(pk=1)
    assert (instance.pk, instance.id) == (1, 1)
    assert User.objects.pkname == "id"
75 |
76 |
async def test_model_crud():
    """Create, read, update and delete a single `User` row."""
    users = await User.objects.all()
    assert users == []

    user = await User.objects.create(name="Tom")
    users = await User.objects.all()
    assert user.name == "Tom"
    assert user.pk is not None
    assert users == [user]

    lookup = await User.objects.get()
    assert lookup == user

    # `update` changes the in-memory instance as well as the stored row.
    await user.update(name="Jane")
    users = await User.objects.all()
    assert user.name == "Jane"
    assert user.pk is not None
    assert users == [user]

    await user.delete()
    users = await User.objects.all()
    assert users == []
99 |
100 |
async def test_model_get():
    """`get()` raises `NoMatch` on zero rows and `MultipleMatches` on several."""
    with pytest.raises(orm.NoMatch):
        await User.objects.get()

    user = await User.objects.create(name="Tom")
    lookup = await User.objects.get()
    assert lookup == user

    # With two rows, an unfiltered `get()` is ambiguous.
    user = await User.objects.create(name="Jane")
    with pytest.raises(orm.MultipleMatches):
        await User.objects.get()

    same_user = await User.objects.get(pk=user.id)
    assert same_user.id == user.id
    assert same_user.pk == user.pk
116 |
117 |
async def test_model_filter():
    """Exercise keyword lookups, including escaping of `%` in string lookups."""
    await User.objects.create(name="Tom")
    await User.objects.create(name="Jane")
    await User.objects.create(name="Lucy")

    user = await User.objects.get(name="Lucy")
    assert user.name == "Lucy"

    with pytest.raises(orm.NoMatch):
        await User.objects.get(name="Jim")

    await Product.objects.create(name="T-Shirt", rating=5, in_stock=True)
    await Product.objects.create(name="Dress", rating=4)
    await Product.objects.create(name="Coat", rating=3, in_stock=True)

    product = await Product.objects.get(name__iexact="t-shirt", rating=5)
    assert product.pk is not None
    assert product.name == "T-Shirt"
    assert product.rating == 5

    products = await Product.objects.all(rating__gte=2, in_stock=True)
    assert len(products) == 2

    # "T-Shirt" and "Coat" contain a "t", case-insensitively.
    products = await Product.objects.all(name__icontains="T")
    assert len(products) == 2

    # Test escaping % character from icontains, contains, and iexact
    await Product.objects.create(name="100%-Cotton", rating=3)
    await Product.objects.create(name="Cotton-100%-Egyptian", rating=3)
    await Product.objects.create(name="Cotton-100%", rating=3)
    products = Product.objects.filter(name__iexact="100%-cotton")
    assert await products.count() == 1

    products = Product.objects.filter(name__contains="%")
    assert await products.count() == 3

    products = Product.objects.filter(name__icontains="%")
    assert await products.count() == 3

    # Six products exist now, so the excludes leave 5, 3 and 3 rows.
    products = Product.objects.exclude(name__iexact="100%-cotton")
    assert await products.count() == 5

    products = Product.objects.exclude(name__contains="%")
    assert await products.count() == 3

    products = Product.objects.exclude(name__icontains="%")
    assert await products.count() == 3
165 |
166 |
async def test_model_order_by():
    """Order by one or more fields, ascending or `-` descending, with limit/offset."""
    # The ids asserted below (1, 2, 3) follow this creation order.
    await User.objects.create(name="Bob")
    await User.objects.create(name="Allen")
    await User.objects.create(name="Bob")

    users = await User.objects.order_by("name").all()
    assert users[0].name == "Allen"
    assert users[1].name == "Bob"

    users = await User.objects.order_by("-name").all()
    assert users[1].name == "Bob"
    assert users[2].name == "Allen"

    # Ties on name are broken by descending id.
    users = await User.objects.order_by("name", "-id").all()
    assert users[0].name == "Allen"
    assert users[0].id == 2
    assert users[1].name == "Bob"
    assert users[1].id == 3

    users = await User.objects.filter(name="Bob").order_by("-id").all()
    assert users[0].name == "Bob"
    assert users[0].id == 3
    assert users[1].name == "Bob"
    assert users[1].id == 1

    users = await User.objects.order_by("id").limit(1).all()
    assert users[0].name == "Bob"
    assert users[0].id == 1

    users = await User.objects.order_by("id").limit(1).offset(1).all()
    assert users[0].name == "Allen"
    assert users[0].id == 2
199 |
200 |
async def test_model_exists():
    """`exists()` is True for a matching filter and False for a non-matching one."""
    await User.objects.create(name="Tom")

    tom_found = await User.objects.filter(name="Tom").exists()
    jane_found = await User.objects.filter(name="Jane").exists()

    assert tom_found is True
    assert jane_found is False
205 |
206 |
async def test_model_count():
    """`count()` counts all rows, or only the rows matched by a filter."""
    for name in ("Tom", "Jane", "Lucy"):
        await User.objects.create(name=name)

    total = await User.objects.count()
    assert total == 3

    containing_t = await User.objects.filter(name__icontains="T").count()
    assert containing_t == 1
214 |
215 |
async def test_model_limit():
    """`limit(2)` caps the result at two rows when three exist."""
    for name in ("Tom", "Jane", "Lucy"):
        await User.objects.create(name=name)

    limited = await User.objects.limit(2).all()
    assert len(limited) == 2
222 |
223 |
async def test_model_limit_with_filter():
    """A limit applied before a filter still caps the filtered result."""
    for _ in range(3):
        await User.objects.create(name="Tom")

    toms = await User.objects.limit(2).filter(name__iexact="Tom").all()
    assert len(toms) == 2
230 |
231 |
async def test_offset():
    """`offset(1).limit(1)` skips the first row and returns the second."""
    for name in ("Tom", "Jane"):
        await User.objects.create(name=name)

    page = await User.objects.offset(1).limit(1).all()
    second_user = page[0]
    assert second_user.name == "Jane"
238 |
239 |
async def test_model_first():
    """`first()` returns the first match, accepts filters, and yields None if empty."""
    tom = await User.objects.create(name="Tom")
    jane = await User.objects.create(name="Jane")

    assert await User.objects.first() == tom
    assert await User.objects.first(name="Jane") == jane
    assert await User.objects.filter(name="Jane").first() == jane
    assert await User.objects.filter(name="Lucy").first() is None
248 |
249 |
async def test_model_search():
    """`search(term=...)` matches case-insensitive substrings; "" matches the row."""
    tom = await User.objects.create(name="Tom", language="English")
    tshirt = await Product.objects.create(name="T-Shirt", rating=5)

    assert await User.objects.search(term="").first() == tom
    assert await User.objects.search(term="tom").first() == tom
    assert await Product.objects.search(term="shirt").first() == tshirt
257 |
258 |
async def test_model_get_or_create():
    """`get_or_create` applies `defaults` only when it creates the row."""
    user, created = await User.objects.get_or_create(
        name="Tom", defaults={"language": "Spanish"}
    )
    assert created is True
    assert user.name == "Tom"
    assert user.language == "Spanish"

    # The row already exists, so the new defaults are ignored.
    user, created = await User.objects.get_or_create(
        name="Tom", defaults={"language": "English"}
    )
    assert created is False
    assert user.name == "Tom"
    assert user.language == "Spanish"
273 |
274 |
async def test_queryset_delete():
    """`delete()` removes the filtered rows, or every row when unfiltered."""
    shirt = await Product.objects.create(name="Shirt", rating=5)
    await Product.objects.create(name="Belt", rating=5)
    await Product.objects.create(name="Tie", rating=5)

    await Product.objects.filter(pk=shirt.id).delete()
    assert await Product.objects.count() == 2

    await Product.objects.delete()
    assert await Product.objects.count() == 0
285 |
286 |
async def test_queryset_update():
    """`update()` changes the filtered rows, or every row when unfiltered."""
    shirt = await Product.objects.create(name="Shirt", rating=5)
    tie = await Product.objects.create(name="Tie", rating=5)

    await Product.objects.filter(pk=shirt.id).update(rating=3)
    shirt = await Product.objects.get(pk=shirt.id)
    assert shirt.rating == 3
    # The other row is untouched by the filtered update.
    assert await Product.objects.get(pk=tie.id) == tie

    await Product.objects.update(rating=3)
    tie = await Product.objects.get(pk=tie.id)
    assert tie.rating == 3
299 |
300 |
async def test_model_update_or_create():
    """`update_or_create` applies `defaults` on both the create and update paths."""
    # No match yet: the row is created and `defaults` overrides the name.
    user, created = await User.objects.update_or_create(
        name="Tom", language="English", defaults={"name": "Jane"}
    )
    assert created is True
    assert user.name == "Jane"
    assert user.language == "English"

    # The row now matches: it is updated with `defaults` instead.
    user, created = await User.objects.update_or_create(
        name="Jane", language="English", defaults={"name": "Tom"}
    )
    assert created is False
    assert user.name == "Tom"
    assert user.language == "English"
315 |
316 |
async def test_model_sqlalchemy_filter_operators():
    """`filter`/`exclude` accept expressions built from `Model.columns`."""
    user = await User.objects.create(name="George")

    assert user == await User.objects.filter(User.columns.name == "George").get()
    assert user == await User.objects.filter(User.columns.name.is_not(None)).get()
    # Chained filters are combined.
    assert (
        user
        == await User.objects.filter(User.columns.name.startswith("G"))
        .filter(User.columns.name.endswith("e"))
        .get()
    )

    assert user == await User.objects.exclude(User.columns.name != "Jack").get()

    shirt = await Product.objects.create(name="100%-Cotton", rating=3)
    assert (
        shirt
        == await Product.objects.filter(Product.columns.name.contains("Cotton")).get()
    )
336 |
--------------------------------------------------------------------------------