├── .envrc
├── .github
│   └── workflows
│       ├── pypi.yml
│       └── tests.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── docker-compose.yml
├── misc
│   └── _doc
│       ├── README.md.j2
│       └── README.py
├── mongosql
│   ├── __init__.py
│   ├── bag.py
│   ├── crud
│   │   ├── __init__.py
│   │   ├── crudhelper.py
│   │   └── crudview.py
│   ├── exc.py
│   ├── handlers
│   │   ├── __init__.py
│   │   ├── aggregate.py
│   │   ├── base.py
│   │   ├── count.py
│   │   ├── filter.py
│   │   ├── group.py
│   │   ├── join.py
│   │   ├── joinf.py
│   │   ├── limit.py
│   │   ├── project.py
│   │   └── sort.py
│   ├── query.py
│   ├── sa.py
│   └── util
│       ├── __init__.py
│       ├── bulk.py
│       ├── counting_query_wrapper.py
│       ├── history_proxy.py
│       ├── inspect.py
│       ├── marker.py
│       ├── method_decorator.py
│       ├── mongoquery_settings_handler.py
│       ├── reusable.py
│       ├── selectinquery.py
│       └── settings_dict.py
├── myproject
│   └── __init__.py
├── noxfile.py
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── benchmarks
    │   ├── .gitignore
    │   ├── __init__.py
    │   ├── benchmark_CountingQuery.py
    │   ├── benchmark_compare_orm_overhead_with_pure_jsonb_output.py
    │   ├── benchmark_one_query.py
    │   ├── benchmark_selectinquery.py
    │   ├── benchmark_utils.py
    │   ├── benchmark_v2_vs_v1.py
    │   └── mongosql_v1_checkout.sh
    ├── conftest.py
    ├── crud_view.py
    ├── models.py
    ├── saversion.py
    ├── t1_bags_test.py
    ├── t2_handlers_test.py
    ├── t3_statements_test.py
    ├── t4_query_test.py
    ├── t5_crud_test.py
    ├── t_method_decorator_test.py
    ├── t_modelhistoryproxy_test.py
    ├── t_raiseload_col_test.py
    ├── t_selectinquery_test.py
    └── util.py

--------------------------------------------------------------------------------
/.envrc:
--------------------------------------------------------------------------------
#!/bin/bash

# direnv
# Automatically loads this .envrc when you enter the directory, setting up your dev tools

# Automatically activate poetry virtualenv
if [[ -f "pyproject.toml" ]]; then
    # create venv if it doesn't exist; then print the path to this virtualenv
    export VIRTUAL_ENV=$(poetry run true && poetry env info --path)
    if [[ "$VIRTUAL_ENV" != "" ]] ; then
        export POETRY_ACTIVE=1
        PATH_add "$VIRTUAL_ENV/bin"
        echo "Activated Poetry virtualenv: $(basename "$VIRTUAL_ENV")"
    fi
fi

--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
name: Publish to PyPI
on:
  # https://help.github.com/en/actions/reference/events-that-trigger-workflows
  release:
    types: [ published ]
jobs:
  publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/setup-python@v2
      - uses: actions/checkout@v2
      - run: pip install poetry
      - run: poetry publish --build -u kolypto -p ${{ secrets.PYPI_PASSWORD }}

--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: Test
on:
  # https://help.github.com/en/actions/reference/events-that-trigger-workflows
  pull_request:
    types: [ opened, edited, reopened ]
  release:
    types: [ prereleased, released ]
  push:
    branches: [ master, development ]
  # Trigger this workflow manually from the `actions` page
  workflow_dispatch:
    inputs:
      git-ref:
        description: Git Commit or Branch (Optional)
        required: false
jobs:
  tests:
    runs-on: ubuntu-latest
    #container:

    strategy:
      matrix:
        python-version: ['3.x']

    name: ${{ github.ref }}, Python ${{ matrix.python-version }}
    steps:
      - uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - uses: actions/checkout@v2
      - uses: actions/cache@v2
        with:
          path: ./.nox/  # what we cache: nox virtualenvs (they're expensive)
          key: ${{ runner.os }}-nox-2-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml', '**/noxfile.py') }}  # cache key
          restore-keys: |
            ${{ runner.os }}-nox-2-${{ matrix.python-version }}-
      - run: pip install nox
      - run: nox --report nox-report.json

    services:
      postgres:
        image: postgres
        ports:
          - 5432:5432
        env:
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: test_mongosql

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# ===[ APP ]=== #
# it's a library; no lockfile should be committed
/poetry.lock

# ===[ PYTHON ]=== #

# Env
/.python-version

# Package
/build/
/dist/
/*.egg-info/
/*.egg/

# Tests
/.tox/
/.nox/
/.pytest_cache
/.coverage
/.noseids
/profile.*

# Generated
__pycache__
*.py[cod]



# ===[ COMMON ]=== #

# IDE Projects
/.idea
/.*project
/.settings
/.vscode

# Temps
*~
*.tmp
*.bak
*.swp
*.kate-swp
*.DS_Store
Thumbs.db

# Generated
*.pot
*.mo

# Runtime
*.log

# But ...
!.gitkeep

--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
## 2.0.15 (2021-04-23)
* Added support for `column_property()`
* `nplus1loader` is now an optional dependency. Install it if you want `raiseload_col()`
* Tested with SqlAlchemy 1.3.21-23 and Python 3.9
* Migrated to Poetry


## 2.0.14 (2020-12-03)
* Proper handling of deferred columns: projections overwrite them

## 2.0.12 (2020-11-11)
* Fixed a bug with joining self-referential relationships
* Tested with SqlAlchemy 1.3.20

## 2.0.11 (2020-04-15)
* Removed the `raiseload_col()` and imported the implementation from `nplus1loader`.
  That's a new dependency.
* Tested with SqlAlchemy 1.3.16

## 2.0.10 (2020-01-29)
* Tested with SqlAlchemy 1.3.12-15
* Fixed an issue with MongoSQL not detecting `TypeDecorator()`-wrapped JSON columns as JSON

## 2.0.9 (2019-11-27)
* Support string input for: `project`, `sort`, `group`, `join`, `joinf`.
  Previously, they only accepted arrays or objects.
  Now, they also accept strings: a whitespace-separated list of columns.
* Tested with SqlAlchemy 1.3.11

## 2.0.8 (2019-10-27)
* `CrudView._method_create_or_update_many()`: saving many objects at once
* `MongoQuery.options()`: `no_limit_offset=True` lets you disable limits & offsets for a specific query
* Tested with SqlAlchemy 1.3.9-10

## 2.0.7 (2019-10-07)
* `project`: now allows using a string of field names, separated by whitespace. Example: `{project: "name age weight"}`
* `@saves_relations`: now possible to differentiate a value not provided from a provided `None`
* Better sqlalchemy error messages: a new `mongosql.exc.RuntimeQueryError` provides more details about internal errors.

  Example:
  > mongosql.exc.RuntimeQueryError: Error processing MongoQuery(Assignment -> Category).join: (cryptic sqlalchemy message)
* Bugfix: compilation of dialect-specific clauses used to fail with some JOINs
* Bugfix: `max_items` + `force_filter` used to cause trouble because MongoSQL could not make the correct decision
  that a nested query is necessary

## 2.0.6 (2019-09-17)
* `bundled_project` are now loaded quietly (meaning, they are loaded, but not included into the projection)
* `MongoQuerySettingsDict.pluck_from()`: now skips the `max_items` key because it does not make sense when inherited
* Bugfix: in some cases, the `projection` property returned invalid results
* Bugfix: when the `join` operation includes two relationships that are LEFT JOINed, the query is not broken anymore.

## 2.0.5 (2019-08-27)
* `ensure_loaded` setting for projections

## 2.0.4 (2019-07-26)
* `filter`: new `$prefix` operator
* `AssociationProxy` support for `project` and `filter`
* `MongoQuery.get_final_query_object()` method for debugging
* Project Handler: the new `default_projection` behavior lets you build APIs that return no fields by default:
  the API user will have to require every field explicitly.

## 2.0.3 (2019-07-06)
* CrudHelper is now able to save `@property` values: welcome `writable_properties`!
* Fix: `bundled_project` now takes care of `force_include`d fields as well
* Recommendation when a legacy column is removed: use `legacy_fields` together with `force_include`
  on a `@property` that fakes the missing column or relationship.
* CrudHelper now removes `legacy_fields` from the input dict
* `legacy_fields` are now included into all projections generated by `project` and `join` handlers

## 2.0.2 (2019-06-29)
* `legacy_fields` setting will make handlers ignore certain fields that are not available anymore.
  Works with: `filter`, `sort`, `group`, `join`, `joinf`, `aggregate`.
* `method_decorator` has had a few improvements that no one would notice

## 2.0.0 (2019-06-18)
* Version 2.0 is released!
* Complete redesign
* Query Object format is the same: backwards-compatible
* `outerjoin` is renamed to `join`, old buggy `join` is now `joinf`
* `join` is now handled by a tweaked `selectin` loader, which is a lot easier and faster!
* Overall 1.5x-2.5x performance improvement
* `MongoQuery` settings lets you configure everything
* `StrictCrudHelper` is much more powerful
* `@saves_relations` helps with saving related entities
* `MongoQuery.end_count()` counts and selects at the same time

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Mark Vartanyan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: '3'

services:
  py-mongosql-postgres-test:
    image: postgres
    restart: always
    ports:
      - '127.0.0.1:5432:5432'
    environment:
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: postgres
      POSTGRES_DB: test_mongosql

--------------------------------------------------------------------------------
/misc/_doc/README.md.j2:
--------------------------------------------------------------------------------
[![Build Status](https://api.travis-ci.org/kolypto/py-mongosql.png?branch=master)](https://travis-ci.org/kolypto/py-mongosql)
[![Pythons](https://img.shields.io/badge/python-3.6%E2%80%933.8-blue.svg)](.travis.yml)


MongoSQL
========

{{ mongosql['doc'] }}



Table of Contents
=================

* Querying
  * Query Object Syntax
  * Operations
    * Project Operation
    * Sort Operation
    * Filter Operation
    * Join Operation
    * Filtering Join Operation
    * Aggregate Operation
    * Group Operation
    * Slice Operation
    * Count Operation
  * JSON Column Support
* MongoSQL Programming Interface
  * MongoQuery
    * Creating a MongoQuery
    * Reusable
    * Querying: MongoQuery.query()
    * Getting Results: MongoQuery.end()
    * Getting All Sorts of Results
  * MongoQuery Configuration
  * MongoQuery API
    * MongoQuery(model, handler_settings=None)
    * MongoQuery.from_query(query) -> MongoQuery
    * MongoQuery.with_session(ssn) -> MongoQuery
    * MongoQuery.query(**query_object) -> MongoQuery
    * MongoQuery.end() -> Query
    * MongoQuery.end_count() -> CountingQuery
    * MongoQuery.result_contains_entities() -> bool
    * MongoQuery.result_is_scalar() -> bool
    * MongoQuery.result_is_tuples() -> bool
    * MongoQuery.get_final_query_object() -> dict
    * MongoQuery.ensure_loaded(*cols) -> MongoQuery
    * MongoQuery.get_projection_tree() -> dict
    * MongoQuery.get_full_projection_tree() -> dict
    * MongoQuery.pluck_instance(instance) -> dict
  * Handlers
* CRUD Helpers
  * CrudHelper(model, **handler_settings)
  * StrictCrudHelper
  * CrudViewMixin()
  * @saves_relations(*field_names)
* Other Useful Tools
  * ModelPropertyBags(model)
  * CombinedBag(**bags)
  * CountingQuery(query)

Querying
========

{{ handlers['doc'] }}

Operations
----------

{{ operations['project']['doc'] }}
{{ operations['sort']['doc'] }}
{{ operations['filter']['doc'] }}
{{ operations['join']['doc'] }}
{{ operations['joinf']['doc'] }}
{{ operations['aggregate']['doc'] }}
{{ operations['group']['doc'] }}
{{ operations['limit']['doc'] }}
{{ operations['count']['doc'] }}


JSON Column Support
-------------------

A `JSON` (or `JSONB`) field is a column that contains an embedded object,
which itself has fields too. You can access these fields using a dot.

Given a model field:

```javascript
model.data = { rating: 5.5, list: [1, 2, 3], obj: {a: 1} }
```

You can reference the JSON field's internals:

```javascript
'data.rating'
'data.list.0'
'data.obj.a'
'data.obj.z'  // gives NULL when a field does not exist
```

Operations that support it:

* [Sort](#sort-operation) and [Group](#group-operation) operations:

    ```javascript
    $.get('/api/user?query=' + JSON.stringify({
        sort: ['data.rating']  // JSON field sorting
    }))
    ```

* [Filter](#filter-operation) operation:

    ```javascript
    $.get('/api/user?query=' + JSON.stringify({
        filter: {
            'data.rating': { $gte: 5.5 },  // JSON field condition
        }
    }))
    ```

    or this is how you test that a property is missing:

    ```javascript
    { 'data.rating': null }  // Test for missing property
    ```

    *CAVEAT*: PostgreSQL is a bit capricious about data types, so MongoSQL tries to guess the type *using the operand you provide*.
    Hence, when filtering with a property known to contain a `float`-typed field, please provide a `float` value!
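
    For example, building on `data.rating` above (a float-typed field), provide a float operand (a sketch; the commented-out line shows the risky variant):

    ```javascript
    $.get('/api/user?query=' + JSON.stringify({
        filter: {
            'data.rating': { $gte: 5.0 },  // float operand: the field is compared as a float
            // 'data.rating': { $gte: 5 }  // an integer operand may make the field be compared as an integer
        }
    }))
    ```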

* [Aggregate](#aggregate-operation):

    ```javascript
    $.get('/api/user?query=' + JSON.stringify({
        aggregate: {
            avg_rating: { $avg: 'data.rating' }
        }
    }))
    ```





{% macro doc_class_method(method) -%}

### `{{ method['qrtsignature'] }}`
{{ method['doc'] }}

{% if method['args'] %}
Arguments:

{% for arg in method['args'] %}
* `{{ arg['name'] }}{% if arg['type'] %}: {{ arg['type'] }}{% endif %}{% if 'default' in arg %} = {{ arg['default'] }}{% endif %}`: {{ arg['doc']|indent(4) }}
{% endfor %}
{% endif %}

{% if method['ret'] %}
Returns{% if method['ret']['type'] %} `{{ method['ret']['type'] }}`{% endif %}{% if method['ret']['doc'] %}: {{ method['ret']['doc'] }}{% endif %}
{% endif %}

{% if method['exc'] %}
Exceptions:

{% for exc in method['exc'] %}
* `{{ exc['name'] }}`: {{ exc['doc']|indent(4) }}
{% endfor %}

{% endif %}

{% if method['example'] %}
Example:

{{ method['example'] }}

{% endif %}

{%- endmacro %}



MongoSQL Programming Interface
==============================

MongoQuery
----------
{{ mongosql_query['doc'] }}

MongoQuery Configuration
------------------------

{{ MongoQuerySettingsDict_init['doc'] }}

Example:

{{ MongoQuerySettingsDict_init['example'] }}

The available settings are:

{% for arg in MongoQuerySettingsDict_init['args'] %}
* `{{ arg['name'] }}`: {{ arg['doc']|indent(4) }}
{% endfor %}



{{ StrictCrudHelperSettingsDict_init['doc'] }}

{% for arg in StrictCrudHelperSettingsDict_init['args'] %}
* `{{ arg['name'] }}`: {{ arg['doc']|indent(4) }}
{% endfor %}


MongoQuery API
--------------

### `{{ MongoQuery['cls']['signature'] }}`
{{ MongoQuery['cls']['clsdoc'] }}

{{ doc_class_method(MongoQuery['attrs']['from_query']) }}
{{ doc_class_method(MongoQuery['attrs']['with_session']) }}
{{ doc_class_method(MongoQuery['attrs']['query']) }}
{{ doc_class_method(MongoQuery['attrs']['end']) }}
{{ doc_class_method(MongoQuery['attrs']['end_count']) }}
{{ doc_class_method(MongoQuery['attrs']['result_contains_entities']) }}
{{ doc_class_method(MongoQuery['attrs']['result_is_scalar']) }}
{{ doc_class_method(MongoQuery['attrs']['result_is_tuples']) }}
{{ doc_class_method(MongoQuery['attrs']['ensure_loaded']) }}
{{ doc_class_method(MongoQuery['attrs']['get_final_query_object']) }}
{{ doc_class_method(MongoQuery['attrs']['get_projection_tree']) }}
{{ doc_class_method(MongoQuery['attrs']['get_full_projection_tree']) }}
{{ doc_class_method(MongoQuery['attrs']['pluck_instance']) }}

### Handlers
In addition, `MongoQuery` lets you inspect its internals.
Every handler is available as a property of the `MongoQuery`:

* `MongoQuery.handler_project`: [handlers.MongoProject](mongosql/handlers/project.py)
* `MongoQuery.handler_sort`: [handlers.MongoSort](mongosql/handlers/sort.py)
* `MongoQuery.handler_group`: [handlers.MongoGroup](mongosql/handlers/group.py)
* `MongoQuery.handler_join`: [handlers.MongoJoin](mongosql/handlers/join.py)
* `MongoQuery.handler_joinf`: [handlers.MongoFilteringJoin](mongosql/handlers/joinf.py)
* `MongoQuery.handler_filter`: [handlers.MongoFilter](mongosql/handlers/filter.py)
* `MongoQuery.handler_aggregate`: [handlers.MongoAggregate](mongosql/handlers/aggregate.py)
* `MongoQuery.handler_limit`: [handlers.MongoLimit](mongosql/handlers/limit.py)
* `MongoQuery.handler_count`: [handlers.MongoCount](mongosql/handlers/count.py)

Some of them have methods which may be useful for the application you're building,
especially if you need to get some information out of `MongoQuery`.




CRUD Helpers
============

{{ crudhelper['doc'] }}

## `{{ CrudHelper['cls']['signature'] }}`
{{ CrudHelper['cls']['clsdoc'] }}

{{ doc_class_method(CrudHelper['attrs']['query_model']) }}
{{ doc_class_method(CrudHelper['attrs']['create_model']) }}
{{ doc_class_method(CrudHelper['attrs']['update_model']) }}


## `{{ StrictCrudHelper['cls']['name'] }}`
{{ StrictCrudHelper['cls']['clsdoc'] }}

{{ doc_class_method(StrictCrudHelper['cls']) }}


## `{{ CrudViewMixin['cls']['signature'] }}`
{{ CrudViewMixin['cls']['clsdoc'] }}

{{ doc_class_method(CrudViewMixin['attrs']['_get_db_session']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_get_query_object']) }}

{{ doc_class_method(CrudViewMixin['attrs']['_method_get']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_method_list']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_method_create']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_method_update']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_method_delete']) }}

{{ doc_class_method(CrudViewMixin['attrs']['_mongoquery_hook']) }}
{{ doc_class_method(CrudViewMixin['attrs']['_save_hook']) }}

{{ doc_class_method(CrudViewMixin['attrs']['_method_create_or_update_many']) }}


## `@{{ saves_relations['cls']['signature'] }}`
{{ saves_relations['cls']['clsdoc'] }}




Other Useful Tools
==================

## `{{ ModelPropertyBags['cls']['signature'] }}`
{{ ModelPropertyBags['cls']['clsdoc'] }}

## `{{ CombinedBag['cls']['signature'] }}`
{{ CombinedBag['cls']['clsdoc'] }}

## `{{ CountingQuery['cls']['signature'] }}`
{{ CountingQuery['cls']['clsdoc'] }}

--------------------------------------------------------------------------------
/misc/_doc/README.py:
--------------------------------------------------------------------------------
import json
from exdoc import doc, getmembers, subclasses

# Methods
doccls = lambda cls, *allowed_keys: {
    'cls': doc(cls),
    'attrs': {name: doc(m, cls)
              for name, m in getmembers(cls, None,
                                        lambda key, value: key in allowed_keys or not key.startswith('_'))}
}

docmodule = lambda mod: {
    'module': doc(mod),
    'members': [doc(getattr(mod, name)) for name in mod.__all__]
}

# Data
import mongosql
from mongosql.handlers import project, sort, group, filter, join, joinf, aggregate, limit, count
from mongosql import query, MongoQuerySettingsDict, StrictCrudHelperSettingsDict
from mongosql import ModelPropertyBags, CombinedBag, CountingQuery
from mongosql.crud import crudhelper, CrudHelper, StrictCrudHelper, CrudViewMixin, saves_relations

data = dict(
    mongosql=doc(mongosql),
    handlers=doc(mongosql.handlers),
    operations={
        m.__name__.rsplit('.', 1)[1]: doc(m)
        for m in (project, sort, group, filter, join, joinf, aggregate, limit, count)},
    mongosql_query=doc(mongosql.query),

    MongoQuery=doccls(mongosql.query.MongoQuery),
    MongoQuerySettingsDict_init=doc(MongoQuerySettingsDict.__init__, MongoQuerySettingsDict),
    StrictCrudHelperSettingsDict_init=doc(StrictCrudHelperSettingsDict.__init__, StrictCrudHelperSettingsDict),

    crudhelper=doc(crudhelper),
    CrudHelper=doccls(CrudHelper),
    StrictCrudHelper=doccls(StrictCrudHelper),
    CrudViewMixin=doccls(CrudViewMixin, *dir(CrudViewMixin)),
    saves_relations=doccls(saves_relations),

    ModelPropertyBags=doccls(ModelPropertyBags),
    CombinedBag=doccls(CombinedBag),
    CountingQuery=doccls(CountingQuery),
)

# Patches

class MyJsonEncoder(json.JSONEncoder):
    def default(self, o):
        # Classes
        if isinstance(o, type):
            return o.__name__
        return super(MyJsonEncoder, self).default(o)

# Document
print(json.dumps(data, indent=2, cls=MyJsonEncoder))

--------------------------------------------------------------------------------
/mongosql/__init__.py:
--------------------------------------------------------------------------------
"""
MongoSQL is a JSON query engine that lets you query [SqlAlchemy](http://www.sqlalchemy.org/)
like a MongoDB database.

The main use case is the interaction with the UI:
every time the UI needs some *sorting*, *filtering*, *pagination*, or to load some
*related objects*, you won't have to write a single line of repetitive code!

It will let the API user send a JSON Query Object along with the REST request,
which will control the way the result set is generated:

```javascript
$.get('/api/user?query=' + JSON.stringify({
    sort: ['first_name-'],  // sort by `first_name` DESC
    filter: { age: { $gte: 18 } },  // filter: age >= 18
    join: ['user_profile'],  // load related `user_profile`
    limit: 10,  // limit to 10 rows
}))
```

Tired of adding query parameters for pagination, filtering, sorting?
Here is the ultimate solution.

NOTE: currently, only tested with PostgreSQL.
"""

# SqlAlchemy versions
from sqlalchemy import __version__ as SA_VERSION
SA_12 = SA_VERSION.startswith('1.2')
SA_13 = SA_VERSION.startswith('1.3')

# Exceptions that are used here and there
from .exc import *

# MongoSQL needs a lot of information about the properties of your models.
# All this is handled by the following class:
from .bag import ModelPropertyBags, CombinedBag

# The heart of MongoSQL is its handlers:
# that's where your JSON objects are converted to actual SqlAlchemy queries!
from . import handlers

# MongoQuery is the man that parses your QueryObject and applies the methods from MongoModel that
# implement individual fields.
from .query import MongoQuery

# SqlAlchemy declarative base that defines .mongomodel() and .mongoquery() on it.
# That's just for your convenience.
from .sa import MongoSqlBase

# CrudHelper is something that enables you to use JSON for:
# - Creation (i.e. save a record into DB using JSON)
# - Replacement (i.e. completely replace a record in the DB)
# - Modification (i.e. update some specific fields of a record)
# CrudHelper is something that you'll need when building a JSON API that implements CRUD:
# Create/Read/Update/Delete
from .crud import CrudHelper, StrictCrudHelper, CrudViewMixin
from .crud import saves_relations, ABSENT

# Helpers
# Reusable query objects (so that you don't have to initialize them over and over again)
from mongosql.util import Reusable
# selectinquery() relationship loader that supports custom queries
from mongosql.util import selectinquery
# `Query` object wrapper that is able to query and count() at the same time
from mongosql.util import CountingQuery
# Settings objects for MongoQuery and StrictCrudHelper
from mongosql.util import MongoQuerySettingsDict, StrictCrudHelperSettingsDict

# raiseload_col() that can be applied to columns, not only relationships
try:
    from mongosql.util import raiseload_col, raiseload_rel, raiseload_all
except ImportError:
    pass

--------------------------------------------------------------------------------
/mongosql/crud/__init__.py:
--------------------------------------------------------------------------------
from .crudhelper import CrudHelper, StrictCrudHelper
from .crudview import CrudViewMixin, CRUD_METHOD, saves_relations, ABSENT
from ..util import EntityDictWrapper, load_many_instance_dicts, model_primary_key_columns_and_names

--------------------------------------------------------------------------------
/mongosql/crud/crudhelper.py:
--------------------------------------------------------------------------------
"""
MongoSQL is designed to help with data selection for the APIs.
To ease the pain of implementing CRUD for all of your models,
MongoSQL comes with a CRUD helper that exposes MongoSQL capabilities for querying to the API user.
Together with [RestfulView](https://github.com/kolypto/py-flask-jsontools#restfulview)
from [flask-jsontools](https://github.com/kolypto/py-flask-jsontools),
CRUD controllers are extremely easy to build.
"""

from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified

from mongosql import exc
from mongosql import MongoQuery, ModelPropertyBags
from mongosql.util import Reusable

from typing import Union, Mapping, Iterable, Set, Callable, MutableMapping
from sqlalchemy.ext.declarative import DeclarativeMeta


class CrudHelper:
    """ Crud helper: an object that helps implement CRUD operations for an API endpoint:

        * Create: construct SqlAlchemy instances from the submitted entity dict
        * Read: use MongoQuery for querying
        * Update: update SqlAlchemy instances from the submitted entity using a dict
        * Delete: use MongoQuery for deletion

        Source: [mongosql/crud/crudhelper.py](mongosql/crud/crudhelper.py)

        This object is supposed to be initialized only once;
        don't do it for every query, keep it at the class level!

        Most likely, you'll want to keep it at the class level of your view:

        ```python
        from .models import User
        from mongosql import CrudHelper

        class UserView:
            crudhelper = CrudHelper(
                # The model to work with
                User,
                # Settings for MongoQuery
                **MongoQuerySettingsDict(
                    allowed_relations=('user_profile',),
                )
            )
            # ...
        ```

        Note that during "create" and "update" operations, this class lets you write values
        to column attributes, and also to @property attributes that are writable (have a setter).
        If this behavior (with writable properties) is undesirable,
        set `writable_properties=False`

        The following methods are available:
    """

    # The class to use for getting structural data from a model
    _MODEL_PROPERTY_BAGS_CLS = ModelPropertyBags
    # The class to use for MongoQuery
    _MONGOQUERY_CLS = MongoQuery

    def __init__(self, model: DeclarativeMeta, writable_properties=True, **handler_settings):
        """ Init CRUD helper

        :param model: The model to work with
        :param writable_properties: Let the API user write to @property attributes that have a setter?
        :param handler_settings: Settings for the MongoQuery used to make queries
        """
        self.model = model
        self.handler_settings = handler_settings
        self.bags = self._MODEL_PROPERTY_BAGS_CLS.for_model(model)
        self.reusable_mongoquery = Reusable(self._MONGOQUERY_CLS(self.model, handler_settings))  # type: MongoQuery

        # Settings
        self.writable_properties = writable_properties

        # We also need `legacy_fields`
        # we're going to ignore them in the input
        self.legacy_fields = self.reusable_mongoquery.handler_project.legacy_fields

    def query_model(self, query_obj: Union[Mapping, None] = None, from_query: Union[Query, None] = None) -> MongoQuery:
        """ Make a MongoQuery using the provided Query Object

        Note that you have to provide the MongoQuery yourself.
        This is because it has to be properly configured with handler_settings.

        :param query_obj: The Query Object to use
        :param from_query: An optional Query to initialize MongoQuery with
        :raises exc.InvalidColumnError: Invalid column name specified in the Query Object by the user
        :raises exc.InvalidRelationError: Invalid relationship name specified in the Query Object by the user
        :raises exc.InvalidQueryError: There is an error in the Query Object that the user has made
        :raises exc.DisabledError: A feature is disabled; likely, due to a configuration issue. See handler_settings.
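
        Example (a minimal sketch; assumes `ssn` is an SqlAlchemy Session and `crudhelper` is a configured CrudHelper):

        ```python
        mq = crudhelper.query_model({'filter': {'age': {'$gte': 18}}})
        users = mq.with_session(ssn).end().all()
        ```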
        """
        # Validate
        if not isinstance(query_obj, (Mapping, NoneType)):
            raise exc.InvalidQueryError('Query Object must be either an object, or null')

        # Query
        return self._query_model(query_obj or {}, from_query)  # ensure dict

    def _query_model(self, query_obj: Mapping, from_query: Union[Query, None] = None) -> MongoQuery:
        """ Make a MongoQuery """
        return self.reusable_mongoquery.from_query(from_query).query(**query_obj)

    def _validate_columns(self, column_names: Iterable[str], where: str) -> Set[str]:
        """ Validate column names

        :raises exc.InvalidColumnError: Invalid column name
        """
        unk_cols = self.bags.columns.get_invalid_names(column_names)
        if unk_cols:
            raise exc.InvalidColumnError(self.bags.model_name, unk_cols.pop(), where)
        return set(column_names)

    def _validate_attributes(self, column_names: Iterable[str], where: str) -> Set[str]:
        """ Validate attribute names (any, inc. properties)

        :raises exc.InvalidColumnError: Invalid column name
        """
        column_names = set(column_names)
        unk_cols = column_names - self.bags.all_names
        if unk_cols:
            raise exc.InvalidColumnError(self.bags.model_name, unk_cols.pop(), where)
        return column_names

    def _validate_writable_attributes(self, attr_names: Iterable[str], where: str) -> Set[str]:
        """ Validate attribute names (columns, properties, hybrid properties) that are writable

        This list does not include attributes like relationships and read-only properties

        :raises exc.InvalidColumnError: Column name was not writable
        :rtype: set[str]
        """
        attr_names = set(attr_names)
        unk_cols = attr_names - self.bags.writable.names
        if unk_cols:
            raise exc.InvalidColumnError(self.bags.model_name, unk_cols.pop(), where)
        return attr_names

    def validate_incoming_entity_dict_fields(self, entity_dict: dict, action: str) -> dict:
        """ Validate the incoming JSON data """
        # Validate
        if not isinstance(entity_dict, Mapping):
            raise exc.InvalidQueryError(f'Model "{action}": the value has to be an object, '
                                        f'not {type(entity_dict)}')

        # Remove certain fields from the entity dict
        if action == 'create':
            self._remove_entity_dict_fields(entity_dict, self._fields_to_remove_on_create)
        elif action == 'update':
            self._remove_entity_dict_fields(entity_dict, self._fields_to_remove_on_update)
        else:
            raise ValueError(action)

        # Check fields
        if self.writable_properties:
            # let both columns and @properties
            self._validate_writable_attributes(entity_dict.keys(), action)
        else:
            # let only columns
            self._validate_columns(entity_dict.keys(), action)

        # Done
        return entity_dict

    @property
    def _fields_to_remove_on_create(self):
        """ The list of fields to remove when creating an instance from an entity dict """
        return self.legacy_fields

    @property
    def _fields_to_remove_on_update(self):
        """ The list of fields to remove when updating an instance from an entity dict """
        return self.legacy_fields

    def _remove_entity_dict_fields(self, entity_dict: MutableMapping, rm_fields: Set[str]):
        """ Remove certain fields from the incoming entity dict """
        for k in set(entity_dict.keys()) & rm_fields:
            entity_dict.pop(k)

    def create_model(self, entity_dict: Mapping) -> object:
        """ Create an instance from entity dict.

        This method lets you set the value of columns and writable properties,
        but not relations. Use @saves_relations to handle additional fields.

        :param entity_dict: Entity dict
        :return: Created instance
        :raises InvalidQueryError: validation errors
        :raises InvalidColumnError: invalid column
        """
        # Validate and prepare it
        entity_dict = self.validate_incoming_entity_dict_fields(entity_dict, 'create')

        # Create
        return self._create_model(entity_dict)

    def _create_model(self, entity_dict: Mapping) -> object:
        """ Create an instance from a dict

        This method does not validate `entity_dict`
        """
        return self.model(**entity_dict)

    def update_model(self, entity_dict: Mapping, instance: object) -> object:
        """ Update an instance from an entity dict by merging the fields

        - Attributes are copied over
        - JSON dicts are shallowly merged

        Note that because properties are *copied over*,
        this operation does not replace the entity; it merely updates the entity.

        In other words, this method does a *partial update*:
        only updates the fields that were provided by the client, leaving all the rest intact.

        :param entity_dict: Entity dict
        :param instance: The instance to update
        :return: The updated instance
        :raises InvalidQueryError: validation errors
        :raises InvalidColumnError: invalid column
        """
        # Validate and prepare it
        entity_dict = self.validate_incoming_entity_dict_fields(entity_dict, 'update')

        # Update
        return self._update_model(entity_dict, instance)

    def _update_model(self, entity_dict: Mapping, instance: object) -> object:
        """ Update an instance from an entity dict

        This method does not validate `entity_dict`
        """
        # Update
        for name, val in entity_dict.items():
            if isinstance(val, Mapping) and self.bags.columns.is_column_json(name):
                # JSON column with a dict: do a shallow merge
                getattr(instance, name).update(val)
                # Tell SqlAlchemy that a mutable collection was updated
                flag_modified(instance, name)
            else:
                # Other columns: just assign
                setattr(instance, name, val)

        # Finish
        return instance


class StrictCrudHelper(CrudHelper):
    """ A Strict Crud Helper imposes defaults and limitations on the API user:

        Source: [mongosql/crud/crudhelper.py](mongosql/crud/crudhelper.py)

        - Read-only fields can not be set: not with create, nor with update
        - Constant fields can be set initially, but never be updated
        - Defaults for the Query Object provide the default values for every query, unless overridden

        The following behavior is implemented:

        * By default, all fields are writable
        * If `ro_fields` is provided, these fields become read-only, all other fields are writable
        * If `rw_fields` is provided, only these fields are writable, all other fields are read-only
        * If `const_fields` is provided, it is seen as a further limitation on `rw_fields`: those fields
          would be writable, but only once.

        Attributes:
            writable_properties (bool): Enable saving values from incoming JSON into @property attrs?
            ro_fields (set[str]): The list of read-only field names
            rw_fields (set[str]): The list of writable field names
            const_fields (set[str]): The list of constant field names
            query_defaults (dict): Default values for every field of the Query Object
    """

    def __init__(self, model: DeclarativeMeta,
                 writable_properties: bool = True,
                 ro_fields: Union[Iterable[str], Callable, None] = None,
                 rw_fields: Union[Iterable[str], Callable, None] = None,
                 const_fields: Union[Iterable[str], Callable, None] = None,
                 query_defaults: Union[Iterable[str], Callable, None] = None,
                 **handler_settings):
        """ Initializes a strict CRUD helper

        Note: use a `**StrictCrudHelperSettingsDict()` to help you with the argument names and their docs!

        Args:
            model: The model to work with
            writable_properties: enable writing to @property attributes?
            ro_fields: List of read-only property names, or a callable which gives the list
            rw_fields: List of writable property names, or a callable which gives the list
            const_fields: List of property names that are constant once set, or a callable which gives the list
            query_defaults: Defaults for every Query Object: the Query Object will be merged into it.
            handler_settings: Settings for the `MongoQuery` used to make queries

        Example:

        ```python
        from .models import User
        from mongosql import StrictCrudHelper, StrictCrudHelperSettingsDict

        class UserView:
            crudhelper = StrictCrudHelper(
                # The model to work with
                User,
                # Settings for MongoQuery and StrictCrudHelper
                **StrictCrudHelperSettingsDict(
                    # Can never be set or modified
                    ro_fields=('id',),
                    # Can only be set once
                    const_fields=('login',),
                    # Relations that can be `join`ed
                    allowed_relations=('user_profile',),
                )
            )
            # ...
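
        # A hypothetical usage sketch: `login` is a const field,
        # so it can be set when creating an instance...
        user = UserView.crudhelper.create_model({'login': 'kolypto'})
        # ...but it is silently dropped from the entity dict on update:
        user = UserView.crudhelper.update_model({'login': 'changed'}, user)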
        ```
        """
        super().__init__(model, writable_properties=writable_properties, **handler_settings)

        # ro, rw, const fields
        ro, rw, cn = self._init_ro_rw_cn_fields(ro_fields, rw_fields, const_fields)
        self.ro_fields = ro
        self.rw_fields = rw
        self.const_fields = cn

        # Defaults for the Query Object
        self.query_defaults = query_defaults or {}  # type: dict

        # Validate the Default Query Object
        MongoQuery(self.model).query(**self.query_defaults)

    def _init_ro_rw_cn_fields(self, ro_fields, rw_fields, cn_fields):
        """ Initialize ro_fields and rw_fields and const_fields

        :rtype: (set[str], set[str], set[str])
        """
        # Usage
        ro_provided = ro_fields is not None  # provided, even if empty
        rw_provided = rw_fields is not None
        if ro_provided and rw_provided:
            raise ValueError('Use either `ro_fields` or `rw_fields`, but not both')

        # Read-only and Read-Write fields
        ro_fields = set(call_if_callable(ro_fields)) if ro_fields is not None else set()
        rw_fields = set(call_if_callable(rw_fields)) if rw_fields is not None else set()
        cn_fields = set(call_if_callable(cn_fields)) if cn_fields is not None else set()

        # Validate
        self._validate_attributes(ro_fields, 'ro_fields')
        self._validate_writable_attributes(rw_fields, 'rw_fields')
        self._validate_writable_attributes(cn_fields, 'const_fields')

        # ro_fields
        if rw_provided:
            ro_fields = set(self.bags.all_names - rw_fields - cn_fields)

        # rw_fields
        rw_fields = self.bags.writable.names - ro_fields - cn_fields

        # Done
        return frozenset(ro_fields), frozenset(rw_fields), frozenset(cn_fields)

    @property
    def _fields_to_remove_on_create(self):
        """ The list of fields to remove when creating an instance from an entity dict """
        return super()._fields_to_remove_on_create | self.ro_fields

    @property
    def _fields_to_remove_on_update(self):
        """ The list of fields to remove when updating an instance from an entity dict """
        return super()._fields_to_remove_on_update | self.ro_fields | self.const_fields

    def _query_model(self, query_obj: Mapping, from_query: Union[Query, None] = None) -> MongoQuery:
        # Default Query Object
        if self.query_defaults:
            query_obj = {**self.query_defaults, **(query_obj or {})}

        # Super
        return super()._query_model(query_obj, from_query=from_query)


NoneType = type(None)


def call_if_callable(v):
    """ Preprocess a value: return it; but call it, if it's a lambda (for late binding) """
    return v() if callable(v) else v

--------------------------------------------------------------------------------
/mongosql/exc.py:
--------------------------------------------------------------------------------

class BaseMongoSqlException(AssertionError):  # `AssertionError` for backwards-compatibility
    pass


class InvalidQueryError(BaseMongoSqlException):
    """ Invalid input provided by the User """

    def __init__(self, err: str):
        super(InvalidQueryError, self).__init__('Query object error: {err}'.format(err=err))


class DisabledError(InvalidQueryError):
    """ The feature is disabled """


class InvalidColumnError(BaseMongoSqlException):
    """ Query mentioned an invalid column name """

    def __init__(self, model: str, column_name: str, where: str):
        self.model = model
        self.column_name = column_name
        self.where = where

        super(InvalidColumnError, self).__init__(
            'Invalid column "{column_name}" for "{model}" specified in {where}'.format(
                column_name=column_name,
                model=model,
                where=where)
        )


class InvalidRelationError(InvalidColumnError, BaseMongoSqlException):
    """ Query mentioned an invalid relationship name """
    def __init__(self, model: str, column_name: str, where: str):
        self.model = model
        self.column_name = column_name
        self.where = where

        super(InvalidColumnError, self).__init__(
            'Invalid relation "{column_name}" for "{model}" specified in {where}'.format(
                column_name=column_name,
                model=model,
                where=where)
        )


class RuntimeQueryError(BaseMongoSqlException):
    """ Uncaught error while processing a MongoQuery

    This class is used to augment other errors
    """

--------------------------------------------------------------------------------
/mongosql/handlers/__init__.py:
--------------------------------------------------------------------------------
"""

If you know how to query documents in MongoDB, you can query your database with the same language.
MongoSQL uses the familiar [MongoDB Query Operators](https://docs.mongodb.com/manual/reference/operator/query/)
language with a few custom additions.

The Query Object, in JSON format, will let you sort, filter, paginate, and do other things.
You would typically send this object in the URL query string, like this:

```
GET /api/user?query={"filter":{"age":{"$gte":18}}}
```

The name of the `query` argument, however, may differ from project to project.



Query Object Syntax
-------------------

A Query Object is a JSON object that the API user can submit to the server to change the way the results are generated.
It is an object with the following properties:

* `project`: [Project Operation](#project-operation) selects the fields to be loaded
* `sort`: [Sort Operation](#sort-operation) determines the sorting of the results
* `filter`: [Filter Operation](#filter-operation) filters the results, using your criteria
* `join`: [Join Operation](#join-operation) loads related models
* `joinf`: [Filtering Join Operation](#filtering-join-operation) loads related models with filtering
* `aggregate`: [Aggregate Operation](#aggregate-operation) lets you calculate statistics
* `group`: [Group Operation](#group-operation) determines how to group rows while doing aggregation
* `skip`, `limit`: [Rows slicing](#slice-operation): paginates the results
* `count`: [Counting rows](#count-operation) counts the number of rows without producing results

An example Query Object is:

```javascript
{
  project: ['id', 'name'],  // Only fetch these columns
  sort: ['age+'],  // Sort by age, ascending
  filter: {
    // Filter condition
    sex: 'female',  // Girls
    age: { $gte: 18 },  // Age >= 18
  },
  join: ['user_profile'],  // Load the 'user_profile' relationship
  limit: 100,  // Display 100 per page
  skip: 10,  // Skip first 10 rows
}
```

Detailed syntax for every operation is provided in the relevant sections.
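
On the server side, such a Query Object is handed over to `MongoQuery`.
A minimal sketch (assuming a `User` model and an SqlAlchemy session `ssn`):

```python
from mongosql import MongoQuery

query_object = {'filter': {'age': {'$gte': 18}}, 'sort': ['age+'], 'limit': 100}
users = MongoQuery(User).with_session(ssn).query(**query_object).end().all()
```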

Please keep in mind that while MongoSQL provides a query language that is rich enough for most typical tasks,
there would still be cases when an implementation of a custom API would be better, or even the only option available.

MongoSQL was not designed to be a complete replacement for the SQL; it was designed only to keep you from doing
repetitive work :) So it's absolutely fine that some queries that you may have in mind won't be possible with MongoSQL.
"""

from .project import MongoProject
from .sort import MongoSort
from .group import MongoGroup
from .join import MongoJoin, \
    MongoJoinParams
from .joinf import MongoFilteringJoin
from .filter import MongoFilter, \
    FilterExpressionBase, FilterBooleanExpression, FilterColumnExpression, FilterRelatedColumnExpression
from .aggregate import MongoAggregate, \
    AggregateExpressionBase, AggregateLabelledColumn, AggregateColumnOperator, AggregateBooleanCount
from .aggregate import MongoAggregateInsecure
from .limit import MongoLimit
from .count import MongoCount

# TODO: implement update operations on a model in MongoDB-style
# TODO: document MongoHandler classes

--------------------------------------------------------------------------------
/mongosql/handlers/base.py:
--------------------------------------------------------------------------------
from ..bag import ModelPropertyBags
from ..exc import InvalidQueryError, InvalidColumnError, InvalidRelationError


class MongoQueryHandlerBase:
    """ An implementation of a handler from MongoQuery

        Every subclass will handle a single field from the Query object
    """

    #: Name of the QueryObject section that this object is capable of handling
    query_object_section_name = None

    def __init__(self, model, bags):
        """ Initialize the Query Object section handler with a model.

        This method does *not* receive any input data just yet, with the purpose of having an
        object that can be extended with some interesting defaults right at init time.

        :param model: The sqlalchemy model it's being applied to
        :type model: sqlalchemy.ext.declarative.DeclarativeMeta
        :param bags: Model bags.
            We have to have `bags` provided to us, because then someone may subclass MongoQuery,
            use a different MongoPropertyBags, and customize the way a model is analyzed.
        :type bags: ModelPropertyBags

        NOTE: Any arguments that have default values will be treated as handler settings!!
        """
        #: The model to handle the Query Object for
        self.model = model  # the model, or its alias (when used with self.aliased())
        #: Model property bags: because we need access to the lists of its properties
        self.bags = bags
        #: A CombinedBag() that allows to handle properties of multiple types (e.g. columns + hybrid properties)
        self.supported_bags = self._get_supported_bags()

        # Has the input() method been called already?
        # This may be important for handlers that depend on other handlers
        self.input_received = False

        # Has the aliased() method been called already?
        # This is important because it can't be done again, or undone.
        self.is_aliased = False

        # Should this handler's alter_query() be skipped by MongoQuery?
        # This is used by MongoJoin when it removes a filtering condition into the ON-clause,
        # and does not want the original filter to be executed.
        self.skip_this_handler = False

        #: MongoQuery bound to this object. It may remain uninitialized.
        self.mongoquery = None

    def with_mongoquery(self, mongoquery):
        """ Bind this object with a MongoQuery

        :type mongoquery: mongosql.query.MongoQuery
        """
        self.mongoquery = mongoquery
        return self

    def __copy__(self):
        """ Some objects may be reused: i.e. their state before input() is called.

        Reusable handlers are implemented using the Reusable() wrapper which performs the
        automatic copying on input() call
        """
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def aliased(self, model):
        """ Use an aliased model to build queries

        This is used by MongoQuery.aliased(), which is ultimately useful to MongoJoin handler.
        Note that the method modifies the current object and does not make a copy!
        """
        # Only once
        assert not self.is_aliased, 'You cannot call {}.aliased() ' \
                                    'on a handler that has already been aliased()' \
                                    .format(self.__class__.__name__)
        self.is_aliased = True

        # aliased() everything
        self.model = model
        self.bags = self.bags.aliased(model)
        self.supported_bags = self._get_supported_bags()  # re-initialize
        return self

    def _get_supported_bags(self):
        """ Get the _PropertiesBag interface supported by this handler

        :rtype: mongosql.bag._PropertiesBagBase
        """
        raise NotImplementedError()

    def validate_properties(self, prop_names, bag=None, where=None):
        """ Validate the given list of property names against `self.supported_bags`

        :param prop_names: List of property names
        :param bag: A specific bag to use
        :raises InvalidColumnError
        """
        # Bag to check against
        if bag is None:
            bag = self.supported_bags

        # Validate
        invalid = bag.get_invalid_names(prop_names)
        if invalid:
            raise InvalidColumnError(self.bags.model_name,
                                     invalid.pop(),
                                     where or self.query_object_section_name)

    def input_prepare_query_object(self, query_object):
        """ Modify the Query Object before it is processed.

        Sometimes a handler would need to alter it.
        Here's its chance.

        This method is called before any input(), or validation, or anything.

        :param query_object: dict
        """
        return query_object

    def input(self, qo_value):
        """ Get a section of the Query object.

        The purpose of this method is to receive the input, validate it, and store as a public
        property so that external tools may export its value.
        Note that validation does not *have* to happen here: it may in fact be implemented in one
        of the compile_*() methods.

        :param qo_value: the value of the Query object field it's handling
        :type qo_value: Any

        :rtype: MongoQueryHandlerBase
        :raises InvalidRelationError
        :raises InvalidColumnError
        :raises InvalidQueryError
        """
        self.input_value = qo_value  # no copying. Try not to modify it.

        # Set the flag
        self.input_received = True

        # Make sure that input() can only be used once
        self.input = self.__raise_input_not_reusable

        return self

    def is_input_empty(self):
        """ Test whether the input value was empty """
        return not self.input_value

    def __raise_input_not_reusable(self, *args, **kwargs):
        raise RuntimeError("You can't use the {}.input() method twice. "
                           "Wrap the class into Reusable(), or copy() it!"
                           .format(self.__class__.__name__))

    # These methods implement the logic of individual handlers
    # Note that not all methods are going to be implemented by subclasses!

    def compile_columns(self):
        """ Compile a list of columns.

        Purpose: argument for Query(*)

        :rtype: list[sqlalchemy.sql.schema.Column]
        """
        raise NotImplementedError()

    def compile_options(self, as_relation):
        """ Compile a list of options for a Query

        Purpose: argument for Query.options(*)

        :param as_relation: Load interface to chain the loader options from
        :type as_relation: sqlalchemy.orm.Load
        :return: list
        """
        raise NotImplementedError()

    def compile_statement(self):
        """ Compile a statement

        :return: SQL statement
        """
        raise NotImplementedError()

    def compile_statements(self):
        """ Compile a list of statements

        :return: list of SQL statements
        """
        raise NotImplementedError()

    def alter_query(self, query, as_relation):
        """ Alter the given query and apply the Query Object section this handler is handling

        :param query: The query to apply this MongoSQL Query Object to
        :type query: Query
        :param as_relation: Load interface to work with nested relations.
            Note that some classes need it, others don't
        :type as_relation: Load
        :rtype: Query
        """
        raise NotImplementedError()

    def get_final_input_value(self):
        """ Get the final input of the handler """
        return self.input_value

--------------------------------------------------------------------------------
/mongosql/handlers/count.py:
--------------------------------------------------------------------------------
"""
### Count Operation
Counting corresponds to the `SELECT COUNT(*)` part of an SQL query.

Simply, return the number of items, without returning the items themselves. Just a number. That's it.

Example:

```javascript
$.get('/api/user?query=' + JSON.stringify({
    count: 1,
}))
```

The `1` is the *on* switch. Replace it with `0` to stop counting.

NOTE: In MongoSQL 2.0, there is a way to get both the list of items, *and* their count *simultaneously*.
This would have way better performance than two separate queries.
Please have a look: [CountingQuery](#countingqueryquery) and [MongoQuery.end_count()](#mongoqueryend_count---countingquery).
20 | """ 21 | 22 | from sqlalchemy import func 23 | from sqlalchemy import exc as sa_exc 24 | 25 | from .base import MongoQueryHandlerBase 26 | from ..exc import InvalidQueryError, InvalidColumnError, InvalidRelationError 27 | 28 | 29 | class MongoCount(MongoQueryHandlerBase): 30 | """ MongoDB count query 31 | 32 | Just give it: 33 | * count=True 34 | """ 35 | 36 | query_object_section_name = 'count' 37 | 38 | def __init__(self, model, bags): 39 | """ Init a count 40 | 41 | :param model: Sqlalchemy model to work with 42 | :param bags: Model bags 43 | """ 44 | super(MongoCount, self).__init__(model, bags) 45 | 46 | # On input 47 | self.count = None 48 | 49 | def input_prepare_query_object(self, query_object): 50 | # When we count, we don't care about certain things 51 | if query_object.get('count', False): 52 | # Performance: do not sort when counting 53 | query_object.pop('sort', None) 54 | # We don't care about projections either 55 | query_object.pop('project', None) 56 | # Also, remove all skips & limits 57 | query_object.pop('skip', None) 58 | query_object.pop('limit', None) 59 | # Remove all join, but not joinf (as it may filter) 60 | query_object.pop('join', None) 61 | # Finally, when we count, we have to remove `max_items` setting from MongoLimit. 62 | # Only MongoLimit can do it, and it will do it for us. 63 | # See: MongoLimit.input_prepare_query_object 64 | 65 | return query_object 66 | 67 | def input(self, count=None): 68 | super(MongoCount, self).input(count) 69 | if not isinstance(count, (int, bool, NoneType)): 70 | raise InvalidQueryError('Count must be either true or false. Or at least a 1, or a 0') 71 | 72 | # Done 73 | self.count = count 74 | return self 75 | 76 | def _get_supported_bags(self): 77 | return None # not used by this class 78 | 79 | # Not Implemented for this Query Object handler 80 | compile_columns = NotImplemented 81 | compile_options = NotImplemented 82 | compile_statement = NotImplemented 83 | compile_statements = NotImplemented 84 | 85 | def alter_query(self, query, as_relation=None): 86 | """ Apply offset() and limit() to the query """ 87 | if self.count: 88 | # Previously, we used to do counts like this: 89 | # >>> query = query.with_entities(func.count()) 90 | # However, when there's no WHERE clause set on a Query, it's left without any reference to the target table. 91 | # In this case, SqlAlchemy will actually generate a query without a FROM clause, which gives a wrong count! 92 | # Therefore, we have to make sure that there will always be a FROM clause. 93 | # 94 | # Normally, we just do the following: 95 | # >>> query = query.select_from(self.model) 96 | # This is supposed to indicate which table to select from. 97 | # However, it can only be applied when there's no FROM nor ORDER BY clauses present. 98 | # 99 | # But wait a second... didn't we just assume that there would be no FROM clause? 100 | # Have a look at this ugly duckling: 101 | # >>> Query(User).filter_by().select_from(User) 102 | # This filter_by() would actually create an EMPTY condition, which will break select_from()'s assertions! 103 | # This is reported to SqlAlchemy: 104 | # https://github.com/sqlalchemy/sqlalchemy/issues/4606 105 | # And (is fixed in version x.x.x | is not going to be fixed) 106 | # 107 | # Therefore, we'll try to do it the nice way ; and if it fails, we'll have to do something else. 
108 |             try:
109 |                 query = query.with_entities(func.count()).select_from(self.model)
110 |             except sa_exc.InvalidRequestError:
111 |                 query = query.from_self(func.count())
112 |
113 |         return query
114 |
115 |
116 | NoneType = type(None)
--------------------------------------------------------------------------------
/mongosql/handlers/group.py:
--------------------------------------------------------------------------------
1 | """
2 | ### Group Operation
3 | Grouping corresponds to the `GROUP BY` part of an SQL query.
4 |
5 | By default, the [Aggregate Operation](#aggregate-operation) gives statistical results over all rows.
6 |
7 | For instance, if you've asked for `{ avg_age: { $avg: 'age' } }`, you'll get the average age of all users.
8 |
9 | Oftentimes this is not enough, and you'll want statistics calculated over groups of items.
10 | This is what the Group Operation does: it specifies which field to use as the "group" discriminator.
11 |
12 | Better start with a few examples.
13 |
14 | #### Example #1: calculate the number of users of every specific age.
15 | We use the `age` field as the group discriminator, and the total number of users is therefore calculated per group.
16 | The result would be something like:
17 |
18 |     age 18: 25 users
19 |     age 19: 20 users
20 |     age 21: 35 users
21 |     ...
22 |
23 | The code:
24 |
25 | ```javascript
26 | $.get('/api/user?query=' + JSON.stringify({
27 |     // The statistics
28 |     aggregate: {
29 |         age: 'age',  // Get the unadulterated column value
30 |         count: { $sum: 1 },  // The count
31 |     },
32 |     // The discriminator
33 |     group: ['age'],  // we do not discriminate by sex this time... :)
34 | }))
35 | ```
36 |
37 | #### Example #2: calculate the average salary per profession
38 |
39 | ```javascript
40 | $.get('/api/user?query=' + JSON.stringify({
41 |     aggregate: {
42 |         prof: 'profession',
43 |         salary: { '$avg': 'salary' } },
44 |     group: ['profession_id'],
45 | }))
46 | ```
47 |
48 | #### Syntax
49 | The Group Operator, as you have seen, receives an array of column names.
50 |
51 | * Array syntax.
52 |
53 |     List of column names, optionally suffixed by the sort direction: `-` for `DESC`, `+` for `ASC`.
54 |     The default is `+`.
55 |
56 |     Example:
57 |
58 |     ```javascript
59 |     { group: [ 'a+', 'b-', 'c' ] }  // -> a ASC, b DESC, c ASC
60 |     ```
61 |
62 | * String syntax
63 |
64 |     List of columns, with optional `+` / `-`, separated by whitespace.
65 |
66 |     Example:
67 |
68 |     ```javascript
69 |     { group: 'a+ b- c' }
70 |     ```
71 |
72 | """
73 |
74 | from .sort import MongoSort
75 |
76 |
77 | class MongoGroup(MongoSort):
78 |     """ MongoDB-style grouping
79 |
80 |     It has the same syntax as MongoSort, so we just reuse the code.
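
    For instance, a sketch of what the spec compiles to (column names assumed):

        group: ['profession_id', 'age-']   # -> GROUP BY profession_id, age DESC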
81 |
82 |     See :cls:MongoSort
83 |     """
84 |
85 |     query_object_section_name = 'group'
86 |
87 |     def __init__(self, model, bags, legacy_fields=None):
88 |         # Legacy fields
89 |         self.legacy_fields = frozenset(legacy_fields or ())
90 |
91 |         # Parent
92 |         super(MongoSort, self).__init__(model, bags)  # yes, call the base; not the parent
93 |
94 |         # On input
95 |         #: OrderedDict() of a group spec: {key: +1|-1}
96 |         self.group_spec = None
97 |
98 |     def input(self, group_spec):
99 |         super(MongoSort, self).input(group_spec)  # call the base; not the parent
100 |         self.group_spec = self._input(group_spec)
101 |         return self
102 |
103 |     def compile_columns(self):
104 |         return [
105 |             self.supported_bags.get(name).desc() if d == -1 else self.supported_bags.get(name)
106 |             for name, d in self.group_spec.items()
107 |         ]
108 |
109 |     # Not Implemented for this Query Object handler
110 |     compile_options = NotImplemented
111 |     compile_statement = NotImplemented
112 |     compile_statements = NotImplemented
113 |
114 |     def alter_query(self, query, as_relation=None):
115 |         if not self.group_spec:
116 |             return query  # short-circuit
117 |
118 |         return query.group_by(*self.compile_columns())
119 |
120 |     def get_final_input_value(self):
121 |         return [f'{name}{"-" if d == -1 else ""}'
122 |                 for name, d in self.group_spec.items()]
123 |
--------------------------------------------------------------------------------
/mongosql/handlers/joinf.py:
--------------------------------------------------------------------------------
1 | """
2 | ### Filtering Join Operation
3 | The [Join Operation](#join-operation) has the following behavior:
4 | when you request the loading of a relation, and there are no related items, an empty value is returned
5 | (a `null`, or an empty array).
6 |
7 | ```javascript
8 | // This one will return all users
9 | // (even those that have no articles)
10 | $.get('/api/user?query=' + JSON.stringify({
11 |     join: ["articles"]  // Regular Join: `join`
12 | }))
13 | ```
14 |
15 | This `joinf` Filtering Join operation does just the same thing that `join` does;
16 | however, if there are no related items, the primary one is also removed.
17 |
18 | ```javascript
19 | // This one will return *only those users that have articles*
20 | // (users with no articles will be excluded)
21 | $.get('/api/user?query=' + JSON.stringify({
22 |     joinf: ["articles"]  // Filtering Join: `joinf`
23 | }))
24 | ```
25 |
26 | This feature is, quite honestly, weird, and is only available for backward-compatibility with a bug that existed
27 | in some early MongoSQL versions. It has proven to be useful in some cases, so the bug has been given a name and a
28 | place within the MongoSQL library :)
29 |
30 | Note that `joinf` does not support `skip` and `limit`
31 | on nested entities because of the way it's implemented with Postgres.
32 | """
33 |
34 | from .join import MongoJoin
35 |
36 | class MongoFilteringJoin(MongoJoin):
37 |     """ Joining relations: perform a real SQL JOIN to the related model, applying a filter to the
38 |         whole result set (!)
39 |
40 |         Note that this will distort the results of the original query:
41 |         essentially, it will only return entities *having* at least one related entity with
42 |         the given condition.
43 |
44 |         This means that if you take an `Article`, make a 'joinf' to `Article.author`,
45 |         and specify a filter with `age > 20`,
46 |         you will get articles and their authors,
47 |         but *only those articles whose author is over 20*.
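
        A usage sketch (assuming the usual MongoQuery chain; the exact entry point may differ in your setup):

            Article.mongoquery(ssn).query(
                joinf={'author': {'filter': {'age': {'$gt': 20}}}}
            ).end()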
48 | """ 49 | 50 | query_object_section_name = 'joinf' 51 | 52 | def _choose_relationship_loading_strategy(self, mjp): 53 | if mjp.has_nested_query: 54 | # Quite intentionally, we will use a regular JOIN here. 55 | # It will remove rows that 1) have no related rows, and 2) do not match our filter conditions. 56 | # This is what the user wants when they use 'joinf' handler. 57 | return self.RELSTRATEGY_JOINF 58 | else: 59 | return self.RELSTRATEGY_EAGERLOAD 60 | 61 | # merge() is not implemented for joinf, because the results wouldn't be compatible 62 | merge = NotImplemented 63 | -------------------------------------------------------------------------------- /mongosql/handlers/limit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ### Slice Operation 3 | Slicing corresponds to the `LIMIT .. OFFSET ..` part of an SQL query. 4 | 5 | The Slice operation consists of two optional parts: 6 | 7 | * `limit` would limit the number of items returned by the API 8 | * `skip` would shift the "window" a number of items 9 | 10 | Together, these two elements implement pagination. 11 | 12 | Example: 13 | 14 | ```javascript 15 | $.get('/api/user?query=' + JSON.stringify({ 16 | limit: 100, // 100 items per page 17 | skip: 200, // skip 200 items, meaning, we're on the third page 18 | })) 19 | ``` 20 | 21 | Values: can be a number, or a `null`. 22 | """ 23 | 24 | from sqlalchemy import inspect 25 | from sqlalchemy.sql import func, literal_column 26 | 27 | from .base import MongoQueryHandlerBase 28 | from ..exc import InvalidQueryError, InvalidColumnError, InvalidRelationError 29 | 30 | 31 | class MongoLimit(MongoQueryHandlerBase): 32 | """ MongoDB limits and offsets 33 | 34 | Handles two keys: 35 | * 'limit': None, or int: LIMIT for the query 36 | * 'offset': None, or int: OFFSET for the query 37 | """ 38 | 39 | query_object_section_name = 'limit' 40 | 41 | def __init__(self, model, bags, max_items=None): 42 | """ Init a limit 43 | 44 | :param model: Sqlalchemy model to work with 45 | :param bags: Model bags 46 | :param max_items: The maximum number of items that can be loaded with this query. 47 | The user can never go any higher than that, and this value is forced onto every query. 48 | """ 49 | super(MongoLimit, self).__init__(model, bags) 50 | 51 | # Config 52 | self.max_items = max_items 53 | assert self.max_items is None or self.max_items > 0 54 | 55 | # On input 56 | self.skip = None 57 | self.limit = None 58 | 59 | # Internal 60 | # List of columns to group results with (in order to import a limit per group) 61 | self._window_over_columns = None 62 | 63 | def input_prepare_query_object(self, query_object): 64 | """ Alter Query Object 65 | 66 | Unlike other handlers, this one receives 2 values: 'skip' and 'limit'. 67 | MongoQuery only supports one key per handler. 68 | Solution: pack them as a tuple 69 | """ 70 | # (skip, limit) hack 71 | # LimitHandler is the only one that receives two arguments instead of one. 
72 |         # Collect them, and rename
73 |         if 'skip' in query_object or 'limit' in query_object:
74 |             query_object['limit'] = (query_object.pop('skip', None),
75 |                                      query_object.pop('limit', None))
76 |             if query_object['limit'] == (None, None):
77 |                 query_object.pop('limit')  # remove it if it's actually empty
78 |
79 |         # When there is a 'count', we have to disable self.max_items
80 |         # We can safely just alter ourselves, because we're a copy anyway
81 |         if query_object.get('count', False):
82 |             self.max_items = None
83 |
84 |         return query_object
85 |
86 |     def input(self, skip=None, limit=None):
87 |         # MongoQuery actually gives us a tuple (skip, limit)
88 |         # Adapt.
89 |         if isinstance(skip, tuple):
90 |             skip, limit = skip
91 |
92 |         # Super
93 |         super(MongoLimit, self).input((skip, limit))
94 |
95 |         # Validate
96 |         if not isinstance(skip, (int, NoneType)):
97 |             raise InvalidQueryError('Skip must be either an integer, or null')
98 |         if not isinstance(limit, (int, NoneType)):
99 |             raise InvalidQueryError('Limit must be either an integer, or null')
100 |
101 |         # Clamp
102 |         skip = None if skip is None or skip <= 0 else skip
103 |         limit = None if limit is None or limit <= 0 else limit
104 |
105 |         # Max limit
106 |         if self.max_items:
107 |             limit = min(self.max_items, limit or self.max_items)
108 |
109 |         # Done
110 |         self.skip = skip
111 |         self.limit = limit
112 |         return self
113 |
114 |     def _get_supported_bags(self):
115 |         return None  # not used by this class
116 |
117 |     # Not Implemented for this Query Object handler
118 |     compile_columns = NotImplemented
119 |     compile_options = NotImplemented
120 |     compile_statement = NotImplemented
121 |     compile_statements = NotImplemented
122 |
123 |     @property
124 |     def has_limit(self):
125 |         """ Check whether there's a limit on this handler """
126 |         return self.limit is not None or self.skip is not None
127 |
128 |     def limit_groups_over_columns(self, fk_columns):
129 |         """ Instead of the usual limit, use a window function over the given columns.
130 |
131 |         This method is used by MongoJoin when doing a custom selectinquery() to load a limited number of related
132 |         items per every primary entity.
133 |
134 |         Instead of using LIMIT, LimitHandler will group rows over `fk_columns`, and impose a limit per group.
135 |         This is used to load related models with selectinquery(), where you can now put a limit per group:
136 |         that is, a limit on the number of related entities per primary entity.
137 |
138 |         This is achieved using a window function:
139 |
140 |             SELECT *, row_number() OVER(PARTITION BY author_id) AS group_row_n
141 |             FROM articles
142 |             WHERE group_row_n <= 10  -- (in practice, the filter is applied in an outer query)
143 |
144 |         This will result in the following table:
145 |
146 |             id  |  author_id  |  group_row_n
147 |             ------------------------------------
148 |             1        1             1
149 |             2        1             2
150 |             3        2             1
151 |             4        2             2
152 |             5        2             3
153 |             6        3             1
154 |             7        3             2
155 |
156 |         That's what window functions do: they work like aggregate functions, but they don't group rows.
157 | 158 | :param fk_columns: List of foreign key columns to group with 159 | """ 160 | # Adaptation not needed, because this method is never used with aliases 161 | # pa_insp = inspect(self.model) 162 | # fk_columns = [col.adapt_to_entity(pa_insp) for col in fk_columns] 163 | assert not inspect(self.model).is_aliased_class, "Cannot be used with aliases; not implemented yet (because nobody needs it anyway!)" 164 | 165 | self._window_over_columns = fk_columns 166 | 167 | def alter_query(self, query, as_relation=None): 168 | """ Apply offset() and limit() to the query """ 169 | if not self._window_over_columns: 170 | # Use the regular skip/limit 171 | if self.skip: 172 | query = query.offset(self.skip) 173 | if self.limit: 174 | query = query.limit(self.limit) 175 | return query 176 | else: 177 | # Use a window function 178 | return self._limit_using_window_function(query) 179 | 180 | def _limit_using_window_function(self, query): 181 | """ Apply a limit using a window function 182 | 183 | This approach enables us to limit the number of eagerly loaded related entities 184 | """ 185 | # Only do it when there is a limit 186 | if self.skip or self.limit: 187 | # First, add a row counter: 188 | query = query.add_columns( 189 | # for every group, count the rows with row_number(). 190 | func.row_number().over( 191 | # Groups are partitioned by self._window_over_columns, 192 | partition_by=self._window_over_columns, 193 | # We have to apply the same ordering from the outside query; 194 | # otherwise, the numbering will be undetermined 195 | order_by=self.mongoquery.handler_sort.compile_columns() 196 | ).label('group_row_n') # give it a name that we can use later 197 | ) 198 | 199 | # Now, make ourselves into a subquery 200 | query = query.from_self() 201 | 202 | # Well, it turns out that subsequent joins somehow work. 203 | # I have no idea how, but they do. 204 | # Otherwise, we would have had to ban using 'joins' after 'limit' in nested queries. 205 | 206 | # And apply the LIMIT condition using row numbers 207 | # These two statements simulate skip/limit using window functions 208 | if self.skip: 209 | query = query.filter(literal_column('group_row_n') > self.skip) 210 | if self.limit: 211 | query = query.filter(literal_column('group_row_n') <= ((self.skip or 0) + self.limit)) 212 | 213 | # Done 214 | return query 215 | 216 | def get_final_input_value(self): 217 | return dict(skip=self.skip, limit=self.limit) 218 | 219 | NoneType = type(None) 220 | -------------------------------------------------------------------------------- /mongosql/handlers/sort.py: -------------------------------------------------------------------------------- 1 | """ 2 | ### Sort Operation 3 | 4 | Sorting corresponds to the `ORDER BY` part of an SQL query. 5 | 6 | The UI would normally require the records to be sorted by some field, or fields. 7 | 8 | The sort operation lets the API user specify the sorting of the results, 9 | which makes sense for API endpoints that return a list of items. 10 | 11 | An example of a sort operation would look like this: 12 | 13 | ```javascript 14 | $.get('/api/user?query=' + JSON.stringify({ 15 | // sort by age, descending; 16 | // then sort by first name, alphabetically 17 | sort: ['age-', 'first_name+'], 18 | })) 19 | ``` 20 | 21 | #### Syntax 22 | 23 | * Array syntax. 24 | 25 | List of column names, optionally suffixed by the sort direction: `-` for `DESC`, `+` for `ASC`. 26 | The default is `+`. 
27 |
28 |     Example:
29 |
30 |     ```javascript
31 |     { sort: [ 'a+', 'b-', 'c' ] }  // -> a ASC, b DESC, c ASC
32 |     ```
33 |
34 | * String syntax
35 |
36 |     List of columns, with optional `+` / `-`, separated by whitespace.
37 |
38 |     Example:
39 |
40 |     ```javascript
41 |     { sort: 'a+ b- c' }
42 |     ```
43 |
44 | Object syntax is not supported because it does not preserve the ordering of keys.
45 | """
46 |
47 | from collections import OrderedDict
48 |
49 | from .base import MongoQueryHandlerBase
50 | from ..bag import CombinedBag, FakeBag
51 | from ..exc import InvalidQueryError, InvalidColumnError, InvalidRelationError
52 |
53 |
54 | class MongoSort(MongoQueryHandlerBase):
55 |     """ MongoDB sorting
56 |
57 |     * None: no sorting
58 |     * OrderedDict({ a: +1, b: -1 })
59 |     * [ 'a+', 'b-', 'c' ] - array of strings 'name[+|-]'; default direction = +1
60 |     * dict({a: +1}) -- you can only use a dict with ONE COLUMN (because of its unstable order)
61 |
62 |     Supports: Columns, hybrid properties
63 |     """
64 |
65 |     query_object_section_name = 'sort'
66 |
67 |     def __init__(self, model, bags, legacy_fields=None):
68 |         # Legacy fields
69 |         self.legacy_fields = frozenset(legacy_fields or ())
70 |
71 |         # Parent
72 |         super(MongoSort, self).__init__(model, bags)
73 |
74 |         # On input
75 |         #: OrderedDict() of a sort spec: {key: +1|-1}
76 |         self.sort_spec = None
77 |
78 |     def _get_supported_bags(self):
79 |         return CombinedBag(
80 |             col=self.bags.columns,
81 |             colp=self.bags.column_properties,
82 |             hybrid=self.bags.hybrid_properties,
83 |             assocproxy=self.bags.association_proxies,
84 |             legacy=FakeBag({n: None for n in self.legacy_fields}),
85 |         )
86 |
87 |     def _input(self, spec):
88 |         """ Reusable method: fits both MongoSort and MongoGroup """
89 |
90 |         # Empty
91 |         if not spec:
92 |             spec = []
93 |
94 |         # String syntax
95 |         if isinstance(spec, str):
96 |             # Split by whitespace and convert to a list
97 |             spec = spec.split()
98 |
99 |         # List
100 |         if isinstance(spec, (list, tuple)):
101 |             # Strings: convert "column[+-]" into an ordered dict
102 |             if all(isinstance(v, str) for v in spec):
103 |                 spec = OrderedDict([
104 |                     [v[:-1], -1 if v[-1] == '-' else +1]
105 |                     if v[-1] in {'+', '-'}
106 |                     else [v, +1]
107 |                     for v in spec
108 |                 ])
109 |
110 |         # Dict
111 |         if isinstance(spec, OrderedDict):
112 |             pass  # nothing to do here
113 |         elif isinstance(spec, dict):
114 |             if len(spec) > 1:
115 |                 raise InvalidQueryError('{} is a plain object; can only have 1 column '
116 |                                         'because of unstable ordering of object keys; '
117 |                                         'use list syntax instead'
118 |                                         .format(self.query_object_section_name))
119 |             spec = OrderedDict(spec)
120 |         else:
121 |             raise InvalidQueryError('{name} must be either a list, a string, or an object; {type} provided.'
122 |                                     .format(name=self.query_object_section_name, type=type(spec)))
123 |
124 |         # Validate directions: +1 or -1
125 |         if not all(dir in {-1, +1} for field, dir in spec.items()):
126 |             raise InvalidQueryError('{} direction can be either +1 or -1'.format(self.query_object_section_name))
127 |
128 |         # Validate columns
129 |         self.validate_properties(spec.keys())
130 |         return spec
131 |
132 |     def input(self, sort_spec):
133 |         super(MongoSort, self).input(sort_spec)
134 |         self.sort_spec = self._input(sort_spec)
135 |         return self
136 |
137 |     def merge(self, sort_spec):
138 |         self.sort_spec.update(self._input(sort_spec))
139 |         return self
140 |
141 |     def compile_columns(self):
142 |         return [
143 |             self.supported_bags.get(name).desc() if d == -1 else self.supported_bags.get(name)
144 |             for name, d in self.sort_spec.items()
145 |             if name not in self.supported_bags.bag('legacy')  # remove fake items
146 |         ]
147 |
148 |     # Not Implemented for this Query Object handler
149 |     compile_options = NotImplemented
150 |     compile_statement = NotImplemented
151 |     compile_statements = NotImplemented
152 |
153 |     def alter_query(self, query, as_relation=None):
154 |         if not self.sort_spec:
155 |             return query  # short-circuit
156 |         return query.order_by(*self.compile_columns())
157 |
158 |     def get_final_input_value(self):
159 |         return [f'{name}{"-" if d == -1 else ""}'
160 |                 for name, d in self.sort_spec.items()]
161 |
162 |     # Extra stuff
163 |
164 |     def undefer_columns_involved_in_sorting(self, as_relation):
165 |         """ undefer() columns required for this sort """
166 |         # Get the names of the columns
167 |         order_by_column_names = [c.key or c.element.key
168 |                                  for c in self.compile_columns()]
169 |
170 |         # Return options: undefer() every column
171 |         return (as_relation.undefer(column_name)
172 |                 for column_name in order_by_column_names)
173 |
--------------------------------------------------------------------------------
/mongosql/sa.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | from typing import Union
3 |
4 | from sqlalchemy.orm import Session, Query
5 |
6 | from .query import MongoQuery
7 |
8 |
9 | class MongoSqlBase:
10 |     """ Mixin for SqlAlchemy models that provides the .mongoquery() method for convenience """
11 |
12 |     # Override this method in your subclass in order to be able to configure MongoSql on a per-model basis!
13 |     @classmethod
14 |     def _init_mongoquery(cls, handler_settings: dict = None) -> MongoQuery:
15 |         """ Get a reusable MongoQuery object. Is only invoked once.
16 |
17 |         Override this method in order to initialize MongoQuery the way you need.
18 |         For example, you might want to pass a `handler_settings` dict to it.
19 |
20 |         But for now, you can only use `mongoquery_configure()` on it.
21 |
22 |         :rtype: MongoQuery
23 |         """
24 |
25 |         # Idea: have an `_api` field in your models that will feed MongoQuery with the settings
26 |         # Example: return MongoQuery(cls, handler_settings=cls._api)
27 |
28 |         return MongoQuery(cls, handler_settings=handler_settings)
29 |
30 |     __mongoquery_per_class_cache = {}
31 |
32 |     @classmethod
33 |     def _get_mongoquery(cls) -> MongoQuery:
34 |         """ Get a reusable MongoQuery for this model; initialize it only once
35 |
36 |         :rtype: MongoQuery
37 |         """
38 |         try:
39 |             # We want every model class to have its own MongoQuery,
40 |             # and we want no one to inherit it.
41 |             # We could use model.__dict__ for this, but classes in Python 3 use an immutable `mappingproxy` instead.
42 |             # Thus, we have to keep our own cache of ModelPropertyBags.
43 |             mq = cls.__mongoquery_per_class_cache[cls]
44 |         except KeyError:
45 |             cls.__mongoquery_per_class_cache[cls] = mq = cls._init_mongoquery()
46 |
47 |         # Return a copy
48 |         return copy(mq)
49 |
50 |     @classmethod
51 |     def mongoquery_configure(cls, handler_settings: dict) -> MongoQuery:
52 |         """ Initialize this model's MongoQuery settings and make it permanent.
53 |
54 |         This method is just a shortcut to do configuration the lazy way.
55 |         A better way would be to subclass MongoSqlBase and override the _init_mongoquery() method.
56 |         See _init_mongoquery() for a suggestion on how to do this.
57 |
58 |         :param handler_settings: a dict of settings. See MongoQuery
59 |         """
60 |         # Initialize a configured MongoQuery
61 |         mq = cls._init_mongoquery(handler_settings)
62 |
63 |         # Put it into the same cache that _get_mongoquery() reads
64 |         cls.__mongoquery_per_class_cache[cls] = mq
65 |
66 |         # Done
67 |         return mq
68 |
69 |     @classmethod
70 |     def mongoquery(cls, query_or_session: Union[Query, Session] = None) -> MongoQuery:
71 |         """ Build a MongoQuery
72 |
73 |         Note that when `None` is given, the resulting Query is not bound to any session!
74 |         You'll have to bind it manually, after calling .end()
75 |
76 |         :param query_or_session: Query to start with, or a session object to initiate the query with
77 |         :type query_or_session: sqlalchemy.orm.Query | sqlalchemy.orm.Session | None
78 |         :rtype: mongosql.MongoQuery
79 |         """
80 |         if query_or_session is None:
81 |             query = Query([cls])
82 |         elif isinstance(query_or_session, Session):
83 |             query = query_or_session.query(cls)
84 |         elif isinstance(query_or_session, Query):
85 |             query = query_or_session
86 |         else:
87 |             raise ValueError('Argument must be Query or Session')
88 |
89 |         return cls._get_mongoquery().from_query(query)
--------------------------------------------------------------------------------
/mongosql/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .selectinquery import selectinquery
2 | from .counting_query_wrapper import CountingQuery
3 | from .reusable import Reusable
4 | from .mongoquery_settings_handler import MongoQuerySettingsHandler
5 | from .marker import Marker
6 | from .settings_dict import MongoQuerySettingsDict, StrictCrudHelperSettingsDict
7 | from .method_decorator import method_decorator, method_decorator_meta
8 | from .bulk import \
9 |     EntityDictWrapper, load_many_instance_dicts, \
10 |     model_primary_key_columns_and_names, entity_dict_has_primary_key
11 |
12 | try:
13 |     import nplus1loader
14 | except ImportError:
15 |     pass
16 | else:
17 |     from nplus1loader import raiseload_col, raiseload_rel, raiseload_all
--------------------------------------------------------------------------------
/mongosql/util/bulk.py:
--------------------------------------------------------------------------------
1 | """ Tools for working with bulk operations """
2 | from typing import Iterable, List, Tuple, Union, Mapping, Sequence
3 | from collections import UserDict, UserList
4 |
5 | from sqlalchemy import inspect, Column, tuple_ as sql_tuple
6 | from sqlalchemy.ext.declarative import DeclarativeMeta
7 | from sqlalchemy.orm import Query
8 | from sqlalchemy.sql.elements import BinaryExpression
9 |
10 |
11 | NoneType = type(None)
12 |
13 |
14 | class EntityDictWrapper(UserDict, dict):  # `dict` to allow isinstance() checks
15 |     """ Entity dict wrapper with metadata
16 |
17 |     When the user submits N objects to be saved, we need to handle extra information along with every dict.
18 |     This is what this object is for: caching primary keys and associating results with the input.
19 |
20 |     Attributes:
21 |         Model: the model class
22 |         has_primary_key: Whether the entity dict contains the complete primary key
23 |         primary_key_tuple: The primary key tuple, if present
24 |         ordinal_number: The ordinal number of this entity dict in the submitted data
25 |         skip: Ignore this entity dict
26 |         loaded_instance: The instance that was (possibly) loaded from the database (if the primary key was given)
27 |         instance: The instance that was (possibly) saved as a result of this operation (unless an error has occurred)
28 |         error: The exception that was (possibly) raised while processing this entity dict (if an error has occurred)
29 |     """
30 |     Model: DeclarativeMeta
31 |     has_primary_key: bool
32 |     primary_key_tuple: Union[NoneType, Tuple]
33 |
34 |     ordinal_number: int = None
35 |     skip: bool = False
36 |
37 |     loaded_instance: object = None
38 |     instance: object = None
39 |     error: BaseException = None
40 |
41 |     def __init__(self,
42 |                  Model: DeclarativeMeta,
43 |                  entity_dict: dict,
44 |                  *,
45 |                  ordinal_number: int = None,
46 |                  pk_names: Sequence[str] = None):
47 |         super().__init__(entity_dict)
48 |         self.Model = Model
49 |         self.ordinal_number = ordinal_number
50 |
51 |         # Primary key names: use the provided list; get it ourselves if not provided
52 |         if not pk_names:
53 |             _, pk_names = model_primary_key_columns_and_names(Model)
54 |
55 |         # The primary key tuple
56 |         try:
57 |             self.primary_key_tuple = tuple(entity_dict[pk_field]
58 |                                            for pk_field in pk_names)
59 |             self.has_primary_key = True
60 |         # If any of the primary key fields has raised a KeyError, assume that no PK is defined
61 |         except KeyError:
62 |             self.has_primary_key = False
63 |             self.primary_key_tuple = None
64 |
65 |     @classmethod
66 |     def from_entity_dicts(cls,
67 |                           Model: DeclarativeMeta,
68 |                           entity_dicts: Sequence[dict],
69 |                           *,
70 |                           preprocess: callable = None,
71 |                           pk_names: Sequence[str] = None) -> Sequence['EntityDictWrapper']:
72 |         """ Given a list of entity dicts, create a list of EntityDictWrappers with ordinal numbers
73 |
74 |         If any dicts are already wrapped with EntityDictWrapper, they are not re-wrapped;
75 |         but be careful to maintain their ordinal numbers, or the client will have difficulties!
76 | 77 | Example: 78 | 79 | _, pk_names = model_primary_key_columns_and_names(Model) 80 | entity_dicts = EntityDictWrapper.from_entity_dicts(models.User, [ 81 | {'id': 1, 'login': 'kolypto'}, 82 | { 'login': 'vdmit11'}, 83 | ], pk_names=pk_names) 84 | """ 85 | # Prepare the list of primary key columns 86 | if not pk_names: 87 | _, pk_names = model_primary_key_columns_and_names(Model) 88 | 89 | # Generator: EntityDictWrappers with ordinal numbers 90 | return [entity_dict 91 | if isinstance(entity_dict, EntityDictWrapper) else 92 | cls(Model, entity_dict, ordinal_number=i, pk_names=pk_names) 93 | for i, entity_dict in enumerate(entity_dicts)] 94 | 95 | # Object states 96 | 97 | @property 98 | def is_new(self): 99 | """ The submitted object has no primary key and will therefore be created """ 100 | return not self.has_primary_key 101 | 102 | @property 103 | def is_found(self): 104 | """ The submitted object has a primary key and it was successfully loaded from the database """ 105 | return self.has_primary_key and self.loaded_instance is not None 106 | 107 | @property 108 | def is_not_found(self): 109 | """ The submitted object has a primary key but it was not found in the database """ 110 | return self.has_primary_key and self.loaded_instance is None 111 | 112 | 113 | # This function isn't really used by MongoSQL, but it's here because it's beautiful 114 | # MongoSQL uses load_many_instance_dicts() instead 115 | def filter_many_objects_by_list_of_primary_keys(Model: DeclarativeMeta, entity_dicts: Sequence[dict]) -> BinaryExpression: 116 | """ Build an expression to load many objects from the database by their primary keys 117 | 118 | This function uses SQL tuples to build an expression which looks like this: 119 | 120 | SELECT * FROM users WHERE (uid, login) IN ((1, 'vdmit11'), (2, 'kolypto')); 121 | 122 | Example: 123 | 124 | entity_dicts = [ 125 | {'id': 1, ...}, 126 | {'id': 2, ...}, 127 | ... 128 | ] 129 | ssn.query(models.User).filter( 130 | filter_many_objects_by_list_of_primary_keys(models.User, entity_dicts) 131 | ) 132 | 133 | Args: 134 | Model: the model to query 135 | entity_dicts: list of entity dicts to pluck the PK values from 136 | 137 | Returns: 138 | The condition for filter() 139 | 140 | Raises: 141 | KeyError: one of `entity_dicts` did not contain a full primary key set of fields 142 | """ 143 | pk_columns, pk_names = model_primary_key_columns_and_names(Model) 144 | 145 | # Build the condition: (primary-key-tuple) IN (....) 146 | # It uses sql tuples and the IN operator: (pk_col_a, pk_col_b, ...) IN ((val1, val2, ...), (val3, val4, ...), ...) 147 | # Thanks @vdmit11 for this beautiful approach! 148 | return sql_tuple(*pk_columns).in_( 149 | # Every object is represented by its primary key tuple 150 | tuple(entity_dict[pk_field] for pk_field in pk_names) 151 | for entity_dict in entity_dicts 152 | ) 153 | 154 | 155 | def load_many_instance_dicts(query: Query, pk_columns: Sequence[Column], entity_dicts: Sequence[EntityDictWrapper]) -> Sequence[EntityDictWrapper]: 156 | """ Given a list of wrapped entity dicts submitted by the client, load some of them from the database 157 | 158 | As the client submits a list of entity dicts, some of them may contain the primary key. 159 | This function loads them from the database with one query and returns a list of EntityDictWrapper objects. 160 | 161 | Note that there will be three kinds of EntityDictWrapper objects: is_new, is_found, is_not_found: 162 | 163 | 1. New: entity dicts without a primary key 164 | 2. 
Found: entity dicts with a primary key that were also found in the database 165 | 3. Not found: entity dicts with a primary key that were not found in the database 166 | 167 | NOTE: no errors are raised for instances that were not found by their primary key! 168 | 169 | Args: 170 | query: The query to load the instances with 171 | pk_columns: The list of primary key columns for the target model. 172 | Use model_primary_key_columns_and_names() 173 | entity_dicts: The list of entity dicts submitted by the user 174 | """ 175 | # Load all instances by their primary keys at once 176 | # It uses sql tuples and the IN operator: (pk_col_a, pk_col_b, ...) IN ((val1, val2, ...), (val3, val4, ...), ...) 177 | # Thanks @vdmit11 for this beautiful approach! 178 | instances = query.filter(sql_tuple(*pk_columns).in_( 179 | # Search by PK tuples 180 | entity_dict.primary_key_tuple 181 | for entity_dict in entity_dicts 182 | if entity_dict.has_primary_key 183 | )) 184 | 185 | # Prepare a PK lookup object: we want to look up entity dicts by primary key tuples 186 | entity_dict_lookup_by_pk: Mapping[Tuple, EntityDictWrapper] = { 187 | entity_dict.primary_key_tuple: entity_dict 188 | for entity_dict in entity_dicts 189 | if entity_dict.has_primary_key 190 | } 191 | 192 | # Match instances with entity dicts 193 | for instance in instances: 194 | # Lookup an entity dict by its primary key tuple 195 | # We safely expect it to be there because objects were loaded by those primary keys in the first place :) 196 | entity_dict = entity_dict_lookup_by_pk[inspect(instance).identity] 197 | # Associate the instance with it 198 | entity_dict.loaded_instance = instance 199 | 200 | # Done 201 | return entity_dicts 202 | 203 | 204 | def model_primary_key_columns_and_names(Model: DeclarativeMeta) -> (Sequence[Column], List[str]): 205 | """ Get the list of primary columns and their names as two separate tuples 206 | 207 | Example: 208 | 209 | pk_columns, pk_names = model_primary_key_columns_and_names(models.User) 210 | pk_columns # -> (models.User.id, ) 211 | pk_names # -> ('id', ) 212 | """ 213 | pk_columns: Sequence[Column] = inspect(Model).primary_key 214 | pk_names: List[str] = [col.key for col in pk_columns] 215 | return pk_columns, pk_names 216 | 217 | def entity_dict_has_primary_key(pk_names: Sequence[str], entity_dict: dict) -> bool: 218 | """ Check whether the given dict contains all primary key fields """ 219 | return set(pk_names) <= set(entity_dict) 220 | -------------------------------------------------------------------------------- /mongosql/util/counting_query_wrapper.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | from sqlalchemy import func 4 | from sqlalchemy.orm import Query, Session 5 | 6 | 7 | class CountingQuery: 8 | """ `Query` object wrapper that can count the rows while returning results 9 | 10 | This is achieved by SELECTing like this: 11 | 12 | SELECT *, count(*) OVER() AS full_count 13 | 14 | In order to be transparent, this class eliminates all those tuples in results and still returns objects 15 | like a normal query would. The total count is available through a property. 16 | 17 | Example: 18 | 19 | ```python 20 | qc = CountingQuery(ssn.query(...)) 21 | 22 | # Get the count 23 | qc.count # -> 127 24 | 25 | # Get the results 26 | list(qc) 27 | 28 | # (!) 
only one SQL query was made
29 |     ```
30 |     """
31 |     __slots__ = ('_query', '_original_query',
32 |                  '_count', '_query_iterator',
33 |                  '_single_entity', '_row_fixer')
34 |
35 |     def __init__(self, query: Query):
36 |         # The original query. We store it just in case.
37 |         self._original_query = query
38 |
39 |         # The current query
40 |         # It differs from the original query in that it is modified with a window function counting the rows
41 |         self._query = query
42 |
43 |         # The iterator for query results; `None` if the query has not yet been executed
44 |         # If the query has been executed, there is always an iterator available, even if there were no results
45 |         self._query_iterator = None
46 |
47 |         # The total count; `None` if the query has not yet been executed
48 |         self._count = None
49 |
50 |         # Whether the query is going to return single entities
51 |         self._single_entity = (  # copied from sqlalchemy.orm.loading.instances
52 |             not getattr(query, '_only_return_tuples', False)  # accessing protected properties
53 |             and len(query._entities) == 1
54 |             and query._entities[0].supports_single_entity
55 |         )
56 |
57 |         # The method that will fix result rows
58 |         self._row_fixer = self._fix_result_tuple__single_entity if self._single_entity else self._fix_result_tuple__tuple
59 |
60 |     def with_session(self, ssn: Session):
61 |         """ Return a `Query` that will use the given `Session`. """
62 |         self._query = self._query.with_session(ssn)
63 |         return self
64 |
65 |     @property
66 |     def count(self):
67 |         """ Get the total count
68 |
69 |         If the query has not been executed yet, it will be at this point.
70 |         If there are no rows, it will make an additional query to make sure the result is available.
71 |         """
72 |         # Execute the query and get the count
73 |         if self._count is None:
74 |             self._get_query_count()
75 |
76 |         # Done
77 |         return self._count
78 |
79 |     def __iter__(self):
80 |         """ Get Query results """
81 |         # Make sure the Query is executed
82 |         if self._query_iterator is None:
83 |             self._query_execute()
84 |
85 |         # Iterate
86 |         return self._query_iterator
87 |
88 |     # region Counting logic
89 |
90 |     def _get_query_count(self):
91 |         """ Retrieve the first row and get the count.
92 |         If that fails due to an OFFSET being present in the query, make an additional COUNT query.
93 |         """
94 |         # Make a new query
95 |         self._query = self._query.add_columns(
96 |             func.count().over()  # this window function will count all rows
97 |         )
98 |
99 |         # Execute it
100 |         qi = iter(self._query)
101 |
102 |         # Attempt to retrieve the first row
103 |         try:
104 |             first_row = next(qi)
105 |         except StopIteration:
106 |             # No rows in the result.
107 |
108 |             # Prepare the iterator anyway
109 |             self._query_iterator = iter(())  # empty iterator
110 |
111 |             # If there was an OFFSET in the query, we may have failed because of it.
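            # (E.g. with OFFSET 100 and only 50 matching rows, zero rows come back,
            # so there is no row to read the count(*) OVER() result from.)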
112 | if not self._query_has_offset(): 113 | # If there was no offset, then the count is simply zero 114 | self._count = 0 115 | else: 116 | # A separate COUNT() query will do better than us 117 | self._count = self._get_query_count__make_another_query() 118 | 119 | # Done here 120 | return 121 | 122 | # Alright, there are some results 123 | 124 | # Get the count 125 | self._count = self._get_count_from_result_tuple(first_row) 126 | 127 | # Build an iterator that will yield normal result rows 128 | self._query_iterator = map( 129 | # The callback that will drop the extra count column 130 | self._row_fixer, 131 | itertools.chain( 132 | # Prepend the first row we're taken off 133 | [first_row], 134 | # Add the rest of the results 135 | qi 136 | ) 137 | ) 138 | 139 | _query_execute = _get_query_count # makes more sense when called this way in the context of __iter__ method 140 | 141 | def _get_query_count__make_another_query(self) -> int: 142 | """ Make an additional query to count the number of rows """ 143 | # Build the query 144 | q = self._original_query 145 | 146 | # Remove eager loads 147 | q = q.enable_eagerloads(False) 148 | 149 | # Remove LIMIT and OFFSET 150 | q = q.limit(None).offset(None) 151 | 152 | # Exec 153 | return q.count() 154 | 155 | def _query_has_offset(self) -> bool: 156 | """ Tell if the query has an OFFSET clause 157 | 158 | The issue is that with an OFFSET large enough, our window function won't have any rows to return its 159 | result with. Therefore, we'd be forced to make an additional query. 160 | """ 161 | return self._query._offset is not None # accessing protected property 162 | 163 | # endregion 164 | 165 | # region Result tuple processing 166 | 167 | @staticmethod 168 | def _get_count_from_result_tuple(row): 169 | """ Get the count from the result row """ 170 | return row[-1] 171 | 172 | @staticmethod 173 | def _fix_result_tuple__single_entity(row): 174 | """ Fix the result tuple: get the first Entity only """ 175 | return row[0] 176 | 177 | @staticmethod 178 | def _fix_result_tuple__tuple(row): 179 | """ Fix the result tuple: drop the last item """ 180 | return row[:-1] 181 | 182 | # endregion 183 | 184 | -------------------------------------------------------------------------------- /mongosql/util/history_proxy.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from sqlalchemy import inspect 3 | from sqlalchemy.orm.base import DEFAULT_STATE_ATTR 4 | from sqlalchemy.orm.state import InstanceState 5 | 6 | from mongosql.bag import ModelPropertyBags 7 | 8 | 9 | class ModelHistoryProxy: 10 | """ Proxy object to gain access to historical model attributes. 11 | 12 | This leverages SqlAlchemy attribute history to provide access to the previous value of an 13 | attribute. The only reason why this object exists is because keeping two instances in memory may 14 | be expensive. But because normally you'll only need a field or two, the decision was to use 15 | this magic proxy object that will load model history on demand. 16 | 17 | Why would you need to access model history at all? 18 | Because CrudHelper's update method (i.e., changing model fields) gives you two objects: the 19 | current instance, and the old instance, so that your custom code in the update handler can 20 | compare those fields. 21 | For instance, when a certain object is being moved from one User to another, you might want 22 | to notify both of them. In that case, you'll need access to the historical user. 
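
    A sketch of the intended usage (assuming a modified, not-yet-flushed `comment` instance):

        comment.text = 'new text'
        old = ModelHistoryProxy(comment)
        old.text  # -> the value `comment.text` had before the change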
23 |
24 |     The initial solution was to *copy* the instance, apply the modifications from JSON to a copy,
25 |     and then feed both of them to the save handler... but copying was expensive.
26 |     That's why we have this proxy: it does not load all the fields of the historical model,
27 |     but acts as a proxy object (__getattr__()) that will get those properties on demand.
28 |     """
29 |
30 |     def __init__(self, instance):
31 |         # Save the information that we'll definitely need
32 |         self.__instance = instance
33 |         self.__model = self.__instance.__class__
34 |         self.__bags = ModelPropertyBags.for_model(self.__model)  # type: ModelPropertyBags
35 |         self.__inspect = inspect(instance)  # type: InstanceState
36 |
37 |         # Copy every field onto ourselves
38 |         self.__copy_from_instance(self.__instance)
39 |
40 |         # Enable accessing relationships through our proxy
41 |         self.__install_instance_state(instance)
42 |
43 |     def __copy_from_instance(self, instance):
44 |         """ Copy all attributes of `instance` to `self`
45 |
46 |         Alright, this code renders the whole point of having ModelHistoryProxy void.
47 |         There is an issue with model history:
48 |
49 |             "Each time the Session is flushed, the history of each attribute is reset to empty.
50 |             The Session by default autoflushes each time a Query is invoked"
51 |             https://docs.sqlalchemy.org/en/latest/orm/internals.html#sqlalchemy.orm.state.AttributeState.history
52 |
53 |         This means that as soon as you load a relationship, model history is reset.
54 |         To solve this, we have to make a copy of this model.
55 |         All attributes are set on `self`, so accessing `self.attr` will not trigger `__getattr__()`.
56 |         """
57 |
58 |         insp = self.__inspect  # type: InstanceState
59 |
60 |         # Copy all values onto `self`
61 |         for column_name in self.__bags.columns.names:
62 |             # Skip unloaded columns (because that would emit sql queries)
63 |             # Also skip the columns that were already copied (perhaps, mutable columns?)
64 |             if column_name not in insp.unloaded:
65 |                 # The state
66 |                 attr_state = insp.attrs[column_name]  # type: AttributeState
67 |
68 |                 # Get the historical value
69 |                 # deepcopy() ensures JSON and ARRAY values are copied in full
70 |                 hist_val = deepcopy(_get_historical_value(attr_state))
71 |
72 |                 # Set the value onto `self`: we're bearing the value now
73 |                 setattr(self, column_name, hist_val)
74 |
75 |     def __install_instance_state(self, instance):
76 |         """ Install an InstanceState, so that relationship descriptors can work properly """
77 |         # These lines install the internal SqlAlchemy's property on our proxy
78 |         # This property mimics the original object.
79 |         # This ensures that we can access relationship attributes through a ModelHistoryProxy object
80 |         # Example:
81 |         #     hist = ModelHistoryProxy(comment)
82 |         #     hist.user.id  # wow!
83 | instance_state = getattr(instance, DEFAULT_STATE_ATTR) 84 | my_state = InstanceState(self, instance_state.manager) 85 | my_state.key = instance_state.key 86 | my_state.session_id = instance_state.session_id 87 | setattr(self, DEFAULT_STATE_ATTR, my_state) 88 | 89 | def __getattr__(self, key): 90 | # Get a relationship: 91 | if key in self.__bags.relations: 92 | relationship = getattr(self.__model, key) 93 | return relationship.__get__(self, self.__model) 94 | 95 | # Get a property (@property) 96 | if key in self.__bags.properties: 97 | # Because properties may use other columns, 98 | # we have to run it against our`self`, because only then it'll be able to get the original values. 99 | return getattr(self.__model, key).fget(self) 100 | 101 | # Every column attribute is accessed through history 102 | attr = self.__inspect.attrs[key] 103 | return _get_historical_value(attr) 104 | 105 | 106 | def _get_historical_value(attr): 107 | """ Get the previous value of an attribute 108 | 109 | This is where the magic happens: this method goes into the SqlAlchemy instance and 110 | obtains the historical value of an attribute called `key` 111 | """ 112 | # Examine attribute history 113 | # If a value was deleted (e.g. replaced) -- we return it as the previous version. 114 | history = attr.history 115 | if not history.deleted: 116 | # No previous value, return the current value instead 117 | return attr.value 118 | else: 119 | # Return the previous value 120 | # It's a tuple, since History supports collections, but we do not support these, 121 | # so just get the first element 122 | return history.deleted[0] 123 | -------------------------------------------------------------------------------- /mongosql/util/inspect.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from functools import lru_cache 3 | from typing import Callable, Mapping, Tuple 4 | 5 | 6 | @lru_cache(100) 7 | def get_function_defaults(for_func: Callable) -> dict: 8 | """ Get a dict of function's arguments that have default values """ 9 | # Analyze the method 10 | argspec = inspect.getfullargspec(for_func) # TODO: use signature(): Python 3.3 11 | 12 | # Get the names of the kwargs 13 | # Only process those that have defaults 14 | n_args = len(argspec.args) - len(argspec.defaults or ()) # Args without defaults 15 | kwargs_names = argspec.args[n_args:] 16 | 17 | # Get defaults for kwargs: put together argument names + default values 18 | defaults = dict(zip(kwargs_names, argspec.defaults or ())) 19 | 20 | # Done 21 | return defaults 22 | 23 | 24 | def pluck_kwargs_from(dct: Mapping, for_func: Callable, skip: Tuple[str] = ()) -> dict: 25 | """ Analyze a function, pluck the arguments it needs from a dict """ 26 | defaults = get_function_defaults(for_func) 27 | 28 | # Get the values for these kwargs 29 | kwargs = {k: dct.get(k, defaults[k]) 30 | for k in defaults.keys() 31 | if k not in skip} 32 | 33 | # Done 34 | return kwargs 35 | -------------------------------------------------------------------------------- /mongosql/util/marker.py: -------------------------------------------------------------------------------- 1 | class Marker: 2 | """ An object that can transparently wrap a dict key 3 | 4 | Example: 5 | 6 | d = { Marker('key'): value } 7 | 8 | # You can still use the original key: 9 | d['key'] # -> value 10 | 'key' in d # -> True 11 | 12 | # At the same time, your marker key will pass an explicit isinstance() check: 13 | key, value = d.popitem() 14 | key == 'key' # -> True 
15 |         isinstance(key, Marker)  # -> True
16 |
17 |     This enables you to easily define custom markers, and, for instance,
18 |     keep track of where dictionary keys originate from!
19 |     """
20 |
21 |     __slots__ = ('key',)
22 |
23 |     @classmethod
24 |     def unwrap(cls, value):
25 |         """ Unwrap the value if it's wrapped with a marker """
26 |         return value.key if isinstance(value, cls) else value
27 |
28 |     def __init__(self, key):
29 |         # Store the original key
30 |         self.key = key
31 |
32 |     def __str__(self):
33 |         return str(self.key)
34 |
35 |     def __repr__(self):
36 |         return '{}({})'.format(self.__class__.__name__, repr(self.key))
37 |
38 |     # region Marker is a Proxy
39 |
40 |     # All these methods have to be implemented in order to mimic the behavior of a dict key
41 |
42 |     def __hash__(self):
43 |         # Dict keys rely on hashes
44 |         # We ought to have the same hash as the underlying value
45 |         return hash(self.key)
46 |
47 |     def __eq__(self, other):
48 |         # Marker equality comparison:
49 |         #   key == key | key == Marker.key
50 |         return self.key == (other.key if isinstance(other, Marker) else other)
51 |
52 |     def __bool__(self):
53 |         # Marker truth check:
54 |         #   `if include:` would always be true otherwise
55 |         return bool(self.key)
56 |
57 |     def __instancecheck__(self, instance):
58 |         # isinstance() will react on both the Marker's type and the value's type
59 |         return isinstance(instance, type(self)) or isinstance(instance, type(self.key))
60 |
61 |     # endregion
62 |
--------------------------------------------------------------------------------
/mongosql/util/method_decorator.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 | from functools import partial, lru_cache, update_wrapper
3 |
4 |
5 | class method_decorator_meta(type):
6 |     def __instancecheck__(self, method):
7 |         """ Metaclass magic enables isinstance() checks even for decorators wrapped with other decorators """
8 |         # Recursion stopper
9 |         if method is None:
10 |             return False
11 |         # Check the type: isinstance() on self, or on the wrapped object
12 |         # Have to implement isinstance() manually.
13 |         # We use issubclass() to enable it detecting a generic method_decorator: isinstance(something, method_decorator)
14 |         return issubclass(type(method), self) \
15 |                or isinstance(getattr(method, '__wrapped__', None), self)
16 |
17 |
18 | class method_decorator(metaclass=method_decorator_meta):
19 |     """ A decorator that marks a method, receives arguments, adds metadata, and provides custom group behavior
20 |
21 |         Sometimes in Python there's a need to mark some methods of a class and then use them for some sort
22 |         of special processing.
23 |
24 |         The important goals here are:
25 |         1) to be able to mark methods,
26 |         2) to be able to execute them transparently,
27 |         3) to be able to collect them and get their names,
28 |         4) to be able to store metadata on them (by receiving arguments)
29 |
30 |         I've found out that a good solution would be to implement a class decorator,
31 |         which is also a descriptor. This way, we'll have an attribute that lets you transparently use the method,
32 |         but also knows some metadata about it.
33 |
34 |         This wrapper can also contain some business logic related to this decorator,
35 |         which lets us keep all the relevant logic in one place.
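
        A sketch of how such a decorator is meant to be used (`saves_field` and the view class are
        made up for illustration):

            class saves_field(method_decorator):
                def __init__(self, field_name):
                    super().__init__()
                    self.field_name = field_name

            class UserView:
                @saves_field('login')
                def save_login(self, new, old):
                    ...

            # Transparently callable, and collectible together with its metadata:
            [d.method_name for d in saves_field.all_decorators_from(UserView)]  # -> ['save_login']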
36 | """ 37 | 38 | # The name of the property to install onto every wrapped method 39 | # Please override, or set `None` if this behavior is undesired 40 | METHOD_PROPERTY_NAME = 'method_decorator' 41 | 42 | def __init__(self): # override me to receive arguments 43 | # Handler method 44 | self.method = None 45 | # Handler method function name 46 | self.method_name = None 47 | 48 | def __call__(self, handler_method): 49 | # Make sure the object itself is callable only once 50 | if self.method is not None: 51 | raise RuntimeError("@{decorator}, when used, is not itself callable".format(decorator=self.__class__.__name__)) 52 | 53 | # The handler method to use for saving the field's data 54 | self.method = handler_method 55 | self.method_name = handler_method.__name__ 56 | 57 | # Store ourselves as a property of the wrapped function :) 58 | if self.METHOD_PROPERTY_NAME: 59 | setattr(self.method, self.METHOD_PROPERTY_NAME, self) 60 | 61 | # Use the proper update_wrapper() for we are a decorator 62 | update_wrapper(self, self.method) 63 | 64 | # Done 65 | return self # This is what is saved on the class' __dict__ 66 | 67 | def __get__(self, instance, owner): 68 | """ Magic descriptor: return the wrapped method when accessed """ 69 | # This descriptor magic makes the decorated method accessible directly, even though it's wrapped. 70 | # This is how it works: 71 | # whenever a method is wrapped with @saves_relations, there is this decorator class standing in the object's 72 | # dict instead of the method. The decorator is not callable anymore. 73 | # However, because it's a descriptor (has the __get__ method), when you access this method 74 | # (by using class.method or object.method), it will hide itself and give you the wrapped method instead. 75 | 76 | # We, however, will have to pass the `self` argument manually, because this descriptor magic 77 | # breaks python's passing of `self` to the method 78 | if instance is None: 79 | # Accessing a class attribute directly 80 | # We return the decorator object. It's callable. 81 | return self 82 | 83 | # Old behavior: 84 | # # Accessing a class attribute directly 85 | # # We return the method function, so that subclasses can actually call invoke it unwrapped. 86 | # return self.method # got from the class 87 | else: 88 | # Accessing an object's attribute 89 | # We prepare for calling the method. 90 | return partial(self.method, instance) # pass the `self` 91 | 92 | def __repr__(self): 93 | return '@{decorator}({func})'.format(decorator=self.__class__.__name__, func=self.method_name) 94 | 95 | # region: Usage API 96 | 97 | @classmethod 98 | def is_decorated(cls, method) -> bool: 99 | """ Check whether the given method is decorated with @cls() 100 | 101 | It also supports detecting methods wrapped with multiple decorators, one of them being @cls. 102 | Note that it works only when update_wrapper() was properly used. 103 | """ 104 | return isinstance(method, cls) 105 | 106 | @classmethod 107 | def get_method_decorator(cls, Klass: type, name: str) -> 'method_decorator': 108 | """ Get the decorator object, stored as `METHOD_PROPERTY_NAME` on the wrapped method """ 109 | return getattr(getattr(Klass, name), cls.METHOD_PROPERTY_NAME) 110 | 111 | @classmethod 112 | @lru_cache(256) # can't be too many views out there.. 
:) 113 | def all_decorators_from(cls, Klass: type) -> Iterable['method_decorator']: 114 | """ Get all decorator objects from a class (cached) 115 | 116 | Note that it won't collect any inherited handler methods: 117 | only those declared directly on this class. 118 | """ 119 | if not isinstance(Klass, type): 120 | raise ValueError('Can only collect decorators from a class, not from an object {}'.format(Klass)) 121 | 122 | return tuple( 123 | attr 124 | for attr in Klass.__dict__.values() 125 | if cls.is_decorated(attr)) 126 | 127 | # endregion 128 | -------------------------------------------------------------------------------- /mongosql/util/mongoquery_settings_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from sqlalchemy.ext.declarative import DeclarativeMeta 4 | 5 | import mongosql 6 | from mongosql.util.inspect import pluck_kwargs_from 7 | from ..exc import DisabledError 8 | 9 | 10 | class MongoQuerySettingsHandler: 11 | """ Settings keeper for MongoQuery 12 | 13 | This is essentially a helper which will feed the correct kwargs to every class. 14 | 15 | MongoSql handlers receive settings as kwargs to their __init__() methods, 16 | and those kwargs have unique names. 17 | 18 | This class will collect all settings as a single, flat array, 19 | and give each handler only the settings it wants. 20 | 21 | This approach will let us use a flat configuration dict. 22 | In addition, because some handlers have matching settings (e.g. join and joinf), 23 | both of those will receive them! 24 | """ 25 | 26 | def __init__(self, settings: dict): 27 | """ Store the settings for every handler 28 | 29 | :param settings: dict of handler kwargs 30 | """ 31 | assert isinstance(settings, dict) 32 | 33 | #: Settings dict 34 | self._settings = settings # we don't make a copy, because we don't modify it 35 | 36 | #: Handler names 37 | self._handler_names = set() 38 | 39 | #: kwarg names for every handler: dict[handler] = set() 40 | self._handler_kwargs_names = {} 41 | 42 | #: all kwargs names (to identify invalid ones) 43 | self._all_known_kwargs_names = set() 44 | 45 | #: disabled handler names 46 | self._disabled_handlers = set() 47 | 48 | #: Nested MongoQuery settings (for relations) 49 | self._nested_relation_settings = call_if_callable(self._settings.get('related', None)) or {} 50 | 51 | #: Nested MongoQuery settings (for related models) 52 | self._nested_model_settings = call_if_callable(self._settings.get('related_models', None))or {} 53 | 54 | def validate_related_settings(self, bags: mongosql.ModelPropertyBags): 55 | """ Validate the settings for related entities. 56 | 57 | This method only validates the keys for "related" and "related_models". 
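
        For instance (a sketch; `articles` and `Article` are illustrative names):
        `related={'articles': {...}}` must be keyed by relationship names,
        while `related_models={Article: {...}}` must be keyed by model classes;
        anything else, apart from the '*' catch-all key, is rejected here.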
58 | 
59 |         :raises KeyError: Invalid keys
60 |         """
61 |         # Validate "related": all keys must be relationship names
62 |         invalid_keys = set(self._nested_relation_settings.keys()) - bags.relations.names - {'*'}
63 |         if invalid_keys:
64 |             raise KeyError('Invalid relationship name provided to "related": {!r}'
65 |                            .format(list(invalid_keys)))
66 | 
67 |         # Validate "related_models": all keys must be models, not names
68 |         invalid_keys = set(v
69 |                            for v in self._nested_model_settings.keys()
70 |                            if not isinstance(v, DeclarativeMeta))
71 |         invalid_keys -= {'*'}
72 |         if invalid_keys:
73 |             raise KeyError('Invalid related model object provided to "related_models": {!r}'
74 |                            .format(list(invalid_keys)))
75 | 
76 |     def get_settings(self, handler_name: str, handler_cls: type) -> dict:
77 |         """ Get settings for the given handler
78 | 
79 |         Because we do not know in advance how many handlers we will have, what their names will be,
80 |         and what classes implement them, we have to handle them one by one.
81 | 
82 |         Every time a class is given to us, we analyze its __init__() method in order to know its kwargs and its default values.
83 |         Then, we take the matching keys from the settings dict, we take defaults from the argument defaults,
84 |         and make it all into `kwargs` that will be given to the class.
85 | 
86 |         In addition to that, if the settings contain `<handler_name>_enabled=False`, the handler is disabled.
87 |         The is_handler_enabled() method will later tell that to MongoQuery.
88 |         """
89 |         # Now we know the handler name
90 |         # See if it's actually disabled
91 |         if not self._settings.get('{}_enabled'.format(handler_name), True):
92 |             self._disabled_handlers.add(handler_name)
93 | 
94 |         # Analyze a function, pluck the arguments that it needs
95 |         kwargs = pluck_kwargs_from(self._settings, for_func=handler_cls.__init__)
96 |         kwargs_names = kwargs.keys()  # always all of them
97 | 
98 |         # Store the data that we'll need
99 |         self._handler_kwargs_names[handler_name] = set(kwargs_names)
100 |         self._handler_names.add(handler_name)
101 |         self._all_known_kwargs_names.update(kwargs_names)
102 | 
103 |         # Done
104 |         return kwargs  # for the handler's __init__()
105 | 
106 |     def is_handler_enabled(self, handler_name: str) -> bool:
107 |         """ Test if the handler is enabled in the configuration """
108 |         return handler_name not in self._disabled_handlers
109 | 
110 |     def raise_if_not_handler_enabled(self, model_name: str, handler_name: str):
111 |         """ Raise an error if the handler is not enabled """
112 |         if not self.is_handler_enabled(handler_name):
113 |             raise DisabledError('Query handler "{}" is disabled for "{}"'
114 |                                 .format(handler_name, model_name))
115 | 
116 |     def raise_if_invalid_handler_settings(self, mongoquery: 'mongosql.MongoQuery'):
117 |         """ Check whether there were any typos in setting names
118 | 
119 |         After all handlers were initialized, we've had a chance to analyze all their keyword arguments.
120 |         Now, we have the information about them, and we can check whether every kwarg was actually used.
121 |         If not, there must be a typo.
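
        For instance, a misspelled setting like `force_excllude=('password',)`
        (a hypothetical typo of `force_exclude`) would match no handler's kwargs
        and would therefore be reported here as a KeyError.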
122 | 123 | :raises: KeyError: Invalid settings provided 124 | """ 125 | # Known keys 126 | handler_names = set('{}_enabled'.format(handler_name) 127 | for handler_name in self._handler_names) 128 | valid_kwargs = set(self._all_known_kwargs_names) 129 | other_known_keys = {'related', 'related_models'} 130 | 131 | # Merge all known keys into one 132 | all_known_keys = handler_names | valid_kwargs | other_known_keys 133 | 134 | # Provided keys 135 | provided_keys = set(self._settings.keys()) 136 | 137 | # Result: unknown keys 138 | invalid_keys = provided_keys - all_known_keys 139 | 140 | # Raise? 141 | if invalid_keys: 142 | raise KeyError('Invalid settings were provided for MongoQuery {!r}: {}' 143 | .format(mongoquery, ','.join(invalid_keys))) 144 | 145 | def _get_nested_settings_from_store_attr(self, store: dict, key: str, star_lambda_args) -> Union[dict, None]: 146 | """ Get settings from `store`, which is "related" or "related_models" 147 | 148 | handler_settings may be stored in two dict keys: 149 | * `related` is keyed by relation_name 150 | * `related_models` is keyed by target_model 151 | * Both map the key either a dict, or a lambda: dict | None, 152 | * Both have the default catch-all '*' 153 | * Both keep looking when a `None` is discovered 154 | 155 | Because of these similarities, this method handles them both. 156 | 157 | :param store: `self._nested_relation_settings` or `self._nested_model_settings` 158 | :param key: `relation_name`, or `target_model` 159 | :param args: Arguments passed to '*' lambda-handler 160 | :return: dict | None 161 | """ 162 | # Try to get it by key 163 | sets = store.get(key, None) 164 | 165 | # callable? 166 | if callable(sets): 167 | sets = sets() if key != '*' else sets(*star_lambda_args) 168 | 169 | # Found? 170 | if sets is not None: 171 | return sets 172 | 173 | # Fallback: '*' 174 | if key != '*': 175 | return self._get_nested_settings_from_store_attr(store, '*', star_lambda_args) 176 | else: 177 | # Not found 178 | return None 179 | 180 | def settings_for_nested_mongoquery(self, relation_name: str, target_model: DeclarativeMeta) -> dict: 181 | """ Get settings for a nested MongoQuery 182 | 183 | Tries in turn: 184 | related[relation-name] 185 | related[*] 186 | related_models[target-model] 187 | related_models[*] 188 | """ 189 | # Try "related" 190 | sets = self._get_nested_settings_from_store_attr(self._nested_relation_settings, relation_name, (relation_name, target_model)) 191 | 192 | # Try "related_models" 193 | if sets is None: 194 | sets = self._get_nested_settings_from_store_attr(self._nested_model_settings, target_model, (relation_name, target_model)) 195 | 196 | # Done 197 | return sets 198 | 199 | def __repr__(self): 200 | return repr('{}({})'.format(self.__class__.__name__, self._settings)) 201 | 202 | 203 | 204 | call_if_callable = lambda v: v() if callable(v) else v 205 | -------------------------------------------------------------------------------- /mongosql/util/reusable.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | 4 | class Reusable: 5 | """ Make a reusable handler or query 6 | 7 | When a handler object is initialized, it's a pity to waste it! 8 | This class wrapper makes a copy every time .input() is called on its wrapped object. 
9 | 
10 |     Example:
11 | 
12 |         project = Reusable(MongoProject(User, force_exclude=('password',)))
13 | 
14 |     It also works for MongoQuery:
15 | 
16 |         query = Reusable(MongoQuery(User))
17 |     """
18 |     __slots__ = ('__obj',)
19 | 
20 |     def __init__(self, obj):
21 |         # Just store the object inside
22 |         self.__obj = obj
23 | 
24 |     # Whenever any attribute (property or method) is accessed, the whole thing is copied.
25 |     # This is copy-on-access
26 | 
27 |     def __getattr__(self, attr):
28 |         return getattr(copy(self.__obj), attr)
29 | 
30 |     def __repr__(self):
31 |         return repr(self.__obj)
32 | 
--------------------------------------------------------------------------------
/mongosql/util/selectinquery.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.orm.strategy_options import loader_option, _UnboundLoad
2 | from sqlalchemy.orm.strategies import SelectInLoader
3 | from sqlalchemy.orm import properties
4 | from sqlalchemy import log, util
5 | 
6 | 
7 | @log.class_logger
8 | @properties.RelationshipProperty.strategy_for(lazy="selectin_query")
9 | class SelectInQueryLoader(SelectInLoader, util.MemoizedSlots):
10 |     """ A custom loader that acts like selectinload(), but supports using a custom query for related models.
11 | 
12 |     This enables us to use selectinload() with relationships which are loaded with a query that we
13 |     can alter with a callable.
14 | 
15 |     Example usage:
16 | 
17 |         selectinquery(
18 |             User.articles,
19 |             lambda q, **kw: \
20 |                 q.filter(Article.rating > 0.5)
21 |         )
22 |     """
23 | 
24 |     __slots__ = ('_alter_query', '_cache_key', '_bakery')
25 | 
26 |     def create_row_processor(self, context, path, loadopt, mapper, result, adapter, populators):
27 |         # Pluck the custom callable that alters the query out of the `loadopt`
28 |         self._alter_query = loadopt.local_opts['alter_query']
29 |         self._cache_key = loadopt.local_opts['cache_key']
30 | 
31 |         # Call super
32 |         return super(SelectInQueryLoader, self) \
33 |             .create_row_processor(context, path, loadopt, mapper, result, adapter, populators)
34 | 
35 |     # The easiest way would be to just copy `SelectInLoader` and make adjustments to the code,
36 |     # but that would require us to support it, porting every change from SqlAlchemy.
37 |     # We don't want that!
38 |     # Therefore, this class is hacky, and tries to reuse as much code as possible.
39 | 
40 |     # The main method that performs all the magic is SelectInLoader._load_for_path()
41 |     # I don't want to copy it.
42 |     # Solution? Let's hack into it.
43 | 
44 |     # The first step is to investigate the way the query is generated.
45 |     # 1. q = self._bakery(lambda session: session.query(...)) makes the Query
46 |     # 2. q.add_criteria(lambda q: ...) is used to alter the query, and add all criteria
47 |     # 3. q.add_criteria(lambda q: q.filter(in_expr.in_....)) builds the IN clause
48 |     # 4. q._add_lazyload_options() copies some options from the original query (`context.query`)
49 |     # 5. (sometimes) q.add_criteria(lambda q: q.order_by()) orders by a foreign key
50 |     # 6. In a loop, q(context.session).params(..) is invoked, for every chunk
51 | 
52 |     # Looks like we can wrap self._bakery with a wrapper which will make the query,
53 |     # and inject our alter_query() just before it is used, in step 6.
54 |     # How do we do it?
55 |     # See SmartInjectorBakedQuery below
56 | 
57 |     def _memoized_attr__bakery(self):
58 |         # Here we override the `self._bakery` attribute
59 |         # We feed it with a callable that can fetch the information about the current query
60 |         return SmartInjectorBakedQuery.bakery(
61 |             lambda: (self._alter_query, self._cache_key),
62 |             size=300  # we can expect a lot of different queries
63 |         )
64 | 
65 | 
66 | # region Bakery Wrapper that will apply alter_query() in the end
67 | 
68 | from sqlalchemy.ext.baked import Bakery, BakedQuery
69 | 
70 | 
71 | class SmartInjectorBakedQuery(BakedQuery):
72 |     """ A BakedQuery that is able to inject another callable at the very last step, and still use the cache
73 | 
74 |     The whole point of the trick is that the function that we want to hack into uses a series of
75 |     q.add_criteria(), and ultimately, does q.__call__() in a loop.
76 |     Our goal is to inject another q.add_criteria() right before the final __call__().
77 | 
78 |     To achieve that, we subclass BakedQuery, and do our injection in the overridden __call__(), once.
79 |     """
80 |     __slots__ = ('_alter_query', '_done_once', '_can_be_cached')
81 | 
82 |     @classmethod
83 |     def bakery(cls, alter_query_getter, size=200, _size_alert=None):
84 |         bakery = SmartInjectorBakery(cls, util.LRUCache(size, size_alert=_size_alert))  # Copied from sqlalchemy
85 |         bakery.alter_query_getter(alter_query_getter)
86 |         return bakery
87 | 
88 |     def __init__(self, bakery, initial_fn, args=(), alter_query=None, cache_key=None):
89 |         """ Initialize the baked query wrapper that will apply `alter_query` at the last moment
90 | 
91 |         Here, we just pass everything down the chain,
92 |         but add another item to `args`, which is our cache key
93 |         """
94 |         super(SmartInjectorBakedQuery, self).__init__(bakery, initial_fn, args + (cache_key,))
95 |         self._alter_query = alter_query
96 |         self._can_be_cached = cache_key is not None
97 |         self._done_once = False
98 | 
99 |     def __call__(self, session):
100 |         # This method will be called many times in a loop, so we have to inject only once.
101 | 
102 |         # Do it just once
103 |         if not self._done_once:
104 |             # If no external cache key was provided, we can't cache
105 |             if not self._can_be_cached:
106 |                 self.spoil()
107 | 
108 |             # Inject our custom query
109 |             self.add_criteria(self._alter_query)
110 |             self._done_once = True  # never again
111 | 
112 |         # Execute the query
113 |         return super(SmartInjectorBakedQuery, self).__call__(session)
114 | 
115 | 
116 | class SmartInjectorBakery(Bakery):
117 |     """ A bakery that remembers its parent class and is able to load additional data from it.
118 | 
119 |     In our case, it remembers a getter function that asks the parent class to provide an
120 |     `alter_query` function. It is then passed to a BakedQuery, and is injected at the very
121 |     last stage.
122 |     """
123 |     __slots__ = ('_alter_query_getter', )
124 | 
125 |     def alter_query_getter(self, alter_query_getter):
126 |         self._alter_query_getter = alter_query_getter
127 | 
128 |     def __call__(self, initial_fn, *args):
129 |         # Copy-paste from Bakery.__call__()
130 |         return self.cls(self.cache, initial_fn, args, *self._alter_query_getter())
131 | 
132 | # endregion
133 | 
134 | 
135 | # Register the loader option
136 | 
137 | @loader_option()
138 | def selectinquery(loadopt, relationship, alter_query, cache_key=None):
139 |     """Indicate that the given attribute should be loaded using SELECT IN eager loading,
140 |     with a custom `alter_query(q)` callable that returns a modified query.
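
    A sketch of the intended usage (model names are illustrative; see the tests
    and benchmarks for real-world examples):

        ssn.query(User).options(
            selectinquery(User.articles,
                          lambda q: q.filter(Article.rating > 0.5),
                          cache_key='articles-rating-0.5')
        )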
141 | 142 | Args 143 | ---- 144 | 145 | alter_query: Callable 146 | A callable(query) that alters the query produced by selectinloader 147 | cache_key: Hashable 148 | A value to use for caching the query (if possible) 149 | """ 150 | # The loader option just declares which class to use 151 | loadopt = loadopt.set_relationship_strategy(relationship, {"lazy": "selectin_query"}) 152 | 153 | # Loader options don't let us pass any other data to the class, but we need our custom query in. 154 | # The only way is to use the loader option itself. 155 | # create_row_processor() method will pluck it out. 156 | assert 'alter_query' not in loadopt.local_opts # I'm not too sure that there won't be a clash. If there is, we'll have to use a unique key per relationship. 157 | loadopt.local_opts['alter_query'] = alter_query 158 | loadopt.local_opts['cache_key'] = cache_key 159 | 160 | # Done 161 | return loadopt 162 | 163 | 164 | @selectinquery._add_unbound_fn 165 | def selectinquery(relationship, alter_query, cache_key=None): 166 | return _UnboundLoad.selectinquery(_UnboundLoad(), relationship, alter_query, cache_key) 167 | 168 | 169 | # The exported loader option 170 | selectinquery = selectinquery._unbound_fn 171 | -------------------------------------------------------------------------------- /myproject/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = __import__('pkg_resources').get_distribution('mongosql').version 2 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox.sessions 2 | 3 | PYTHON_VERSIONS = ['3.7', '3.8', '3.9'] 4 | SQLALCHEMY_VERSIONS = [ 5 | *(f'1.2.{x}' for x in range(0, 1 + 19)), 6 | *(f'1.3.{x}' for x in range(0, 1 + 24)), 7 | # '1.4.0b3', # not yet 8 | ] 9 | SQLALCHEMY_VERSIONS.remove('1.2.9') # bug 10 | 11 | 12 | nox.options.reuse_existing_virtualenvs = True 13 | nox.options.sessions = [ 14 | 'tests', 15 | 'tests_sqlalchemy', 16 | ] 17 | 18 | 19 | @nox.session(python=PYTHON_VERSIONS) 20 | def tests(session: nox.sessions.Session, sqlalchemy=None): 21 | """ Run all tests """ 22 | session.install('poetry') 23 | session.run('poetry', 'install') 24 | 25 | # Specific package versions 26 | if sqlalchemy: 27 | session.install(f'sqlalchemy=={sqlalchemy}') 28 | 29 | # Test 30 | session.run('pytest', 'tests/', '--cov=mongosql') 31 | 32 | 33 | @nox.session(python=PYTHON_VERSIONS[-1]) 34 | @nox.parametrize('sqlalchemy', SQLALCHEMY_VERSIONS) 35 | def tests_sqlalchemy(session: nox.sessions.Session, sqlalchemy): 36 | """ Test against a specific SqlAlchemy version """ 37 | tests(session, sqlalchemy) 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mongosql" 3 | version = "2.0.15-1" 4 | description = "A JSON query engine with SqlAlchemy as a back-end" 5 | authors = ["Mark Vartanyan "] 6 | repository = 'https://github.com/kolypto/py-mongosql' 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.6" 10 | sqlalchemy = '^1.2, !=1.2.9, < 1.4' 11 | nplus1loader = { version = '^1.0', optional = true } 12 | 13 | [tool.poetry.dev-dependencies] 14 | nox = "^2020.8.22" 15 | pytest = "^6.0.1" 16 | pytest-cov = "^2.10.1" 17 | nplus1loader = '^1.0' 18 | j2cli = '^0.3.10' 19 | psycopg2-binary = '^2.8' 20 | exdoc = '^0.1.3' 21 | flask_jsontools = '^0.1.7' 
22 | 
23 | [tool.pytest.ini_options]
24 | testpaths = [
25 |     "tests/",
26 | ]
27 | 
28 | [build-system]
29 | requires = ["poetry-core>=1.0.0"]
30 | build-backend = "poetry.core.masonry.api"
31 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # import warnings
2 | #
3 | # from sqlalchemy.exc import SAWarning
4 | # warnings.filterwarnings('error', category=SAWarning)
5 | #
6 | 
--------------------------------------------------------------------------------
/tests/benchmarks/.gitignore:
--------------------------------------------------------------------------------
1 | /mongosql_v1
2 | 
--------------------------------------------------------------------------------
/tests/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kolypto/py-mongosql/27d2d125e862106077addec0376b07b13894439d/tests/benchmarks/__init__.py
--------------------------------------------------------------------------------
/tests/benchmarks/benchmark_CountingQuery.py:
--------------------------------------------------------------------------------
1 | """
2 | This benchmark compares the performance of:
3 | * making a separate COUNT query to get the count
4 | * CountingQuery, which gets the count
5 |   together with the results of the main query
6 | """
7 | from tests.benchmarks.benchmark_utils import benchmark_parallel_funcs
8 | from tests.models import get_big_db_for_benchmarks, User
9 | 
10 | from mongosql import CountingQuery
11 | 
12 | # Init DB
13 | engine, Session = get_big_db_for_benchmarks(50, 0, 0)
14 | 
15 | # Prepare
16 | N_REPEATS = 1000
17 | ssn = Session()
18 | 
19 | 
20 | # Tests
21 | def test_two_queries(n):
22 |     """ Test making an additional query to get the count """
23 |     for i in range(n):
24 |         users = list(ssn.query(User))
25 |         count = ssn.query(User).count()
26 | 
27 | def test_counting_query(n):
28 |     """ Test CountingQuery """
29 |     for i in range(n):
30 |         qc = CountingQuery(ssn.query(User))
31 |         users = list(qc)
32 |         count = qc.count
33 | 
34 | 
35 | # Run
36 | res = benchmark_parallel_funcs(
37 |     N_REPEATS, 10,
38 |     test_two_queries,
39 |     test_counting_query,
40 | )
41 | 
42 | # Done
43 | print(res)
44 | 
--------------------------------------------------------------------------------
/tests/benchmarks/benchmark_compare_orm_overhead_with_pure_jsonb_output.py:
--------------------------------------------------------------------------------
1 | """
2 | Test: mapping DB objects directly to dicts.
3 | 
4 | I've run the following benchmark: we load 100 Users, each having 5
5 | articles, each having 3 comments.
6 | We use several different ways to do it, and compare the results.
7 | 
8 | The results are quite interesting:
9 | 
10 |     test_joinedload: 15.40s (395.12%)
11 |     test_selectinload: 18.73s (480.50%)
12 |     test_core__left_join_with_python_nesting: 3.90s (100.00%)
13 |     test_core__3_queries__tuples: 4.57s (117.38%)
14 |     test_core__3_queries__json: 10.98s (281.63%)
15 |     test_subquery_jsonb_tuples: 12.35s (316.98%)
16 |     test_subquery_jsonb_objects: 20.82s (534.27%)
17 |     test_subquery_json_tuples: 6.05s (155.31%)
18 |     test_subquery_json_objects: 9.44s (242.32%)
19 |     test_single_line_agg__json: 7.31s (187.58%)
20 |     test_single_line_agg__jsonb: 14.42s (369.89%)
21 |     test_semisingle_line_agg__json: 7.34s (188.41%)
22 |     test_semisingle_line_agg__jsonb: 15.40s (395.09%)
23 | 
24 | 
25 | What that means is:
26 | 
27 | Postgres was able to load all data in 7.20s and send it to us.
28 | With SqlAlchemy ORM, the whole process has taken 32.22 seconds: the ORM has spent an additional 25s making its Python objects!!
29 | That's 500% overhead!
30 | 
31 | That's alright if you load just one object. But when all we need is to load a bunch of objects and immediately convert
32 | them to JSON, that's a huge, huge overhead for no added benefit. We don't need any ORM features for this task.
33 | 
34 | Sidenote: joinedload() is somehow 20% faster than selectinload(). Surprise!
35 | But that is probably because we didn't have many fields.
36 | 
37 | Now, @vdmit11 has suggested a crazy idea: what if we make JSON objects in
38 | Postgres? `jsonb_agg()` is what he has suggested. I've tested different ways to do it, and discovered that it really
39 | is faster.
40 | 
41 | Using `json_agg()` is somehow 2x faster than `jsonb_agg()`, both in Postgres 9.6 and 11.5.
42 | We can also win an additional 2x by not using `to_json()` on rows, but returning tuples.
43 | Both these techniques let us fetch the results 3.5x faster than `selectinload()`, 2.7x faster than `joinedload()`.
44 | 
45 | But this will give us tuples.
46 | If we try to use `to_json()` and fetch keyed objects, it's more convenient, but it reduces the performance improvement
47 | to just 1.5x, which brings it close to what SqlAlchemy does.
48 | 
49 | Conclusion: forming JSON directly in Postgres can potentially speed up some queries 3x. But this is only applicable
50 | to those queries that feed the data to JSON immediately. It's worth doing, but is rather complicated.
51 | 
52 | The problem with this approach: we'll have to make sure that `@property` fields are included in the results if
53 | they were specified in a projection.
54 | """
55 | 
56 | from sqlalchemy.orm import selectinload, joinedload
57 | 
58 | from tests.benchmarks.benchmark_utils import benchmark_parallel_funcs
59 | from tests.models import get_big_db_for_benchmarks, User, Article
60 | 
61 | # Init DB
62 | engine, Session = get_big_db_for_benchmarks(n_users=100,
63 |                                             n_articles_per_user=5,
64 |                                             n_comments_per_article=3)
65 | 
66 | # Prepare
67 | N_REPEATS = 500
68 | ssn = Session()
69 | 
70 | 
71 | # Tests
72 | def test_selectinload(n):
73 |     """ Load Users+Articles+Comments, with selectinload() """
74 |     for i in range(n):
75 |         users = ssn.query(User).options(
76 |             selectinload(User.articles).selectinload(Article.comments)
77 |         ).all()
78 | 
79 | 
80 | def test_joinedload(n):
81 |     """ Load Users+Articles+Comments, with joinedload() """
82 |     for i in range(n):
83 |         users = ssn.query(User).options(
84 |             joinedload(User.articles).joinedload(Article.comments)
85 |         ).all()
86 | 
87 | 
88 | 
89 | def test_core__left_join_with_no_post_processing(n):
90 |     """ Use a plain SQL query with LEFT JOIN; no post-processing in Python """
91 |     for i in range(n):
92 |         rows = ssn.execute("""
93 |             SELECT u.*, a.*, c.*
94 |             FROM u
95 |             LEFT JOIN a ON u.id = a.uid
96 |             LEFT JOIN c ON c.aid = a.id
97 |         """).fetchall()
98 | 
99 | 
100 | def test_core__left_join_with_python_nesting(n):
101 |     """ Use a plain SQL query with LEFT JOIN + generate JSON objects in Python """
102 |     for i in range(n):
103 |         rows = ssn.execute("""
104 |             SELECT u.*, a.*, c.*
105 |             FROM u
106 |             LEFT JOIN a ON u.id = a.uid
107 |             LEFT JOIN c ON c.aid = a.id
108 |         """).fetchall()
109 | 
110 |         # Okay, that was fast; what if we collect all the data in Python?
111 |         # Will Python loops kill all performance?
112 | 
113 |         users = []
114 |         user_id_to_user = {}
115 |         article_id_to_article = {}
116 |         comment_id_to_comment = {}
117 | 
118 |         for row in rows:
119 |             # user dict()
120 |             user_id = row[0]
121 |             if user_id in user_id_to_user:
122 |                 user = user_id_to_user[user_id]
123 |             else:
124 |                 user = {'id': row[0], 'name': row[1], 'tags': row[2], 'age': row[3], 'articles': []}
125 |                 user_id_to_user[user_id] = user
126 |                 users.append(user)
127 | 
128 |             # article dict()
129 |             article_id = row[4]
130 |             if article_id:
131 |                 if article_id in article_id_to_article:
132 |                     article = article_id_to_article[article_id]
133 |                 else:
134 |                     article = {'id': row[4], 'uid': row[5], 'title': row[6], 'theme': row[7], 'data': row[8], 'comments': []}
135 |                     article_id_to_article[article_id] = article
136 |                 if user:
137 |                     user['articles'].append(article)
138 |             else:
139 |                 article = None
140 | 
141 |             # comment dict()
142 |             comment_id = row[9]
143 |             if comment_id:
144 |                 if comment_id in comment_id_to_comment:
145 |                     comment = comment_id_to_comment[comment_id]
146 |                 else:
147 |                     comment = comment_id_to_comment[comment_id] = {'id': row[9], 'aid': row[10], 'uid': row[11], 'text': row[12]}  # store it for dedup, too
148 |                 if article:
149 |                     article['comments'].append(comment)
150 |             else:
151 |                 comment = None
152 | 
153 | 
154 | def test_core__3_queries__tuples(n):
155 |     """ Make 3 queries, load tuples """
156 |     for i in range(n):
157 |         users = ssn.execute('SELECT u.* FROM u;').fetchall()
158 | 
159 |         user_ids = set(str(u.id) for u in users)
160 |         articles = ssn.execute('SELECT a.* FROM a WHERE uid IN (' + (','.join(user_ids)) + ')').fetchall()
161 | 
162 |         article_ids = set(str(a.id) for a in articles)
163 |         comments = ssn.execute('SELECT c.* FROM c WHERE aid IN (' + (','.join(article_ids)) + ')').fetchall()
164 | 
165 | 
166 | def test_core__3_queries__json(n):
167 |     """ Make 3 queries, load json rows """
168 |     for i in range(n):
169 |         users = ssn.execute('SELECT to_json(u) FROM u;').fetchall()
170 | 
171 |         user_ids = set(str(u[0]['id']) for u in users)
172 |         articles = ssn.execute('SELECT to_json(a) FROM a WHERE uid IN (' + (','.join(user_ids)) + ')').fetchall()
173 | 
174 |         article_ids = set(str(a[0]['id']) for a in articles)
175 |         comments = ssn.execute('SELECT to_json(c) FROM c WHERE aid IN (' + (','.join(article_ids)) + ')').fetchall()
176 | 
177 | # Now do the same with nested JSON
178 | # We query Users, their Articles, and their Comments, all as nested objects
179 | # We use `AGG_FUNCTION`, which is either json_agg(), or jsonb_agg()
180 | # We use different `*_SELECT` expressions, which are either a list of columns, or to_json(row), or to_jsonb(row)
181 | # All those combinations we test to see which one is faster
182 | NESTED_AGG_QUERY_TEMPLATE = """
183 |     SELECT
184 |         {USERS_SELECT},
185 |         {AGG_FUNCTION}(articles_q) AS articles
186 |     FROM u AS users
187 |     LEFT JOIN (
188 |         SELECT
189 |             {ARTICLES_SELECT},
190 |             {AGG_FUNCTION}(comments_q) AS comments
191 |         FROM a AS articles
192 |         LEFT JOIN (
193 |             SELECT
194 |                 {COMMENTS_SELECT}
195 |             FROM c AS comments
196 |         ) AS comments_q ON articles.id = comments_q.aid
197 |         GROUP BY articles.id
198 |     ) AS articles_q ON users.id = articles_q.uid
199 |     GROUP BY users.id;
200 | """
201 | 
202 | def test_subquery_jsonb_tuples(n):
203 |     """ Test making JSONB on the server. Return tuples. """
204 |     query = NESTED_AGG_QUERY_TEMPLATE.format(
205 |         # Use JSONB for nested objects
206 |         AGG_FUNCTION='jsonb_agg',
207 |         # Select rows as tuples
208 |         USERS_SELECT='users.*',
209 |         ARTICLES_SELECT='articles.*',
210 |         COMMENTS_SELECT='comments.*',
211 |     )
212 |     for i in range(n):
213 |         users = ssn.execute(query).fetchall()
214 | 
215 | 
216 | def test_subquery_jsonb_objects(n):
217 |     """ Test making JSONB on the server. Make objects with to_jsonb() """
218 |     query = NESTED_AGG_QUERY_TEMPLATE.format(
219 |         # Use JSONB for nested objects
220 |         AGG_FUNCTION='jsonb_agg',
221 |         # Select rows as JSONB objects
222 |         # Select ids needed for joining as well
223 |         USERS_SELECT='users.id, to_jsonb(users) AS user',
224 |         ARTICLES_SELECT='articles.id, articles.uid, to_jsonb(articles) AS article',
225 |         COMMENTS_SELECT='comments.id, comments.aid, to_jsonb(comments) AS comment',
226 |     )
227 | 
228 |     for i in range(n):
229 |         users = list(ssn.execute(query))
230 | 
231 | def test_subquery_json_tuples(n):
232 |     """ Test making JSON on the server. Return tuples. """
233 |     query = NESTED_AGG_QUERY_TEMPLATE.format(
234 |         # Use JSON for nested objects
235 |         AGG_FUNCTION='json_agg',
236 |         # Select rows as tuples
237 |         USERS_SELECT='users.*',
238 |         ARTICLES_SELECT='articles.*',
239 |         COMMENTS_SELECT='comments.*',
240 |     )
241 |     for i in range(n):
242 |         users = list(ssn.execute(query))
243 | 
244 | def test_subquery_json_objects(n):
245 |     """ Test making JSON on the server.
Make objects with row_to_json() """ 246 | query = NESTED_AGG_QUERY_TEMPLATE.format( 247 | # Use JSON for nested objects 248 | AGG_FUNCTION='json_agg', 249 | USERS_SELECT='users.id, to_json(users) AS user', 250 | # Select rows as JSON objects 251 | # Select ids needed for joining as well 252 | ARTICLES_SELECT='articles.id, articles.uid, to_json(articles) AS article', 253 | COMMENTS_SELECT='comments.id, comments.aid, to_json(comments) AS comment', 254 | ) 255 | 256 | for i in range(n): 257 | users = list(ssn.execute(query)) 258 | 259 | 260 | LINE_AGG_QUERY_TEMPLATE = """ 261 | SELECT {TO_JSON}(u), {JSON_AGG}(a), {JSON_AGG}(c) 262 | FROM u 263 | LEFT JOIN a ON(u.id=a.uid) 264 | LEFT JOIN c ON (a.id=c.aid) 265 | GROUP BY u.id; 266 | """ 267 | 268 | def test_single_line_agg__json(n): 269 | """ Linear aggregation: no nesting; everything's aggregated into separate JSON lists """ 270 | query = LINE_AGG_QUERY_TEMPLATE.format( 271 | TO_JSON='to_json', 272 | JSON_AGG='json_agg', 273 | ) 274 | for i in range(n): 275 | rows = ssn.execute(query).fetchall() 276 | 277 | def test_single_line_agg__jsonb(n): 278 | """ Linear aggregation: no nesting; everything's aggregated into separate JSONB lists """ 279 | query = LINE_AGG_QUERY_TEMPLATE.format( 280 | TO_JSON='to_jsonb', 281 | JSON_AGG='jsonb_agg', 282 | ) 283 | for i in range(n): 284 | rows = ssn.execute(query).fetchall() 285 | 286 | 287 | SEMILINE_AGG_QUERY_TEMPLATE = """ 288 | SELECT 289 | {JSON_BUILD_OBJECT}( 290 | 'id', u.id, 'name', u.name, 'tags', u.tags, 'age', u.age, 291 | 'articles', {JSON_AGG}(a)), 292 | {JSON_AGG}(c) 293 | FROM u 294 | LEFT JOIN a ON(u.id=a.uid) 295 | LEFT JOIN c ON (a.id=c.aid) 296 | GROUP BY u.id; 297 | """ 298 | 299 | def test_semisingle_line_agg__json(n): 300 | """ Aggregate only 1st level objects; things that are nested deeper are expelled to the outskirts """ 301 | query = SEMILINE_AGG_QUERY_TEMPLATE.format( 302 | JSON_BUILD_OBJECT='json_build_object', 303 | JSON_AGG='json_agg', 304 | ) 305 | for i in range(n): 306 | rows = ssn.execute(query).fetchall() 307 | 308 | 309 | def test_semisingle_line_agg__jsonb(n): 310 | """ Aggregate only 1st level objects; things that are nested deeper are expelled to the outskirts """ 311 | query = SEMILINE_AGG_QUERY_TEMPLATE.format( 312 | JSON_BUILD_OBJECT='jsonb_build_object', 313 | JSON_AGG='jsonb_agg', 314 | ) 315 | for i in range(n): 316 | rows = ssn.execute(query).fetchall() 317 | 318 | 319 | 320 | # Run 321 | res = benchmark_parallel_funcs( 322 | N_REPEATS, 10, 323 | test_joinedload, 324 | test_selectinload, 325 | test_core__left_join_with_no_post_processing, 326 | test_core__left_join_with_python_nesting, 327 | test_core__3_queries__tuples, 328 | test_core__3_queries__json, 329 | test_subquery_jsonb_tuples, 330 | test_subquery_jsonb_objects, 331 | test_subquery_json_tuples, 332 | test_subquery_json_objects, 333 | test_single_line_agg__json, 334 | test_single_line_agg__jsonb, 335 | test_semisingle_line_agg__json, 336 | test_semisingle_line_agg__jsonb, 337 | ) 338 | 339 | # Done 340 | print(res) 341 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_one_query.py: -------------------------------------------------------------------------------- 1 | from tests.benchmarks.benchmark_utils import benchmark_parallel_funcs 2 | 3 | from mongosql.handlers import MongoJoin 4 | from tests.models import * 5 | 6 | # Run me: 7 | # $ python -m cProfile -o profile.out tests/benchmark_one_query.py 8 | 9 | # Init DB: choose one 10 | # engine, 
Session = get_big_db_for_benchmarks(100, 10, 3) 11 | engine, Session = get_working_db_for_tests() 12 | # engine, Session = get_empty_db() 13 | 14 | # Prepare 15 | N_REPEATS = 1000 16 | ssn = Session() 17 | 18 | # Tests 19 | def run_query(n): 20 | for i in range(n): 21 | q = User.mongoquery(ssn).query( 22 | project=['name'], 23 | filter={'age': {'$ne': 100}}, 24 | join={'articles': dict(project=['title'], 25 | filter={'theme': {'$ne': 'sci-fi'}}, 26 | join={'comments': dict(project=['aid'], 27 | filter={'text': {'$exists': True}})})} 28 | ).end() 29 | list(q.all()) 30 | 31 | def test_selectinquery(n): 32 | """ Test with selectinquery() """ 33 | MongoJoin.ENABLED_EXPERIMENTAL_SELECTINQUERY = True 34 | run_query(n) 35 | 36 | def test_joinedload(n): 37 | """ Test with joinedload() """ 38 | MongoJoin.ENABLED_EXPERIMENTAL_SELECTINQUERY = False 39 | run_query(n) 40 | 41 | # Run 42 | print('Running tests...') 43 | res = benchmark_parallel_funcs( 44 | N_REPEATS, 10, 45 | test_joinedload, 46 | test_selectinquery 47 | ) 48 | 49 | # Done 50 | print(res) 51 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_selectinquery.py: -------------------------------------------------------------------------------- 1 | """ 2 | This benchmark compares the performance of: 3 | * selectinload() 4 | * selectinquery() with query caching 5 | * selectinquery() with no query caching 6 | """ 7 | 8 | 9 | from tests.benchmarks.benchmark_utils import benchmark_parallel_funcs 10 | 11 | from sqlalchemy.orm import selectinload, joinedload 12 | 13 | from mongosql import selectinquery 14 | from tests.models import get_working_db_for_tests, User, Article 15 | 16 | # Run me: PyCharm Profiler 17 | # Run me: python -m cProfile -o profile.out tests/benchmark_selectinquery.py 18 | 19 | # Init DB 20 | engine, Session = get_working_db_for_tests() 21 | 22 | # Prepare 23 | N_REPEATS = 1000 24 | ssn = Session() 25 | 26 | # Tests 27 | def test_selectinload(n): 28 | """ Test SqlAlchemy's selectinload(): using it as a baseline """ 29 | for i in range(n): 30 | q = ssn.query(User).options( 31 | selectinload(User.articles).selectinload(Article.comments) 32 | ) 33 | list(q.all()) 34 | 35 | def test_selectinquery__cache(n): 36 | """ Test our custom selectinquery(), with query caching """ 37 | for i in range(n): 38 | q = ssn.query(User).options( 39 | selectinquery(User.articles, lambda q: q, 'a').selectinquery(Article.comments, lambda q: q, 'b') 40 | ) 41 | list(q.all()) 42 | 43 | def test_selectinquery__no_cache(n): 44 | """ Test our custom selectinquery(), without query caching """ 45 | for i in range(n): 46 | q = ssn.query(User).options( 47 | selectinquery(User.articles, lambda q: q).selectinquery(Article.comments, lambda q: q) 48 | ) 49 | list(q.all()) 50 | 51 | 52 | # Run 53 | res = benchmark_parallel_funcs( 54 | N_REPEATS, 10, 55 | test_selectinload, 56 | test_selectinquery__cache, 57 | test_selectinquery__no_cache, 58 | ) 59 | 60 | # Done 61 | print(res) 62 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | from time import time_ns 2 | from collections import defaultdict 3 | 4 | 5 | class Nanotimers: 6 | """ A timer with nanosecond precision 7 | 8 | This timer lets you do start() and stop() many times, measuring small intervals. 9 | It supports measuring many things at once, each having its distinct `name`. 
10 | 11 | Because it's designed for benchmarking functions, where a function call itself has overhead in Python, 12 | it can also compensate for this overhead with overhead_shift(). 13 | 14 | Typical usage: 15 | 16 | # Init 17 | timers = Nanotimers() 18 | 19 | # Measure a `test-name` 20 | timers.start('test-name') 21 | for i in range(1000): 22 | call_your_function() 23 | timers.stop('test-name') 24 | 25 | timers.overhead_shift(100) # shift all results by 100ns (as measured by N iterations of an empty function) 26 | """ 27 | 28 | __slots__ = ('_timers', '_results') 29 | 30 | def __init__(self): 31 | self._timers = {} 32 | self._results = defaultdict(int) 33 | 34 | def start(self, name): 35 | """ Start measuring time for `name` """ 36 | self._timers[name] = time_ns() 37 | 38 | def stop(self, name): 39 | """ Stop measuring time for `name`. 40 | 41 | You can add more time by calling start()/stop() again. 42 | """ 43 | total = time_ns() - self._timers[name] 44 | self._results[name] += total 45 | 46 | def overhead_shift(self, value): 47 | """ Reduce all results by `value` nanoseconds to compensate for some overhead """ 48 | for name in self._results: 49 | self._results[name] -= value 50 | 51 | def __getitem__(self, name): 52 | return self._results[name] / 10**9 53 | 54 | def dict(self): 55 | return {name: ns / 10**9 for name, ns in self._results.items()} 56 | 57 | def results(self): 58 | """ Get results. 59 | 60 | This method does not only return the raw measured times, but also calculates the relative percentages. 61 | Example return value: 62 | 63 | { 64 | 'your-name': dict( 65 | time=1.2, # seconds 66 | perc=120, # percent of the best time 67 | ), 68 | ... 69 | } 70 | """ 71 | min_time = min(self._results.values()) 72 | return { 73 | name: { 74 | 'time': ns / 10**9, 75 | 'perc': 100 * ns / min_time, 76 | } 77 | for name, ns in self._results.items() 78 | } 79 | 80 | def __str__(self): 81 | """ Format the measured values to make it look great! """ 82 | return '\n'.join( 83 | f'{name}: {res["time"]:.02f}s ({res["perc"]:.02f}%)' 84 | for name, res in self.results().items() 85 | ) 86 | 87 | 88 | def benchmark_parallel(n_iterations, n_parts, **tests): 89 | """ Run the given tests in parallel. 90 | 91 | It will switch between all the given tests back and forth, making sure that some local 92 | performance fluctuations won't hurt running the tests. 93 | 94 | Args 95 | ---- 96 | 97 | n_iterations: int 98 | The total number of iterations 99 | n_parts: int 100 | The number of parts to break those iterations into 101 | tests: 102 | Named tests to run. 103 | Each is a callable that receives the `n` argument: number of repetitions 104 | """ 105 | timers = Nanotimers() 106 | iters_per_run = n_iterations // n_parts 107 | 108 | # Run 109 | for run in range(n_parts): 110 | for name, test in tests.items(): 111 | timers.start(name) 112 | test(iters_per_run) 113 | timers.stop(name) 114 | 115 | # Fix overhead 116 | # Measure overhead: call an empty function the same number of times 117 | def f(): pass 118 | t1 = time_ns() 119 | for i in range(n_iterations): f() 120 | t2 = time_ns() 121 | overhead_ns = t2 - t1 122 | # Fix it: shift all results by the measured number of nanoseconds 123 | timers.overhead_shift(overhead_ns) 124 | 125 | # Done 126 | return timers 127 | 128 | 129 | def benchmark_parallel_funcs(n_iterations, n_parts, *funcs): 130 | """ Run the given `funcs` test functions `n_iterations` times. 131 | 132 | Every function receives the `n` argument and is supposed to do its job `n` times in a loop. 
133 | This is to reduce the impact of a repeated function call, and to let your tests initialize before they run. 134 | 135 | Names of those functions are used in displaying results. 136 | """ 137 | return benchmark_parallel( 138 | n_iterations, 139 | n_parts, 140 | **{f.__name__: f for f in funcs}) 141 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_v2_vs_v1.py: -------------------------------------------------------------------------------- 1 | from tests.benchmarks.benchmark_utils import benchmark_parallel_funcs 2 | from mongosql.handlers import MongoJoin 3 | from tests.models import * 4 | 5 | # Import both MongoSQL packages 6 | try: 7 | from tests.benchmarks.mongosql_v1.mongosql import MongoQuery as MongoQuery_v1 8 | except ImportError: 9 | print('Please install MongoSQL 1.5: ') 10 | print('$ bash tests/benchmarks/mongosql_v1_checkout.sh') 11 | exit(1) 12 | 13 | from mongosql import MongoQuery as MongoQuery_v2 14 | 15 | # Check SqlAlchemy version 16 | from sqlalchemy import __version__ as SA_VERSION 17 | assert SA_VERSION.startswith('1.2.'), 'Only works with SqlAlchemy 1.2.x' 18 | 19 | 20 | 21 | # Init DB: choose one 22 | engine, Session = get_working_db_for_tests() 23 | 24 | # Prepare 25 | N_REPEATS = 1000 26 | MongoJoin.ENABLED_EXPERIMENTAL_SELECTINQUERY = False 27 | ssn = Session() 28 | 29 | # Tests 30 | def test_v1(n): 31 | """ Test MongoSQL v1 """ 32 | for i in range(n): 33 | q = MongoQuery_v1(User, query=ssn.query(User)).query( 34 | project=['name'], 35 | filter={'age': {'$ne': 100}}, 36 | join={'articles': dict(project=['title'], 37 | filter={'theme': {'$ne': 'sci-fi'}}, 38 | join={'comments': dict(project=['aid'], 39 | filter={'text': {'$exists': True}})})} 40 | ).end() 41 | list(q.all()) 42 | 43 | def test_v2(n): 44 | """ Test MongoSQL v2 """ 45 | for i in range(n): 46 | q = MongoQuery_v2(User).with_session(ssn).query( 47 | project=['name'], 48 | filter={'age': {'$ne': 100}}, 49 | join={'articles': dict(project=['title'], 50 | filter={'theme': {'$ne': 'sci-fi'}}, 51 | join={'comments': dict(project=['aid'], 52 | filter={'text': {'$exists': True}})})} 53 | ).end() 54 | list(q.all()) 55 | 56 | # Run 57 | print('Running tests...') 58 | res = benchmark_parallel_funcs( 59 | N_REPEATS, 10, 60 | test_v1, 61 | test_v2 62 | ) 63 | 64 | # Done 65 | print(res) 66 | -------------------------------------------------------------------------------- /tests/benchmarks/mongosql_v1_checkout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | 4 | cd $(dirname "$0") 5 | git clone --depth 1 --branch v1.5 git@github.com:kolypto/py-mongosql.git mongosql_v1 6 | touch mongosql_v1/__init__.py 7 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | -------------------------------------------------------------------------------- /tests/crud_view.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from logging import getLogger 3 | 4 | from mongosql import CrudViewMixin, StrictCrudHelper, StrictCrudHelperSettingsDict, saves_relations 5 | 6 | from . import models 7 | from flask import request, g, jsonify 8 | from flask_jsontools import jsonapi, RestfulView 9 | 10 | logger = getLogger(__name__) 11 | 12 | 13 | def passthrough_decorator(f): 14 | """ A no-op decorator. 
15 |     Its only purpose is to see whether @saves_relations() works even when decorated with something else.
16 |     """
17 |     @wraps(f)
18 |     def wrapper(*args, **kwargs):
19 |         return f(*args, **kwargs)
20 |     return wrapper
21 | 
22 | 
23 | class RestfulModelView(RestfulView, CrudViewMixin):
24 |     """ Base view class for all other views """
25 |     crudhelper = None
26 | 
27 |     # RestfulView needs that for routing
28 |     primary_key = None
29 |     decorators = (jsonapi,)
30 | 
31 |     # Every response will have either { article: ... } or { articles: [...] }
32 |     # Stick to the DRY principle: store the key name once
33 |     entity_name = None
34 |     entity_names = None
35 | 
36 |     # Implement the method that fetches the Query Object for this request
37 |     def _get_query_object(self):
38 |         """ Get Query Object from request
39 | 
40 |         :rtype: dict | None
41 |         """
42 |         return (request.get_json() or {}).get('query', None)
43 | 
44 |     # CrudViewMixin demands: needs to be able to get a session so that it can run a query
45 |     def _get_db_session(self):
46 |         """ Get database Session
47 | 
48 |         :rtype: sqlalchemy.orm.Session
49 |         """
50 |         return g.db
51 | 
52 |     # This is our method: it plucks an instance using the current projection.
53 |     # This is just a convenience: if the user has requested certain fields, we return only those fields.
54 |     def _return_instance(self, instance):
55 |         """ Modify a returned instance """
56 |         return self._mongoquery.pluck_instance(instance)
57 | 
58 |     def _save_hook(self, new: models.Article, prev: models.Article = None):
59 |         # There's one special case for a failure: title='z'.
60 |         # This is how unit-tests can test exceptions
61 |         if new.title == 'z':
62 |             # Simulate a bug
63 |             raise RuntimeError(
64 |                 'This method inexplicably fails when title="z"'
65 |             )
66 |         super()._save_hook(new, prev)
67 | 
68 |     # region CRUD methods
69 | 
70 |     def list(self):
71 |         """ List method: GET /article/ """
72 |         # List results
73 |         results = self._method_list()
74 | 
75 |         # Format response
76 |         # NOTE: can't return map(), because it's not JSON serializable
77 |         return {self.entity_names: results}
78 | 
79 |     def _method_list_result__groups(self, dicts):
80 |         """ Format the result from GET /article/ when the result is a list of dicts (GROUP BY) """
81 |         return list(dicts)  # our JSON serializer does not like generators. Have to make it into a list
82 | 
83 |     def _method_list_result__entities(self, entities):
84 |         """ Format the result from GET /article/ when the result is a list of sqlalchemy entities """
85 |         # Pluck results: apply projection to the result set
86 |         # This is just our good manners: if the client has requested certain fields, we return only those they requested.
87 |         # Even if our code loads some more columns (and it does!), the client will always get what they requested.
88 |         return list(map(self._return_instance, entities))
89 | 
90 |     def get(self, id):
91 |         item = self._method_get(id=id)
92 |         return {self.entity_name: self._return_instance(item)}
93 | 
94 |     def create(self):
95 |         # Trying to save many objects at once?
96 | if self.entity_names in request.get_json(): 97 | return self.save_many() 98 | 99 | # Saving only one object 100 | input_entity_dict = request.get_json()[self.entity_name] 101 | instance = self._method_create(input_entity_dict) 102 | 103 | ssn = self._get_db_session() 104 | ssn.add(instance) 105 | ssn.commit() 106 | 107 | return {self.entity_name: self._return_instance(instance)} 108 | 109 | def save_many(self): 110 | # Get the input 111 | input_json = request.get_json() 112 | entity_dicts = input_json[self.entity_names] 113 | 114 | # Process 115 | results = self._method_create_or_update_many(entity_dicts) 116 | 117 | # Save 118 | ssn = self._get_db_session() 119 | ssn.add_all(res.instance for res in results if res.instance is not None) 120 | ssn.commit() 121 | 122 | # Log every error 123 | for res in results: 124 | if res.error: 125 | logger.exception(str(res.error), exc_info=res.error) 126 | 127 | # Results 128 | return { 129 | # Entities 130 | self.entity_names: [ 131 | # Each one goes through self._return_instance() 132 | self._return_instance(res.instance) if res.instance else None 133 | for res in results 134 | ], 135 | # Errors 136 | 'errors': { 137 | res.ordinal_number: str(res.error) 138 | for res in results 139 | if res.error 140 | }, 141 | } 142 | 143 | def update(self, id): 144 | input_entity_dict = request.get_json()[self.entity_name] 145 | instance = self._method_update(input_entity_dict, id=id) 146 | 147 | ssn = self._get_db_session() 148 | ssn.add(instance) 149 | ssn.commit() 150 | 151 | return {self.entity_name: self._return_instance(instance)} 152 | 153 | def delete(self, id): 154 | instance = self._method_delete(id=id) 155 | 156 | ssn = self._get_db_session() 157 | ssn.delete(instance) 158 | ssn.commit() 159 | 160 | return {self.entity_name: self._return_instance(instance)} 161 | 162 | # endregion 163 | 164 | 165 | class ArticleView(RestfulModelView): 166 | """ Full-featured CRUD view """ 167 | 168 | # First, configure a CrudHelper 169 | crudhelper = StrictCrudHelper( 170 | # The model to work with 171 | models.Article, 172 | **StrictCrudHelperSettingsDict( 173 | # Read-only fields, as a callable (just because) 174 | ro_fields=lambda: ('id', 'uid',), 175 | legacy_fields=('removed_column',), 176 | # MongoQuery settings 177 | aggregate_columns=('id', 'data',), # have to explicitly enable aggregation for columns 178 | query_defaults=dict( 179 | sort=('id-',), 180 | ), 181 | writable_properties=True, 182 | max_items=2, 183 | # Related entities configuration 184 | allowed_relations=('user', 'comments'), 185 | related={ 186 | 'user': dict( 187 | # Exclude @property by default 188 | default_exclude=('user_calculated',), 189 | allowed_relations=('comments',), 190 | related={ 191 | 'comments': dict( 192 | # Exclude @property by default 193 | default_exclude=('comment_calc',), 194 | # No further joins 195 | join_enabled=False, 196 | ) 197 | } 198 | ), 199 | 'comments': dict( 200 | # Exclude @property by default 201 | default_exclude=('comment_calc',), 202 | # No further joins 203 | join_enabled=False, 204 | ), 205 | }, 206 | ) 207 | ) 208 | 209 | # ensure_loaded: always load these columns and relationships 210 | # This is necessary in case some custom code relies on it 211 | ensure_loaded = ('data', 'comments') # that's a weird requirement, but since the user is supposed to use projections, it will be excluded 212 | 213 | primary_key = ('id',) 214 | decorators = (jsonapi,) 215 | 216 | entity_name = 'article' 217 | entity_names = 'articles' 218 | 219 | def _method_create(self, 
entity_dict: dict) -> object: 220 | instance = super()._method_create(entity_dict) 221 | instance.uid = 3 # Manually set ro field value, because the client can't 222 | return instance 223 | 224 | # Our completely custom stuff 225 | 226 | @passthrough_decorator # no-op to demonstrate that it still works 227 | @saves_relations('comments') 228 | def save_comments(self, new, prev=None, comments=None): 229 | # Just store it in the class for unit-test to find it 230 | self.__class__._save_comments__args = dict(new=new, prev=prev, comments=comments) 231 | 232 | @passthrough_decorator # no-op to demonstrate that it still works 233 | @saves_relations('user', 'comments') 234 | def save_relations(self, new, prev=None, user=None, comments=None): 235 | # Just store it in the class for unit-test to find it 236 | self.__class__._save_relations__args = dict(new=new, prev=prev, user=user, comments=comments) 237 | 238 | @saves_relations('removed_column') 239 | def save_removed_column(self, new, prev=None, removed_column=None): 240 | # Store 241 | self.__class__._save_removed_column = dict(removed_column=removed_column) 242 | 243 | _save_comments__args = None 244 | _save_relations__args = None 245 | _save_removed_column = None 246 | 247 | 248 | class GirlWatcherView(RestfulModelView): 249 | crudhelper = StrictCrudHelper( 250 | models.GirlWatcher, 251 | **StrictCrudHelperSettingsDict( 252 | # Read-only fields, as a callable (just because) 253 | ro_fields=('id', 'favorite_id',), 254 | allowed_relations=('good', 'best') 255 | ) 256 | ) 257 | 258 | primary_key = ('id',) 259 | decorators = (jsonapi,) 260 | 261 | entity_name = 'girlwatcher' 262 | entity_names = 'girlwatchers' 263 | 264 | def _return_instance(self, instance): 265 | instance = super()._return_instance(instance) 266 | 267 | # TypeError: Object of type _AssociationList is not JSON serializable 268 | for k in ('good_names', 'best_names'): 269 | if k in instance: 270 | # Convert this _AssociationList() object into a real list 271 | instance[k] = list(instance[k]) 272 | 273 | return instance 274 | 275 | -------------------------------------------------------------------------------- /tests/saversion.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | 3 | from mongosql import SA_VERSION, SA_12, SA_13 4 | 5 | 6 | def SA_VERSION_IN(min_version, max_version): 7 | """ Check that SqlAlchemy version lies within a range 8 | 9 | This is slow; only use in unit-tests! 
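
    Example (the version bounds are illustrative):

        SA_VERSION_IN('1.3.0', '1.3.24')  # True iff 1.3.0 <= SA_VERSION <= 1.3.24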
10 | """ 11 | return LooseVersion(min_version) <= LooseVersion(SA_VERSION) <= LooseVersion(max_version) 12 | 13 | 14 | def SA_SINCE(version): 15 | """ Check SqlAlchemy >= version """ 16 | return LooseVersion(SA_VERSION) >= LooseVersion(version) 17 | 18 | 19 | def SA_UNTIL(version): 20 | """ Check SqlAlchemy <= version """ 21 | return LooseVersion(SA_VERSION) <= LooseVersion(version) 22 | -------------------------------------------------------------------------------- /tests/t_method_decorator_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from functools import wraps 3 | 4 | from mongosql.util.method_decorator import method_decorator 5 | 6 | 7 | class method_decorator_test(unittest.TestCase): 8 | def test_method_decorators(self): 9 | # === Test: method decorator 10 | # Create a class 11 | class A: 12 | @method_decorator_1(1) 13 | def a(self): pass 14 | 15 | @method_decorator_1(2) 16 | def b(self): pass 17 | 18 | @method_decorator_2(3) 19 | def c(self): pass 20 | 21 | # isinstance() checks 22 | self.assertTrue(isinstance(method_decorator.get_method_decorator(A, 'a'), method_decorator_1)) 23 | self.assertTrue(isinstance(method_decorator.get_method_decorator(A, 'b'), method_decorator_1)) 24 | self.assertTrue(isinstance(method_decorator.get_method_decorator(A, 'c'), method_decorator_2)) 25 | 26 | self.assertFalse(isinstance(method_decorator.get_method_decorator(A, 'a'), method_decorator_2)) 27 | self.assertFalse(isinstance(method_decorator.get_method_decorator(A, 'b'), method_decorator_2)) 28 | self.assertFalse(isinstance(method_decorator.get_method_decorator(A, 'c'), method_decorator_1)) 29 | 30 | # Collect: decorator 1 31 | m1s = method_decorator_1.all_decorators_from(A) 32 | self.assertEqual(len(m1s), 2) 33 | self.assertEqual(m1s[0].method_name, 'a') 34 | self.assertEqual(m1s[0].arg1, 1) 35 | self.assertEqual(m1s[1].method_name, 'b') 36 | self.assertEqual(m1s[1].arg1, 2) 37 | 38 | # Collect: decorator 2 39 | m2s = method_decorator_2.all_decorators_from(A) 40 | self.assertEqual(len(m2s), 1) 41 | self.assertEqual(m2s[0].method_name, 'c') 42 | self.assertEqual(m2s[0].arg2, 3) 43 | 44 | # === Test: now try to mix it with other decorators 45 | class B: 46 | @nop_decorator # won't hide it 47 | @method_decorator_1(0) 48 | def a(self): pass 49 | 50 | @method_decorator_1(0) 51 | @nop_decorator 52 | def b(self): pass 53 | 54 | # isinstance() checks 55 | # They work even through the second decorator 56 | self.assertTrue(isinstance(method_decorator.get_method_decorator(B, 'a'), method_decorator_1)) 57 | self.assertTrue(isinstance(method_decorator.get_method_decorator(B, 'b'), method_decorator_1)) 58 | 59 | # Collect 60 | m1s = method_decorator_1.all_decorators_from(B) 61 | self.assertEqual(len(m1s), 2) 62 | self.assertEqual(m1s[0].method_name, 'a') 63 | self.assertEqual(m1s[1].method_name, 'b') 64 | 65 | # === Test: apply two method_decorators at the same time! 
66 | class C:
67 | @nop_decorator
68 | @method_decorator_1(1)
69 | @method_decorator_2(2)
70 | def a(self): pass
71 | 
72 | @nop_decorator
73 | @method_decorator_2(3)
74 | @method_decorator_1(4)
75 | def b(self): pass
76 | 
77 | # isinstance() checks
78 | self.assertTrue(isinstance(method_decorator.get_method_decorator(C, 'a'), method_decorator_1))
79 | self.assertTrue(isinstance(method_decorator.get_method_decorator(C, 'a'), method_decorator_2))
80 | self.assertTrue(isinstance(method_decorator.get_method_decorator(C, 'b'), method_decorator_1))
81 | self.assertTrue(isinstance(method_decorator.get_method_decorator(C, 'b'), method_decorator_2))
82 | 
83 | # Collect: decorator 1
84 | m1s = method_decorator_1.all_decorators_from(C)
85 | self.assertEqual(len(m1s), 2)
86 | self.assertEqual(m1s[0].method_name, 'a')
87 | self.assertEqual(m1s[0].arg1, 1)
88 | self.assertEqual(m1s[1].method_name, 'b')
89 | self.assertEqual(m1s[1].arg1, 4)
90 | 
91 | # Collect: decorator 2
92 | m2s = method_decorator_2.all_decorators_from(C)
93 | self.assertEqual(len(m2s), 2)
94 | self.assertEqual(m2s[0].method_name, 'a')
95 | self.assertEqual(m2s[0].arg2, 2)
96 | self.assertEqual(m2s[1].method_name, 'b')
97 | self.assertEqual(m2s[1].arg2, 3)
98 | 
99 | 
100 | 
101 | # Example decorators used by the tests above
102 | 
103 | class method_decorator_1(method_decorator):
104 | def __init__(self, arg):
105 | super().__init__()
106 | self.arg1 = arg
107 | 
108 | 
109 | class method_decorator_2(method_decorator):
110 | def __init__(self, arg):
111 | super().__init__()
112 | self.arg2 = arg
113 | 
114 | 
115 | def nop_decorator(f):
116 | @wraps(f)
117 | def wrapper(*a, **k):
118 | return f(*a, **k)
119 | return wrapper
120 | 
-------------------------------------------------------------------------------- /tests/t_modelhistoryproxy_test.py: --------------------------------------------------------------------------------
1 | import unittest
2 | from sqlalchemy.orm import load_only, lazyload
3 | 
4 | from . import models
5 | from .util import ExpectedQueryCounter
6 | from mongosql.util.history_proxy import ModelHistoryProxy
7 | 
8 | 
9 | class HistoryTest(unittest.TestCase):
10 | """ Test ModelHistoryProxy """
11 | 
12 | @classmethod
13 | def setUpClass(cls):
14 | # Init db
15 | cls.engine, cls.Session = models.get_working_db_for_tests(autoflush=False)
16 | 
17 | def test_model_history__does_not_lose_history_on_flush(self):
18 | # Session is reused
19 | ssn = self.Session()
20 | assert ssn.autoflush is False, 'these tests rely on the Session not autoflushing'
21 | 
22 | # === Test 1: ModelHistoryProxy does not lose history when lazyloading a column
23 | user = ssn.query(models.User).options(
24 | load_only('name'),
25 | ).get(1)
26 | 
27 | with ExpectedQueryCounter(self.engine, 0, 'Expected no queries here'):
28 | # Prepare a ModelHistoryProxy
29 | old_user_hist = ModelHistoryProxy(user)
30 | 
31 | # Modify
32 | user.name = 'CHANGED'
33 | 
34 | # History works
35 | self.assertEqual(old_user_hist.name, 'a')
36 | 
37 | # Load a column
38 | with ExpectedQueryCounter(self.engine, 1, 'Expected 1 lazyload query'):
39 | user.age
40 | 
41 | # History is NOT broken!
42 | self.assertEqual(old_user_hist.name, 'a')
43 | 
44 | # Change another column; history is NOT broken!
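# (the proxy keeps reading the snapshot taken at construction time,
#  so the pre-change value 18 is still visible below)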
45 | user.age = 1800
46 | self.assertEqual(old_user_hist.age, 18)
47 | 
48 | 
49 | 
50 | # === Test 2: ModelHistoryProxy does not lose history when lazyloading a one-to-many relationship
51 | user = ssn.query(models.User).get(1)
52 | 
53 | with ExpectedQueryCounter(self.engine, 0, 'Expected no queries here'):
54 | # Prepare a ModelHistoryProxy
55 | old_user_hist = ModelHistoryProxy(user)
56 | 
57 | # Modify
58 | user.name = 'CHANGED'
59 | 
60 | # History works
61 | self.assertEqual(old_user_hist.name, 'a')
62 | 
63 | # Load a relationship
64 | with ExpectedQueryCounter(self.engine, 1, 'Expected 1 lazyload query'):
65 | list(user.articles)
66 | 
67 | # History is NOT broken!
68 | self.assertEqual(old_user_hist.name, 'a')
69 | 
70 | 
71 | 
72 | # === Test 3: ModelHistoryProxy does not lose history when lazyloading a many-to-one relationship
73 | # We intentionally choose an article by another user (uid=2),
74 | # because User(uid=1) is cached in the session, and accessing `article.user` would just reuse it.
75 | # However, we want a fresh query to be emitted here
76 | article = ssn.query(models.Article).get(20)
77 | 
78 | with ExpectedQueryCounter(self.engine, 0, 'Expected no queries here'):
79 | # Prepare a ModelHistoryProxy
80 | old_article_hist = ModelHistoryProxy(article)
81 | 
82 | # Modify
83 | article.title = 'CHANGED'
84 | 
85 | # History works
86 | self.assertEqual(old_article_hist.title, '20')
87 | 
88 | # Load a relationship
89 | with ExpectedQueryCounter(self.engine, 1, 'Expected 1 lazyload query'):
90 | article.user
91 | 
92 | # History is NOT broken!
93 | self.assertEqual(old_article_hist.title, '20')
94 | 
95 | 
96 | 
97 | # === Test 4: ModelHistoryProxy does not lose history when flushing a session
98 | user = ssn.query(models.User).options(
99 | load_only('name'),
100 | ).get(1)
101 | with ExpectedQueryCounter(self.engine, 0, 'Expected no queries here'):
102 | # Prepare a ModelHistoryProxy
103 | old_user_hist = ModelHistoryProxy(user)
104 | 
105 | # Modify
106 | user.name = 'CHANGED'
107 | 
108 | # History works
109 | self.assertEqual(old_user_hist.name, 'a')
110 | 
111 | # Flush session
112 | ssn.flush()
113 | 
114 | # History is NOT broken
115 | self.assertEqual(old_user_hist.name, 'a')
116 | 
117 | # Undo
118 | ssn.rollback() # undo our changes
119 | ssn.close() # have to close(), or other queries might hang
120 | 
121 | 
122 | def test_model_history__both_classes(self):
123 | ssn = self.Session()
124 | # Get a user from the DB
125 | user = ssn.query(models.User).options(
126 | lazyload('*')
127 | ).get(1)
128 | 
129 | # Prepare a history object
130 | old_user = ModelHistoryProxy(user)
131 | 
132 | # Check `user` properties
133 | self.assertEqual(user.id, 1)
134 | self.assertEqual(user.name, 'a')
135 | self.assertEqual(user.age, 18)
136 | self.assertEqual(user.tags, ['1', 'a'])
137 | 
138 | # === Test: columns
139 | # Check `old_user` properties
140 | # self.assertEqual(old_user.id, 1)
141 | self.assertEqual(old_user.name, 'a')
142 | self.assertEqual(old_user.age, 18)
143 | self.assertEqual(old_user.tags, ['1', 'a'])
144 | 
145 | # Change `user`
146 | user.id = 1000
147 | user.name = 'aaaa'
148 | user.age = 1800
149 | user.tags = [1000,]
150 | 
151 | # Check `old_user` retains properties
152 | self.assertEqual(old_user.id, 1)
153 | self.assertEqual(old_user.name, 'a')
154 | self.assertEqual(old_user.age, 18)
155 | self.assertEqual(old_user.tags, ['1', 'a'])
156 | 
157 | # Undo
158 | ssn.close()
159 | 
160 | # Older tests
161 | 
162 | def test_change_field(self):
163 | ssn = self.Session()
164 | comment = ssn.query(models.Comment).get(100)
165 | old_text = comment.text
166 | comment.text = 'Changed two'
167 | hist = ModelHistoryProxy(comment)
168 | # Loading a relationship used to drop model history:
169 | # history is reset on flush(), and a query can trigger one. The proxy must survive it
170 | user = comment.user
171 | self.assertEqual(hist.text, old_text)
172 | ssn.close()
173 | 
174 | # Test for JSON fields
175 | ssn = self.Session()
176 | article = ssn.query(models.Article).get(10)
177 | old_rating = article.data['rating']
178 | hist = ModelHistoryProxy(article)
179 | article.data['rating'] = 11111
180 | 
181 | self.assertEqual(hist.data['rating'], old_rating)
182 | article.data = {'one': {'two': 2}}
183 | ssn.add(article)
184 | ssn.flush()
185 | article = ssn.query(models.Article).get(10)
186 | hist = ModelHistoryProxy(article)
187 | article.data['one']['two'] = 10
188 | self.assertEqual(hist.data['one']['two'], 2)
189 | 
190 | # Undo
191 | ssn.rollback()
192 | ssn.close()
193 | 
194 | def test_model_property(self):
195 | ssn = self.Session()
196 | 
197 | # Get one comment
198 | comment = ssn.query(models.Comment).get(100)
199 | 
200 | # Check the original value
201 | old_value = '0-a'
202 | self.assertEqual(comment.comment_calc, old_value)
203 | 
204 | # Change the value of another attribute: the one the @property depends on
205 | comment.text = 'Changed one'
206 | 
207 | # Try to build history after the fact
208 | hist = ModelHistoryProxy(comment)
209 | 
210 | # Historical value for a @property
211 | self.assertEqual(hist.comment_calc, old_value)
212 | 
213 | # Current value
214 | self.assertEqual(comment.comment_calc, 'one')
215 | 
216 | # Undo
217 | ssn.close()
218 | 
219 | def test_relation(self):
220 | ssn = self.Session()
221 | comment = ssn.query(models.Comment).get(100)
222 | 
223 | old_id = comment.user.id
224 | hist = ModelHistoryProxy(comment)
225 | new_user = ssn.query(models.User).filter(models.User.id != old_id).first()
226 | comment.user = new_user
227 | article = comment.article # load a relationship; see that history is not reset
228 | self.assertEqual(hist.user.id, old_id) # note how we can access a relationship's attrs through history!
229 | 
230 | article = ssn.query(models.Article).get(10)
231 | 
232 | old_comments = set([c.id for c in article.comments])
233 | article.comments = article.comments[:1]
234 | hist = ModelHistoryProxy(article)
235 | u = article.user # load a relationship; see that history is not reset
236 | self.assertEqual(old_comments, set([c.id for c in hist.comments]))
237 | 
238 | # Undo
239 | ssn.close()
240 | 
-------------------------------------------------------------------------------- /tests/t_raiseload_col_test.py: --------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import pytest
4 | from sqlalchemy.exc import NoSuchColumnError
5 | from sqlalchemy.orm import Load
6 | from sqlalchemy import exc as sa_exc
7 | 
8 | from .util import ExpectedQueryCounter
9 | 
10 | try:
11 | from mongosql import raiseload_col
12 | except ImportError:
13 | raiseload_col = None
14 | 
15 | from . import models
16 | 
17 | 
18 | class RaiseloadTesterMixin:
19 | def assertRaiseloadWorked(self, entity, loaded, raiseloaded, unloaded):
20 | """ Test columns and their load state
21 | 
22 | :param entity: the entity
23 | :param loaded: column names that are loaded and may be accessed without emitting any sql queries
24 | :param raiseloaded: column names that will raise an InvalidRequestError when accessed
25 | :param unloaded: column names that will emit 1 query when accessed
26 | """
27 | # loaded
28 | for name in loaded:
29 | with ExpectedQueryCounter(self.engine, 0,
30 | 'Unexpected query while accessing column {}'.format(name)):
31 | getattr(entity, name)
32 | 
33 | # raiseloaded
34 | for name in raiseloaded:
35 | with self.assertRaises(sa_exc.InvalidRequestError, msg='Exception was not raised when accessing attr `{}`'.format(name)):
36 | getattr(entity, name)
37 | 
38 | # unloaded
39 | for name in unloaded:
40 | with ExpectedQueryCounter(self.engine, 1,
41 | 'Expected one query while accessing column {}'.format(name)):
42 | getattr(entity, name)
43 | 
44 | 
45 | class RaiseloadColTest(RaiseloadTesterMixin, unittest.TestCase):
46 | @classmethod
47 | def setUpClass(cls):
48 | cls.engine, cls.Session = models.get_working_db_for_tests()
49 | 
50 | def test_defer_pk(self):
51 | """ Test: we can't defer a PK """
52 | # Test: load_only() must not defer a PK
53 | ssn = self.Session()
54 | u = ssn.query(models.User).options(
55 | Load(models.User).load_only(models.User.name),
56 | ).first()
57 | 
58 | self.assertRaiseloadWorked(u,
59 | loaded={'id', 'name'}, # PK is still here
60 | raiseloaded={},
61 | unloaded={'age', 'tags'})
62 | 
63 | # defer()'s default behavior has no failsafe mechanism: it will defer a PK
64 | ssn = self.Session()
65 | with self.assertRaises(NoSuchColumnError):
66 | u = ssn.query(models.User).options(
67 | Load(models.User).undefer(models.User.name),
68 | Load(models.User).defer('*'), # as it happens, sqlalchemy will actually let us defer a PK! be careful!
69 | ).first()
70 | 
71 | @pytest.mark.skipif(raiseload_col is None, reason='nplus1loader is not available')
72 | def test_raiseload_col(self):
73 | """ raiseload_col() on individual columns """
74 | # raiseload_col() some columns
75 | ssn = self.Session()
76 | u = ssn.query(models.User).options(
77 | Load(models.User).load_only(models.User.name),
78 | Load(models.User).raiseload_col(models.User.tags, models.User.age),
79 | ).first()
80 | 
81 | self.assertRaiseloadWorked(u,
82 | loaded={'id', 'name'},
83 | raiseloaded={'age', 'tags'},
84 | unloaded={})
85 | 
86 | # raiseload_col() on a PK destroys entity loading, and sqlalchemy gives an error
87 | with self.assertRaises(NoSuchColumnError):
88 | ssn = self.Session()
89 | u = ssn.query(models.User).options(
90 | Load(models.User).raiseload_col(models.User.id),
91 | ).first()
92 | 
93 | @pytest.mark.skipif(raiseload_col is None, reason='nplus1loader is not available')
94 | def test_raiseload_star(self):
95 | """ raiseload_col('*') """
96 | ssn = self.Session()
97 | 
98 | # raiseload_col() will defer our PKs!
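# (presumably because '*' covers every column that is not explicitly
#  undeferred, including the primary key, so the row cannot even be
#  identity-mapped and loading fails with NoSuchColumnError)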
99 | with self.assertRaises(NoSuchColumnError): 100 | u = ssn.query(models.User).options( 101 | Load(models.User).load_only(models.User.name), 102 | Load(models.User).raiseload_col('*') 103 | ).first() 104 | 105 | # Have to undefer() PKs manually 106 | u = ssn.query(models.User).options( 107 | Load(models.User).load_only(models.User.name), 108 | Load(models.User).undefer(models.User.id), # undefer PK manually 109 | Load(models.User).raiseload_col('*') 110 | ).first() 111 | 112 | self.assertRaiseloadWorked(u, 113 | loaded={'id', 'name'}, 114 | raiseloaded={'age', 'tags'}, 115 | unloaded={}) 116 | 117 | @pytest.mark.skipif(raiseload_col is None, reason='nplus1loader is not available') 118 | def test_interaction_with_other_options(self): 119 | # === Test: just load_only() 120 | # NOTE: we have to restart ssn = self.Session() every time because otherwise SqlAlchemy is too clever and caches entities in the session!! 121 | ssn = self.Session() 122 | user = ssn.query(models.User).options( 123 | Load(models.User).load_only('name', 'age'), # only these two 124 | ).first() 125 | 126 | self.assertRaiseloadWorked( 127 | user, 128 | loaded={'id', 'name', 'age'}, 129 | raiseloaded={}, 130 | unloaded={'tags', 131 | 'articles', 'comments'} 132 | ) 133 | 134 | # === Test: raiseload_rel() 135 | ssn = self.Session() 136 | user = ssn.query(models.User).options( 137 | Load(models.User).load_only('name', 'age'), # only these two 138 | Load(models.User).undefer(models.User.id), # undefer PK manually 139 | Load(models.User).raiseload('*'), 140 | ).first() 141 | 142 | self.assertRaiseloadWorked( 143 | user, 144 | loaded={'id', 'name', 'age'}, 145 | raiseloaded={'articles', 'comments'}, 146 | unloaded={'tags'} 147 | ) 148 | 149 | # === Test: raiseload_rel() + raiseload_col() 150 | ssn = self.Session() 151 | user = ssn.query(models.User).options( 152 | Load(models.User).load_only('name', 'age'), # only these two 153 | Load(models.User).undefer(models.User.id), # undefer PK manually 154 | Load(models.User).raiseload_col('*'), 155 | Load(models.User).raiseload('*'), 156 | ).first() 157 | 158 | self.assertRaiseloadWorked( 159 | user, 160 | loaded={'id', 'name', 'age'}, 161 | raiseloaded={'tags', 162 | 'articles', 'comments'}, 163 | unloaded={} 164 | ) 165 | 166 | # More tests in: 167 | # tests.t4_query_test.QueryTest#test_raise 168 | -------------------------------------------------------------------------------- /tests/t_selectinquery_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from random import shuffle 3 | from sqlalchemy.orm import defaultload, selectinload 4 | 5 | from . 
import models
6 | from .util import QueryLogger, TestQueryStringsMixin
7 | from .saversion import SA_SINCE, SA_UNTIL
8 | from mongosql import selectinquery
9 | 
10 | 
11 | # Detect SqlAlchemy version
12 | # We need to differentiate, because:
13 | # in 1.2.x, selectinload() builds a JOIN query from the left entity to the right entity
14 | # in 1.3.x, selectinload() queries just the right entity, and filters by the foreign key field directly
15 | from mongosql import SA_12, SA_13
16 | 
17 | 
18 | class SelectInQueryLoadTest(unittest.TestCase, TestQueryStringsMixin):
19 | @classmethod
20 | def setUpClass(cls):
21 | cls.engine, cls.Session = models.get_working_db_for_tests()
22 | cls.ssn = cls.Session() # let every test reuse the same session; expect some interference issues
23 | 
24 | def test_filter(self):
25 | """ selectinquery() + filter """
26 | engine, ssn = self.engine, self.ssn
27 | 
28 | # Test: load a relationship, filtered
29 | with QueryLogger(engine) as ql:
30 | q = ssn.query(models.User).options(selectinquery(
31 | models.User.articles,
32 | lambda q, **kw: q.filter(models.Article.id.between(11,21))
33 | ))
34 | res = q.all()
35 | 
36 | # Test query
37 | if SA_12:
38 | # SqlAlchemy 1.2.x used to make a JOIN
39 | self.assertQuery(ql[1],
40 | 'FROM u AS u_1 JOIN a ON u_1.id = a.uid',
41 | 'WHERE u_1.id IN (1, 2, 3) AND '
42 | 'a.id BETWEEN 11 AND 21 '
43 | 'ORDER BY u_1.id',
44 | )
45 | else:
46 | # SqlAlchemy 1.3.x uses foreign keys directly, no joins
47 | self.assertNotIn('JOIN', ql[1])  # note: (member, container) argument order
48 | self.assertQuery(ql[1],
49 | 'WHERE a.uid IN (1, 2, 3) AND ',
50 | 'a.id BETWEEN 11 AND 21 ',
51 | # v1.3.16: no ordering by PK anymore
52 | 'ORDER BY a.uid' if SA_UNTIL('1.3.15') else '',
53 | )
54 | 
55 | 
56 | # Test results
57 | self.assert_users_articles_comments(res, 3, 4, None) # 3 users, 4 articles in total
58 | 
59 | def test_plain_old_selectinload(self):
60 | """ Test plain selectinload() """
61 | engine, ssn = self.engine, self.ssn
62 | 
63 | with QueryLogger(self.engine) as ql:
64 | q = ssn.query(models.User).options(selectinload(models.User.articles))
65 | res = q.all()
66 | 
67 | # Test query
68 | if SA_12:
69 | self.assertQuery(ql[1],
70 | 'WHERE u_1.id IN (1, 2, 3)',
71 | # v1.3.16: no ordering by PK anymore
72 | 'ORDER BY u_1.id' if SA_UNTIL('1.3.15') else '',
73 | )
74 | else:
75 | self.assertQuery(ql[1],
76 | 'WHERE a.uid IN (1, 2, 3)',
77 | # v1.3.16: no ordering by PK anymore
78 | 'ORDER BY a.uid' if SA_UNTIL('1.3.15') else '',
79 | )
80 | 
81 | # Test results
82 | self.assert_users_articles_comments(res, 3, 6, None) # 3 users, 6 articles in total
83 | 
84 | 
85 | def test_options(self):
86 | """ selectinquery() + options(load_only()) + limit """
87 | engine, ssn = self.engine, self.ssn
88 | 
89 | with QueryLogger(engine) as ql:
90 | q = ssn.query(models.User).options(selectinquery(
91 | models.User.articles,
92 | # Notice how we still have to apply the options using the relationship!
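# i.e. the option path is spelled from the outer query's perspective,
# defaultload(User.articles).load_only(...), rather than calling load_only()
# on the inner Article query directly -- our understanding of how
# selectinquery() re-maps options onto its per-relationship query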
93 | lambda q, **kw: q.options(defaultload(models.User.articles) 94 | .load_only(models.Article.title)).limit(1) 95 | )) 96 | 97 | res = q.all() 98 | 99 | # Test query 100 | self.assertQuery(ql[1], 'LIMIT 1') 101 | if SA_12: 102 | self.assertSelectedColumns(ql[1], 'a.id', 'u_1.id', 'a.title') # PK, FK, load_only() 103 | else: 104 | self.assertSelectedColumns(ql[1], 'a.id', 'a.uid', 'a.title') # PK, FK, load_only() 105 | 106 | # Test results 107 | self.assert_users_articles_comments(res, 3, 1, None) # 3 users, 1 article in total ; just one, because of the limit 108 | 109 | def test_options_joinedload(self): 110 | """ selectinquery() + options(joinedload()) """ 111 | engine, ssn = self.engine, self.ssn 112 | 113 | with QueryLogger(engine) as ql: 114 | q = ssn.query(models.User).options(selectinquery( 115 | models.User.articles, 116 | lambda q, **kw: q.options(defaultload(models.User.articles) 117 | .joinedload(models.Article.comments)) 118 | )) 119 | 120 | res = q.all() 121 | 122 | # Test query 123 | self.assertQuery(ql[1], 'LEFT OUTER JOIN c AS c_1 ON a.id = c_1.aid') 124 | 125 | # Test results 126 | self.assert_users_articles_comments(res, 3, 6, 9) # 3 users, 6 articles, 9 comments 127 | 128 | def test_options_selectinload(self): 129 | """ selectinquery() + options(selectinload()) """ 130 | engine, ssn = self.engine, self.ssn 131 | 132 | with QueryLogger(engine) as ql: 133 | q = ssn.query(models.User).options(selectinquery( 134 | models.User.articles, 135 | lambda q, **kw: q.options(defaultload(models.User.articles) 136 | .selectinload(models.Article.comments)) 137 | )) 138 | 139 | res = q.all() 140 | 141 | # Test second query 142 | if SA_12: 143 | self.assertQuery(ql[2], 'JOIN c') 144 | else: 145 | self.assertQuery(ql[2], 'FROM c') 146 | 147 | # Test results 148 | self.assert_users_articles_comments(res, 3, 6, 9) # 3 users, 6 articles, 9 comments 149 | 150 | def test_options_selectinquery(self): 151 | """ selectinquery() + load_only() + options(selectinquery() + load_only()) """ 152 | engine, ssn = self.engine, self.ssn 153 | 154 | with QueryLogger(engine) as ql: 155 | q = ssn.query(models.User).options(selectinquery( 156 | models.User.articles, 157 | lambda q, **kw: q 158 | .filter(models.Article.id > 10) # first level filter() 159 | .options(defaultload(models.User.articles) 160 | .load_only(models.Article.title) # first level options() 161 | .selectinquery(models.Article.comments, 162 | lambda q, **kw: 163 | q 164 | .filter(models.Comment.uid > 1) # second level filter() 165 | .options( 166 | defaultload(models.User.articles) 167 | .defaultload(models.Article.comments) 168 | .load_only(models.Comment.text) # second level options() 169 | ))) 170 | )) 171 | 172 | res = q.all() 173 | 174 | # Test query 175 | self.assertQuery(ql[1], 'AND a.id > 10') 176 | 177 | if SA_12: 178 | self.assertSelectedColumns(ql[1], 'a.id', 'u_1.id', 'a.title') # PK, FK, load_only() 179 | else: 180 | self.assertSelectedColumns(ql[1], 'a.id', 'a.uid', 'a.title') # PK, FK, load_only() 181 | 182 | # Test second query 183 | self.assertQuery(ql[2], 'AND c.uid > 1') 184 | 185 | if SA_12: 186 | self.assertSelectedColumns(ql[2], 'c.id', 'a_1.id', 'c.text') # PK, FK, load_only() 187 | else: 188 | self.assertSelectedColumns(ql[2], 'c.id', 'c.aid', 'c.text') # PK, FK, load_only() 189 | 190 | # Test results 191 | self.assert_users_articles_comments(res, 3, 5, 1) # 3 users, 5 articles, 1 comment 192 | 193 | # Re-run all tests in wild combinations 194 | def test_all_tests_interference(self): 195 | """ Repeat all tests by 
randomly mixing them and running them in different order
196 | to make sure that they do not interfere with each other """
197 | all_tests = [getattr(self, name)
198 | for name in dir(self)
199 | if name.startswith('test_')
200 | and name != 'test_all_tests_interference']  # a list (not a generator!), so it can be re-used on every run
201 | 
202 | for i in range(20):
203 | # Make a randomized mix of all tests
204 | tests = list(all_tests)
205 | shuffle(tests)
206 | 
207 | # Run them all
208 | print('='*20 + ' Random run #{}'.format(i))
209 | for t in tests:
210 | try:
211 | # Repeat every test several times
212 | for n in range(3):
213 | t()
214 | except unittest.SkipTest: pass # proceed
215 | 
216 | def assert_users_articles_comments(self, users, n_users, n_articles=None, n_comments=None):
217 | self.assertEqual(len(users), n_users)
218 | if n_articles is not None:
219 | self.assertEqual(sum(len(u.articles) for u in users), n_articles)
220 | if n_comments is not None:
221 | self.assertEqual(sum(sum(len(a.comments) for a in u.articles) for u in users), n_comments)
222 | 
-------------------------------------------------------------------------------- /tests/util.py: --------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | 
4 | from sqlalchemy import event
5 | from sqlalchemy.orm import Query
6 | from sqlalchemy.dialects import postgresql as pg
7 | 
8 | PY2 = sys.version_info[0] == 2
9 | 
10 | 
11 | def _insert_query_params(statement_str, parameters, dialect):
12 | """ Compile a statement by inserting *unquoted* parameters into the query """
13 | return statement_str % parameters
14 | 
15 | 
16 | def stmt2sql(stmt):
17 | """ Convert an SqlAlchemy statement into a string """
18 | # See: http://stackoverflow.com/a/4617623/134904
19 | # This intentionally does not escape values!
20 | dialect = pg.dialect()
21 | query = stmt.compile(dialect=dialect)
22 | return _insert_query_params(query.string, query.params, dialect)
23 | 
24 | 
25 | def q2sql(q):
26 | """ Convert an SqlAlchemy query to string """
27 | return stmt2sql(q.statement)
28 | 
29 | 
30 | class TestQueryStringsMixin:
31 | """ unittest mixin that will help testing query strings """
32 | 
33 | def assertQuery(self, qs, *expected_lines):
34 | """ Compare a query line by line
35 | 
36 | Problem: because dicts are unordered, you can't just compare a query string: all columns and expressions may be present,
37 | but in a completely different order.
38 | Solution: compare the query piece by piece.
39 | To achieve this, feed the query as a string where every logical piece
40 | is separated by \n, and the pieces are compared.
41 | Trailing commas are stripped before comparison.
42 | 
43 | :param expected_lines: the query, separated into pieces
44 | """
45 | try:
46 | # Query?
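# (accept either a Query object or an already-rendered SQL string)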
47 | if isinstance(qs, Query): 48 | qs = q2sql(qs) 49 | 50 | # tuple 51 | expected_lines = '\n'.join(expected_lines) 52 | 53 | # Test 54 | for line in expected_lines.splitlines(): 55 | self.assertIn(line.strip().rstrip(','), qs) 56 | 57 | # Done 58 | return qs 59 | except: 60 | print(qs) 61 | raise 62 | 63 | @staticmethod 64 | def _qs_selected_columns(qs): 65 | """ Get the set of column names from the SELECT clause 66 | 67 | Example: 68 | SELECT a, u.b, c AS c_1, u.d AS u_d 69 | -> {'a', 'u.b', 'c', 'u.d'} 70 | """ 71 | rex = re.compile(r'^SELECT (.*?)\s+FROM') 72 | # Match 73 | m = rex.match(qs) 74 | # Results 75 | if not m: 76 | return set() 77 | selected_columns_str = m.group(1) 78 | # Match results 79 | rex = re.compile(r'(\S+?)(?: AS \w+)?(?:,|$)') # column names, no 'as' 80 | return set(rex.findall(selected_columns_str)) 81 | 82 | def assertSelectedColumns(self, qs, *expected): 83 | """ Test that the query has certain columns in the SELECT clause 84 | 85 | :param qs: Query | query string 86 | :param expected: list of expected column names 87 | :returns: query string 88 | """ 89 | # Query? 90 | if isinstance(qs, Query): 91 | qs = q2sql(qs) 92 | 93 | try: 94 | self.assertEqual( 95 | self._qs_selected_columns(qs), 96 | set(expected) 97 | ) 98 | return qs 99 | except: 100 | print(qs) 101 | raise 102 | 103 | 104 | class QueryCounter: 105 | """ Counts the number of queries """ 106 | 107 | def __init__(self, engine): 108 | super(QueryCounter, self).__init__() 109 | self.engine = engine 110 | self.n = 0 111 | 112 | def start_logging(self): 113 | event.listen(self.engine, "after_cursor_execute", self._after_cursor_execute_event_handler, named=True) 114 | 115 | def stop_logging(self): 116 | event.remove(self.engine, "after_cursor_execute", self._after_cursor_execute_event_handler) 117 | self._done() 118 | 119 | def _done(self): 120 | """ Handler executed when logging is stopped """ 121 | 122 | def _after_cursor_execute_event_handler(self, **kw): 123 | self.n += 1 124 | 125 | def print_log(self): 126 | pass # nothing to do 127 | 128 | # Context manager 129 | 130 | def __enter__(self): 131 | self.start_logging() 132 | return self 133 | 134 | def __exit__(self, *exc): 135 | self.stop_logging() 136 | if exc != (None, None, None): 137 | self.print_log() 138 | return False 139 | 140 | 141 | class QueryLogger(QueryCounter, list): 142 | """ Log raw SQL queries on the given engine """ 143 | 144 | def _after_cursor_execute_event_handler(self, **kw): 145 | super(QueryLogger, self)._after_cursor_execute_event_handler() 146 | # Compile, append 147 | self.append(_insert_query_params(kw['statement'], kw['parameters'], kw['context'])) 148 | 149 | def print_log(self): 150 | for i, q in enumerate(self): 151 | print('=' * 5, ' Query #{}'.format(i)) 152 | print(q) 153 | 154 | 155 | class ExpectedQueryCounter(QueryLogger): 156 | """ A QueryLogger that expects a certain number of queries, raises an error otherwise """ 157 | 158 | def __init__(self, engine, expected_queries, comment): 159 | super(ExpectedQueryCounter, self).__init__(engine) 160 | self.expected_queries = expected_queries 161 | self.comment = comment 162 | 163 | def _done(self): 164 | if self.n != self.expected_queries: 165 | self.print_log() 166 | raise AssertionError('{} (expected {} queries, actually had {})' 167 | .format(self.comment, self.expected_queries, self.n)) 168 | 169 | --------------------------------------------------------------------------------
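# A minimal usage sketch of the helpers above (not part of the repo; `engine`
# and `ssn` are assumed to come from models.get_working_db_for_tests()):
#
#     with QueryLogger(engine) as ql:
#         ssn.query(models.User).all()
#     print(ql[0])  # the raw SQL of the first executed statement
#
#     with ExpectedQueryCounter(engine, 1, 'Expected exactly one query'):
#         ssn.query(models.User).first()  # raises AssertionError on a count mismatch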