├── .editorconfig ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.rst ├── CONTRIBUTING.rst ├── LICENSE ├── README.rst ├── docs ├── Makefile ├── make.bat └── source │ ├── api │ ├── index.rst │ ├── modules.rst │ ├── odata_query.ast.rst │ ├── odata_query.django.django_q.rst │ ├── odata_query.django.django_q_ext.rst │ ├── odata_query.django.rst │ ├── odata_query.django.shorthand.rst │ ├── odata_query.django.utils.rst │ ├── odata_query.exceptions.rst │ ├── odata_query.grammar.rst │ ├── odata_query.rewrite.rst │ ├── odata_query.roundtrip.rst │ ├── odata_query.rst │ ├── odata_query.sql.athena.rst │ ├── odata_query.sql.base.rst │ ├── odata_query.sql.rst │ ├── odata_query.sql.sqlite.rst │ ├── odata_query.sqlalchemy.common.rst │ ├── odata_query.sqlalchemy.core.rst │ ├── odata_query.sqlalchemy.functions_ext.rst │ ├── odata_query.sqlalchemy.orm.rst │ ├── odata_query.sqlalchemy.rst │ ├── odata_query.sqlalchemy.shorthand.rst │ ├── odata_query.typing.rst │ ├── odata_query.utils.rst │ └── odata_query.visitor.rst │ ├── changelog.rst │ ├── conf.py │ ├── contributing.rst │ ├── deviations-and-extensions.rst │ ├── django.rst │ ├── glossary.rst │ ├── index.rst │ ├── parsing-odata.rst │ ├── snippets │ ├── modifying.rst │ └── parsing.rst │ ├── sql.rst │ ├── sqlalchemy.rst │ └── working-with-ast.rst ├── odata_query ├── __init__.py ├── ast.py ├── django │ ├── __init__.py │ ├── django_q.py │ ├── django_q_ext.py │ ├── shorthand.py │ └── utils.py ├── exceptions.py ├── grammar.py ├── py.typed ├── rewrite.py ├── roundtrip.py ├── sql │ ├── __init__.py │ ├── athena.py │ ├── base.py │ └── sqlite.py ├── sqlalchemy │ ├── __init__.py │ ├── common.py │ ├── core.py │ ├── functions_ext.py │ ├── orm.py │ └── shorthand.py ├── typing.py ├── utils.py └── visitor.py ├── poetry.lock ├── pyproject.toml ├── setup.cfg ├── sonar-project.properties ├── tests ├── __init__.py ├── conftest.py ├── data │ └── world_borders.zip ├── integration │ ├── __init__.py │ ├── django │ │ ├── __init__.py │ │ ├── apps.py │ │ ├── conftest.py │ │ ├── models.py │ │ ├── settings.py │ │ ├── test_odata_to_django_q.py │ │ ├── test_querying.py │ │ └── test_utils.py │ ├── django_geo │ │ ├── __init__.py │ │ ├── apps.py │ │ ├── conftest.py │ │ ├── models.py │ │ └── test_querying.py │ ├── sql │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_odata_to_athena_sql.py │ │ ├── test_odata_to_sql.py │ │ ├── test_odata_to_sqlite.py │ │ └── test_querying.py │ ├── sqlalchemy │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── models.py │ │ ├── test_odata_to_sqlalchemy_core.py │ │ ├── test_odata_to_sqlalchemy_orm.py │ │ └── test_querying.py │ └── test_roundtrip.py └── unit │ ├── __init__.py │ ├── django │ └── __init__.py │ ├── sql │ ├── __init__.py │ ├── test_ast_to_sql.py │ └── test_ast_to_sqlite.py │ ├── test_odata_parser.py │ ├── test_rewrite.py │ ├── test_typing.py │ ├── test_utils.py │ └── test_visitor.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [Makefile] 14 | indent_style = tab 15 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Run tests & linting 2 | 3 | on: [push, pull_request, 
workflow_dispatch] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Install GIS libraries 15 | run: sudo apt-get install -y binutils libproj-dev gdal-bin libsqlite3-mod-spatialite 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install tox tox-gh-actions 24 | - name: Run tests 25 | run: tox 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | *.cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | coverage.xml 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | .envrc 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # Serverless 106 | .serverless/ 107 | node_modules/ 108 | 109 | tests/integration/django/db/* 110 | tests/data/world_borders/* 111 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | 17 | # Optionally build your docs in additional formats such as PDF 18 | formats: [] 19 | 20 | # Optionally set the version of Python and requirements required to build your docs 21 | python: 22 | install: 23 | - method: pip 24 | path: . 
25 | extra_requirements: 26 | - docs 27 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | 2 | Changelog 3 | ========= 4 | 5 | All notable changes to this project will be documented in this file. 6 | 7 | The format is based on `Keep a Changelog `_\ , 8 | and this project adheres to `Semantic Versioning `_. 9 | 10 | [Unreleased] 11 | --------------------- 12 | 13 | Changed 14 | ^^^^^^^ 15 | 16 | * Log ``debug`` instead of ``warning`` when type inference fails. 17 | 18 | 19 | [0.10.0] - 2024-01-21 20 | --------------------- 21 | 22 | Added 23 | ^^^^^ 24 | 25 | * Parser,Django: Added support for geo.{intersects, distance, length} functions. 26 | * Python 3.11 coverage in test matrix. 27 | * Parser: Added support for durations specified in years and months. 28 | 29 | 30 | [0.9.1] - 2023-11-15 31 | -------------------- 32 | 33 | Fixed 34 | ^^^^^ 35 | 36 | * SQL: support function calls on right-hand side of `contains` node. 37 | 38 | 39 | [0.9.0] - 2023-08-17 40 | -------------------- 41 | 42 | Fixed 43 | ^^^^^ 44 | 45 | * SQLAlchemy functions defined by ``odata_query`` 46 | 47 | - no longer clash with other functions defined in ``sqlalchemy.func`` with 48 | the same name. 49 | - inherit cache to prevent SQLAlchemy performance warnings. 50 | 51 | 52 | [0.8.1] - 2023-02-17 53 | -------------------- 54 | 55 | Fixed 56 | ^^^^^ 57 | 58 | * SQLAlchemy: Fix detection of pre-existing joins for legacy Query objects. 59 | 60 | 61 | [0.8.0] - 2023-01-20 62 | -------------------- 63 | 64 | Added 65 | ^^^^^ 66 | 67 | * SQL: Support for the SQLite dialect. 68 | 69 | Changed 70 | ^^^^^^^ 71 | 72 | * SQL: Refactored SQL support 73 | 74 | - The ``AstToSqlVisitor`` now implements basic, standard SQL and functions 75 | as a base class for dialects. 76 | - The Athena dialect now lives in ``AstToAthenaSqlVisitor``. 77 | 78 | 79 | [0.7.2] - 2022-12-19 80 | -------------------- 81 | 82 | Fixed 83 | ^^^^^ 84 | 85 | * SQLAlchemy; Return a 1.x ``Query`` if a 1.x ``Query`` is passed in the shorthand. 86 | 87 | 88 | [0.7.1] - 2022-12-16 89 | -------------------- 90 | 91 | Fixed 92 | ^^^^^ 93 | 94 | * SQLAlchemy; Fix accidentally dropped support for 1.x style queries. 95 | 96 | 97 | [0.7.0] - 2022-10-28 98 | -------------------- 99 | 100 | Added 101 | ^^^^^ 102 | 103 | * A new visitor that roundtrips OData AST back to an OData query string. 104 | * SQLAlchemy: Support for SQLAlchemy Core through a new visitor. 105 | 106 | 107 | Changed 108 | ^^^^^^^ 109 | 110 | * Some lower-level SQLAlchemy related modules have been reorganized to 111 | facilitate SQLAlchemy Core support. 112 | 113 | 114 | [0.6.0] - 2022-10-28 115 | -------------------- 116 | 117 | Changed 118 | ^^^^^^^ 119 | 120 | * Django; Rework Django query transformer to use Django query nodes more. 121 | With the release of Django 4, we can now use ``Lookup``, ``Function``, and other 122 | Django query nodes in queries directly instead of relying on the keyworded 123 | syntax. For older Django versions, we can still transform those nodes to the 124 | keyworded syntax. 125 | * Deps; Upgraded Sphinx and the ReadTheDocs theme. 126 | 127 | 128 | Fixed 129 | ^^^^^ 130 | 131 | * Django; Support comparisons of boolean funcs to booleans 132 | (e.g. ``contains(a, 'b') eq true``) 133 | 134 | 135 | [0.5.2] - 2022-03-14 136 | -------------------- 137 | 138 | Changed 139 | ^^^^^^^ 140 | 141 | * Deps; Upgraded pytest. 
142 | 143 | Fixed 144 | ^^^^^ 145 | 146 | * SQLAlchemy; Fixed datetime extract functions. 147 | 148 | 149 | [0.5.1] - 2022-02-28 150 | -------------------- 151 | 152 | Fixed 153 | ^^^^^ 154 | 155 | * QA; Remove ``type:ignore`` from ``grammar.py`` and fix resulting type issues. 156 | 157 | 158 | [0.5.0] - 2022-02-28 159 | -------------------- 160 | 161 | Added 162 | ^^^^^ 163 | 164 | * Parser: Rudimentary OData namespace support. 165 | * AST: Literal nodes now have a `py_val` getter that returns the closest Python 166 | approximation to the OData value. 167 | * QA: Added full typing support. 168 | 169 | Changed 170 | ^^^^^^^ 171 | 172 | * QA: Upgraded linting libraries. 173 | 174 | 175 | [0.4.2] - 2021-12-19 176 | -------------------- 177 | 178 | Added 179 | ^^^^^ 180 | 181 | * Docs: Include contribution guidelines and changelog in the main documentation. 182 | 183 | Changed 184 | ^^^^^^^ 185 | 186 | * Docs: Use ReStructuredText instead of markdown where possible, for easier 187 | interaction with Sphinx. 188 | 189 | Removed 190 | ^^^^^^^ 191 | 192 | * Docs: Removed the ``Myst`` dependency as we're no longer mixing markdown into 193 | our docs. 194 | * Dev: Removed the ``moto`` and ``Faker`` dependencies as they weren't used. 195 | 196 | [0.4.1] - 2021-07-16 197 | -------------------- 198 | 199 | Added 200 | ^^^^^ 201 | 202 | * Added shorthands for the most common use cases: Applying an OData filter 203 | straight to a Django QuerySet or SQLAlchemy query. 204 | 205 | Fixed 206 | ^^^^^ 207 | 208 | * Cleared warnings produced in SLY by wrong regex flag placement. 209 | 210 | [0.4.0] - 2021-05-28 211 | -------------------- 212 | 213 | Changed 214 | ^^^^^^^ 215 | 216 | * Raise a new ``InvalidFieldException`` if a field in a query doesn't exist. 217 | 218 | Fixed 219 | ^^^^^ 220 | 221 | * Allow ``AliasRewriter`` to recurse into ``Attribute`` nodes, in order to replace 222 | nodes in the ``Attribute``\ 's ownership chain. 223 | 224 | [0.3.0] - 2021-05-17 225 | -------------------- 226 | 227 | Added 228 | ^^^^^ 229 | 230 | * Added ``NodeTransformers``\ , which are like ``NodeVisitors`` but replace visited 231 | nodes with the returned value. 232 | * Initial API documentation. 233 | 234 | Changed 235 | ^^^^^^^ 236 | 237 | * The AstTo{ORMQuery} visitors for SQLAlchemy and Django now have the same 238 | interface. 239 | * AstToDjangoQVisitor now builds subqueries for ``any()/all()`` itself, instead 240 | of relying on ``SubQueryToken``\ s and a seperate visitor. 241 | * Made all AST Nodes ``frozen`` (read-only), so they can be hashed. 242 | * Replaced ``field_mapping`` on the ORM visitors with a more general 243 | ``AliasRewriter`` based on the new ``NodeTransformers``. 244 | * Refactored ``IdentifierStripper`` to use the new ``NodeTransformers``. 245 | 246 | [0.2.0] - 2021-05-05 247 | -------------------- 248 | 249 | Added 250 | ^^^^^ 251 | 252 | * Transform OData queries to SQLAlchemy expressions with the new 253 | AstToSqlAlchemyClauseVisitor. 254 | 255 | Changed 256 | ^^^^^^^ 257 | 258 | * Don't write a debugfile for the parser by default. 259 | 260 | [0.1.0] - 2021-03-12 261 | -------------------- 262 | 263 | Added 264 | ^^^^^ 265 | 266 | * Initial split to seperate package. 
267 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | How to contribute 2 | ================= 3 | 4 | Getting started 5 | --------------- 6 | 7 | So you're interested in contributing to ``odata-query``? That's great! We're 8 | excited to hear your ideas and experiences. 9 | 10 | This file describes all the different ways in which you can contribute. 11 | 12 | 13 | Reporting a bug 14 | ------------------ 15 | 16 | Have you encountered a bug? Please let us know by reporting it. 17 | 18 | Before doing so, take a look at the existing `Issues`_ to make sure the bug you 19 | encountered hasn't already been reported by someone else. If so, we ask you to 20 | reply to the existing Issue rather than creating a new one. Bugs with many 21 | replies will obviously have a higher priority. 22 | 23 | If the bug you encountered has not been reported yet, create a new Issue for it 24 | and make sure to label it as a 'bug'. To allow us to help you as efficiently as 25 | possible, always try to include the following: 26 | 27 | - Which version of ``odata-query`` are you using? 28 | - What went wrong? 29 | - What did you expect to happen? 30 | - Detailed steps to reproduce. 31 | 32 | The maintainer of this repository monitors issues on a regular basis and will 33 | respond to your bug report as soon as possible. 34 | 35 | 36 | Requesting an enhancement 37 | ------------------------- 38 | 39 | Do you have a great idea that could make ``odata-query`` even better? 40 | Feel free to request an enhancement. 41 | 42 | Before doing so, take a look at the existing `Issues`_ to make sure your idea 43 | hasn't already been requested by someone else. If so, we ask you to reply or 44 | give a thumbs-up to the existing Issue rather than creating a new one. Requests 45 | with many replies will obviously have a higher priority. 46 | 47 | If your idea has not been requested yet, create a new Issue for it and make sure 48 | to label it as an 'enhancement'. Explain what your idea is in detail and how it 49 | could improve ``odata-query``. 50 | 51 | The maintainer of this repository monitors issues on a regular basis and will 52 | respond to your request as soon as possible. 53 | 54 | 55 | Contributing code 56 | ----------------- 57 | 58 | Would you prefer to contribute directly by writing some code yourself? That's 59 | great. 60 | 61 | **If your contribution is minor**, such as fixing a bug or typo, we encourage 62 | you to open a pull request right away. 63 | 64 | **If your contribution is major**, such as a new feature or a breaking change, 65 | start by opening an issue first. That way, other people can weigh in on the 66 | discussion before you do any work. 67 | 68 | The workflow for creating a pull request: 69 | 70 | 1. Fork the repository. 71 | 2. ``clone`` your forked repository. 72 | 3. Create a new feature branch from the ``master`` branch. 73 | 4. Make your contributions to the project's code. Please run the tests before 74 | committing, and keep the code style conform to the project's style guide 75 | (listed below). 76 | 5. Add documentation where required. Please keep the style conform. 77 | 6. Add your changes to the :ref:`changelog` in the "Unreleased" 78 | section. Include a link to your GitHub profile for some internet fame. 79 | 7. ``commit`` your changes in logical chunks. 80 | 8. ``push`` your branch to your fork on GitHub. 81 | 9. 
Create a pull request from your feature branch to the original repository's 82 | ``master`` branch. 83 | 84 | The maintainer of this repository monitors pull requests on a regular basis and 85 | will respond as soon as possible. The smaller individual pull requests are, the 86 | faster the maintainer will be able to respond. 87 | 88 | 89 | Local development 90 | ^^^^^^^^^^^^^^^^^ 91 | 92 | ``odata-query`` uses `poetry`_ to manage the package and its dependencies, and 93 | `tox`_ to run tests and linting in dedicated environments. 94 | 95 | To install the project for development, run: 96 | 97 | .. code-block:: bash 98 | 99 | poetry install -E testing -E linting -E docs 100 | 101 | 102 | To run all tests and linting, run: 103 | 104 | .. code-block:: bash 105 | 106 | tox 107 | 108 | To ensure your files are configured correctly (linting will bug you if they're not), 109 | you can configure your IDE to use the included linting tools OR run them manually 110 | as follows: 111 | 112 | 113 | .. code-block:: bash 114 | 115 | poetry run black . # Format files to adhere to the Black code style 116 | poetry run isort . # Ensure the imports are organised correctly 117 | 118 | 119 | Contact 120 | ------- 121 | 122 | Discussions about ``odata-query`` take place on this repository's `Issues`_ and 123 | `Pull Requests`_ sections. Anybody is welcome to join the conversation. Wherever 124 | possible, do not take these conversations to private channels, including 125 | contacting the maintainers directly. Keeping communication public means 126 | everybody can benefit and learn. 127 | 128 | 129 | .. _Issues: https://github.com/gorilla-co/odata-query/issues 130 | .. _Pull Requests: https://github.com/gorilla-co/odata-query/pulls 131 | .. _poetry: https://python-poetry.org/ 132 | .. _tox: https://tox.readthedocs.io/en/latest/index.html 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 Gorillini NV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | OData-Query 2 | =========== 3 | 4 | .. 
image:: https://readthedocs.org/projects/odata-query/badge/?version=latest 5 | :alt: Documentation Status 6 | :target: https://odata-query.readthedocs.io/en/latest/?badge=latest 7 | 8 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg 9 | :alt: Code style: black 10 | :target: https://github.com/psf/black 11 | 12 | 13 | ``odata-query`` is a library that parses `OData v4`_ filter strings, and can 14 | convert them to other forms such as `Django Queries`_, `SQLAlchemy Queries`_, 15 | or just plain SQL. 16 | 17 | 18 | Installation 19 | ------------ 20 | 21 | ``odata-query`` is available on pypi, so can be installed with the package manager 22 | of your choice: 23 | 24 | .. code-block:: bash 25 | 26 | pip install odata-query 27 | # OR 28 | poetry add odata-query 29 | # OR 30 | pipenv install odata-query 31 | 32 | 33 | The package defines the following optional ``extra``'s: 34 | 35 | * ``django``: If you want to pin a compatible Django version. 36 | * ``sqlalchemy``: If you want to pin a compatible SQLAlchemy version. 37 | 38 | 39 | The following ``extra``'s relate to the development of this library: 40 | 41 | - ``linting``: The linting and code style tools. 42 | - ``testing``: Packages for running the tests. 43 | - ``docs``: For building the project documentation. 44 | 45 | 46 | You can install ``extra``'s by adding them between square brackets during 47 | installation: 48 | 49 | .. code-block:: bash 50 | 51 | pip install odata-query[sqlalchemy] 52 | 53 | 54 | Quickstart 55 | ---------- 56 | 57 | The most common use case is probably parsing an OData query string, and applying 58 | it to a query your ORM understands. For this purpose there is an all-in-one function: 59 | ``apply_odata_query``. 60 | 61 | Example for Django: 62 | 63 | .. code-block:: python 64 | 65 | from odata_query.django import apply_odata_query 66 | 67 | orm_query = MyModel.objects # This can be a Manager or a QuerySet. 68 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 69 | 70 | query = apply_odata_query(orm_query, odata_query) 71 | results = query.all() 72 | 73 | 74 | Example for SQLAlchemy ORM: 75 | 76 | .. code-block:: python 77 | 78 | from odata_query.sqlalchemy import apply_odata_query 79 | 80 | orm_query = select(MyModel) # This is any form of Query or Selectable. 81 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 82 | 83 | query = apply_odata_query(orm_query, odata_query) 84 | results = session.execute(query).scalars().all() 85 | 86 | Example for SQLAlchemy Core: 87 | 88 | .. code-block:: python 89 | 90 | from odata_query.sqlalchemy import apply_odata_core 91 | 92 | core_query = select(MyTable) # This is any form of Query or Selectable. 93 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 94 | 95 | query = apply_odata_core(core_query, odata_query) 96 | results = session.execute(query).scalars().all() 97 | 98 | .. splitinclude-1 99 | 100 | Advanced Usage 101 | -------------- 102 | 103 | Not all use cases are as simple as that. Luckily, ``odata-query`` is modular 104 | and extendable. See the `documentation`_ for advanced usage or extending the 105 | library for other cases. 106 | 107 | .. splitinclude-2 108 | 109 | Contact 110 | ------- 111 | 112 | Got any questions or ideas? We'd love to hear from you. Check out our 113 | `contributing guidelines`_ for ways to offer feedback and 114 | contribute. 115 | 116 | 117 | License 118 | ------- 119 | 120 | Copyright © `Gorillini NV`_. 
121 | All rights reserved. 122 | 123 | Licensed under the MIT License. 124 | 125 | 126 | .. _odata v4: https://www.odata.org/ 127 | .. _django queries: https://docs.djangoproject.com/en/3.2/topics/db/queries/ 128 | .. _sqlalchemy queries: https://docs.sqlalchemy.org/en/14/orm/loading_objects.html 129 | .. _documentation: https://odata-query.readthedocs.io/en/latest 130 | .. _Gorillini NV: https://gorilla.co/ 131 | .. _contributing guidelines: ./CONTRIBUTING.rst 132 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | apidoc: 18 | sphinx-apidoc -f -o source/api -e ../odata_query 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 22 | %: Makefile 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API Docs 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | modules 8 | 9 | -------------------------------------------------------------------------------- /docs/source/api/modules.rst: -------------------------------------------------------------------------------- 1 | odata_query 2 | =========== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | odata_query 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.ast.rst: -------------------------------------------------------------------------------- 1 | odata\_query.ast module 2 | ======================= 3 | 4 | .. 
automodule:: odata_query.ast 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.django.django_q.rst: -------------------------------------------------------------------------------- 1 | odata\_query.django.django\_q module 2 | ==================================== 3 | 4 | .. automodule:: odata_query.django.django_q 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.django.django_q_ext.rst: -------------------------------------------------------------------------------- 1 | odata\_query.django.django\_q\_ext module 2 | ========================================= 3 | 4 | .. automodule:: odata_query.django.django_q_ext 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.django.rst: -------------------------------------------------------------------------------- 1 | odata\_query.django package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | odata_query.django.django_q 11 | odata_query.django.django_q_ext 12 | odata_query.django.shorthand 13 | odata_query.django.utils 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: odata_query.django 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.django.shorthand.rst: -------------------------------------------------------------------------------- 1 | odata\_query.django.shorthand module 2 | ==================================== 3 | 4 | .. automodule:: odata_query.django.shorthand 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.django.utils.rst: -------------------------------------------------------------------------------- 1 | odata\_query.django.utils module 2 | ================================ 3 | 4 | .. automodule:: odata_query.django.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.exceptions.rst: -------------------------------------------------------------------------------- 1 | odata\_query.exceptions module 2 | ============================== 3 | 4 | .. automodule:: odata_query.exceptions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.grammar.rst: -------------------------------------------------------------------------------- 1 | odata\_query.grammar module 2 | =========================== 3 | 4 | .. automodule:: odata_query.grammar 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.rewrite.rst: -------------------------------------------------------------------------------- 1 | odata\_query.rewrite module 2 | =========================== 3 | 4 | .. 
automodule:: odata_query.rewrite 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.roundtrip.rst: -------------------------------------------------------------------------------- 1 | odata\_query.roundtrip module 2 | ============================= 3 | 4 | .. automodule:: odata_query.roundtrip 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.rst: -------------------------------------------------------------------------------- 1 | odata\_query package 2 | ==================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | odata_query.django 11 | odata_query.sql 12 | odata_query.sqlalchemy 13 | 14 | Submodules 15 | ---------- 16 | 17 | .. toctree:: 18 | :maxdepth: 4 19 | 20 | odata_query.ast 21 | odata_query.exceptions 22 | odata_query.grammar 23 | odata_query.rewrite 24 | odata_query.roundtrip 25 | odata_query.typing 26 | odata_query.utils 27 | odata_query.visitor 28 | 29 | Module contents 30 | --------------- 31 | 32 | .. automodule:: odata_query 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sql.athena.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sql.athena module 2 | ============================== 3 | 4 | .. automodule:: odata_query.sql.athena 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sql.base.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sql.base module 2 | ============================ 3 | 4 | .. automodule:: odata_query.sql.base 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sql.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sql package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | odata_query.sql.athena 11 | odata_query.sql.base 12 | odata_query.sql.sqlite 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: odata_query.sql 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sql.sqlite.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sql.sqlite module 2 | ============================== 3 | 4 | .. automodule:: odata_query.sql.sqlite 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.common.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy.common module 2 | ===================================== 3 | 4 | .. 
automodule:: odata_query.sqlalchemy.common 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.core.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy.core module 2 | =================================== 3 | 4 | .. automodule:: odata_query.sqlalchemy.core 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.functions_ext.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy.functions\_ext module 2 | ============================================= 3 | 4 | .. automodule:: odata_query.sqlalchemy.functions_ext 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.orm.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy.orm module 2 | ================================== 3 | 4 | .. automodule:: odata_query.sqlalchemy.orm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | odata_query.sqlalchemy.common 11 | odata_query.sqlalchemy.core 12 | odata_query.sqlalchemy.functions_ext 13 | odata_query.sqlalchemy.orm 14 | odata_query.sqlalchemy.shorthand 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: odata_query.sqlalchemy 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.sqlalchemy.shorthand.rst: -------------------------------------------------------------------------------- 1 | odata\_query.sqlalchemy.shorthand module 2 | ======================================== 3 | 4 | .. automodule:: odata_query.sqlalchemy.shorthand 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.typing.rst: -------------------------------------------------------------------------------- 1 | odata\_query.typing module 2 | ========================== 3 | 4 | .. automodule:: odata_query.typing 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.utils.rst: -------------------------------------------------------------------------------- 1 | odata\_query.utils module 2 | ========================= 3 | 4 | .. automodule:: odata_query.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/odata_query.visitor.rst: -------------------------------------------------------------------------------- 1 | odata\_query.visitor module 2 | =========================== 3 | 4 | .. 
automodule:: odata_query.visitor 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changelog: 2 | 3 | .. include:: ../../CHANGELOG.rst 4 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "OData Query" 21 | copyright = "2021, Gorillini NV" 22 | author = "Oliver Hofkens" 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = "0.10.0" 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | "sphinx.ext.autodoc", 35 | "sphinx.ext.napoleon", 36 | "sphinx.ext.doctest", 37 | "sphinx.ext.graphviz", 38 | "sphinx.ext.viewcode", 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ["_templates"] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = [] 48 | 49 | 50 | # -- Options for HTML output ------------------------------------------------- 51 | 52 | # The theme to use for HTML and HTML Help pages. See the documentation for 53 | # a list of builtin themes. 54 | # 55 | html_theme = "sphinx_rtd_theme" 56 | 57 | # Add any paths that contain custom static files (such as style sheets) here, 58 | # relative to this directory. They are copied after the builtin static files, 59 | # so a file named "default.css" will overwrite the builtin "default.css". 60 | html_static_path = ["_static"] 61 | -------------------------------------------------------------------------------- /docs/source/contributing.rst: -------------------------------------------------------------------------------- 1 | .. _contributing: 2 | 3 | .. include:: ../../CONTRIBUTING.rst 4 | -------------------------------------------------------------------------------- /docs/source/deviations-and-extensions.rst: -------------------------------------------------------------------------------- 1 | Deviations and Extensions to the OData Spec 2 | =========================================== 3 | 4 | There are some minor cases where this library deviates from the official `spec`_, 5 | e.g. 
when the spec is ambiguous, or to add extra non-breaking functionality. 6 | 7 | 8 | Lists with a single item need a trailing comma 9 | ---------------------------------------------- 10 | 11 | This library needs a trailing comma to represent a list with a single item, 12 | whereas the spec does not describe this trailing comma (e.g. 13 | ``(item,)`` instead of ``(item)``). 14 | 15 | The reason is that the spec doesn't seem to differentiate between a list of a 16 | single item and any other parenthesized expression. This can lead to parsing 17 | conflicts. Consider the following expression: 18 | 19 | .. code-block:: 20 | 21 | concat(('a'), ('b')) 22 | 23 | 24 | The grammar in the official spec could parse this as either: 25 | 26 | * Concatenate 2 lists, the first one containing the single string 'a', the second 27 | one containing the single string 'b'. The result would be ``('a', 'b')``. 28 | * Concatenate 2 expressions in parentheses. The first one is the string literal 29 | 'a', the second one the string literal 'b'. The result would be ``'ab'`` 30 | 31 | 32 | To make this difference explicit, this library requires a trailing comma to 33 | signify a list. The same behavior is present in Python: 34 | 35 | >>> ("a") 36 | 'a' 37 | >>> ("a",) 38 | ('a',) 39 | 40 | .. >>> from odata_query.grammar import ODataLexer, ODataParser 41 | .. >>> lexer = ODataLexer() 42 | .. >>> parser = ODataParser() 43 | >>> parser.parse(lexer.tokenize("('a')")) 44 | String(val='a') 45 | >>> parser.parse(lexer.tokenize("('a',)")) 46 | List(val=[String(val='a')]) 47 | 48 | 49 | Durations expressed in years and months 50 | --------------------------------------- 51 | 52 | The official `spec`_ defines a duration with a number of days, hours, minutes, 53 | and seconds. E.g. ``duration'P1DT2H'`` is a duration of 1 **D**\ ay and 2 **H**\ ours. 54 | 55 | This library adds the ability to express durations containing **Y**\ ears and 56 | **M**\ onths. E.g. ``duration'P1Y2M3DT4H'`` would express a duration of 1 **Y**\ ear, 57 | 2 **M**\ onths, 3 **D**\ ays, and 4 **H**\ ours. 58 | 59 | It's important to note that the final internal value is still expressed in days, 60 | based on average durations. 61 | Thus, a month simply represents 30.44 days, while a year represents 365.25 days. 62 | 63 | 64 | .. _spec: https://www.odata.org/documentation/ 65 | -------------------------------------------------------------------------------- /docs/source/django.rst: -------------------------------------------------------------------------------- 1 | Using OData with Django 2 | ======================= 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | The easiest way to add OData filtering to a Django QuerySet is with the shorthand: 8 | 9 | .. code-block:: python 10 | 11 | from odata_query.django import apply_odata_query 12 | 13 | orm_query = MyModel.objects # This can be a Manager or a QuerySet. 14 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 15 | 16 | query = apply_odata_query(orm_query, odata_query) 17 | results = query.all() 18 | 19 | 20 | Advanced Usage 21 | -------------- 22 | 23 | If you need some more flexibility or advanced features, the implementation of the 24 | shorthand is a nice starting point: :py:mod:`odata_query.django.shorthand` 25 | 26 | Let's break it down real quick: 27 | 28 | 29 | Parsing the OData Query 30 | ^^^^^^^^^^^^^^^^^^^^^^^ 31 | 32 | .. include:: snippets/parsing.rst 33 | 34 | 35 | Optional: Modifying the AST 36 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 37 | 38 | .. 
include:: snippets/modifying.rst 39 | 40 | 41 | Building a Query Filter 42 | ^^^^^^^^^^^^^^^^^^^^^^^ 43 | 44 | To get from an :term:`AST` to something Django can use, you'll need to use the 45 | :py:class:`odata_query.django.django_q.AstToDjangoQVisitor`. It needs to know 46 | about the 'root model' of your query in order to build relationships if necessary. 47 | In most cases, this will be ``queryset.model``. 48 | 49 | 50 | .. code-block:: python 51 | 52 | from odata_query.django.django_q import AstToDjangoQVisitor 53 | 54 | visitor = AstToDjangoQVisitor(MyModel) 55 | query_filter = visitor.visit(ast) 56 | 57 | 58 | Optional: QuerySet Annotations 59 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 60 | 61 | For some queries using advanced expressions, the ``AstToDjangoQVisitor`` will 62 | generate `queryset annotations`_. For the query to work, these will need to be 63 | applied: 64 | 65 | .. code-block:: python 66 | 67 | if visitor.queryset_annotations: 68 | queryset = queryset.annotate(**visitor.queryset_annotations) 69 | 70 | 71 | Running the query 72 | ^^^^^^^^^^^^^^^^^ 73 | 74 | Finally, we're ready to run the query: 75 | 76 | .. code-block:: python 77 | 78 | results = queryset.filter(query_filter).all() 79 | 80 | 81 | .. _queryset annotations: https://docs.djangoproject.com/en/3.2/ref/models/querysets/#annotate 82 | -------------------------------------------------------------------------------- /docs/source/glossary.rst: -------------------------------------------------------------------------------- 1 | Glossary 2 | ======== 3 | 4 | .. glossary:: 5 | 6 | AST 7 | Abstract Syntax Tree. A tree data structure representing the syntactic 8 | structure of some text. For example, in OData we might write 9 | ``name eq 'Bobby'``, which could be represented as the tree: 10 | ``Compare(Eq, Identifier('name'), String('Bobby'))``. For more, 11 | see `Wikipedia `_ 12 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../../README.rst 2 | :end-before: .. splitinclude-1 3 | 4 | Contents 5 | ======== 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | 10 | Getting Started 11 | django 12 | sqlalchemy 13 | sql 14 | parsing-odata 15 | working-with-ast 16 | deviations-and-extensions 17 | glossary 18 | 19 | contributing 20 | changelog 21 | 22 | api/index 23 | 24 | .. include:: ../../README.rst 25 | :start-after: .. splitinclude-2 26 | 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/source/parsing-odata.rst: -------------------------------------------------------------------------------- 1 | .. _ref-parsing-odata: 2 | 3 | Parsing OData 4 | ============= 5 | 6 | ``odata-query`` includes a parser that tries to cover as much as possible of the `OData v4 filter spec`_. 7 | This parser is built with `SLY`_ and consists of a :ref:`ref-lexer` and a :ref:`ref-parser`. 8 | 9 | 10 | .. _ref-lexer: 11 | 12 | Lexer 13 | ----- 14 | 15 | The lexer's job is to take an OData query string and break it up into a series of 16 | meaningful tokens, where each token has a type and a value. These are like the 17 | words of a sentence. At this stage we're not looking at the structure of the 18 | entire sentence yet, just if the individual words make sense. 
19 | 20 | The OData lexer this library uses is defined in :py:class:`odata_query.grammar.ODataLexer`. Here's 21 | an example of what it does: 22 | 23 | .. doctest:: 24 | 25 | >>> from odata_query.grammar import ODataLexer 26 | >>> lexer = ODataLexer() 27 | >>> list(lexer.tokenize("name eq 'Hello World'")) 28 | [Token(type='ODATA_IDENTIFIER', value=Identifier(name='name'), lineno=1, index=0), Token(type='EQ', value=Eq(), lineno=1, index=4), Token(type='STRING', value=String(val='Hello World'), lineno=1, index=8)] 29 | 30 | 31 | 32 | .. _ref-parser: 33 | 34 | Parser 35 | ------ 36 | 37 | The parser's job is to take the tokens as produced by the :ref:`ref-lexer` 38 | and find the language structure in them, according to the grammar rules defined 39 | by the OData standard. In our case, the parser tries to build an :term:`AST` that 40 | represents the entire query. This :term:`AST` is a tree structure that consists 41 | of the nodes found in :py:mod:`odata_query.ast`. 42 | 43 | As an example, the following OData query:: 44 | 45 | name eq 'Hello World' 46 | 47 | can be represented in the following :term:`AST`: 48 | 49 | .. graphviz:: 50 | 51 | digraph { 52 | "Compare()" -> "Identifier('name')" [label = "left"]; 53 | "Compare()" -> "Eq()" [label = "comparator"]; 54 | "Compare()" -> "String('Hello World')" [label = "right"]; 55 | } 56 | 57 | 58 | The OData parser this library uses is defined in :py:class:`odata_query.grammar.ODataParser`. 59 | Here's an example of what it does: 60 | 61 | .. doctest:: 62 | 63 | >>> from odata_query.grammar import ODataParser 64 | >>> parser = ODataParser() 65 | >>> parser.parse(lexer.tokenize("name eq 'Hello World'")) 66 | Compare(comparator=Eq(), left=Identifier(name='name'), right=String(val='Hello World')) 67 | 68 | 69 | 70 | .. _OData v4 filter spec: https://docs.oasis-open.org/odata/odata/v4.01/cs01/abnf/odata-abnf-construction-rules.txt 71 | .. _SLY: https://github.com/dabeaz/sly 72 | 73 | -------------------------------------------------------------------------------- /docs/source/snippets/modifying.rst: -------------------------------------------------------------------------------- 1 | There are cases where you'll want to modify the query before executing it. That's 2 | what :ref:`ref-node-transformer`'s are for! 3 | 4 | One example might be that certain fields are exposed to end users under a different 5 | name than the one in the database. In this case, the 6 | :py:class:`odata_query.rewrite.AliasRewriter` will come in handy. Just pass it a 7 | mapping of aliases to their full name and let it do its job: 8 | 9 | .. code-block:: python 10 | 11 | from odata_query.rewrite import AliasRewriter 12 | 13 | rewriter = AliasRewriter({ 14 | "name": "author/name", 15 | }) 16 | new_ast = rewriter.visit(ast) 17 | 18 | -------------------------------------------------------------------------------- /docs/source/snippets/parsing.rst: -------------------------------------------------------------------------------- 1 | To get from a string representing an OData query to a usable representation, 2 | we need to tokenize and parse it as follows: 3 | 4 | .. code-block:: python 5 | 6 | from odata_query.grammar import ODataParser, ODataLexer 7 | 8 | lexer = ODataLexer() 9 | parser = ODataParser() 10 | ast = parser.parse(lexer.tokenize(my_odata_query)) 11 | 12 | This process is described in more detail in :ref:`ref-parsing-odata`. 
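As a quick illustration, a simple equality filter yields a small tree of AST
nodes (its ``repr`` is shown here as a comment):

.. code-block:: python

    ast = parser.parse(lexer.tokenize("name eq 'Hello World'"))
    # Compare(comparator=Eq(), left=Identifier(name='name'), right=String(val='Hello World'))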
13 | -------------------------------------------------------------------------------- /docs/source/sql.rst: -------------------------------------------------------------------------------- 1 | Using OData with raw SQL 2 | ======================== 3 | 4 | Using a raw SQL interface is slightly more involved and less powerful, but 5 | offers a lot of flexibility in return. 6 | 7 | 8 | Parsing the OData Query 9 | ^^^^^^^^^^^^^^^^^^^^^^^ 10 | 11 | .. include:: snippets/parsing.rst 12 | 13 | 14 | Optional: Modifying the AST 15 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 16 | 17 | .. include:: snippets/modifying.rst 18 | 19 | 20 | Building a Query Filter 21 | ^^^^^^^^^^^^^^^^^^^^^^^ 22 | 23 | To get from an :term:`AST` to a SQL clause, you'll need to use 24 | :py:class:`odata_query.sql.base.AstToSqlVisitor` (standard SQL) or one of its 25 | dialect-specific subclasses, such as 26 | :py:class:`odata_query.sql.sqlite.AstToSqliteSqlVisitor` (SQLite). 27 | 28 | .. code-block:: python 29 | 30 | from odata_query.sql import AstToSqlVisitor 31 | 32 | visitor = AstToSqlVisitor() 33 | where_clause = visitor.visit(ast) 34 | 35 | 36 | Running the query 37 | ^^^^^^^^^^^^^^^^^ 38 | 39 | Finally, we're ready to run the query: 40 | 41 | .. code-block:: python 42 | 43 | query = "SELECT * FROM my_table WHERE " + where_clause 44 | results = conn.execute(query).fetchall() 45 | 46 | 47 | Supported dialects 48 | ^^^^^^^^^^^^^^^^^^ 49 | 50 | .. autoclass:: odata_query.sql.base.AstToSqlVisitor 51 | .. autoclass:: odata_query.sql.sqlite.AstToSqliteSqlVisitor 52 | .. autoclass:: odata_query.sql.athena.AstToAthenaSqlVisitor 53 | -------------------------------------------------------------------------------- /docs/source/sqlalchemy.rst: -------------------------------------------------------------------------------- 1 | Using OData with SQLAlchemy 2 | =========================== 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | The easiest way to add OData filtering to a SQLAlchemy query is with the shorthand: 8 | 9 | 10 | SQLAlchemy ORM 11 | ^^^^^^^^^^^^^^ 12 | 13 | .. code-block:: python 14 | 15 | from odata_query.sqlalchemy import apply_odata_query 16 | 17 | orm_query = select(MyModel) # This is any form of Query or Selectable. 18 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 19 | 20 | query = apply_odata_query(orm_query, odata_query) 21 | results = session.execute(query).scalars().all() 22 | 23 | 24 | SQLAlchemy Core 25 | ^^^^^^^^^^^^^^^ 26 | 27 | .. attention:: 28 | 29 | Basic support for SQLAlchemy Core is new since version 0.7.0. 30 | It currently does not support relationship traversal or ``any``/``all`` 31 | yet. Those operations will raise a ``NotImplementedException``. 32 | 33 | 34 | .. code-block:: python 35 | 36 | from odata_query.sqlalchemy import apply_odata_core 37 | 38 | core_query = select(MyTable) # This is any form of Query or Selectable. 39 | odata_query = "name eq 'test'" # This will usually come from a query string parameter. 40 | 41 | query = apply_odata_query(core_query, odata_query) 42 | results = session.execute(query).scalars().all() 43 | 44 | 45 | Advanced Usage 46 | -------------- 47 | 48 | If you need some more flexibility or advanced features, the implementation of the 49 | shorthand is a nice starting point: :py:mod:`odata_query.sqlalchemy.shorthand` 50 | 51 | Let's break it down real quick: 52 | 53 | 54 | Parsing the OData Query 55 | ^^^^^^^^^^^^^^^^^^^^^^^ 56 | 57 | .. 
include:: snippets/parsing.rst 58 | 59 | 60 | Optional: Modifying the AST 61 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 62 | 63 | .. include:: snippets/modifying.rst 64 | 65 | 66 | Building a Query Filter 67 | ^^^^^^^^^^^^^^^^^^^^^^^ 68 | 69 | To get from an :term:`AST` to something SQLAlchemy can use, you'll need to use the 70 | :py:class:`odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor` (ORM mode) or 71 | the :py:class:`odata_query.sqlalchemy.core.AstToSqlAlchemyCoreVisitor` (Core 72 | mode). 73 | It needs to know about the 'root model' or table of your query in order to see 74 | which fields exists and how objects are related. 75 | 76 | 77 | SQLAlchemy ORM 78 | """""""""""""" 79 | 80 | .. code-block:: python 81 | 82 | from odata_query.sqlalchemy.orm import AstToSqlAlchemyOrmVisitor 83 | 84 | visitor = AstToSqlAlchemyOrmVisitor(MyModel) 85 | query_filter = visitor.visit(ast) 86 | 87 | SQLAlchemy Core 88 | """"""""""""""" 89 | 90 | .. code-block:: python 91 | 92 | from odata_query.sqlalchemy.core import AstToSqlAlchemyCoreVisitor 93 | 94 | visitor = AstToSqlAlchemyCoreVisitor(MyTable) 95 | query_filter = visitor.visit(ast) 96 | 97 | 98 | Optional: Joins 99 | ^^^^^^^^^^^^^^^ 100 | 101 | .. attention:: 102 | 103 | Relationship traversal and automatic joins are not yet supported for 104 | SQLAlchemy Core mode. 105 | 106 | 107 | If your query spans relationships, the ``AstToSqlAlchemyClauseVisitor`` will 108 | generate join statements. For the query to work, these will need to be 109 | applied explicitly: 110 | 111 | .. code-block:: python 112 | 113 | for j in visitor.join_relationships: 114 | query = query.join(j) 115 | 116 | 117 | Running the query 118 | ^^^^^^^^^^^^^^^^^ 119 | 120 | Finally, we're ready to run the query: 121 | 122 | .. code-block:: python 123 | 124 | query = query.where(query_filter) 125 | results = s.execute(query).scalars().all() 126 | -------------------------------------------------------------------------------- /docs/source/working-with-ast.rst: -------------------------------------------------------------------------------- 1 | Working with the AST 2 | ==================== 3 | 4 | Now that our :ref:`OData query has been parsed ` to an :term:`AST`, 5 | how do we work with it? `The Visitor Pattern`_ is a popular way to walk tree 6 | structures such as :term:`AST`'s and modify or transform them to another 7 | representation. ``odata-query`` contains the :ref:`ref-node-visitor` and 8 | :ref:`ref-node-transformer` base classes that implement this pattern, as well 9 | as some concrete implementations. 10 | 11 | 12 | .. _ref-node-visitor: 13 | 14 | NodeVisitor 15 | ----------- 16 | 17 | A :py:class:`odata_query.visitor.NodeVisitor` is a class that walks an :term:`AST` 18 | (depth-first by default) and calls a ``visit_{node_type}`` method on each 19 | :py:class:`odata_query.ast._Node` it encounters. These methods can return whatever 20 | they want, making this a very flexible pattern! If no ``visit_`` method is 21 | implemented for the type of the node the visitor will continue with the node's 22 | children if it has any, so you only need to implement what you explicitly need. 23 | A simple :py:class:`odata_query.visitor.NodeVisitor` that counts comparison 24 | expressions for example, might look like this: 25 | 26 | .. 
code-block:: python 27 | 28 | class ComparisonCounter(NodeVisitor): 29 | def visit_Comparison(self, node: ast.Comparison) -> int: 30 | count_lhs = self.visit(node.left) or 0 31 | count_rhs = self.visit(node.right) or 0 32 | return 1 + count_lhs + count_rhs 33 | 34 | 35 | count = ComparisonCounter().visit(my_ast) 36 | 37 | 38 | This isn't the most useful implementation... For some more realistic examples, 39 | take a look at the :py:class:`odata_query.django.django_q.AstToDjangoQVisitor` or 40 | the :py:class:`odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor` 41 | implementations. They transform an :term:`AST` to Django and SQLAlchemy ORM queries 42 | respectively. 43 | 44 | 45 | .. _ref-node-transformer: 46 | 47 | NodeTransformer 48 | --------------- 49 | 50 | A :py:class:`odata_query.visitor.NodeTransformer` is very similar to a 51 | :ref:`ref-node-visitor`, with one difference: The ``visit_`` methods should return 52 | an :py:class:`odata_query.ast._Node`, which will replace the node that is being 53 | visited. This allows ``NodeTransformer``'s to modify the :term:`AST` while it's 54 | being traversed. For example, the following 55 | :py:class:`odata_query.visitor.NodeTransformer` would invert all 'less-than' 56 | comparisons to 'greater-than' and vice-versa: 57 | 58 | 59 | .. code-block:: python 60 | 61 | class ComparisonInverter(NodeTransformer): 62 | def visit_Comparison(self, node: ast.Comparison) -> ast.Comparison: 63 | if node.comparator == ast.Lt(): 64 | new_comparator = ast.Gt() 65 | elif node.comparator == ast.Gt(): 66 | new_comparator = ast.Lt() 67 | else: 68 | new_comparator = node.comparator 69 | 70 | return ast.Comparison(new_comparator, node.left, node.right) 71 | 72 | 73 | inverted = ComparisonInverter().visit(my_ast) 74 | 75 | 76 | An interesting concrete implementation in ``odata-query`` is the 77 | :py:class:`odata_query.rewrite.AliasRewriter`. This transformer looks for 78 | aliases in identifiers and attributes, and replaces them with their full names. 79 | 80 | 81 | 82 | Included Visitors 83 | ----------------- 84 | 85 | 86 | .. autoclass:: odata_query.django.django_q.AstToDjangoQVisitor 87 | .. autoclass:: odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor 88 | .. autoclass:: odata_query.sqlalchemy.core.AstToSqlAlchemyCoreVisitor 89 | .. autoclass:: odata_query.rewrite.AliasRewriter 90 | .. autoclass:: odata_query.roundtrip.AstToODataVisitor 91 | 92 | 93 | .. _The Visitor Pattern: https://en.wikipedia.org/wiki/Visitor_pattern 94 | -------------------------------------------------------------------------------- /odata_query/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.10.0" 2 | -------------------------------------------------------------------------------- /odata_query/ast.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import re 3 | from dataclasses import dataclass, field 4 | from typing import List as ListType, Optional, Tuple 5 | from uuid import UUID 6 | 7 | from dateutil.parser import isoparse 8 | 9 | DURATION_PATTERN = re.compile( 10 | r"([+-])?P(\d+Y)?(\d+M)?(\d+D)?(?:T(\d+H)?(\d+M)?(\d+(?:\.\d+)?S)?)?" 11 | ) 12 | 13 | 14 | @dataclass(frozen=True) 15 | class _Node: 16 | pass 17 | 18 | 19 | @dataclass(frozen=True) 20 | class Identifier(_Node): 21 | name: str 22 | namespace: Tuple[str, ...] 
= field(default_factory=tuple) 23 | 24 | def full_name(self): 25 | return ".".join(self.namespace + (self.name,)) 26 | 27 | 28 | @dataclass(frozen=True) 29 | class Attribute(_Node): 30 | owner: _Node 31 | attr: str 32 | 33 | 34 | ############################################################################### 35 | # Literals 36 | ############################################################################### 37 | @dataclass(frozen=True) 38 | class _Literal(_Node): 39 | pass 40 | 41 | @property 42 | def py_val(self): 43 | raise NotImplementedError() 44 | 45 | 46 | @dataclass(frozen=True) 47 | class Null(_Literal): 48 | @property 49 | def py_val(self) -> None: 50 | return None 51 | 52 | 53 | @dataclass(frozen=True) 54 | class Integer(_Literal): 55 | val: str 56 | 57 | @property 58 | def py_val(self) -> int: 59 | return int(self.val) 60 | 61 | 62 | @dataclass(frozen=True) 63 | class Float(_Literal): 64 | val: str 65 | 66 | @property 67 | def py_val(self) -> float: 68 | return float(self.val) 69 | 70 | 71 | @dataclass(frozen=True) 72 | class Boolean(_Literal): 73 | val: str 74 | 75 | @property 76 | def py_val(self) -> bool: 77 | return self.val.lower() == "true" 78 | 79 | 80 | @dataclass(frozen=True) 81 | class String(_Literal): 82 | val: str 83 | 84 | @property 85 | def py_val(self) -> str: 86 | return self.val 87 | 88 | 89 | @dataclass(frozen=True) 90 | class Geography(_Literal): 91 | val: str 92 | 93 | def wkt(self): 94 | return self.val 95 | 96 | 97 | @dataclass(frozen=True) 98 | class Date(_Literal): 99 | val: str 100 | 101 | @property 102 | def py_val(self) -> dt.date: 103 | return dt.date.fromisoformat(self.val) 104 | 105 | 106 | @dataclass(frozen=True) 107 | class Time(_Literal): 108 | val: str 109 | 110 | @property 111 | def py_val(self) -> dt.time: 112 | return dt.time.fromisoformat(self.val) 113 | 114 | 115 | @dataclass(frozen=True) 116 | class DateTime(_Literal): 117 | val: str 118 | 119 | @property 120 | def py_val(self) -> dt.datetime: 121 | return isoparse(self.val) 122 | 123 | 124 | @dataclass(frozen=True) 125 | class Duration(_Literal): 126 | val: str 127 | 128 | @property 129 | def py_val(self) -> dt.timedelta: 130 | sign, years, months, days, hours, minutes, seconds = self.unpack() 131 | 132 | # Initialize days to 0 if None 133 | num_days = float(days or 0) 134 | 135 | # Approximate conversion, adjust as necessary for more precision 136 | num_days += float(years or 0) * 365.25 # Average including leap years 137 | num_days += float(months or 0) * 30.44 # Average month length 138 | 139 | delta = dt.timedelta( 140 | days=num_days, 141 | hours=float(hours or 0), 142 | minutes=float(minutes or 0), 143 | seconds=float(seconds or 0), 144 | ) 145 | if sign and sign == "-": 146 | delta = -1 * delta 147 | return delta 148 | 149 | def unpack( 150 | self, 151 | ) -> Tuple[ 152 | Optional[str], 153 | Optional[str], 154 | Optional[str], 155 | Optional[str], 156 | Optional[str], 157 | Optional[str], 158 | Optional[str], 159 | ]: 160 | """ 161 | Returns: 162 | ``(sign, years, months, days, hours, minutes, seconds)`` 163 | """ 164 | 165 | match = DURATION_PATTERN.fullmatch(self.val) 166 | if not match: 167 | raise ValueError(f"Could not unpack Duration with value {self.val}") 168 | 169 | sign, years, months, days, hours, minutes, seconds = match.groups() 170 | 171 | _years = years[:-1] if years else None 172 | _months = months[:-1] if months else None 173 | _days = days[:-1] if days else None 174 | _hours = hours[:-1] if hours else None 175 | _minutes = minutes[:-1] if minutes else None 176 | 
_seconds = seconds[:-1] if seconds else None 177 | 178 | return sign, _years, _months, _days, _hours, _minutes, _seconds 179 | 180 | 181 | @dataclass(frozen=True) 182 | class GUID(_Literal): 183 | val: str 184 | 185 | @property 186 | def py_val(self) -> UUID: 187 | return UUID(self.val) 188 | 189 | 190 | @dataclass(frozen=True) 191 | class List(_Literal): 192 | val: ListType[_Literal] 193 | 194 | @property 195 | def py_val(self) -> list: 196 | return [v.py_val for v in self.val] 197 | 198 | 199 | ############################################################################### 200 | # Arithmetic 201 | ############################################################################### 202 | @dataclass(frozen=True) 203 | class _BinOpToken(_Node): 204 | pass 205 | 206 | 207 | @dataclass(frozen=True) 208 | class Add(_BinOpToken): 209 | pass 210 | 211 | 212 | @dataclass(frozen=True) 213 | class Sub(_BinOpToken): 214 | pass 215 | 216 | 217 | @dataclass(frozen=True) 218 | class Mult(_BinOpToken): 219 | pass 220 | 221 | 222 | @dataclass(frozen=True) 223 | class Div(_BinOpToken): 224 | pass 225 | 226 | 227 | @dataclass(frozen=True) 228 | class Mod(_BinOpToken): 229 | pass 230 | 231 | 232 | @dataclass(frozen=True) 233 | class BinOp(_Node): 234 | op: _BinOpToken 235 | left: _Node 236 | right: _Node 237 | 238 | 239 | ############################################################################### 240 | # Comparison 241 | ############################################################################### 242 | @dataclass(frozen=True) 243 | class _Comparator(_Node): 244 | pass 245 | 246 | 247 | @dataclass(frozen=True) 248 | class Eq(_Comparator): 249 | pass 250 | 251 | 252 | @dataclass(frozen=True) 253 | class NotEq(_Comparator): 254 | pass 255 | 256 | 257 | @dataclass(frozen=True) 258 | class Lt(_Comparator): 259 | pass 260 | 261 | 262 | @dataclass(frozen=True) 263 | class LtE(_Comparator): 264 | pass 265 | 266 | 267 | @dataclass(frozen=True) 268 | class Gt(_Comparator): 269 | pass 270 | 271 | 272 | @dataclass(frozen=True) 273 | class GtE(_Comparator): 274 | pass 275 | 276 | 277 | @dataclass(frozen=True) 278 | class In(_Comparator): 279 | pass 280 | 281 | 282 | @dataclass(frozen=True) 283 | class Compare(_Node): 284 | comparator: _Comparator 285 | left: _Node 286 | right: _Node 287 | 288 | 289 | ############################################################################### 290 | # Boolean ops 291 | ############################################################################### 292 | @dataclass(frozen=True) 293 | class _BoolOpToken(_Node): 294 | pass 295 | 296 | 297 | @dataclass(frozen=True) 298 | class And(_BoolOpToken): 299 | pass 300 | 301 | 302 | @dataclass(frozen=True) 303 | class Or(_BoolOpToken): 304 | pass 305 | 306 | 307 | @dataclass(frozen=True) 308 | class BoolOp(_Node): 309 | op: _BoolOpToken 310 | left: _Node 311 | right: _Node 312 | 313 | 314 | ############################################################################### 315 | # Unary ops 316 | ############################################################################### 317 | @dataclass(frozen=True) 318 | class _UnaryOpToken(_Node): 319 | pass 320 | 321 | 322 | @dataclass(frozen=True) 323 | class Not(_UnaryOpToken): 324 | pass 325 | 326 | 327 | @dataclass(frozen=True) 328 | class USub(_UnaryOpToken): 329 | pass 330 | 331 | 332 | @dataclass(frozen=True) 333 | class UnaryOp(_Node): 334 | op: _UnaryOpToken 335 | operand: _Node 336 | 337 | 338 | ############################################################################### 339 | # Function calls 
340 | ############################################################################### 341 | @dataclass(frozen=True) 342 | class NamedParam(_Node): 343 | name: Identifier 344 | param: _Node 345 | 346 | 347 | @dataclass(frozen=True) 348 | class Call(_Node): 349 | func: Identifier 350 | args: ListType[_Node] 351 | 352 | 353 | ############################################################################### 354 | # Collections 355 | ############################################################################### 356 | @dataclass(frozen=True) 357 | class _CollectionOperator(_Node): 358 | pass 359 | 360 | 361 | @dataclass(frozen=True) 362 | class Any(_CollectionOperator): 363 | pass 364 | 365 | 366 | @dataclass(frozen=True) 367 | class All(_CollectionOperator): 368 | pass 369 | 370 | 371 | @dataclass(frozen=True) 372 | class Lambda(_Node): 373 | identifier: Identifier 374 | expression: _Node 375 | 376 | 377 | @dataclass(frozen=True) 378 | class CollectionLambda(_Node): 379 | owner: _Node 380 | operator: _CollectionOperator 381 | lambda_: Optional[Lambda] 382 | -------------------------------------------------------------------------------- /odata_query/django/__init__.py: -------------------------------------------------------------------------------- 1 | from .django_q import AstToDjangoQVisitor 2 | from .django_q_ext import * 3 | from .shorthand import apply_odata_query 4 | -------------------------------------------------------------------------------- /odata_query/django/django_q_ext.py: -------------------------------------------------------------------------------- 1 | from django.db.models import CharField, Lookup, Subquery, fields, functions 2 | from django.db.models.query import QuerySet 3 | 4 | 5 | @fields.Field.register_lookup 6 | class NotEqual(Lookup): 7 | """https://docs.djangoproject.com/en/2.2/howto/custom-lookups/""" 8 | 9 | lookup_name = "ne" 10 | 11 | def as_sql(self, compiler, connection): # type: ignore 12 | lhs, lhs_params = self.process_lhs(compiler, connection) 13 | rhs, rhs_params = self.process_rhs(compiler, connection) 14 | params = lhs_params + rhs_params 15 | return "%s <> %s" % (lhs, rhs), params 16 | 17 | 18 | CharField.register_lookup(functions.Length) 19 | CharField.register_lookup(functions.Upper) 20 | CharField.register_lookup(functions.Lower) 21 | CharField.register_lookup(functions.Trim) 22 | 23 | 24 | class _AnyAll(Subquery): 25 | template = "" 26 | 27 | def __init__(self, queryset: QuerySet, negated: bool = False, **kwargs): 28 | # As a performance optimization, remove ordering since ~ doesn't 29 | # care about it, just whether or not a row matches. 
30 | queryset = queryset.order_by() 31 | self.negated = negated 32 | super().__init__(queryset, **kwargs) 33 | 34 | def __invert__(self) -> Subquery: 35 | clone = self.copy() 36 | clone.negated = not self.negated 37 | return clone 38 | 39 | def __repr__(self) -> str: 40 | return self.template % {"subquery": self.queryset.query} 41 | 42 | def as_sql( # type: ignore 43 | self, compiler, connection, template=None, **extra_context 44 | ): 45 | sql, params = super().as_sql(compiler, connection, template, **extra_context) 46 | # if self.negated: 47 | # sql = "NOT {}".format(sql) 48 | return sql, params 49 | 50 | def select_format(self, compiler, sql, params): # type:ignore 51 | return sql, params 52 | 53 | 54 | class Any(_AnyAll): 55 | template = "ANY(%(subquery)s)" 56 | 57 | 58 | class All(_AnyAll): 59 | template = "ALL(%(subquery)s)" 60 | -------------------------------------------------------------------------------- /odata_query/django/shorthand.py: -------------------------------------------------------------------------------- 1 | from django.db.models.query import QuerySet 2 | 3 | from odata_query.grammar import ODataLexer, ODataParser # type: ignore 4 | 5 | from .django_q import AstToDjangoQVisitor 6 | 7 | 8 | def apply_odata_query(queryset: QuerySet, odata_query: str) -> QuerySet: 9 | """ 10 | Shorthand for applying an OData query to a Django QuerySet. 11 | 12 | Args: 13 | queryset: Django QuerySet to apply the OData query to. 14 | odata_query: OData query string. 15 | Returns: 16 | QuerySet: The modified QuerySet 17 | """ 18 | lexer = ODataLexer() 19 | parser = ODataParser() 20 | model = queryset.model 21 | 22 | ast = parser.parse(lexer.tokenize(odata_query)) 23 | transformer = AstToDjangoQVisitor(model) 24 | where_clause = transformer.visit(ast) 25 | 26 | if transformer.queryset_annotations: 27 | queryset = queryset.annotate(**transformer.queryset_annotations) 28 | 29 | return queryset.filter(where_clause) 30 | -------------------------------------------------------------------------------- /odata_query/django/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type 2 | 3 | from django.db.models import Model 4 | 5 | 6 | def reverse_relationship( 7 | relationship_expr: str, root_model: Type[Model] 8 | ) -> Tuple[str, Type[Model]]: 9 | """ 10 | Reverses a relationship expression relative to root_model. 11 | 12 | Args: 13 | relationship_expr: The Django relationship string, with underscores to 14 | represent relationship traversal. 15 | root_model: The model to which relationship_expr is relative. 16 | 17 | Returns: 18 | str: The django relationship string in reverse, so from the last joined 19 | relationship back to the root model. 20 | Type[Model]: The model to which the returned expression is relative. 
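For example, given a hypothetical ``BlogPost`` model with a foreign key ``author`` pointing to ``Author``, ``reverse_relationship("author", BlogPost)`` would return something like ``("blogpost", Author)``: the reverse path from ``Author`` back to ``BlogPost``, and the model that path is relative to.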
21 | """ 22 | relation_steps = relationship_expr.split("__") 23 | 24 | related_model = root_model 25 | path_to_outerref_parts = [] 26 | for step in relation_steps: 27 | related_field = related_model._meta.get_field(step) 28 | related_model = related_field.related_model 29 | path_to_outerref_parts.append(related_field.remote_field.name) 30 | 31 | path_to_outerref = "__".join(reversed(path_to_outerref_parts)) 32 | 33 | return (path_to_outerref, related_model) 34 | -------------------------------------------------------------------------------- /odata_query/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from sly.lex import Token 4 | 5 | 6 | class ODataException(Exception): 7 | """ 8 | Base class for all exceptions in this library. 9 | """ 10 | 11 | pass 12 | 13 | 14 | class ODataSyntaxError(ODataException): 15 | """ 16 | Base class for syntax errors. 17 | """ 18 | 19 | pass 20 | 21 | 22 | class TokenizingException(ODataSyntaxError): 23 | """ 24 | Thrown when the lexer cannot tokenize the query. 25 | """ 26 | 27 | def __init__(self, token: Token): 28 | self.token = token 29 | super().__init__(f"Failed to tokenize at: {token}") 30 | 31 | 32 | class ParsingException(ODataSyntaxError): 33 | """ 34 | Thrown when the parser cannot parse the query. 35 | """ 36 | 37 | def __init__(self, token: Optional[Token], eof: bool = False): 38 | self.token = token 39 | self.eof = eof 40 | super().__init__(f"Failed to parse at: {token}") 41 | 42 | 43 | class FunctionCallException(ODataException): 44 | """ 45 | Base class for errors in function calls. 46 | """ 47 | 48 | pass 49 | 50 | 51 | class UnknownFunctionException(FunctionCallException): 52 | """ 53 | Thrown when the parser encounters an undefined function call. 54 | """ 55 | 56 | def __init__(self, function_name: str): 57 | self.function_name = function_name 58 | super().__init__(f"Unknown function: '{function_name}'") 59 | 60 | 61 | class ArgumentCountException(FunctionCallException): 62 | """ 63 | Thrown when the parser encounters a function called with a wrong number 64 | of arguments. 65 | """ 66 | 67 | def __init__( 68 | self, function_name: str, exp_min_args: int, exp_max_args: int, given_args: int 69 | ): 70 | self.function_name = function_name 71 | self.exp_min_args = exp_min_args 72 | self.exp_max_args = exp_max_args 73 | self.n_args_given = given_args 74 | if exp_min_args != exp_max_args: 75 | super().__init__( 76 | f"Function '{function_name}' takes between {exp_min_args} and " 77 | f"{exp_max_args} arguments. {given_args} given." 78 | ) 79 | else: 80 | super().__init__( 81 | f"Function '{function_name}' takes {exp_min_args} arguments. " 82 | f"{given_args} given." 83 | ) 84 | 85 | 86 | class UnsupportedFunctionException(FunctionCallException): 87 | """ 88 | Thrown when a function is used that is not implemented yet. 89 | """ 90 | 91 | def __init__(self, function_name: str): 92 | self.function_name = function_name 93 | super().__init__(f"Function '{function_name}' is not implemented yet.") 94 | 95 | 96 | class ArgumentTypeException(FunctionCallException): 97 | """ 98 | Thrown when a function is called with argument of the wrong type. 
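E.g. `length(10)` or `startswith(name, 5)`, where a string (or collection) argument is expected.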
99 | """ 100 | 101 | def __init__( 102 | self, 103 | function_name: Optional[str] = None, 104 | expected_type: Optional[str] = None, 105 | actual_type: Optional[str] = None, 106 | ): 107 | self.function_name = function_name 108 | self.expected_type = expected_type 109 | self.actual_type = actual_type 110 | 111 | if function_name: 112 | message = f"Unsupported or invalid type for function or operator '{function_name}'" 113 | else: 114 | message = "Invalid argument type for function or operator." 115 | if expected_type: 116 | message += f" Expected {expected_type}" 117 | if actual_type: 118 | message += f", got {actual_type}" 119 | 120 | super().__init__(message) 121 | 122 | 123 | class TypeException(ODataException): 124 | """ 125 | Thrown when doing an invalid operation on a value. 126 | E.g. `10 gt null` or `~date()` 127 | """ 128 | 129 | def __init__(self, operation: str, value: str): 130 | self.operation = operation 131 | self.value = value 132 | super().__init__(f"Cannot apply '{operation}' to '{value}'") 133 | 134 | 135 | class ValueException(ODataException): 136 | """ 137 | Thrown when a value has an invalid value, such as an invalid datetime. 138 | """ 139 | 140 | def __init__(self, value: Any): 141 | self.value = value 142 | super().__init__(f"Invalid value: {value}") 143 | 144 | 145 | class InvalidFieldException(ODataException): 146 | """ 147 | Thrown when a field mentioned in a query does not exist. 148 | """ 149 | 150 | def __init__(self, field_name: str): 151 | self.field_name = field_name 152 | super().__init__(f"Invalid field: {field_name}") 153 | -------------------------------------------------------------------------------- /odata_query/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/odata_query/py.typed -------------------------------------------------------------------------------- /odata_query/rewrite.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from . import ast 4 | from .grammar import ODataLexer, ODataParser # type: ignore 5 | from .visitor import NodeTransformer 6 | 7 | 8 | class AliasRewriter(NodeTransformer): 9 | """ 10 | A :class:`NodeTransformer` that replaces aliases in the :term:`AST` with their 11 | aliased identifiers or attributes. 12 | 13 | Args: 14 | field_aliases: A mapping of aliases to their full name. These can 15 | be identifiers, attributes, and even function calls in odata 16 | syntax. 17 | lexer: Optional lexer instance to use. If not passed, will construct 18 | the default one. 19 | parser: Optional parser instance to use. If not passed, will construct 20 | the default one. 
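For example (the alias mapping here is purely illustrative)::

    rewriter = AliasRewriter({"name": "author/name"})
    new_ast = rewriter.visit(ast)

would replace every bare ``name`` identifier in ``ast`` with the attribute ``author/name``.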
21 | """ 22 | 23 | def __init__( 24 | self, 25 | field_aliases: Dict[str, str], 26 | lexer: Optional[ODataLexer] = None, 27 | parser: Optional[ODataParser] = None, 28 | ): 29 | self.field_aliases = field_aliases 30 | 31 | if not lexer: 32 | lexer = ODataLexer() 33 | if not parser: 34 | parser = ODataParser() 35 | 36 | self.replacements = { 37 | parser.parse(lexer.tokenize(k)): parser.parse(lexer.tokenize(v)) 38 | for k, v in self.field_aliases.items() 39 | } 40 | 41 | def visit_Identifier(self, node: ast.Identifier) -> ast._Node: 42 | """:meta private:""" 43 | if node in self.replacements: 44 | return self.replacements[node] 45 | return node 46 | 47 | def visit_Attribute(self, node: ast.Attribute) -> ast._Node: 48 | """:meta private:""" 49 | if node in self.replacements: 50 | return self.replacements[node] 51 | else: 52 | new_owner = self.visit(node.owner) 53 | return ast.Attribute(new_owner, node.attr) 54 | 55 | 56 | class IdentifierStripper(NodeTransformer): 57 | """ 58 | A :class:`NodeTransformer` that strips the given identifier off of 59 | attributes. E.g. ``author/name`` -> ``name``. 60 | 61 | Args: 62 | strip: The identifier to strip off of all attributes in the :term:`AST` 63 | """ 64 | 65 | def __init__(self, strip: ast.Identifier): 66 | self.strip = strip 67 | 68 | def visit_Attribute(self, node: ast.Attribute) -> ast._Node: 69 | """:meta private:""" 70 | if node.owner == self.strip: 71 | return ast.Identifier(node.attr) 72 | elif isinstance(node.owner, ast.Attribute): 73 | return ast.Attribute(self.visit(node.owner), node.attr) 74 | 75 | return node 76 | -------------------------------------------------------------------------------- /odata_query/roundtrip.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Type 3 | 4 | if sys.version_info >= (3, 8): 5 | from typing import Protocol 6 | else: 7 | from typing_extensions import Protocol 8 | 9 | from odata_query import ast, visitor 10 | 11 | 12 | class LiteralValNode(Protocol): 13 | """:meta private:""" 14 | 15 | val: str 16 | 17 | 18 | PRECEDENCE = { 19 | ast.Attribute: 10, 20 | ast.Call: 10, 21 | ast.Not: 9, 22 | ast.USub: 9, 23 | ast.Mult: 8, 24 | ast.Div: 8, 25 | ast.Mod: 8, 26 | ast.Add: 7, 27 | ast.Sub: 7, 28 | ast.Gt: 6, 29 | ast.GtE: 6, 30 | ast.Lt: 6, 31 | ast.LtE: 6, 32 | ast.Eq: 5, 33 | ast.NotEq: 5, 34 | ast.And: 4, 35 | ast.Or: 3, 36 | } 37 | 38 | 39 | class AstToODataVisitor(visitor.NodeVisitor): 40 | """ 41 | :class:`NodeVisitor` that transforms an :term:`AST` back into an OData 42 | query string. 43 | """ 44 | 45 | def visit_Identifier(self, node: ast.Identifier) -> str: 46 | """:meta private:""" 47 | if node.namespace: 48 | prefix = ".".join(node.namespace) 49 | return prefix + "." 
+ node.name 50 | return node.name 51 | 52 | def visit_Attribute(self, node: ast.Attribute) -> str: 53 | """:meta private:""" 54 | return self.visit(node.owner) + "/" + node.attr 55 | 56 | def visit_Null(self, node: ast.Null) -> str: 57 | """:meta private:""" 58 | return "null" 59 | 60 | def visit_String(self, node: ast.String) -> str: 61 | """:meta private:""" 62 | return "'" + node.val + "'" 63 | 64 | def visit_Duration(self, node: ast.Duration) -> str: 65 | """:meta private:""" 66 | return "duration'" + node.val + "'" 67 | 68 | def _visit_Literal(self, node: LiteralValNode) -> str: 69 | """:meta private:""" 70 | return node.val 71 | 72 | visit_Integer = _visit_Literal 73 | visit_Float = _visit_Literal 74 | visit_Boolean = _visit_Literal 75 | visit_Date = _visit_Literal 76 | visit_Time = _visit_Literal 77 | visit_DateTime = _visit_Literal 78 | visit_GUID = _visit_Literal 79 | 80 | def visit_List(self, node: ast.List) -> str: 81 | """:meta private:""" 82 | return "(" + ", ".join(self.visit(v) for v in node.val) + ")" 83 | 84 | def visit_Add(self, node: ast.Add) -> str: 85 | """:meta private:""" 86 | return "add" 87 | 88 | def visit_Sub(self, node: ast.Sub) -> str: 89 | """:meta private:""" 90 | return "sub" 91 | 92 | def visit_Mult(self, node: ast.Mult) -> str: 93 | """:meta private:""" 94 | return "mul" 95 | 96 | def visit_Div(self, node: ast.Div) -> str: 97 | """:meta private:""" 98 | return "div" 99 | 100 | def visit_Mod(self, node: ast.Mod) -> str: 101 | """:meta private:""" 102 | return "mod" 103 | 104 | def visit_BinOp(self, node: ast.BinOp) -> str: 105 | """:meta private:""" 106 | left = self._visit_and_paren_if_precedence_lower(node.left, type(node.op)) 107 | right = self._visit_and_paren_if_precedence_lower(node.right, type(node.op)) 108 | return left + " " + self.visit(node.op) + " " + right 109 | 110 | def visit_Eq(self, node: ast.Eq) -> str: 111 | """:meta private:""" 112 | return "eq" 113 | 114 | def visit_NotEq(self, node: ast.NotEq) -> str: 115 | """:meta private:""" 116 | return "ne" 117 | 118 | def visit_Lt(self, node: ast.Lt) -> str: 119 | """:meta private:""" 120 | return "lt" 121 | 122 | def visit_LtE(self, node: ast.LtE) -> str: 123 | """:meta private:""" 124 | return "le" 125 | 126 | def visit_Gt(self, node: ast.Gt) -> str: 127 | """:meta private:""" 128 | return "gt" 129 | 130 | def visit_GtE(self, node: ast.GtE) -> str: 131 | """:meta private:""" 132 | return "ge" 133 | 134 | def visit_In(self, node: ast.In) -> str: 135 | """:meta private:""" 136 | return "in" 137 | 138 | def visit_Compare(self, node: ast.Compare) -> str: 139 | """:meta private:""" 140 | left = self._visit_and_paren_if_precedence_lower( 141 | node.left, type(node.comparator) 142 | ) 143 | right = self._visit_and_paren_if_precedence_lower( 144 | node.right, type(node.comparator) 145 | ) 146 | return left + " " + self.visit(node.comparator) + " " + right 147 | 148 | def visit_And(self, node: ast.And) -> str: 149 | """:meta private:""" 150 | return "and" 151 | 152 | def visit_Or(self, node: ast.Or) -> str: 153 | """:meta private:""" 154 | return "or" 155 | 156 | def visit_BoolOp(self, node: ast.BoolOp) -> str: 157 | """:meta private:""" 158 | left = self._visit_and_paren_if_precedence_lower(node.left, type(node.op)) 159 | right = self._visit_and_paren_if_precedence_lower(node.right, type(node.op)) 160 | return left + " " + self.visit(node.op) + " " + right 161 | 162 | def visit_Not(self, node: ast.Not) -> str: 163 | """:meta private:""" 164 | return "not" 165 | 166 | def visit_USub(self, node: 
ast.USub) -> str: 167 | """:meta private:""" 168 | return "-" 169 | 170 | def visit_UnaryOp(self, node: ast.UnaryOp) -> str: 171 | """:meta private:""" 172 | operand = self._visit_and_paren_if_precedence_lower(node.operand, type(node.op)) 173 | return self.visit(node.op) + " " + operand 174 | 175 | def visit_Call(self, node: ast.Call) -> str: 176 | """:meta private:""" 177 | return ( 178 | self.visit(node.func) 179 | + "(" 180 | + ", ".join(self.visit(n) for n in node.args) 181 | + ")" 182 | ) 183 | 184 | def visit_Any(self, node: ast.Any) -> str: 185 | """:meta private:""" 186 | return "any" 187 | 188 | def visit_All(self, node: ast.All) -> str: 189 | """:meta private:""" 190 | return "all" 191 | 192 | def visit_Lambda(self, node: ast.Lambda) -> str: 193 | """:meta private:""" 194 | return self.visit(node.identifier) + ": " + self.visit(node.expression) 195 | 196 | def visit_CollectionLambda(self, node: ast.CollectionLambda) -> str: 197 | """:meta private:""" 198 | return ( 199 | self.visit(node.owner) 200 | + "/" 201 | + self.visit(node.operator) 202 | + "(" 203 | + (self.visit(node.lambda_) if node.lambda_ else "") 204 | + ")" 205 | ) 206 | 207 | def _visit_and_paren_if_precedence_lower( 208 | self, node: ast._Node, precedence: Type[ast._Node] 209 | ) -> str: 210 | """ 211 | Transform `node` by visiting it, then wrap the result in parentheses if 212 | the expressions precedence is lower than that of `precedence`. 213 | 214 | :meta private: 215 | """ 216 | res = self.visit(node) 217 | 218 | if hasattr(node, "op"): 219 | node_op = type(node.op) # type: ignore 220 | elif hasattr(node, "comparator"): 221 | node_op = type(node.comparator) # type: ignore 222 | else: 223 | node_op = type(node) 224 | 225 | node_prec = PRECEDENCE.get(node_op, 100) 226 | check_prec = PRECEDENCE.get(precedence, 100) 227 | 228 | if node_prec < check_prec: 229 | res = "(" + res + ")" 230 | 231 | return res 232 | -------------------------------------------------------------------------------- /odata_query/sql/__init__.py: -------------------------------------------------------------------------------- 1 | from .athena import AstToAthenaSqlVisitor 2 | from .base import AstToSqlVisitor 3 | from .sqlite import AstToSqliteSqlVisitor 4 | -------------------------------------------------------------------------------- /odata_query/sql/athena.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from odata_query import ast, exceptions, typing 4 | 5 | from .base import AstToSqlVisitor 6 | 7 | UNSAFE_CHARS = re.compile(r"[^a-zA-Z0-9_]") 8 | 9 | 10 | def clean_athena_identifier(identifier: str) -> str: 11 | """ 12 | Cleans an Athena identifier so it passes the following validation rules: 13 | 14 | - Table names and table column names in Athena must be lowercase 15 | - Athena table, view, database, and column names allow only underscore special characters 16 | - Names should be quoted or backticked when starting with a number or underscore 17 | 18 | Source: https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html 19 | """ 20 | id_new = identifier.lower() 21 | id_new = UNSAFE_CHARS.sub("_", id_new) 22 | return id_new 23 | 24 | 25 | class AstToAthenaSqlVisitor(AstToSqlVisitor): 26 | """ 27 | :class:`NodeVisitor` that transforms an :term:`AST` into an Athena SQL 28 | ``WHERE`` clause. 29 | 30 | Args: 31 | table_alias: Optional alias for the root table. 
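A minimal usage sketch, assuming ``my_ast`` is a previously parsed filter expression::

    visitor = AstToAthenaSqlVisitor()
    where_clause = visitor.visit(my_ast)

The resulting string can be appended to a query after the ``WHERE`` keyword.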
32 | """ 33 | 34 | def visit_Identifier(self, node: ast.Identifier) -> str: 35 | ":meta private:" 36 | # Double quotes for column names acc SQL Standard 37 | sql_id = f'"{clean_athena_identifier(node.name)}"' 38 | 39 | if self.table_alias: 40 | sql_id = f'"{self.table_alias}".' + sql_id 41 | 42 | return sql_id 43 | 44 | def visit_DateTime(self, node: ast.DateTime) -> str: 45 | ":meta private:" 46 | return f"FROM_ISO8601_TIMESTAMP('{node.val}')" 47 | 48 | def sqlfunc_length(self, arg: ast._Node) -> str: 49 | ":meta private:" 50 | arg_sql = self.visit(arg) 51 | inferred_type = typing.infer_type(arg) 52 | 53 | # If the input is a string or default, assume str-length: 54 | if inferred_type is ast.String or inferred_type is None: 55 | return f"LENGTH({arg_sql})" 56 | 57 | # If the input is a list, assume list-length: 58 | if inferred_type is ast.List: 59 | return f"CARDINALITY({arg_sql})" 60 | 61 | raise exceptions.ArgumentTypeException("length") 62 | 63 | def sqlfunc_substring(self, *args: ast._Node) -> str: 64 | ":meta private:" 65 | args_sql = [self.visit(arg) for arg in args] 66 | inferred_type = typing.infer_type(args[0]) 67 | 68 | # If the first input is a string or default, assume str-substr: 69 | if inferred_type is ast.String or inferred_type is None: 70 | if len(args) == 2: 71 | return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1)" 72 | if len(args) == 3: 73 | return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})" 74 | 75 | # If the first input is a list, assume list-substr: 76 | if inferred_type is ast.List: 77 | if len(args) == 2: 78 | return f"SLICE({args_sql[0]}, {args_sql[1]})" 79 | if len(args) == 3: 80 | return f"SLICE({args_sql[0]}, {args_sql[1]}, {args_sql[2]})" 81 | 82 | raise exceptions.ArgumentTypeException("substring") 83 | 84 | def sqlfunc_round(self, arg: ast._Node) -> str: 85 | ":meta private:" 86 | arg_sql = self.visit(arg) 87 | return f"ROUND({arg_sql})" 88 | 89 | def sqlfunc_floor(self, arg: ast._Node) -> str: 90 | ":meta private:" 91 | arg_sql = self.visit(arg) 92 | return f"FLOOR({arg_sql})" 93 | 94 | def sqlfunc_ceiling(self, arg: ast._Node) -> str: 95 | ":meta private:" 96 | arg_sql = self.visit(arg) 97 | return f"CEILING({arg_sql})" 98 | 99 | def sqlfunc_hassubset(self, *args: ast._Node) -> str: 100 | ":meta private:" 101 | args_sql = [self.visit(arg) for arg in args] 102 | return f"CARDINALITY(ARRAY_INTERSECT({args_sql[0]}, {args_sql[1]})) = CARDINALITY({args_sql[1]})" 103 | -------------------------------------------------------------------------------- /odata_query/sql/sqlite.py: -------------------------------------------------------------------------------- 1 | from odata_query import ast, exceptions, typing 2 | 3 | from .base import AstToSqlVisitor 4 | 5 | 6 | class AstToSqliteSqlVisitor(AstToSqlVisitor): 7 | """ 8 | :class:`NodeVisitor` that transforms an :term:`AST` into a SQLite SQL 9 | ``WHERE`` clause. 10 | 11 | Args: 12 | table_alias: Optional alias for the root table. 
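For example, an OData expression like ``year(created_at)`` is rendered along the lines of ``CAST(STRFTIME('%Y', created_at) AS INTEGER)``, since SQLite has no dedicated date-part functions (``created_at`` is an illustrative field name).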
13 | """ 14 | 15 | def visit_Boolean(self, node: ast.Boolean) -> str: 16 | """:meta private:""" 17 | if node.py_val: 18 | return "1" 19 | return "0" 20 | 21 | def visit_Date(self, node: ast.Date) -> str: 22 | """:meta private:""" 23 | return f"DATE('{node.val}')" 24 | 25 | def visit_DateTime(self, node: ast.DateTime) -> str: 26 | """:meta private:""" 27 | return f"DATETIME('{node.val}')" 28 | 29 | def sqlfunc_indexof(self, *args: ast._Node) -> str: 30 | """:meta private:""" 31 | args_sql = [self.visit(arg) for arg in args] 32 | inferred_type = [typing.infer_type(arg) for arg in args] 33 | 34 | # If any of the inputs is a string, assume str-indexof: 35 | if any(typ is ast.String for typ in inferred_type) or all( 36 | typ is None for typ in inferred_type 37 | ): 38 | return f"INSTR({args_sql[0]}, {args_sql[1]}) - 1" 39 | 40 | # If any of the inputs is a list, assume list-indexof 41 | # which isn't easily doable at the moment: 42 | if any(typ is ast.List for typ in inferred_type): 43 | raise exceptions.UnsupportedFunctionException("indexof") 44 | 45 | raise exceptions.ArgumentTypeException("indexof") 46 | 47 | def sqlfunc_length(self, arg: ast._Node) -> str: 48 | """:meta private:""" 49 | arg_sql = self.visit(arg) 50 | return f"LENGTH({arg_sql})" 51 | 52 | def sqlfunc_substring(self, *args: ast._Node) -> str: 53 | """:meta private:""" 54 | args_sql = [self.visit(arg) for arg in args] 55 | inferred_type = typing.infer_type(args[0]) 56 | 57 | # If the first input is a string or default, assume str-substr: 58 | if inferred_type is ast.String or inferred_type is None: 59 | if len(args) == 2: 60 | return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1)" 61 | if len(args) == 3: 62 | return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})" 63 | 64 | # If the first input is a list, assume list-substr: 65 | if inferred_type is ast.List: 66 | raise exceptions.UnsupportedFunctionException("substring") 67 | 68 | raise exceptions.ArgumentTypeException("substring") 69 | 70 | def sqlfunc_year(self, arg: ast._Node) -> str: 71 | """:meta private:""" 72 | arg_sql = self.visit(arg) 73 | return f"CAST(STRFTIME('%Y', {arg_sql}) AS INTEGER)" 74 | 75 | def sqlfunc_month(self, arg: ast._Node) -> str: 76 | """:meta private:""" 77 | arg_sql = self.visit(arg) 78 | return f"CAST(STRFTIME('%m', {arg_sql}) AS INTEGER)" 79 | 80 | def sqlfunc_day(self, arg: ast._Node) -> str: 81 | """:meta private:""" 82 | arg_sql = self.visit(arg) 83 | return f"CAST(STRFTIME('%d', {arg_sql}) AS INTEGER)" 84 | 85 | def sqlfunc_hour(self, arg: ast._Node) -> str: 86 | """:meta private:""" 87 | arg_sql = self.visit(arg) 88 | return f"CAST(STRFTIME('%H', {arg_sql}) AS INTEGER)" 89 | 90 | def sqlfunc_minute(self, arg: ast._Node) -> str: 91 | """:meta private:""" 92 | arg_sql = self.visit(arg) 93 | return f"CAST(STRFTIME('%M', {arg_sql}) AS INTEGER)" 94 | 95 | def sqlfunc_date(self, arg: ast._Node) -> str: 96 | """:meta private:""" 97 | arg_sql = self.visit(arg) 98 | return f"DATE({arg_sql})" 99 | 100 | def sqlfunc_now(self) -> str: 101 | """:meta private:""" 102 | return "DATETIME('now')" 103 | 104 | def sqlfunc_round(self, arg: ast._Node) -> str: 105 | """:meta private:""" 106 | arg_sql = self.visit(arg) 107 | return f"TRUNC({arg_sql} + 0.5)" 108 | 109 | def sqlfunc_floor(self, arg: ast._Node) -> str: 110 | """:meta private:""" 111 | arg_sql = self.visit(arg) 112 | return f"FLOOR({arg_sql})" 113 | 114 | def sqlfunc_ceiling(self, arg: ast._Node) -> str: 115 | """:meta private:""" 116 | arg_sql = self.visit(arg) 117 | return 
f"CEILING({arg_sql})" 118 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import AstToSqlAlchemyCoreVisitor 2 | from .orm import AstToSqlAlchemyOrmVisitor 3 | from .shorthand import apply_odata_core, apply_odata_query 4 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/common.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from typing import Any, Callable, Optional, Union 3 | 4 | from sqlalchemy.sql import functions 5 | from sqlalchemy.sql.expression import ( 6 | BinaryExpression, 7 | BindParameter, 8 | BooleanClauseList, 9 | ClauseElement, 10 | False_, 11 | Null, 12 | True_, 13 | and_, 14 | cast, 15 | extract, 16 | false, 17 | literal, 18 | null, 19 | or_, 20 | true, 21 | ) 22 | from sqlalchemy.types import Date, Time 23 | 24 | from odata_query import ast, exceptions as ex, typing, visitor 25 | 26 | from . import functions_ext 27 | 28 | 29 | class _CommonVisitors(visitor.NodeVisitor): 30 | """ 31 | Contains the visitor methods that are equal between SQLAlchemy Core and ORM. 32 | """ 33 | 34 | def visit_Null(self, node: ast.Null) -> Null: 35 | ":meta private:" 36 | return null() 37 | 38 | def visit_Integer(self, node: ast.Integer) -> BindParameter: 39 | ":meta private:" 40 | return literal(node.py_val) 41 | 42 | def visit_Float(self, node: ast.Float) -> BindParameter: 43 | ":meta private:" 44 | return literal(node.py_val) 45 | 46 | def visit_Boolean(self, node: ast.Boolean) -> Union[True_, False_]: 47 | ":meta private:" 48 | if node.val == "true": 49 | return true() 50 | else: 51 | return false() 52 | 53 | def visit_String(self, node: ast.String) -> BindParameter: 54 | ":meta private:" 55 | return literal(node.py_val) 56 | 57 | def visit_Date(self, node: ast.Date) -> BindParameter: 58 | ":meta private:" 59 | try: 60 | return literal(node.py_val) 61 | except ValueError: 62 | raise ex.ValueException(node.val) 63 | 64 | def visit_DateTime(self, node: ast.DateTime) -> BindParameter: 65 | ":meta private:" 66 | try: 67 | return literal(node.py_val) 68 | except ValueError: 69 | raise ex.ValueException(node.val) 70 | 71 | def visit_Time(self, node: ast.Time) -> BindParameter: 72 | ":meta private:" 73 | try: 74 | return literal(node.py_val) 75 | except ValueError: 76 | raise ex.ValueException(node.val) 77 | 78 | def visit_Duration(self, node: ast.Duration) -> BindParameter: 79 | ":meta private:" 80 | return literal(node.py_val) 81 | 82 | def visit_GUID(self, node: ast.GUID) -> BindParameter: 83 | ":meta private:" 84 | return literal(node.val) 85 | 86 | def visit_List(self, node: ast.List) -> list: 87 | ":meta private:" 88 | return [self.visit(n) for n in node.val] 89 | 90 | def visit_Add(self, node: ast.Add) -> Callable[[Any, Any], Any]: 91 | ":meta private:" 92 | return operator.add 93 | 94 | def visit_Sub(self, node: ast.Sub) -> Callable[[Any, Any], Any]: 95 | ":meta private:" 96 | return operator.sub 97 | 98 | def visit_Mult(self, node: ast.Mult) -> Callable[[Any, Any], Any]: 99 | ":meta private:" 100 | return operator.mul 101 | 102 | def visit_Div(self, node: ast.Div) -> Callable[[Any, Any], Any]: 103 | ":meta private:" 104 | return operator.truediv 105 | 106 | def visit_Mod(self, node: ast.Mod) -> Callable[[Any, Any], Any]: 107 | ":meta private:" 108 | return operator.mod 109 | 110 | def visit_BinOp(self, node: 
ast.BinOp) -> Any: 111 | ":meta private:" 112 | left = self.visit(node.left) 113 | right = self.visit(node.right) 114 | op = self.visit(node.op) 115 | 116 | return op(left, right) 117 | 118 | def visit_Eq( 119 | self, node: ast.Eq 120 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 121 | ":meta private:" 122 | return operator.eq 123 | 124 | def visit_NotEq( 125 | self, node: ast.NotEq 126 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 127 | ":meta private:" 128 | return operator.ne 129 | 130 | def visit_Lt( 131 | self, node: ast.Lt 132 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 133 | ":meta private:" 134 | return operator.lt 135 | 136 | def visit_LtE( 137 | self, node: ast.LtE 138 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 139 | ":meta private:" 140 | return operator.le 141 | 142 | def visit_Gt( 143 | self, node: ast.Gt 144 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 145 | ":meta private:" 146 | return operator.gt 147 | 148 | def visit_GtE( 149 | self, node: ast.GtE 150 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 151 | ":meta private:" 152 | return operator.ge 153 | 154 | def visit_In( 155 | self, node: ast.In 156 | ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]: 157 | ":meta private:" 158 | return lambda a, b: a.in_(b) 159 | 160 | def visit_And( 161 | self, node: ast.And 162 | ) -> Callable[[ClauseElement, ClauseElement], BooleanClauseList]: 163 | ":meta private:" 164 | return and_ 165 | 166 | def visit_Or( 167 | self, node: ast.Or 168 | ) -> Callable[[ClauseElement, ClauseElement], BooleanClauseList]: 169 | ":meta private:" 170 | return or_ 171 | 172 | def visit_BoolOp(self, node: ast.BoolOp) -> BooleanClauseList: 173 | ":meta private:" 174 | left = self.visit(node.left) 175 | right = self.visit(node.right) 176 | op = self.visit(node.op) 177 | return op(left, right) 178 | 179 | def visit_Not(self, node: ast.Not) -> Callable[[ClauseElement], ClauseElement]: 180 | ":meta private:" 181 | return operator.invert 182 | 183 | def visit_UnaryOp(self, node: ast.UnaryOp) -> ClauseElement: 184 | ":meta private:" 185 | mod = self.visit(node.op) 186 | val = self.visit(node.operand) 187 | 188 | try: 189 | return mod(val) 190 | except TypeError: 191 | raise ex.TypeException(node.op.__class__.__name__, val) 192 | 193 | def visit_Call(self, node: ast.Call) -> ClauseElement: 194 | ":meta private:" 195 | try: 196 | handler = getattr(self, "func_" + node.func.name.lower()) 197 | except AttributeError: 198 | raise ex.UnsupportedFunctionException(node.func.name) 199 | 200 | return handler(*node.args) 201 | 202 | def func_contains(self, field: ast._Node, substr: ast._Node) -> ClauseElement: 203 | ":meta private:" 204 | return self._substr_function(field, substr, "contains") 205 | 206 | def func_startswith(self, field: ast._Node, substr: ast._Node) -> ClauseElement: 207 | ":meta private:" 208 | return self._substr_function(field, substr, "startswith") 209 | 210 | def func_endswith(self, field: ast._Node, substr: ast._Node) -> ClauseElement: 211 | ":meta private:" 212 | return self._substr_function(field, substr, "endswith") 213 | 214 | def func_length(self, arg: ast._Node) -> functions.Function: 215 | ":meta private:" 216 | return functions.char_length(self.visit(arg)) 217 | 218 | def func_concat(self, *args: ast._Node) -> functions.Function: 219 | ":meta private:" 220 | return functions.concat(*[self.visit(arg) for arg in args]) 221 | 222 | def func_indexof(self, 
first: ast._Node, second: ast._Node) -> functions.Function: 223 | ":meta private:" 224 | # TODO: Highly dialect dependent, might want to implement in GenericFunction: 225 | # Subtract 1 because OData is 0-indexed while SQL is 1-indexed 226 | return functions_ext.strpos(self.visit(first), self.visit(second)) - 1 227 | 228 | def func_substring( 229 | self, fullstr: ast._Node, index: ast._Node, nchars: Optional[ast._Node] = None 230 | ) -> functions.Function: 231 | ":meta private:" 232 | # Add 1 because OData is 0-indexed while SQL is 1-indexed 233 | if nchars: 234 | return functions_ext.substr( 235 | self.visit(fullstr), 236 | self.visit(index) + 1, 237 | self.visit(nchars), 238 | ) 239 | else: 240 | return functions_ext.substr(self.visit(fullstr), self.visit(index) + 1) 241 | 242 | def func_matchespattern( 243 | self, field: ast._Node, pattern: ast._Node 244 | ) -> functions.Function: 245 | ":meta private:" 246 | identifier = self.visit(field) 247 | return identifier.regexp_match(self.visit(pattern)) 248 | 249 | def func_tolower(self, field: ast._Node) -> functions.Function: 250 | ":meta private:" 251 | return functions_ext.lower(self.visit(field)) 252 | 253 | def func_toupper(self, field: ast._Node) -> functions.Function: 254 | ":meta private:" 255 | return functions_ext.upper(self.visit(field)) 256 | 257 | def func_trim(self, field: ast._Node) -> functions.Function: 258 | ":meta private:" 259 | return functions_ext.ltrim(functions_ext.rtrim(self.visit(field))) 260 | 261 | def func_date(self, field: ast._Node) -> ClauseElement: 262 | ":meta private:" 263 | return cast(self.visit(field), Date) 264 | 265 | def func_day(self, field: ast._Node) -> functions.Function: 266 | ":meta private:" 267 | return extract("day", self.visit(field)) 268 | 269 | def func_hour(self, field: ast._Node) -> functions.Function: 270 | ":meta private:" 271 | return extract("hour", self.visit(field)) 272 | 273 | def func_minute(self, field: ast._Node) -> functions.Function: 274 | ":meta private:" 275 | return extract("minute", self.visit(field)) 276 | 277 | def func_month(self, field: ast._Node) -> functions.Function: 278 | ":meta private:" 279 | return extract("month", self.visit(field)) 280 | 281 | def func_now(self) -> functions.Function: 282 | ":meta private:" 283 | return functions.now() 284 | 285 | def func_second(self, field: ast._Node) -> functions.Function: 286 | ":meta private:" 287 | return extract("second", self.visit(field)) 288 | 289 | def func_time(self, field: ast._Node) -> functions.Function: 290 | ":meta private:" 291 | return cast(self.visit(field), Time) 292 | 293 | def func_year(self, field: ast._Node) -> functions.Function: 294 | ":meta private:" 295 | return extract("year", self.visit(field)) 296 | 297 | def func_ceiling(self, field: ast._Node) -> functions.Function: 298 | ":meta private:" 299 | return functions_ext.ceil(self.visit(field)) 300 | 301 | def func_floor(self, field: ast._Node) -> functions.Function: 302 | ":meta private:" 303 | return functions_ext.floor(self.visit(field)) 304 | 305 | def func_round(self, field: ast._Node) -> functions.Function: 306 | ":meta private:" 307 | return functions_ext.round(self.visit(field)) 308 | 309 | def _substr_function( 310 | self, field: ast._Node, substr: ast._Node, func: str 311 | ) -> ClauseElement: 312 | ":meta private:" 313 | typing.typecheck(field, (ast.Identifier, ast.String), "field") 314 | typing.typecheck(substr, ast.String, "substring") 315 | 316 | identifier = self.visit(field) 317 | substring = self.visit(substr) 318 | op = 
getattr(identifier, func) 319 | 320 | return op(substring) 321 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/core.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | import sqlalchemy as sa 4 | from sqlalchemy.sql.expression import BinaryExpression, ClauseElement, ColumnClause 5 | 6 | from odata_query import ast, exceptions as ex, visitor 7 | 8 | from . import common 9 | 10 | 11 | class AstToSqlAlchemyCoreVisitor(common._CommonVisitors, visitor.NodeVisitor): 12 | """ 13 | :class:`NodeVisitor` that transforms an :term:`AST` into a SQLAlchemy where 14 | clause using Core features. 15 | 16 | Args: 17 | table: A SQLalchemy table 18 | """ 19 | 20 | def __init__(self, table: Type[sa.Table]): 21 | self.table = table 22 | 23 | def visit_Identifier(self, node: ast.Identifier) -> ColumnClause: 24 | """:meta private:""" 25 | try: 26 | return self.table.c[node.name] 27 | except KeyError: 28 | raise ex.InvalidFieldException(node.name) 29 | 30 | def visit_Attribute(self, node: ast.Attribute) -> ColumnClause: 31 | """:meta private:""" 32 | raise NotImplementedError( 33 | "Relationship traversal is not yet implemented for the SQLAlchemy Core visitor." 34 | ) 35 | 36 | def visit_Compare(self, node: ast.Compare) -> BinaryExpression: 37 | """:meta private:""" 38 | left = self.visit(node.left) 39 | right = self.visit(node.right) 40 | op = self.visit(node.comparator) 41 | return op(left, right) 42 | 43 | def visit_CollectionLambda(self, node: ast.CollectionLambda) -> ClauseElement: 44 | """:meta private:""" 45 | raise NotImplementedError( 46 | "Collection lambda is not yet implemented for the SQLAlchemy Core visitor." 47 | ) 48 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/functions_ext.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.sql.functions import GenericFunction 2 | from sqlalchemy.types import Integer, String 3 | 4 | 5 | class strpos(GenericFunction): 6 | type = Integer 7 | package = "odata" 8 | inherit_cache = True 9 | 10 | 11 | class substr(GenericFunction): 12 | type = String 13 | package = "odata" 14 | inherit_cache = True 15 | 16 | 17 | class lower(GenericFunction): 18 | type = String 19 | package = "odata" 20 | inherit_cache = True 21 | 22 | 23 | class upper(GenericFunction): 24 | type = String 25 | package = "odata" 26 | inherit_cache = True 27 | 28 | 29 | class ltrim(GenericFunction): 30 | type = String 31 | package = "odata" 32 | inherit_cache = True 33 | 34 | 35 | class rtrim(GenericFunction): 36 | type = String 37 | package = "odata" 38 | inherit_cache = True 39 | 40 | 41 | class ceil(GenericFunction): 42 | type = Integer 43 | package = "odata" 44 | inherit_cache = True 45 | 46 | 47 | class floor(GenericFunction): 48 | type = Integer 49 | package = "odata" 50 | inherit_cache = True 51 | 52 | 53 | class round(GenericFunction): 54 | type = Integer 55 | package = "odata" 56 | inherit_cache = True 57 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/orm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | 3 | from sqlalchemy.inspection import inspect 4 | from sqlalchemy.orm.attributes import InstrumentedAttribute 5 | from sqlalchemy.orm.decl_api import DeclarativeMeta 6 | from sqlalchemy.orm.relationships import 
RelationshipProperty 7 | from sqlalchemy.sql.expression import BinaryExpression, ClauseElement, ColumnClause 8 | 9 | from odata_query import ast, exceptions as ex, utils, visitor 10 | 11 | from . import common 12 | 13 | 14 | class AstToSqlAlchemyOrmVisitor(common._CommonVisitors, visitor.NodeVisitor): 15 | """ 16 | :class:`NodeVisitor` that transforms an :term:`AST` into a SQLAlchemy query 17 | filter clause using ORM features. 18 | 19 | Args: 20 | root_model: The root model of the query. 21 | """ 22 | 23 | def __init__(self, root_model: Type[DeclarativeMeta]): 24 | self.root_model = root_model 25 | self.join_relationships: List[InstrumentedAttribute] = [] 26 | 27 | def visit_Identifier(self, node: ast.Identifier) -> ColumnClause: 28 | ":meta private:" 29 | try: 30 | return getattr(self.root_model, node.name) 31 | except AttributeError: 32 | raise ex.InvalidFieldException(node.name) 33 | 34 | def visit_Attribute(self, node: ast.Attribute) -> ColumnClause: 35 | ":meta private:" 36 | rel_attr = self.visit(node.owner) 37 | # Owner is an InstrumentedAttribute, hopefully of a relationship. 38 | # But we need the model pointed to by the relationship. 39 | prop_inspect = inspect(rel_attr).property 40 | if not isinstance(prop_inspect, RelationshipProperty): 41 | # TODO: new exception: 42 | raise ValueError(f"Not a relationship: {node.owner}") 43 | self.join_relationships.append(rel_attr) 44 | 45 | # We'd like to reference the column on the related class: 46 | owner_cls = prop_inspect.entity.class_ 47 | try: 48 | return getattr(owner_cls, node.attr) 49 | except AttributeError: 50 | raise ex.InvalidFieldException(node.attr) 51 | 52 | def visit_Compare(self, node: ast.Compare) -> BinaryExpression: 53 | ":meta private:" 54 | left = self.visit(node.left) 55 | right = self.visit(node.right) 56 | op = self.visit(node.comparator) 57 | 58 | # If a node is a `relationship` representing a single foreign key, 59 | # the client meant to compare the foreign key, not the related object. 60 | # E.g. In "blogpost/author eq 1", left should be "blogpost/author_id" 61 | left = self._maybe_sub_relationship_with_foreign_key(left) 62 | right = self._maybe_sub_relationship_with_foreign_key(right) 63 | 64 | return op(left, right) 65 | 66 | def visit_CollectionLambda(self, node: ast.CollectionLambda) -> ClauseElement: 67 | ":meta private:" 68 | owner_prop = self.visit(node.owner) 69 | collection_model = inspect(owner_prop).property.entity.class_ 70 | 71 | if node.lambda_: 72 | # For the lambda, we want to strip the identifier off, because 73 | # we will execute this as a subquery in the wanted model's context. 74 | subq_ast = utils.expression_relative_to_identifier( 75 | node.lambda_.identifier, node.lambda_.expression 76 | ) 77 | subq_transformer = self.__class__(collection_model) 78 | subquery_filter = subq_transformer.visit(subq_ast) 79 | else: 80 | subquery_filter = None 81 | 82 | if isinstance(node.operator, ast.Any): 83 | return owner_prop.any(subquery_filter) 84 | else: 85 | # For an ALL query, invert both the filter and the EXISTS: 86 | if node.lambda_: 87 | subquery_filter = ~subquery_filter 88 | return ~owner_prop.any(subquery_filter) 89 | 90 | def _maybe_sub_relationship_with_foreign_key( 91 | self, elem: ClauseElement 92 | ) -> ClauseElement: 93 | """ 94 | If the given ClauseElement is a `relationship` with a single ForeignKey, 95 | replace it with the `ForeignKey` itself. 
96 | 97 | :meta private: 98 | """ 99 | try: 100 | prop_inspect = inspect(elem).property 101 | if isinstance(prop_inspect, RelationshipProperty): 102 | foreign_key = prop_inspect._calculated_foreign_keys 103 | if len(foreign_key) == 1: 104 | return next(iter(foreign_key)) 105 | except Exception: 106 | pass 107 | 108 | return elem 109 | -------------------------------------------------------------------------------- /odata_query/sqlalchemy/shorthand.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from sqlalchemy.orm.query import Query 4 | from sqlalchemy.sql.expression import ClauseElement, Select 5 | 6 | from odata_query.grammar import ODataLexer, ODataParser # type: ignore 7 | 8 | from .core import AstToSqlAlchemyCoreVisitor 9 | from .orm import AstToSqlAlchemyOrmVisitor 10 | 11 | 12 | def _get_joined_attrs(query: Select) -> List[str]: 13 | # use _legacy_setup_joins for legacy Query objects 14 | setup_joins = ( 15 | getattr(query, "_legacy_setup_joins", query._setup_joins) or query._setup_joins 16 | ) 17 | return [str(join[0]) for join in setup_joins] 18 | 19 | 20 | def apply_odata_query(query: ClauseElement, odata_query: str) -> ClauseElement: 21 | """ 22 | Shorthand for applying an OData query to a SQLAlchemy query. 23 | 24 | Args: 25 | query: SQLAlchemy query to apply the OData query to. 26 | odata_query: OData query string. 27 | Returns: 28 | ClauseElement: The modified query 29 | """ 30 | lexer = ODataLexer() 31 | parser = ODataParser() 32 | 33 | clause_elem: ClauseElement 34 | if isinstance(query, Query): 35 | # For now, we keep supporting the 1.x style of queries unofficially. 36 | # GITHUB-34 37 | clause_elem = query.__clause_element__() 38 | else: 39 | clause_elem = query 40 | 41 | model = clause_elem.columns_clause_froms[0].entity_namespace 42 | 43 | ast = parser.parse(lexer.tokenize(odata_query)) 44 | transformer = AstToSqlAlchemyOrmVisitor(model) 45 | where_clause = transformer.visit(ast) 46 | 47 | existing_joins = _get_joined_attrs(query) 48 | for required_join in transformer.join_relationships: 49 | if ( 50 | str(required_join) not in existing_joins 51 | and str(required_join.key) not in existing_joins 52 | ): 53 | query = query.join(required_join) 54 | 55 | return query.filter(where_clause) 56 | 57 | 58 | def apply_odata_core(query: ClauseElement, odata_query: str) -> ClauseElement: 59 | """ 60 | Shorthand for applying an OData query to a SQLAlchemy core. 61 | 62 | Args: 63 | query: SQLAlchemy query to apply the OData query to. 64 | odata_query: OData query string. 65 | Returns: 66 | ClauseElement: The modified query 67 | """ 68 | lexer = ODataLexer() 69 | parser = ODataParser() 70 | table = query.columns_clause_froms[0] 71 | 72 | ast = parser.parse(lexer.tokenize(odata_query)) 73 | transformer = AstToSqlAlchemyCoreVisitor(table) 74 | where_clause = transformer.visit(ast) 75 | return query.filter(where_clause) 76 | -------------------------------------------------------------------------------- /odata_query/typing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import operator 3 | from typing import Optional, Tuple, Type, Union 4 | 5 | from . 
import ast, exceptions as ex 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def typecheck( 11 | node: ast._Node, expected_type: Union[Type, Tuple[Type, ...]], field_name: str 12 | ) -> None: 13 | """ 14 | Checks that the inferred type of ``node`` is (one) of ``expected_type``, and 15 | raises :class:`ArgumentTypeException` if not. 16 | 17 | Args: 18 | node: The node to type check. 19 | expected_type: The allowed type(s) the node can have. 20 | field_name: The name of the field you're typechecking. Only used in the 21 | exception. 22 | Raises: 23 | ArgumentTypeException 24 | """ 25 | actual_type = infer_type(node) 26 | compare = operator.contains if isinstance(expected_type, tuple) else operator.eq 27 | if actual_type and not compare(expected_type, actual_type): 28 | allowed = ( 29 | [t.__name__ for t in expected_type] 30 | if isinstance(expected_type, tuple) 31 | else expected_type.__name__ 32 | ) 33 | raise ex.ArgumentTypeException(field_name, str(allowed), actual_type.__name__) 34 | 35 | 36 | def infer_type(node: ast._Node) -> Optional[Type[ast._Node]]: 37 | """ 38 | Tries to infer the type of ``node``. 39 | 40 | Args: 41 | node: The node to infer the type for. 42 | Returns: 43 | The inferred type or ``None`` if unable to infer. 44 | """ 45 | if isinstance(node, (ast._Literal)): 46 | return type(node) 47 | 48 | if isinstance(node, (ast.Compare, ast.BoolOp)): 49 | return ast.Boolean 50 | 51 | if isinstance(node, ast.Call): 52 | return infer_return_type(node) 53 | 54 | log.debug("Failed to infer type for %s", node) 55 | return None 56 | 57 | 58 | def infer_return_type(node: ast.Call) -> Optional[Type[ast._Node]]: 59 | """ 60 | Tries to infer the type of a function call ``node``. 61 | 62 | Args: 63 | node: The node to infer the type for. 64 | Returns: 65 | The inferred type or ``None`` if unable to infer. 66 | """ 67 | func = node.func.full_name() 68 | 69 | if func in ( 70 | "contains", 71 | "endswith", 72 | "startswith", 73 | "hassubset", 74 | "hassubsequence", 75 | "geo.intersects", 76 | ): 77 | return ast.Boolean 78 | 79 | if func in ( 80 | "indexof", 81 | "length", 82 | "year", 83 | "month", 84 | "day", 85 | "hour", 86 | "minute", 87 | "second", 88 | "totaloffsetminutes", 89 | ): 90 | return ast.Integer 91 | 92 | if func in ( 93 | "fractionalseconds", 94 | "totalseconds", 95 | "ceiling", 96 | "floor", 97 | "round", 98 | "geo.distance", 99 | "geo.length", 100 | ): 101 | return ast.Float 102 | 103 | if func in ("tolower", "toupper", "trim"): 104 | return ast.String 105 | 106 | if func == "date": 107 | return ast.Date 108 | 109 | if func in ("maxdatetime", "mindatetime", "now"): 110 | return ast.DateTime 111 | 112 | if func == "concat": 113 | return infer_type(node.args[0]) or infer_type(node.args[1]) 114 | 115 | if func == "substring": 116 | return infer_type(node.args[0]) 117 | 118 | return None 119 | -------------------------------------------------------------------------------- /odata_query/utils.py: -------------------------------------------------------------------------------- 1 | from . import ast 2 | from .rewrite import IdentifierStripper 3 | 4 | 5 | def expression_relative_to_identifier( 6 | identifier: ast.Identifier, expression: ast._Node 7 | ) -> ast._Node: 8 | """ 9 | Shorthand for the :class:`IdentifierStripper`. 10 | 11 | Args: 12 | identifier: Identifier to strip from ``expression``. 13 | expression: Expression to strip the ``identifier`` from. 14 | 15 | Returns: 16 | The ``expression`` relative to the ``identifier``. 
17 | """ 18 | stripper = IdentifierStripper(identifier) 19 | result = stripper.visit(expression) 20 | return result 21 | -------------------------------------------------------------------------------- /odata_query/visitor.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields 2 | from typing import Any, Iterator, Tuple 3 | 4 | from . import ast 5 | 6 | 7 | def iter_dataclass_fields(node: ast._Node) -> Iterator[Tuple[str, Any]]: 8 | """ 9 | Loops over all fields of the given node, yielding the field's name and 10 | the current value. 11 | 12 | Yields: 13 | Tuples of ``(fieldname, value)`` for each field in ``node._fields``. 14 | """ 15 | for field in fields(node): 16 | yield field.name, getattr(node, field.name) 17 | 18 | 19 | class NodeVisitor: 20 | """ 21 | Base class for visitors that walk the :term:`AST` and calls a visitor 22 | method for every node found. This method may return a value 23 | which is forwarded by the :func:`visit` method. 24 | 25 | This class is meant to be subclassed, with the subclass adding visitor 26 | methods. 27 | By default the visitor methods for the nodes are named ``'visit_'`` + 28 | class name of the node (e.g. ``visit_Identifier(self, identifier)``). 29 | If no visitor method exists for a node, the :func:`generic_visit` visitor is 30 | used instead. 31 | """ 32 | 33 | def visit(self, node: ast._Node) -> Any: 34 | """ 35 | Looks for an explicit node visiting method on ``self``, 36 | otherwise calls :func:`generic_visit`. 37 | 38 | Returns: 39 | Whatever the called method returned. The user is free to choose what 40 | the :class:`NodeVisitor` should return. 41 | """ 42 | method = "visit_" + node.__class__.__name__ 43 | visitor = getattr(self, method, self.generic_visit) 44 | return visitor(node) 45 | 46 | def generic_visit(self, node: ast._Node): 47 | """ 48 | Visits all fields on ``node`` recursively. 49 | Called if no explicit visitor method exists for a node. 50 | """ 51 | for field, value in iter_dataclass_fields(node): 52 | if isinstance(value, list): 53 | for item in value: 54 | if isinstance(item, ast._Node): 55 | self.visit(item) 56 | elif isinstance(value, ast._Node): 57 | self.visit(value) 58 | 59 | 60 | class NodeTransformer(NodeVisitor): 61 | """ 62 | A subclass of :class:`NodeVisitor` that allows replacing of nodes in the 63 | :term:`AST` as it passes over it. The visitor methods should return instances 64 | of :class:`_Node` that replace the passed node. 65 | """ 66 | 67 | def generic_visit(self, node: ast._Node) -> ast._Node: 68 | new_kwargs = {} 69 | 70 | for field, value in iter_dataclass_fields(node): 71 | if isinstance(value, list): 72 | new_val = [] 73 | for item in value: 74 | if isinstance(item, ast._Node): 75 | new_val.append(self.visit(item)) 76 | else: 77 | new_val.append(item) 78 | elif isinstance(value, ast._Node): 79 | new_val = self.visit(value) 80 | else: 81 | new_val = value 82 | 83 | new_kwargs[field] = new_val 84 | 85 | return type(node)(**new_kwargs) 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "odata-query" 3 | version = "0.10.0" 4 | description = "An OData query parser and transpiler." 
5 | authors = ["Oliver Hofkens "] 6 | readme = "README.rst" 7 | license = "MIT" 8 | keywords = ["OData", "Query", "Parser"] 9 | classifiers = [ 10 | "Development Status :: 5 - Production/Stable", 11 | "Environment :: Web Environment", 12 | "Framework :: Django", 13 | "Intended Audience :: Developers", 14 | "License :: OSI Approved :: MIT License", 15 | "Operating System :: OS Independent", 16 | "Programming Language :: Python :: 3 :: Only", 17 | "Programming Language :: SQL", 18 | "Topic :: Database", 19 | "Topic :: Internet :: WWW/HTTP", 20 | "Topic :: Internet :: WWW/HTTP :: Indexing/Search", 21 | "Topic :: Software Development :: Compilers" 22 | ] 23 | include = ["odata_query/py.typed"] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.7" 27 | 28 | python-dateutil = "^2.8.1" 29 | sly = "^0.4" 30 | 31 | django = { version = ">=2.2", optional = true } 32 | sqlalchemy = { version = "^1.4", optional = true } 33 | 34 | black = { version = "^22.1", optional = true } 35 | bump2version = { version = "^1.0", optional = true } 36 | flake8 = { version = "^3.8", optional = true } 37 | isort = { version = "^5.7", optional = true } 38 | mypy = { version = "^0.931", optional = true } 39 | types-python-dateutil = { version = "^2.8.1", optional = true } 40 | pytest = { version = "^6.2 || ^7.0", optional = true } 41 | pytest-cov = { version = "*", optional = true } 42 | sphinx = { version = "^5.3", optional = true } 43 | sphinx-rtd-theme = { version = "^2.0", optional = true } 44 | vulture = { version = "^2.3", optional = true } 45 | 46 | [tool.poetry.extras] 47 | dev = ["bump2version"] 48 | django = ["django"] 49 | docs = ["sphinx", "sphinx-rtd-theme"] 50 | linting = ["flake8", "black", "isort", "mypy", "types-python-dateutil", "yamllint", "vulture"] 51 | sqlalchemy = ["sqlalchemy"] 52 | testing = ["pytest", "pytest-cov", "pytest-xdist"] 53 | 54 | [build-system] 55 | requires = ["poetry-core>=1.0.0"] 56 | build-backend = "poetry.core.masonry.api" 57 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.10.0 3 | commit = True 4 | tag = True 5 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:(?P<release>a|b|rc)(?P<build>\d+))? 6 | serialize = 7 | {major}.{minor}.{patch}{release}{build} 8 | {major}.{minor}.{patch} 9 | 10 | [bumpversion:part:release] 11 | values = 12 | a 13 | b 14 | rc 15 | 16 | [bumpversion:part:build] 17 | first_value = 0 18 | 19 | [flake8] 20 | ignore = E203, E266, E501, W503 21 | max-line-length = 80 22 | max-complexity = 18 23 | select = B,C,E,F,W,T4,B9 24 | show_source = True 25 | per_file_ignores = 26 | __init__.py: F401,F403 27 | odata_query/grammar.py:F821,F811 28 | 29 | [tool:isort] 30 | multi_line_output = 3 31 | include_trailing_comma = True 32 | force_grid_wrap = 0 33 | combine_as_imports = True 34 | line_length = 88 35 | default_section = THIRDPARTY 36 | known_first_party = odata_query, tests 37 | 38 | [tool:pytest] 39 | addopts = 40 | --cov .
41 | --cov-branch 42 | --cov-config setup.cfg 43 | --cov-report term-missing 44 | --cov-report xml:coverage.xml 45 | testpaths = tests 46 | markers = 47 | slow: marks tests as slow (deselect with '-m "not slow"') 48 | 49 | [coverage:run] 50 | omit = ./tests/*, ./.tox/* 51 | 52 | [bumpversion:file:pyproject.toml] 53 | search = version = "{current_version}" 54 | replace = version = "{new_version}" 55 | 56 | [bumpversion:file:sonar-project.properties] 57 | search = sonar.projectVersion={current_version} 58 | replace = sonar.projectVersion={new_version} 59 | 60 | [bumpversion:file:odata_query/__init__.py] 61 | search = __version__ = "{current_version}" 62 | replace = __version__ = "{new_version}" 63 | 64 | [bumpversion:file:docs/source/conf.py] 65 | search = release = "{current_version}" 66 | replace = release = "{new_version}" 67 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=gorillaco_odata-query 2 | sonar.projectName=OData Query 3 | sonar.projectVersion=0.10.0 4 | 5 | sonar.sources=./odata_query 6 | sonar.tests=./tests 7 | 8 | sonar.python.coverage.reportPaths=coverage.xml 9 | 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from odata_query.grammar import ODataLexer, ODataParser 6 | 7 | 8 | @pytest.fixture(scope="session") 9 | def lexer(): 10 | return ODataLexer() 11 | 12 | 13 | @pytest.fixture 14 | def parser(): 15 | return ODataParser() 16 | 17 | 18 | @pytest.fixture(scope="session") 19 | def data_dir(): 20 | data_dir_path = Path(__file__).parent / "data" 21 | data_dir_path.mkdir(exist_ok=True) 22 | 23 | return data_dir_path 24 | -------------------------------------------------------------------------------- /tests/data/world_borders.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/data/world_borders.zip -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/django/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/django/__init__.py -------------------------------------------------------------------------------- /tests/integration/django/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ODataQueryConfig(AppConfig): 5 | name = "tests.integration.django" 6 | verbose_name = "OData Query Django 
test app" 7 | default = True 8 | 9 | 10 | class DbRouter: 11 | """ 12 | Ensure that GeoDjango models go to the SpatiaLite database, while other 13 | models use the default SQLite database. 14 | """ 15 | 16 | GEO_APP = "django_geo" 17 | 18 | def db_for_read(self, model, **hints): 19 | if model._meta.app_label == self.GEO_APP: 20 | return "geo" 21 | return None 22 | 23 | def db_for_write(self, model, **hints): 24 | if model._meta.app_label == self.GEO_APP: 25 | return "geo" 26 | return None 27 | 28 | def allow_relation(self, obj1, obj2, **hints): 29 | return obj1._meta.app_label == obj2._meta.app_label 30 | 31 | def allow_migrate(self, db: str, app_label: str, model_name=None, **hints): 32 | if app_label != self.GEO_APP and db == "default": 33 | return True 34 | if app_label == self.GEO_APP and db == "geo": 35 | return True 36 | return False 37 | -------------------------------------------------------------------------------- /tests/integration/django/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from django.core import management 3 | 4 | 5 | @pytest.fixture(scope="session") 6 | def django_db(): 7 | management.call_command("migrate", "--run-syncdb") 8 | -------------------------------------------------------------------------------- /tests/integration/django/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import django 4 | from django.db import models 5 | 6 | django.setup() 7 | 8 | 9 | class Author(models.Model): 10 | id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) 11 | name = models.CharField(max_length=128) 12 | 13 | 14 | class BlogPost(models.Model): 15 | published_at = models.DateTimeField() 16 | title = models.CharField(max_length=128) 17 | content = models.TextField() 18 | 19 | authors = models.ManyToManyField(Author, related_name="blogposts") 20 | 21 | 22 | class Comment(models.Model): 23 | content = models.TextField() 24 | 25 | author = models.ForeignKey( 26 | Author, on_delete=models.CASCADE, related_name="comments" 27 | ) 28 | blogpost = models.ForeignKey( 29 | BlogPost, on_delete=models.CASCADE, related_name="comments" 30 | ) 31 | -------------------------------------------------------------------------------- /tests/integration/django/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | DB_DIR = Path(__file__).parent / "db" 4 | DB_DIR.mkdir(exist_ok=True) 5 | 6 | DATABASES = { 7 | "default": { 8 | "ENGINE": "django.db.backends.sqlite3", 9 | "NAME": str(DB_DIR / "odata-query"), 10 | }, 11 | "geo": { 12 | "ENGINE": "django.contrib.gis.db.backends.spatialite", 13 | "NAME": str(DB_DIR / "odata-query-geo"), 14 | }, 15 | } 16 | DATABASE_ROUTERS = ["tests.integration.django.apps.DbRouter"] 17 | DEBUG = True 18 | INSTALLED_APPS = [ 19 | "tests.integration.django.apps.ODataQueryConfig", 20 | # GEO: 21 | "django.contrib.gis", 22 | "tests.integration.django_geo.apps.ODataQueryConfig", 23 | ] 24 | -------------------------------------------------------------------------------- /tests/integration/django/test_querying.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from typing import Type 3 | 4 | import pytest 5 | from django.db import models, transaction 6 | 7 | from odata_query.django import apply_odata_query 8 | from odata_query.django.django_q import DJANGO_LT_4 9 | 10 | from .models 
import Author, BlogPost, Comment 11 | 12 | 13 | @pytest.fixture 14 | def sample_data_sess(django_db): 15 | # https://docs.djangoproject.com/en/3.1/topics/db/transactions/#managing-autocommit 16 | transaction.set_autocommit(False) 17 | a1 = Author(name="Gorilla") 18 | a1.save() 19 | a2 = Author(name="Baboon") 20 | a2.save() 21 | a3 = Author(name="Saki") 22 | a3.save() 23 | 24 | bp1 = BlogPost( 25 | title="Querying Data", 26 | published_at=dt.datetime(2020, 1, 1), 27 | content="How 2 query data...", 28 | ) 29 | bp1.save() 30 | bp2 = BlogPost( 31 | title="Automating Monkey Jobs", 32 | published_at=dt.datetime(2019, 1, 1), 33 | content="How 2 automate monkey jobs...", 34 | ) 35 | bp2.save() 36 | bp1.authors.set([a1]) 37 | bp2.authors.set([a2, a3]) 38 | 39 | c1 = Comment(content="Dope!", author=a2, blogpost=bp1) 40 | c1.save() 41 | c2 = Comment(content="Cool!", author=a1, blogpost=bp2) 42 | c2.save() 43 | yield 44 | transaction.rollback() 45 | 46 | 47 | @pytest.mark.parametrize( 48 | "model, query, exp_results", 49 | [ 50 | (Author, "name eq 'Baboon'", 1), 51 | (Author, "startswith(name, 'Gori')", 1), 52 | (BlogPost, "contains(content, 'How')", 2), 53 | (BlogPost, "published_at gt 2019-06-01", 1), 54 | (Author, "contains(blogposts/title, 'Monkey')", 2), 55 | (Author, "startswith(blogposts/comments/content, 'Cool')", 2), 56 | (Author, "comments/any()", 2), 57 | (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2), 58 | (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1), 59 | (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2), 60 | (Author, "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0), 61 | pytest.param( 62 | Author, 63 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 64 | 0, 65 | marks=pytest.mark.xfail( 66 | not DJANGO_LT_4, 67 | reason="Bug related to https://code.djangoproject.com/ticket/33705", 68 | ), 69 | ), 70 | (BlogPost, "comments/author eq 0", 0), 71 | (BlogPost, "substring(content, 0) eq 'test'", 0), 72 | (BlogPost, "year(published_at) eq 2019", 1), 73 | # GITHUB-19 74 | (BlogPost, "contains(title, 'Query') eq true", 1), 75 | (BlogPost, "contains(title, 'Query') eq false", 1), 76 | ], 77 | ) 78 | def test_query_with_odata( 79 | model: Type[models.Model], 80 | query: str, 81 | exp_results: int, 82 | sample_data_sess, 83 | ): 84 | q = apply_odata_query(model.objects, query) 85 | results = q.all() 86 | assert len(results) == exp_results 87 | 88 | 89 | @pytest.mark.xfail( 90 | not DJANGO_LT_4, 91 | reason="Bug fixed in unreleased version: https://code.djangoproject.com/ticket/33705", 92 | ) 93 | @pytest.mark.parametrize( 94 | "odata_query, expected_sql", 95 | [ 96 | ( 97 | "author eq null", 98 | ( 99 | 'SELECT DISTINCT "django_comment"."id", "django_comment"."content", "django_comment"."author_id", "django_comment"."blogpost_id" ' 100 | 'FROM "django_comment" ' 101 | 'WHERE "django_comment"."author_id" IS NULL' 102 | ), 103 | ), 104 | ( 105 | "author ne null", 106 | ( 107 | 'SELECT DISTINCT "django_comment"."id", "django_comment"."content", "django_comment"."author_id", "django_comment"."blogpost_id" ' 108 | 'FROM "django_comment" ' 109 | 'WHERE "django_comment"."author_id" IS NOT NULL' 110 | ), 111 | ), 112 | ], 113 | ) 114 | def test_odata_filter_to_sql_query(odata_query: str, expected_sql: str): 115 | q = apply_odata_query(Comment.objects, odata_query) 116 | queryset = q.distinct() 117 | sql = str(queryset.query) 118 | assert sql == expected_sql 119 | 
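
A quick usage sketch distilled from the tests above (the variable names and the specific filter string are illustrative; ``BlogPost`` is the test model defined earlier, and ``apply_odata_query`` is the shorthand imported from ``odata_query.django``):

    from odata_query.django import apply_odata_query

    from tests.integration.django.models import BlogPost

    # Translate the OData filter string into a Django queryset filter and apply it.
    queryset = apply_odata_query(BlogPost.objects, "year(published_at) eq 2019")
    results = queryset.all()  # the BlogPost rows matching the filter
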
-------------------------------------------------------------------------------- /tests/integration/django/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query.django import utils 4 | 5 | from .models import Author, BlogPost 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "root_model, rel, exp_model, exp_rel", 10 | [ 11 | (Author, "blogposts", BlogPost, "authors"), 12 | (BlogPost, "comments__author", Author, "comments__blogpost"), 13 | ], 14 | ) 15 | def test_reverse_relationship(root_model, rel, exp_model, exp_rel): 16 | res_rel, res_model = utils.reverse_relationship(rel, root_model) 17 | 18 | assert res_rel == exp_rel 19 | assert res_model is exp_model 20 | -------------------------------------------------------------------------------- /tests/integration/django_geo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/django_geo/__init__.py -------------------------------------------------------------------------------- /tests/integration/django_geo/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ODataQueryConfig(AppConfig): 5 | name = "tests.integration.django_geo" 6 | verbose_name = "OData Query GeoDjango test app" 7 | default = True 8 | -------------------------------------------------------------------------------- /tests/integration/django_geo/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from zipfile import ZipFile 3 | 4 | import pytest 5 | from django.core import management 6 | 7 | 8 | @pytest.fixture(scope="session") 9 | def django_db(): 10 | management.call_command("migrate", "--run-syncdb", "--database", "geo") 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def world_borders_dataset(data_dir: Path): 15 | target_dir = data_dir / "world_borders" 16 | 17 | if target_dir.exists(): 18 | return target_dir 19 | 20 | filename_zip = target_dir.with_suffix(".zip") 21 | with ZipFile(filename_zip, "r") as z: 22 | z.extractall(target_dir) 23 | 24 | return target_dir 25 | -------------------------------------------------------------------------------- /tests/integration/django_geo/models.py: -------------------------------------------------------------------------------- 1 | # https://docs.djangoproject.com/en/4.2/ref/contrib/gis/tutorial/#geographic-models 2 | 3 | from django.core.exceptions import ImproperlyConfigured 4 | from django.db import models 5 | 6 | # This file needs to be importable even without Geo system libraries installed. 7 | # Tests using these libraries will be skipped using pytest.skip 8 | try: 9 | from django.contrib.gis.db.models import MultiPolygonField 10 | except (ImportError, ImproperlyConfigured): 11 | MultiPolygonField = models.CharField 12 | 13 | 14 | class WorldBorder(models.Model): 15 | # Regular Django fields corresponding to the attributes in the 16 | # world borders shapefile. 
17 | name = models.CharField(max_length=50) 18 | area = models.IntegerField() 19 | pop2005 = models.IntegerField("Population 2005") 20 | fips = models.CharField("FIPS Code", max_length=2, null=True) 21 | iso2 = models.CharField("2 Digit ISO", max_length=2) 22 | iso3 = models.CharField("3 Digit ISO", max_length=3) 23 | un = models.IntegerField("United Nations Code") 24 | region = models.IntegerField("Region Code") 25 | subregion = models.IntegerField("Sub-Region Code") 26 | lon = models.FloatField() 27 | lat = models.FloatField() 28 | 29 | # GeoDjango-specific: a geometry field (MultiPolygonField) 30 | mpoly = MultiPolygonField() 31 | 32 | def __str__(self): 33 | return self.name 34 | -------------------------------------------------------------------------------- /tests/integration/django_geo/test_querying.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Type 3 | 4 | import pytest 5 | from django.core.exceptions import ImproperlyConfigured 6 | 7 | from odata_query.django import apply_odata_query 8 | 9 | from .models import WorldBorder 10 | 11 | try: 12 | from django.contrib.gis.db import models 13 | from django.contrib.gis.utils import LayerMapping 14 | except (ImportError, ImproperlyConfigured): 15 | pytest.skip(allow_module_level=True, reason="Could not load GIS libraries") 16 | 17 | # The default spatial reference system for geometry fields is WGS84 18 | # (meaning the SRID is 4326) 19 | SRID = "SRID=4326" 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def sample_data_sess(django_db, world_borders_dataset: Path): 24 | world_mapping = { 25 | "fips": "FIPS", 26 | "iso2": "ISO2", 27 | "iso3": "ISO3", 28 | "un": "UN", 29 | "name": "NAME", 30 | "area": "AREA", 31 | "pop2005": "POP2005", 32 | "region": "REGION", 33 | "subregion": "SUBREGION", 34 | "lon": "LON", 35 | "lat": "LAT", 36 | "mpoly": "MULTIPOLYGON", 37 | } 38 | 39 | world_shp = world_borders_dataset / "TM_WORLD_BORDERS-0.3.shp" 40 | lm = LayerMapping(WorldBorder, world_shp, world_mapping, transform=False) 41 | lm.save(strict=True, verbose=True) 42 | yield 43 | WorldBorder.objects.all().delete() 44 | 45 | 46 | @pytest.mark.parametrize( 47 | "model, query, exp_results", 48 | [ 49 | ( 50 | WorldBorder, 51 | "geo.length(mpoly) gt 1000000", 52 | 154, 53 | ), 54 | ( 55 | WorldBorder, 56 | f"geo.intersects(mpoly, geography'{SRID};Point(-95.3385 29.7245)')", 57 | 1, 58 | ), 59 | ], 60 | ) 61 | def test_query_with_odata( 62 | model: Type[models.Model], 63 | query: str, 64 | exp_results: int, 65 | sample_data_sess, 66 | ): 67 | q = apply_odata_query(model.objects, query) 68 | results = q.all() 69 | assert len(results) == exp_results 70 | -------------------------------------------------------------------------------- /tests/integration/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/sql/__init__.py -------------------------------------------------------------------------------- /tests/integration/sql/conftest.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def db_conn(): 8 | conn = sqlite3.connect(":memory:", isolation_level=None) 9 | yield conn 10 | conn.close() 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def db_schema(db_conn): 15 | cur = 
db_conn.cursor() 16 | 17 | cur.execute("CREATE TABLE author (id INTEGER PRIMARY KEY, name TEXT);") 18 | cur.execute( 19 | "CREATE TABLE blogpost (" 20 | "id INTEGER PRIMARY KEY, " 21 | "published_at TEXT, " 22 | "title TEXT, " 23 | "content TEXT);" 24 | ) 25 | cur.execute( 26 | "CREATE TABLE comment (" 27 | "id INTEGER PRIMARY KEY, " 28 | "content TEXT, " 29 | "author_id INTEGER, " 30 | "blogpost_id INTEGER, " 31 | "FOREIGN KEY (author_id) REFERENCES author(id), " 32 | "FOREIGN KEY (blogpost_id) REFERENCES blogpost(id)" 33 | ");" 34 | ) 35 | cur.execute( 36 | "CREATE TABLE author_blogpost (" 37 | "author_id INTEGER, " 38 | "blogpost_id INTEGER, " 39 | "FOREIGN KEY (author_id) REFERENCES author(id), " 40 | "FOREIGN KEY (blogpost_id) REFERENCES blogpost(id)" 41 | ");" 42 | ) 43 | db_conn.commit() 44 | -------------------------------------------------------------------------------- /tests/integration/sql/test_odata_to_athena_sql.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query import sql 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "odata_query, expected", 8 | [ 9 | ("meter_id eq '1'", "\"meter_id\" = '1'"), 10 | ("meter_id ne '1'", "\"meter_id\" != '1'"), 11 | ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"), 12 | ( 13 | "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67", 14 | "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'", 15 | ), 16 | ("meter_id in ('1',)", "\"meter_id\" IN ('1')"), 17 | ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"), 18 | ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"), 19 | ("meter_id eq null", '"meter_id" IS NULL'), 20 | ("meter_id ne null", '"meter_id" IS NOT NULL'), 21 | ("eac gt 10", '"eac" > 10'), 22 | ("eac ge 10", '"eac" >= 10'), 23 | ("eac lt 10", '"eac" < 10'), 24 | ("eac le 10", '"eac" <= 10'), 25 | ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'), 26 | ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'), 27 | ( 28 | "eac gt 1 and eac lt 1 and meter_id eq '1'", 29 | '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'', 30 | ), 31 | # OData spec defines AND with higher precedence than OR: 32 | ( 33 | "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10", 34 | '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)', 35 | ), 36 | # Unless overridden by parentheses: 37 | ( 38 | "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10", 39 | '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10', 40 | ), 41 | ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'), 42 | ("eac gt 1 eq true", '("eac" > 1) = TRUE'), 43 | ("true eq eac gt 1", 'TRUE = ("eac" > 1)'), 44 | ("eac add 10 gt 1000", '"eac" + 10 > 1000'), 45 | ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'), 46 | ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'), 47 | ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'), 48 | ( 49 | "period_start gt 2020-01-01T00:00:00", 50 | "\"period_start\" > FROM_ISO8601_TIMESTAMP('2020-01-01T00:00:00')", 51 | ), 52 | ( 53 | "period_start add duration'P365D' ge period_end", 54 | '"period_start" + INTERVAL \'365\' DAY >= "period_end"', 55 | ), 56 | ( 57 | "period_start add duration'P1Y' ge period_end", 58 | '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', 59 | ), 60 | ( 61 | "period_start add duration'P2M' ge period_end", 62 | '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', 63 | ), 64 | ( 65 | "period_start add duration'P365DT12H1M1.1S' ge period_end", 66 | "\"period_start\" + 
(INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"", 67 | ), 68 | ( 69 | "period_start add duration'PT1S' ge period_end", 70 | '"period_start" + INTERVAL \'1\' SECOND >= "period_end"', 71 | ), 72 | ("year(period_start) eq 2019", 'EXTRACT (YEAR FROM "period_start") = 2019'), 73 | ( 74 | "period_end lt now() sub duration'P365D'", 75 | "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '365' DAY", 76 | ), 77 | ( 78 | "startswith(trim(meter_id), '999')", 79 | "TRIM(\"meter_id\") LIKE '999%'", 80 | ), 81 | ( 82 | "year(date(now())) eq 2020", 83 | "EXTRACT (YEAR FROM CAST (CURRENT_TIMESTAMP AS DATE)) = 2020", 84 | ), 85 | ("length(concat('abc', 'def')) lt 10", "LENGTH('abc' || 'def') < 10"), 86 | ( 87 | "length(concat(('1', '2'), ('3', '4'))) eq 4", 88 | "CARDINALITY(('1', '2') || ('3', '4')) = 4", 89 | ), 90 | ( 91 | "indexof(substring('abcdefghi', 3), 'hi') gt 1", 92 | "POSITION('hi' IN SUBSTR('abcdefghi', 3 + 1)) - 1 > 1", 93 | ), 94 | ( 95 | "substring('hello', 1, 3) eq 'ell'", 96 | "SUBSTR('hello', 1 + 1, 3) = 'ell'", 97 | ), 98 | ("substring((1, 2, 3), 1)", "SLICE((1, 2, 3), 1)"), 99 | ("substring((1, 2, 3), 1, 1)", "SLICE((1, 2, 3), 1, 1)"), 100 | ( 101 | "contains(meter_id, sub_meter_id)", 102 | "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'", 103 | ), 104 | ( 105 | "year(supply_start_date) eq (year(now()) sub 1)", 106 | 'EXTRACT (YEAR FROM "supply_start_date") = EXTRACT (YEAR FROM CURRENT_TIMESTAMP) - 1', 107 | ), 108 | ( 109 | "measurement_class eq 'C' and endswith(data_collector, 'rie')", 110 | "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'", 111 | ), 112 | # GITHUB-47 113 | ( 114 | "contains(tolower(name), tolower('A'))", 115 | "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'", 116 | ), 117 | ], 118 | ) 119 | def test_odata_filter_to_sql(odata_query: str, expected: str, lexer, parser): 120 | ast = parser.parse(lexer.tokenize(odata_query)) 121 | visitor = sql.AstToAthenaSqlVisitor() 122 | 123 | if isinstance(expected, str): 124 | res = visitor.visit(ast) 125 | assert res == expected 126 | else: 127 | with pytest.raises(expected): 128 | res = visitor.visit(ast) 129 | -------------------------------------------------------------------------------- /tests/integration/sql/test_odata_to_sql.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import pytest 4 | 5 | from odata_query import exceptions as ex, sql 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "odata_query, expected", 10 | [ 11 | ("meter_id eq '1'", "\"meter_id\" = '1'"), 12 | ("meter_id ne '1'", "\"meter_id\" != '1'"), 13 | ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"), 14 | ( 15 | "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67", 16 | "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'", 17 | ), 18 | ("meter_id in ('1',)", "\"meter_id\" IN ('1')"), 19 | ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"), 20 | ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"), 21 | ("meter_id eq null", '"meter_id" IS NULL'), 22 | ("meter_id ne null", '"meter_id" IS NOT NULL'), 23 | ("eac gt 10", '"eac" > 10'), 24 | ("eac ge 10", '"eac" >= 10'), 25 | ("eac lt 10", '"eac" < 10'), 26 | ("eac le 10", '"eac" <= 10'), 27 | ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'), 28 | ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'), 29 | ( 30 | "eac gt 1 and eac lt 1 and meter_id eq '1'", 31 | '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'', 32 | ), 33 
| # OData spec defines AND with higher precedence than OR: 34 | ( 35 | "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10", 36 | '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)', 37 | ), 38 | # Unless overridden by parentheses: 39 | ( 40 | "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10", 41 | '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10', 42 | ), 43 | ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'), 44 | ("eac gt 1 eq true", '("eac" > 1) = TRUE'), 45 | ("true eq eac gt 1", 'TRUE = ("eac" > 1)'), 46 | ("eac add 10 gt 1000", '"eac" + 10 > 1000'), 47 | ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'), 48 | ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'), 49 | ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'), 50 | ( 51 | "period_start gt 2020-01-01T00:00:00", 52 | "\"period_start\" > TIMESTAMP '2020-01-01 00:00:00'", 53 | ), 54 | ( 55 | "period_start add duration'P365D' ge period_end", 56 | '"period_start" + INTERVAL \'365\' DAY >= "period_end"', 57 | ), 58 | ( 59 | "period_start add duration'P365DT12H1M1.1S' ge period_end", 60 | "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"", 61 | ), 62 | ( 63 | "period_start add duration'PT1S' ge period_end", 64 | '"period_start" + INTERVAL \'1\' SECOND >= "period_end"', 65 | ), 66 | ( 67 | "period_start add duration'P1Y' ge period_end", 68 | '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', 69 | ), 70 | ( 71 | "period_start add duration'P2M' ge period_end", 72 | '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', 73 | ), 74 | ("year(period_start) eq 2019", 'EXTRACT (YEAR FROM "period_start") = 2019'), 75 | ( 76 | "period_end lt now() sub duration'P365D'", 77 | "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '365' DAY", 78 | ), 79 | ( 80 | "period_end lt now() sub duration'P1Y'", 81 | "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '1' YEAR", 82 | ), 83 | ( 84 | "period_end lt now() sub duration'P2M'", 85 | "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '2' MONTH", 86 | ), 87 | ( 88 | "startswith(trim(meter_id), '999')", 89 | "TRIM(\"meter_id\") LIKE '999%'", 90 | ), 91 | ( 92 | "year(date(now())) eq 2020", 93 | "EXTRACT (YEAR FROM CAST (CURRENT_TIMESTAMP AS DATE)) = 2020", 94 | ), 95 | ("length(concat('abc', 'def')) lt 10", "CHAR_LENGTH('abc' || 'def') < 10"), 96 | ( 97 | "length(concat(('1', '2'), ('3', '4'))) eq 4", 98 | "CARDINALITY(('1', '2') || ('3', '4')) = 4", 99 | ), 100 | ( 101 | "indexof(substring('abcdefghi', 3), 'hi') gt 1", 102 | "POSITION('hi' IN SUBSTRING('abcdefghi' FROM 3 + 1)) - 1 > 1", 103 | ), 104 | ( 105 | "substring('hello', 1, 3) eq 'ell'", 106 | "SUBSTRING('hello' FROM 1 + 1 FOR 3) = 'ell'", 107 | ), 108 | ("substring((1, 2, 3), 1)", ex.UnsupportedFunctionException), 109 | ("substring((1, 2, 3), 1, 1)", ex.UnsupportedFunctionException), 110 | ( 111 | "contains(meter_id, sub_meter_id)", 112 | "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'", 113 | ), 114 | ( 115 | "year(supply_start_date) eq (year(now()) sub 1)", 116 | 'EXTRACT (YEAR FROM "supply_start_date") = EXTRACT (YEAR FROM CURRENT_TIMESTAMP) - 1', 117 | ), 118 | ( 119 | "measurement_class eq 'C' and endswith(data_collector, 'rie')", 120 | "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'", 121 | ), 122 | # GITHUB-47 123 | ( 124 | "contains(tolower(name), tolower('A'))", 125 | "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'", 126 | ), 127 | ], 128 | ) 129 | def test_odata_filter_to_sql( 130 | odata_query: 
str, expected: Union[str, type], lexer, parser 131 | ): 132 | ast = parser.parse(lexer.tokenize(odata_query)) 133 | visitor = sql.AstToSqlVisitor() 134 | 135 | if isinstance(expected, str): 136 | res = visitor.visit(ast) 137 | assert res == expected 138 | else: 139 | with pytest.raises(expected): 140 | res = visitor.visit(ast) 141 | -------------------------------------------------------------------------------- /tests/integration/sql/test_odata_to_sqlite.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query import exceptions as ex, sql 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "odata_query, expected", 8 | [ 9 | ("meter_id eq '1'", "\"meter_id\" = '1'"), 10 | ("meter_id ne '1'", "\"meter_id\" != '1'"), 11 | ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"), 12 | ( 13 | "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67", 14 | "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'", 15 | ), 16 | ("meter_id in ('1',)", "\"meter_id\" IN ('1')"), 17 | ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"), 18 | ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"), 19 | ("meter_id eq null", '"meter_id" IS NULL'), 20 | ("meter_id ne null", '"meter_id" IS NOT NULL'), 21 | ("eac gt 10", '"eac" > 10'), 22 | ("eac ge 10", '"eac" >= 10'), 23 | ("eac lt 10", '"eac" < 10'), 24 | ("eac le 10", '"eac" <= 10'), 25 | ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'), 26 | ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'), 27 | ( 28 | "eac gt 1 and eac lt 1 and meter_id eq '1'", 29 | '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'', 30 | ), 31 | # OData spec defines AND with higher precedence than OR: 32 | ( 33 | "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10", 34 | '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)', 35 | ), 36 | # Unless overridden by parentheses: 37 | ( 38 | "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10", 39 | '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10', 40 | ), 41 | ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'), 42 | ("eac gt 1 eq true", '("eac" > 1) = 1'), 43 | ("true eq eac gt 1", '1 = ("eac" > 1)'), 44 | ("eac add 10 gt 1000", '"eac" + 10 > 1000'), 45 | ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'), 46 | ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'), 47 | ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'), 48 | ( 49 | "period_start gt 2020-01-01T00:00:00", 50 | "\"period_start\" > DATETIME('2020-01-01T00:00:00')", 51 | ), 52 | ( 53 | "period_start add duration'P365D' ge period_end", 54 | '"period_start" + INTERVAL \'365\' DAY >= "period_end"', 55 | ), 56 | ( 57 | "period_start add duration'P365DT12H1M1.1S' ge period_end", 58 | "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"", 59 | ), 60 | ( 61 | "period_start add duration'PT1S' ge period_end", 62 | '"period_start" + INTERVAL \'1\' SECOND >= "period_end"', 63 | ), 64 | ( 65 | "period_start add duration'P1Y' ge period_end", 66 | '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', 67 | ), 68 | ( 69 | "period_start add duration'P2M' ge period_end", 70 | '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', 71 | ), 72 | ( 73 | "year(period_start) eq 2019", 74 | "CAST(STRFTIME('%Y', \"period_start\") AS INTEGER) = 2019", 75 | ), 76 | ( 77 | "period_end lt now() sub duration'P365D'", 78 | "\"period_end\" < DATETIME('now') - INTERVAL '365' DAY", 79 | ), 
80 | ( 81 | "startswith(trim(meter_id), '999')", 82 | "TRIM(\"meter_id\") LIKE '999%'", 83 | ), 84 | ( 85 | "year(date(now())) eq 2020", 86 | "CAST(STRFTIME('%Y', DATE(DATETIME('now'))) AS INTEGER) = 2020", 87 | ), 88 | ("length(concat('abc', 'def')) lt 10", "LENGTH('abc' || 'def') < 10"), 89 | ( 90 | "length(concat(('1', '2'), ('3', '4'))) eq 4", 91 | "LENGTH(('1', '2') || ('3', '4')) = 4", 92 | ), 93 | ( 94 | "indexof(substring('abcdefghi', 3), 'hi') gt 1", 95 | "INSTR(SUBSTR('abcdefghi', 3 + 1), 'hi') - 1 > 1", 96 | ), 97 | ( 98 | "substring('hello', 1, 3) eq 'ell'", 99 | "SUBSTR('hello', 1 + 1, 3) = 'ell'", 100 | ), 101 | ("substring((1, 2, 3), 1)", ex.UnsupportedFunctionException), 102 | ("substring((1, 2, 3), 1, 1)", ex.UnsupportedFunctionException), 103 | ( 104 | "contains(meter_id, sub_meter_id)", 105 | "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'", 106 | ), 107 | ( 108 | "year(supply_start_date) eq (year(now()) sub 1)", 109 | "CAST(STRFTIME('%Y', \"supply_start_date\") AS INTEGER) = CAST(STRFTIME('%Y', DATETIME('now')) AS INTEGER) - 1", 110 | ), 111 | ( 112 | "measurement_class eq 'C' and endswith(data_collector, 'rie')", 113 | "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'", 114 | ), 115 | # GITHUB-47 116 | ( 117 | "contains(tolower(name), tolower('A'))", 118 | "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'", 119 | ), 120 | ], 121 | ) 122 | def test_odata_filter_to_sql(odata_query: str, expected: str, lexer, parser): 123 | ast = parser.parse(lexer.tokenize(odata_query)) 124 | visitor = sql.AstToSqliteSqlVisitor() 125 | 126 | if isinstance(expected, str): 127 | res = visitor.visit(ast) 128 | assert res == expected 129 | else: 130 | with pytest.raises(expected): 131 | res = visitor.visit(ast) 132 | -------------------------------------------------------------------------------- /tests/integration/sql/test_querying.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query import sql 4 | 5 | 6 | @pytest.fixture 7 | def sample_data_sess(db_conn, db_schema): 8 | cur = db_conn.cursor() 9 | 10 | cur.execute("BEGIN") 11 | cur.execute("INSERT INTO author (name) VALUES ('Gorilla'), ('Baboon'), ('Saki')") 12 | cur.execute( 13 | "INSERT INTO blogpost (title, published_at, content) VALUES " 14 | "('Querying Data', '2020-01-01T00:00:00', 'How 2 query data...'), " 15 | "('Automating Monkey Jobs', '2019-01-01T00:00:00', 'How 2 automate monkey jobs...')" 16 | ) 17 | cur.execute( 18 | "INSERT INTO author_blogpost (author_id, blogpost_id) VALUES " 19 | "(1, 1), " 20 | "(2, 2), " 21 | "(3, 2)" 22 | ) 23 | cur.execute( 24 | "INSERT INTO comment (content, author_id, blogpost_id) VALUES " 25 | "('Dope!', 2, 1), " 26 | "('Cool!', 1, 2)" 27 | ) 28 | 29 | yield cur 30 | 31 | cur.execute("ROLLBACK") 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "table, query, exp_results", 36 | [ 37 | ("author", "name eq 'Baboon'", 1), 38 | ("author", "startswith(name, 'Gori')", 1), 39 | ("blogpost", "contains(content, 'How')", 2), 40 | ("blogpost", "published_at gt 2019-06-01", 1), 41 | # (Author, "contains(blogposts/title, 'Monkey')", 2), 42 | # (Author, "startswith(blogposts/comments/content, 'Cool')", 2), 43 | # (Author, "comments/any()", 2), 44 | # (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2), 45 | # (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1), 46 | # (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2), 47 | ("author", "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0), 48 | ( 49 | "author", 
50 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 51 | 0, 52 | ), 53 | # (BlogPost, "comments/author eq 0", 0), 54 | ("blogpost", "substring(content, 0) eq 'test'", 0), 55 | ("blogpost", "year(published_at) eq 2019", 1), 56 | # GITHUB-19 57 | ("blogpost", "contains(title, 'Query') eq true", 1), 58 | ("blogpost", "contains(tolower(title), tolower('Query'))", 1), 59 | ], 60 | ) 61 | def test_query_with_odata( 62 | table: str, query: str, exp_results: int, lexer, parser, sample_data_sess 63 | ): 64 | ast = parser.parse(lexer.tokenize(query)) 65 | visitor = sql.AstToSqliteSqlVisitor() 66 | where_clause = visitor.visit(ast) 67 | 68 | results = sample_data_sess.execute( 69 | f'SELECT * FROM "{table}" WHERE {where_clause}' 70 | ).fetchall() 71 | assert len(results) == exp_results 72 | -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/sqlalchemy/__init__.py -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import create_engine 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | from .models import Base 6 | 7 | 8 | @pytest.fixture(scope="session") 9 | def db_engine(): 10 | engine = create_engine("sqlite://", future=True) 11 | Base.metadata.create_all(engine) 12 | yield engine 13 | engine.dispose() 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def db_session(db_engine): 18 | session = sessionmaker(bind=db_engine) 19 | return session 20 | -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text 2 | from sqlalchemy.orm import declarative_base, relationship 3 | 4 | Base = declarative_base() 5 | 6 | 7 | author_blogpost = Table( 8 | "author_blogpost", 9 | Base.metadata, 10 | Column("author_id", Integer, ForeignKey("author.id")), 11 | Column("blogpost_id", Integer, ForeignKey("blogpost.id")), 12 | ) 13 | 14 | 15 | class Author(Base): 16 | __tablename__ = "author" 17 | 18 | id = Column(Integer, primary_key=True) 19 | name = Column(String, nullable=False) 20 | 21 | blogposts = relationship( 22 | "BlogPost", back_populates="authors", secondary=author_blogpost 23 | ) 24 | comments = relationship("Comment", back_populates="author") 25 | 26 | 27 | class BlogPost(Base): 28 | __tablename__ = "blogpost" 29 | 30 | id = Column(Integer, primary_key=True) 31 | published_at = Column(DateTime, nullable=False) 32 | title = Column(String, nullable=False) 33 | content = Column(Text) 34 | 35 | authors = relationship( 36 | "Author", back_populates="blogposts", secondary=author_blogpost 37 | ) 38 | comments = relationship("Comment", back_populates="blogpost") 39 | 40 | 41 | class Comment(Base): 42 | __tablename__ = "comment" 43 | 44 | id = Column(Integer, primary_key=True) 45 | content = Column(Text) 46 | 47 | author_id = Column(Integer, ForeignKey("author.id")) 48 | author = relationship("Author", back_populates="comments") 49 | blogpost_id = Column(Integer, ForeignKey("blogpost.id")) 50 | 
blogpost = relationship("BlogPost", back_populates="comments") 51 | -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import pytest 4 | from sqlalchemy.sql import functions 5 | from sqlalchemy.sql.expression import cast, extract, literal 6 | from sqlalchemy.types import Date, Time 7 | 8 | from odata_query.sqlalchemy import AstToSqlAlchemyCoreVisitor, functions_ext 9 | 10 | from . import models 11 | 12 | BlogPost = models.Base.metadata.tables["blogpost"] 13 | Author = models.Base.metadata.tables["author"] 14 | Comment = models.Base.metadata.tables["comment"] 15 | 16 | 17 | def tz(offset: int) -> dt.tzinfo: 18 | return dt.timezone(dt.timedelta(hours=offset)) 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "odata_query, expected_q", 23 | [ 24 | ( 25 | "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 26 | BlogPost.c.id == "a7af27e6-f5a0-11e9-9649-0a252986adba", 27 | ), 28 | ("my_app.c.id eq 1", BlogPost.c.id == 1), 29 | ( 30 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 31 | BlogPost.c.id.in_( 32 | [ 33 | literal("a7af27e6-f5a0-11e9-9649-0a252986adba"), 34 | literal("800c56e4-354d-11eb-be38-3af9d323e83c"), 35 | ] 36 | ), 37 | ), 38 | ("id eq 4", BlogPost.c.id == 4), 39 | ("id ne 4", BlogPost.c.id != 4), 40 | ("4 eq id", 4 == BlogPost.c.id), 41 | ("4 ne id", 4 != BlogPost.c.id), 42 | ( 43 | "published_at gt 2018-01-01", 44 | BlogPost.c.published_at > dt.date(2018, 1, 1), 45 | ), 46 | ( 47 | "published_at ge 2018-01-01", 48 | BlogPost.c.published_at >= dt.date(2018, 1, 1), 49 | ), 50 | ( 51 | "published_at lt 2018-01-01", 52 | BlogPost.c.published_at < dt.date(2018, 1, 1), 53 | ), 54 | ( 55 | "published_at le 2018-01-01", 56 | BlogPost.c.published_at <= dt.date(2018, 1, 1), 57 | ), 58 | ( 59 | "2018-01-01 gt published_at", 60 | literal(dt.date(2018, 1, 1)) > BlogPost.c.published_at, 61 | ), 62 | ( 63 | "2018-01-01 ge published_at", 64 | literal(dt.date(2018, 1, 1)) >= BlogPost.c.published_at, 65 | ), 66 | ( 67 | "published_at gt 2018-01-01T01:02", 68 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2), 69 | ), 70 | ( 71 | "published_at gt 2018-01-01T01:02:03", 72 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3), 73 | ), 74 | ( 75 | "published_at gt 2018-01-01T01:02:03.123", 76 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000), 77 | ), 78 | ( 79 | "published_at gt 2018-01-01T01:02:03.123456", 80 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_456), 81 | ), 82 | ( 83 | "published_at gt 2018-01-01T01:02Z", 84 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(0)), 85 | ), 86 | ( 87 | "published_at gt 2018-01-01T01:02:03Z", 88 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, tzinfo=tz(0)), 89 | ), 90 | ( 91 | "published_at gt 2018-01-01T01:02:03.123Z", 92 | BlogPost.c.published_at 93 | > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000, tzinfo=tz(0)), 94 | ), 95 | ( 96 | "published_at gt 2018-01-01T01:02+02:00", 97 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(+2)), 98 | ), 99 | ( 100 | "published_at gt 2018-01-01T01:02-02:00", 101 | BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(-2)), 102 | ), 103 | ( 104 | "id in (1, 2, 3)", 105 | BlogPost.c.id.in_([literal(1), literal(2), literal(3)]), 106 | ), 107 | ("id eq null", BlogPost.c.id == None), 
# noqa:E711 108 | ("id ne null", BlogPost.c.id != None), # noqa:E711 109 | ("not (id eq 1)", ~(BlogPost.c.id == 1)), 110 | ("id eq 1 or id eq 2", (BlogPost.c.id == 1) | (BlogPost.c.id == 2)), 111 | ( 112 | "id eq 1 and content eq 'executing'", 113 | (BlogPost.c.id == 1) & (BlogPost.c.content == "executing"), 114 | ), 115 | ( 116 | "id eq 1 and (content eq 'executing' or content eq 'failed')", 117 | (BlogPost.c.id == 1) 118 | & ((BlogPost.c.content == "executing") | (BlogPost.c.content == "failed")), 119 | ), 120 | ("id eq 1 add 1", BlogPost.c.id == literal(1) + literal(1)), 121 | ("id eq 2 sub 1", BlogPost.c.id == literal(2) - literal(1)), 122 | ("id eq 2 mul 2", BlogPost.c.id == literal(2) * literal(2)), 123 | ("id eq 2 div 2", BlogPost.c.id == literal(2) / literal(2)), 124 | ("id eq 5 mod 4", BlogPost.c.id == literal(5) % literal(4)), 125 | ("id eq 2 add -1", BlogPost.c.id == literal(2) + literal(-1)), 126 | ("id eq id sub 1", BlogPost.c.id == BlogPost.c.id - literal(1)), 127 | ( 128 | "title eq 'donut' add 'tello'", 129 | BlogPost.c.title == literal("donut") + literal("tello"), 130 | ), 131 | ( 132 | "title eq content add content", 133 | BlogPost.c.title == BlogPost.c.content + BlogPost.c.content, 134 | ), 135 | ( 136 | "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'", 137 | BlogPost.c.published_at 138 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 139 | + dt.timedelta(days=1, hours=1, minutes=1, seconds=1), 140 | ), 141 | ( 142 | "published_at eq 2019-01-01T00:00:00 add duration'P1Y'", 143 | BlogPost.c.published_at 144 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 145 | + dt.timedelta(days=365.25), # 1 times 365.25 (average year in days) 146 | ), 147 | ( 148 | "published_at eq 2019-01-01T00:00:00 add duration'P2M'", 149 | BlogPost.c.published_at 150 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 151 | + dt.timedelta(days=60.88), # 2 times 30.44 (average month in days) 152 | ), 153 | ("contains(title, 'copy')", BlogPost.c.title.contains("copy")), 154 | ("startswith(title, 'copy')", BlogPost.c.title.startswith("copy")), 155 | ("endswith(title, 'bla')", BlogPost.c.title.endswith("bla")), 156 | ( 157 | "id eq length(title)", 158 | BlogPost.c.id == functions.char_length(BlogPost.c.title), 159 | ), 160 | ("length(title) eq 10", functions.char_length(BlogPost.c.title) == 10), 161 | ("10 eq length(title)", 10 == functions.char_length(BlogPost.c.title)), 162 | ( 163 | "length(title) eq length('flippot')", 164 | functions.char_length(BlogPost.c.title) == functions.char_length("flippot"), 165 | ), 166 | ( 167 | "title eq concat('a', 'b')", 168 | BlogPost.c.title == functions.concat("a", "b"), 169 | ), 170 | ( 171 | "title eq concat('test', id)", 172 | BlogPost.c.title == functions.concat("test", BlogPost.c.id), 173 | ), 174 | ( 175 | "title eq concat(concat('a', 'b'), 'c')", 176 | BlogPost.c.title == functions.concat(functions.concat("a", "b"), "c"), 177 | ), 178 | ( 179 | "concat(title, 'a') eq 'testa'", 180 | functions.concat(BlogPost.c.title, "a") == "testa", 181 | ), 182 | ( 183 | "indexof(title, 'Copy') eq 6", 184 | functions_ext.strpos(BlogPost.c.title, "Copy") - 1 == 6, 185 | ), 186 | ( 187 | "substring(title, 0) eq 'Copy'", 188 | functions_ext.substr(BlogPost.c.title, literal(0) + 1) == "Copy", 189 | ), 190 | ( 191 | "substring(title, 0, 4) eq 'Copy'", 192 | functions_ext.substr(BlogPost.c.title, literal(0) + 1, 4) == "Copy", 193 | ), 194 | ( 195 | "matchesPattern(title, 'C.py')", 196 | BlogPost.c.title.regexp_match("C.py"), 197 | ), 198 | ( 199 | "tolower(title) eq 
'copy'", 200 | functions_ext.lower(BlogPost.c.title) == "copy", 201 | ), 202 | ( 203 | "toupper(title) eq 'COPY'", 204 | functions_ext.upper(BlogPost.c.title) == "COPY", 205 | ), 206 | ( 207 | "trim(title) eq 'copy'", 208 | functions_ext.ltrim(functions_ext.rtrim(BlogPost.c.title)) == "copy", 209 | ), 210 | ( 211 | "date(published_at) eq 2019-01-01", 212 | cast(BlogPost.c.published_at, Date) == dt.date(2019, 1, 1), 213 | ), 214 | ( 215 | "day(published_at) eq 1", 216 | extract("day", BlogPost.c.published_at) == 1, 217 | ), 218 | ( 219 | "hour(published_at) eq 1", 220 | extract("hour", BlogPost.c.published_at) == 1, 221 | ), 222 | ( 223 | "minute(published_at) eq 1", 224 | extract("minute", BlogPost.c.published_at) == 1, 225 | ), 226 | ( 227 | "month(published_at) eq 1", 228 | extract("month", BlogPost.c.published_at) == 1, 229 | ), 230 | ("published_at eq now()", BlogPost.c.published_at == functions.now()), 231 | ( 232 | "second(published_at) eq 1", 233 | extract("second", BlogPost.c.published_at) == 1, 234 | ), 235 | ( 236 | "time(published_at) eq 14:00:00", 237 | cast(BlogPost.c.published_at, Time) == dt.time(14, 0, 0), 238 | ), 239 | ( 240 | "year(published_at) eq 2019", 241 | extract("year", BlogPost.c.published_at) == 2019, 242 | ), 243 | ("ceiling(id) eq 1", functions_ext.ceil(BlogPost.c.id) == 1), 244 | ("floor(id) eq 1", functions_ext.floor(BlogPost.c.id) == 1), 245 | ("round(id) eq 1", functions_ext.round(BlogPost.c.id) == 1), 246 | ( 247 | "date(published_at) eq 2019-01-01 add duration'P1D'", 248 | cast(BlogPost.c.published_at, Date) 249 | == literal(dt.date(2019, 1, 1)) + dt.timedelta(days=1), 250 | ), 251 | ( 252 | "date(published_at) eq 2019-01-01 add duration'-P1D'", 253 | cast(BlogPost.c.published_at, Date) 254 | == literal(dt.date(2019, 1, 1)) + -1 * dt.timedelta(days=1), 255 | ), 256 | pytest.param( 257 | "authors/name eq 'Ruben'", 258 | Author.c.name == "Ruben", 259 | marks=pytest.mark.xfail(reason="Not implemented yet."), 260 | ), 261 | pytest.param( 262 | "authors/comments/content eq 'Cool!'", 263 | Comment.c.content == "Cool!", 264 | marks=pytest.mark.xfail(reason="Not implemented yet."), 265 | ), 266 | pytest.param( 267 | "contains(comments/content, 'Cool')", 268 | Comment.c.content.contains("Cool"), 269 | marks=pytest.mark.xfail(reason="Not implemented yet."), 270 | ), 271 | # GITHUB-19 272 | ( 273 | "contains(title, 'TEST') eq true", 274 | BlogPost.c.title.contains("TEST") == True, # noqa:E712 275 | ), 276 | # GITHUB-47 277 | ( 278 | "contains(tolower(title), tolower('A'))", 279 | functions_ext.lower(BlogPost.c.title).contains(functions_ext.lower("A")), 280 | ), 281 | ], 282 | ) 283 | def test_odata_filter_to_sqlalchemy_query( 284 | odata_query: str, expected_q: str, lexer, parser 285 | ): 286 | ast = parser.parse(lexer.tokenize(odata_query)) 287 | transformer = AstToSqlAlchemyCoreVisitor(BlogPost) 288 | res_q = transformer.visit(ast) 289 | 290 | assert res_q.compare(expected_q), (str(res_q), str(expected_q)) 291 | -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import pytest 4 | from sqlalchemy.sql import functions 5 | from sqlalchemy.sql.expression import cast, extract, literal 6 | from sqlalchemy.types import Date, Time 7 | 8 | from odata_query.sqlalchemy import AstToSqlAlchemyOrmVisitor, functions_ext 9 | 10 | from .models import Author, BlogPost, Comment 
11 | 12 | 13 | def tz(offset: int) -> dt.tzinfo: 14 | return dt.timezone(dt.timedelta(hours=offset)) 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "odata_query, expected_q", 19 | [ 20 | ( 21 | "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 22 | BlogPost.id == "a7af27e6-f5a0-11e9-9649-0a252986adba", 23 | ), 24 | ("my_app.id eq 1", BlogPost.id == 1), 25 | ( 26 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 27 | BlogPost.id.in_( 28 | [ 29 | literal("a7af27e6-f5a0-11e9-9649-0a252986adba"), 30 | literal("800c56e4-354d-11eb-be38-3af9d323e83c"), 31 | ] 32 | ), 33 | ), 34 | ("id eq 4", BlogPost.id == 4), 35 | ("id ne 4", BlogPost.id != 4), 36 | ("4 eq id", 4 == BlogPost.id), 37 | ("4 ne id", 4 != BlogPost.id), 38 | ("published_at gt 2018-01-01", BlogPost.published_at > dt.date(2018, 1, 1)), 39 | ("published_at ge 2018-01-01", BlogPost.published_at >= dt.date(2018, 1, 1)), 40 | ("published_at lt 2018-01-01", BlogPost.published_at < dt.date(2018, 1, 1)), 41 | ("published_at le 2018-01-01", BlogPost.published_at <= dt.date(2018, 1, 1)), 42 | ( 43 | "2018-01-01 gt published_at", 44 | literal(dt.date(2018, 1, 1)) > BlogPost.published_at, 45 | ), 46 | ( 47 | "2018-01-01 ge published_at", 48 | literal(dt.date(2018, 1, 1)) >= BlogPost.published_at, 49 | ), 50 | ( 51 | "published_at gt 2018-01-01T01:02", 52 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2), 53 | ), 54 | ( 55 | "published_at gt 2018-01-01T01:02:03", 56 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3), 57 | ), 58 | ( 59 | "published_at gt 2018-01-01T01:02:03.123", 60 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000), 61 | ), 62 | ( 63 | "published_at gt 2018-01-01T01:02:03.123456", 64 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_456), 65 | ), 66 | ( 67 | "published_at gt 2018-01-01T01:02Z", 68 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(0)), 69 | ), 70 | ( 71 | "published_at gt 2018-01-01T01:02:03Z", 72 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, tzinfo=tz(0)), 73 | ), 74 | ( 75 | "published_at gt 2018-01-01T01:02:03.123Z", 76 | BlogPost.published_at 77 | > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000, tzinfo=tz(0)), 78 | ), 79 | ( 80 | "published_at gt 2018-01-01T01:02+02:00", 81 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(+2)), 82 | ), 83 | ( 84 | "published_at gt 2018-01-01T01:02-02:00", 85 | BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(-2)), 86 | ), 87 | ("id in (1, 2, 3)", BlogPost.id.in_([literal(1), literal(2), literal(3)])), 88 | ("id eq null", BlogPost.id == None), # noqa:E711 89 | ("id ne null", BlogPost.id != None), # noqa:E711 90 | ("not (id eq 1)", ~(BlogPost.id == 1)), 91 | ("id eq 1 or id eq 2", (BlogPost.id == 1) | (BlogPost.id == 2)), 92 | ( 93 | "id eq 1 and content eq 'executing'", 94 | (BlogPost.id == 1) & (BlogPost.content == "executing"), 95 | ), 96 | ( 97 | "id eq 1 and (content eq 'executing' or content eq 'failed')", 98 | (BlogPost.id == 1) 99 | & ((BlogPost.content == "executing") | (BlogPost.content == "failed")), 100 | ), 101 | ("id eq 1 add 1", BlogPost.id == literal(1) + literal(1)), 102 | ("id eq 2 sub 1", BlogPost.id == literal(2) - literal(1)), 103 | ("id eq 2 mul 2", BlogPost.id == literal(2) * literal(2)), 104 | ("id eq 2 div 2", BlogPost.id == literal(2) / literal(2)), 105 | ("id eq 5 mod 4", BlogPost.id == literal(5) % literal(4)), 106 | ("id eq 2 add -1", BlogPost.id == literal(2) + literal(-1)), 107 | ("id eq id sub 1", BlogPost.id == 
BlogPost.id - literal(1)), 108 | ( 109 | "title eq 'donut' add 'tello'", 110 | BlogPost.title == literal("donut") + literal("tello"), 111 | ), 112 | ( 113 | "title eq content add content", 114 | BlogPost.title == BlogPost.content + BlogPost.content, 115 | ), 116 | ( 117 | "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'", 118 | BlogPost.published_at 119 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 120 | + dt.timedelta(days=1, hours=1, minutes=1, seconds=1), 121 | ), 122 | ( 123 | "published_at eq 2019-01-01T00:00:00 add duration'P1Y'", 124 | BlogPost.published_at 125 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 126 | + dt.timedelta(days=365.25), # 1 times 365.25 (average year in days) 127 | ), 128 | ( 129 | "published_at eq 2019-01-01T00:00:00 add duration'P2M'", 130 | BlogPost.published_at 131 | == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) 132 | + dt.timedelta(days=60.88), # 2 times 30.44 (average month in days) 133 | ), 134 | ("contains(title, 'copy')", BlogPost.title.contains("copy")), 135 | ("startswith(title, 'copy')", BlogPost.title.startswith("copy")), 136 | ("endswith(title, 'bla')", BlogPost.title.endswith("bla")), 137 | ("id eq length(title)", BlogPost.id == functions.char_length(BlogPost.title)), 138 | ("length(title) eq 10", functions.char_length(BlogPost.title) == 10), 139 | ("10 eq length(title)", 10 == functions.char_length(BlogPost.title)), 140 | ( 141 | "length(title) eq length('flippot')", 142 | functions.char_length(BlogPost.title) == functions.char_length("flippot"), 143 | ), 144 | ("title eq concat('a', 'b')", BlogPost.title == functions.concat("a", "b")), 145 | ( 146 | "title eq concat('test', id)", 147 | BlogPost.title == functions.concat("test", BlogPost.id), 148 | ), 149 | ( 150 | "title eq concat(concat('a', 'b'), 'c')", 151 | BlogPost.title == functions.concat(functions.concat("a", "b"), "c"), 152 | ), 153 | ( 154 | "concat(title, 'a') eq 'testa'", 155 | functions.concat(BlogPost.title, "a") == "testa", 156 | ), 157 | ( 158 | "indexof(title, 'Copy') eq 6", 159 | functions_ext.strpos(BlogPost.title, "Copy") - 1 == 6, 160 | ), 161 | ( 162 | "substring(title, 0) eq 'Copy'", 163 | functions_ext.substr(BlogPost.title, literal(0) + 1) == "Copy", 164 | ), 165 | ( 166 | "substring(title, 0, 4) eq 'Copy'", 167 | functions_ext.substr(BlogPost.title, literal(0) + 1, 4) == "Copy", 168 | ), 169 | ("matchesPattern(title, 'C.py')", BlogPost.title.regexp_match("C.py")), 170 | ("tolower(title) eq 'copy'", functions_ext.lower(BlogPost.title) == "copy"), 171 | ("toupper(title) eq 'COPY'", functions_ext.upper(BlogPost.title) == "COPY"), 172 | ( 173 | "trim(title) eq 'copy'", 174 | functions_ext.ltrim(functions_ext.rtrim(BlogPost.title)) == "copy", 175 | ), 176 | ( 177 | "date(published_at) eq 2019-01-01", 178 | cast(BlogPost.published_at, Date) == dt.date(2019, 1, 1), 179 | ), 180 | ("day(published_at) eq 1", extract("day", BlogPost.published_at) == 1), 181 | ("hour(published_at) eq 1", extract("hour", BlogPost.published_at) == 1), 182 | ("minute(published_at) eq 1", extract("minute", BlogPost.published_at) == 1), 183 | ("month(published_at) eq 1", extract("month", BlogPost.published_at) == 1), 184 | ("published_at eq now()", BlogPost.published_at == functions.now()), 185 | ("second(published_at) eq 1", extract("second", BlogPost.published_at) == 1), 186 | ( 187 | "time(published_at) eq 14:00:00", 188 | cast(BlogPost.published_at, Time) == dt.time(14, 0, 0), 189 | ), 190 | ("year(published_at) eq 2019", extract("year", BlogPost.published_at) == 2019), 191 | 
("ceiling(id) eq 1", functions_ext.ceil(BlogPost.id) == 1), 192 | ("floor(id) eq 1", functions_ext.floor(BlogPost.id) == 1), 193 | ("round(id) eq 1", functions_ext.round(BlogPost.id) == 1), 194 | ( 195 | "date(published_at) eq 2019-01-01 add duration'P1D'", 196 | cast(BlogPost.published_at, Date) 197 | == literal(dt.date(2019, 1, 1)) + dt.timedelta(days=1), 198 | ), 199 | ( 200 | "date(published_at) eq 2019-01-01 add duration'-P1D'", 201 | cast(BlogPost.published_at, Date) 202 | == literal(dt.date(2019, 1, 1)) + -1 * dt.timedelta(days=1), 203 | ), 204 | ("authors/name eq 'Ruben'", Author.name == "Ruben"), 205 | ("authors/comments/content eq 'Cool!'", Comment.content == "Cool!"), 206 | ("contains(comments/content, 'Cool')", Comment.content.contains("Cool")), 207 | # GITHUB-19 208 | ( 209 | "contains(title, 'TEST') eq true", 210 | BlogPost.title.contains("TEST") == True, # noqa:E712 211 | ), 212 | # GITHUB-47 213 | ( 214 | "contains(tolower(title), tolower('A'))", 215 | functions_ext.lower(BlogPost.title).contains(functions_ext.lower("A")), 216 | ), 217 | ], 218 | ) 219 | def test_odata_filter_to_sqlalchemy_query( 220 | odata_query: str, expected_q: str, lexer, parser 221 | ): 222 | ast = parser.parse(lexer.tokenize(odata_query)) 223 | transformer = AstToSqlAlchemyOrmVisitor(BlogPost) 224 | res_q = transformer.visit(ast) 225 | 226 | assert res_q.compare(expected_q), (str(res_q), str(expected_q)) 227 | -------------------------------------------------------------------------------- /tests/integration/sqlalchemy/test_querying.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from typing import Callable, Type 3 | 4 | import pytest 5 | from sqlalchemy import select 6 | 7 | from odata_query.sqlalchemy import apply_odata_core, apply_odata_query 8 | 9 | from .models import Author, Base, BlogPost, Comment 10 | 11 | 12 | @pytest.fixture 13 | def sample_data_sess(db_session): 14 | s = db_session() 15 | a1 = Author(name="Gorilla") 16 | a2 = Author(name="Baboon") 17 | a3 = Author(name="Saki") 18 | bp1 = BlogPost( 19 | title="Querying Data", 20 | published_at=dt.datetime(2020, 1, 1), 21 | content="How 2 query data...", 22 | authors=[a1], 23 | ) 24 | bp2 = BlogPost( 25 | title="Automating Monkey Jobs", 26 | published_at=dt.datetime(2019, 1, 1), 27 | content="How 2 automate monkey jobs...", 28 | authors=[a2, a3], 29 | ) 30 | c1 = Comment(content="Dope!", author=a2, blogpost=bp1) 31 | c2 = Comment(content="Cool!", author=a1, blogpost=bp2) 32 | s.add_all([a1, a2, a3, bp1, bp2, c1, c2]) 33 | s.flush() 34 | yield s 35 | s.rollback() 36 | 37 | 38 | def apply_odata_query_bc_sqla1(*args, **kwargs): 39 | """ 40 | Check for backwards compatibility with the "old style" of ORM querying. 
41 | See: GITHUB-34 42 | """ 43 | return apply_odata_query(*args, **kwargs) 44 | 45 | 46 | @pytest.mark.parametrize( 47 | "model, query, exp_results", 48 | [ 49 | (Author, "name eq 'Baboon'", 1), 50 | (Author, "startswith(name, 'Gori')", 1), 51 | (BlogPost, "contains(content, 'How')", 2), 52 | (BlogPost, "published_at gt 2019-06-01", 1), 53 | (Author, "contains(blogposts/title, 'Monkey')", 2), 54 | (Author, "startswith(blogposts/comments/content, 'Cool')", 2), 55 | (Author, "comments/any()", 2), 56 | (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2), 57 | (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1), 58 | (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2), 59 | (Author, "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0), 60 | ( 61 | Author, 62 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 63 | 0, 64 | ), 65 | (BlogPost, "comments/author eq 0", 0), 66 | (BlogPost, "substring(content, 0) eq 'test'", 0), 67 | (BlogPost, "year(published_at) eq 2019", 1), 68 | # GITHUB-19 69 | (BlogPost, "contains(title, 'Query') eq true", 1), 70 | ], 71 | ) 72 | @pytest.mark.parametrize( 73 | "apply_func", 74 | [ 75 | pytest.param(apply_odata_query, id="ORM"), 76 | pytest.param(apply_odata_query_bc_sqla1, id="ORM 1.x"), 77 | pytest.param(apply_odata_core, id="Core"), 78 | ], 79 | ) 80 | def test_query_with_odata( 81 | model: Type[Base], 82 | query: str, 83 | exp_results: int, 84 | apply_func: Callable, 85 | sample_data_sess, 86 | ): 87 | # ORM mode 1.x: 88 | if apply_func is apply_odata_query_bc_sqla1: 89 | base_q = sample_data_sess.query(model) 90 | q = apply_func(base_q, query) 91 | results = q.all() 92 | assert len(results) == exp_results 93 | return 94 | 95 | # ORM mode: 96 | elif apply_func is apply_odata_query: 97 | base_q = select(model) 98 | 99 | # Core mode: 100 | elif apply_func is apply_odata_core: 101 | base_q = select(model.__table__) 102 | 103 | else: 104 | raise ValueError(apply_func) 105 | 106 | try: 107 | q = apply_func(base_q, query) 108 | except NotImplementedError: 109 | pytest.xfail("Not implemented yet.") 110 | 111 | results = sample_data_sess.execute(q).scalars().all() 112 | assert len(results) == exp_results 113 | 114 | 115 | @pytest.mark.parametrize( 116 | "apply_func", 117 | [ 118 | pytest.param(apply_odata_query, id="ORM"), 119 | pytest.param(apply_odata_query_bc_sqla1, id="ORM 1.x"), 120 | ], 121 | ) 122 | def test_query_with_existing_join(apply_func, sample_data_sess): 123 | """ 124 | GITHUB-37 125 | """ 126 | odata_query = "author/name eq 'Gorilla'" 127 | exp_results = 1 128 | 129 | # ORM mode 1.x: 130 | if apply_func is apply_odata_query_bc_sqla1: 131 | base_q = sample_data_sess.query(Comment).join( 132 | Author, Comment.author_id == Author.id 133 | ) 134 | q = apply_func(base_q, odata_query) 135 | results = q.all() 136 | assert len(results) == exp_results 137 | return 138 | 139 | # ORM mode: 140 | elif apply_func is apply_odata_query: 141 | base_q = select(Comment).join(Comment.author) 142 | 143 | else: 144 | raise ValueError(apply_func) 145 | 146 | q = apply_func(base_q, odata_query) 147 | 148 | results = sample_data_sess.execute(q).scalars().all() 149 | assert len(results) == exp_results 150 | -------------------------------------------------------------------------------- /tests/integration/test_roundtrip.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query.roundtrip import AstToODataVisitor 4 | 5 | 6 | @pytest.mark.parametrize( 
7 | "odata_query", 8 | [ 9 | "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 10 | "my_app.id eq 1", 11 | "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)", 12 | "id eq 4", 13 | "id ne 4", 14 | "4 eq id", 15 | "4 ne id", 16 | "published_at gt 2018-01-01", 17 | "published_at ge 2018-01-01", 18 | "published_at lt 2018-01-01", 19 | "published_at le 2018-01-01", 20 | "2018-01-01 gt published_at", 21 | "2018-01-01 ge published_at", 22 | "published_at gt 2018-01-01T01:02", 23 | "published_at gt 2018-01-01T01:02:03", 24 | "published_at gt 2018-01-01T01:02:03.123", 25 | "published_at gt 2018-01-01T01:02:03.123456", 26 | "published_at gt 2018-01-01T01:02Z", 27 | "published_at gt 2018-01-01T01:02:03Z", 28 | "published_at gt 2018-01-01T01:02:03.123Z", 29 | "published_at gt 2018-01-01T01:02+02:00", 30 | "published_at gt 2018-01-01T01:02-02:00", 31 | "id in (1, 2, 3)", 32 | "id eq null", 33 | "id ne null", 34 | "not (id eq 1)", 35 | "id eq 1 or id eq 2", 36 | "id eq 1 and content eq 'executing'", 37 | "id eq 1 and (content eq 'executing' or content eq 'failed')", 38 | "id eq 1 add 1", 39 | "id eq 2 sub 1", 40 | "id eq 2 mul 2", 41 | "id eq 2 div 2", 42 | "id eq 5 mod 4", 43 | "id eq 2 add -1", 44 | "id eq id sub 1", 45 | "title eq 'donut' add 'tello'", 46 | "title eq content add content", 47 | "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'", 48 | "contains(title, 'copy')", 49 | "startswith(title, 'copy')", 50 | "endswith(title, 'bla')", 51 | "id eq length(title)", 52 | "length(title) eq 10", 53 | "10 eq length(title)", 54 | "length(title) eq length('flippot')", 55 | "title eq concat('a', 'b')", 56 | "title eq concat('test', id)", 57 | "title eq concat(concat('a', 'b'), 'c')", 58 | "concat(title, 'a') eq 'testa'", 59 | "indexof(title, 'Copy') eq 6", 60 | "substring(title, 0) eq 'Copy'", 61 | "substring(title, 0, 4) eq 'Copy'", 62 | "matchesPattern(title, 'C.py')", 63 | "tolower(title) eq 'copy'", 64 | "toupper(title) eq 'COPY'", 65 | "trim(title) eq 'copy'", 66 | "date(published_at) eq 2019-01-01", 67 | "day(published_at) eq 1", 68 | "hour(published_at) eq 1", 69 | "minute(published_at) eq 1", 70 | "month(published_at) eq 1", 71 | "published_at eq now()", 72 | "second(published_at) eq 1", 73 | "time(published_at) eq 14:00:00", 74 | "year(published_at) eq 2019", 75 | "ceiling(id) eq 1", 76 | "floor(id) eq 1", 77 | "round(id) eq 1", 78 | "date(published_at) eq 2019-01-01 add duration'P1D'", 79 | "date(published_at) eq 2019-01-01 add duration'-P1D'", 80 | "authors/name eq 'Ruben'", 81 | "authors/comments/content eq 'Cool!'", 82 | "contains(comments/content, 'Cool')", 83 | "contains(title, 'TEST') eq true", 84 | # Precedence checks: 85 | "1 mul (2 add -3 sub 4) div 5", 86 | ], 87 | ) 88 | def test_odata_filter_roundtrip(odata_query: str, lexer, parser): 89 | ast = parser.parse(lexer.tokenize(odata_query)) 90 | transformer = AstToODataVisitor() 91 | res = transformer.visit(ast) 92 | 93 | assert res == odata_query 94 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/django/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/django/__init__.py -------------------------------------------------------------------------------- /tests/unit/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/sql/__init__.py -------------------------------------------------------------------------------- /tests/unit/sql/test_ast_to_sql.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from odata_query import ast, sql 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "ast_input, sql_expected", 10 | [ 11 | (ast.Compare(ast.Eq(), ast.Integer("1"), ast.Integer("1")), "1 = 1"), 12 | ( 13 | ast.Compare(ast.NotEq(), ast.Boolean("true"), ast.Boolean("false")), 14 | "TRUE != FALSE", 15 | ), 16 | ( 17 | ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("123.12")), 18 | '"eac" <= 123.12', 19 | ), 20 | ( 21 | ast.Compare( 22 | ast.Lt(), ast.Identifier("period_start"), ast.Date("2019-01-01") 23 | ), 24 | "\"period_start\" < DATE '2019-01-01'", 25 | ), 26 | ( 27 | ast.BoolOp( 28 | ast.And(), 29 | ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("123.12")), 30 | ast.Compare( 31 | ast.In(), 32 | ast.Identifier("meter_id"), 33 | ast.List([ast.String("1"), ast.String("2"), ast.String("3")]), 34 | ), 35 | ), 36 | "\"eac\" >= 123.12 AND \"meter_id\" IN ('1', '2', '3')", 37 | ), 38 | ( 39 | ast.BoolOp( 40 | ast.And(), 41 | ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("1")), 42 | ast.BoolOp( 43 | ast.Or(), 44 | ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("10.0")), 45 | ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("1.0")), 46 | ), 47 | ), 48 | '"a" = \'1\' AND ("eac" <= 10.0 OR "eac" >= 1.0)', 49 | ), 50 | ], 51 | ) 52 | def test_ast_to_sql(ast_input: ast._Node, sql_expected: str): 53 | visitor = sql.AstToSqlVisitor() 54 | res = visitor.visit(ast_input) 55 | 56 | assert res == sql_expected 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "func_name, args, sql_expected", 61 | [ 62 | ("concat", [ast.String("ab"), ast.String("cd")], "'ab' || 'cd'"), 63 | ( 64 | "contains", 65 | [ast.String("abc"), ast.String("b")], 66 | "'abc' LIKE '%b%'", 67 | ), 68 | ( 69 | "endswith", 70 | [ast.String("abc"), ast.String("bc")], 71 | "'abc' LIKE '%bc'", 72 | ), 73 | ( 74 | "indexof", 75 | [ast.String("abc"), ast.String("bc")], 76 | "POSITION('bc' IN 'abc') - 1", 77 | ), 78 | ("length", [ast.String("abc")], "CHAR_LENGTH('abc')"), 79 | ( 80 | "length", 81 | [ast.List([ast.String("a"), ast.String("b")])], 82 | "CARDINALITY(('a', 'b'))", 83 | ), 84 | ( 85 | "startswith", 86 | [ast.String("abc"), ast.String("ab")], 87 | "'abc' LIKE 'ab%'", 88 | ), 89 | ( 90 | "substring", 91 | [ast.String("abc"), ast.Integer("1")], 92 | "SUBSTRING('abc' FROM 1 + 1)", 93 | ), 94 | ( 95 | "substring", 96 | [ast.String("abcdef"), ast.Integer("1"), ast.Integer("2")], 97 | "SUBSTRING('abcdef' FROM 1 + 1 FOR 2)", 98 | ), 99 | ("tolower", [ast.String("ABC")], "LOWER('ABC')"), 100 | ("toupper", [ast.String("abc")], "UPPER('abc')"), 101 | ("trim", [ast.String(" abc ")], "TRIM(' abc ')"), 102 | ( 103 | "year", 104 | [ast.DateTime("2018-01-01T10:00:00")], 105 | "EXTRACT (YEAR FROM TIMESTAMP '2018-01-01 10:00:00')", 106 | ), 107 | ( 108 | "month", 109 | [ast.DateTime("2018-01-01T10:00:00")], 110 | "EXTRACT (MONTH FROM 
TIMESTAMP '2018-01-01 10:00:00')", 111 | ), 112 | ( 113 | "day", 114 | [ast.DateTime("2018-01-01T10:00:00")], 115 | "EXTRACT (DAY FROM TIMESTAMP '2018-01-01 10:00:00')", 116 | ), 117 | ( 118 | "hour", 119 | [ast.DateTime("2018-01-01T10:00:00")], 120 | "EXTRACT (HOUR FROM TIMESTAMP '2018-01-01 10:00:00')", 121 | ), 122 | ( 123 | "minute", 124 | [ast.DateTime("2018-01-01T10:00:00")], 125 | "EXTRACT (MINUTE FROM TIMESTAMP '2018-01-01 10:00:00')", 126 | ), 127 | ( 128 | "date", 129 | [ast.DateTime("2018-01-01T10:00:00")], 130 | "CAST (TIMESTAMP '2018-01-01 10:00:00' AS DATE)", 131 | ), 132 | ("now", [], "CURRENT_TIMESTAMP"), 133 | ("round", [ast.Float("123.12")], "CAST (123.12 + 0.5 AS INTEGER)"), 134 | ( 135 | "floor", 136 | [ast.Float("123.12")], 137 | """CASE 123.12 138 | WHEN > 0 CAST (123.12 AS INTEGER) 139 | WHEN < 0 CAST (0 - (ABS(123.12) + 0.5) AS INTEGER)) 140 | ELSE 123.12 141 | END""", 142 | ), 143 | ( 144 | "ceiling", 145 | [ast.Float("123.12")], 146 | """CASE 123.12 - CAST (123.12 AS INTEGER) 147 | WHEN > 0 123.12+1 148 | WHEN < 0 123.12-1 149 | ELSE 123.12 150 | END""", 151 | ), 152 | ], 153 | ) 154 | def test_ast_to_sql_functions(func_name: str, args: List[ast._Node], sql_expected: str): 155 | inp_ast = ast.Call(ast.Identifier(func_name), args) 156 | visitor = sql.AstToSqlVisitor() 157 | res = visitor.visit(inp_ast) 158 | 159 | assert res == sql_expected 160 | -------------------------------------------------------------------------------- /tests/unit/sql/test_ast_to_sqlite.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from odata_query import ast, sql 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "ast_input, sql_expected", 10 | [ 11 | (ast.Compare(ast.Eq(), ast.Integer("1"), ast.Integer("1")), "1 = 1"), 12 | ( 13 | ast.Compare(ast.NotEq(), ast.Boolean("true"), ast.Boolean("false")), 14 | "1 != 0", 15 | ), 16 | ( 17 | ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("123.12")), 18 | '"eac" <= 123.12', 19 | ), 20 | ( 21 | ast.Compare( 22 | ast.Lt(), 23 | ast.Identifier("period_start"), 24 | ast.Date("2019-01-01"), 25 | ), 26 | "\"period_start\" < DATE('2019-01-01')", 27 | ), 28 | ( 29 | ast.BoolOp( 30 | ast.And(), 31 | ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("123.12")), 32 | ast.Compare( 33 | ast.In(), 34 | ast.Identifier("meter_id"), 35 | ast.List([ast.String("1"), ast.String("2"), ast.String("3")]), 36 | ), 37 | ), 38 | "\"eac\" >= 123.12 AND \"meter_id\" IN ('1', '2', '3')", 39 | ), 40 | ( 41 | ast.BoolOp( 42 | ast.And(), 43 | ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("1")), 44 | ast.BoolOp( 45 | ast.Or(), 46 | ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("10.0")), 47 | ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("1.0")), 48 | ), 49 | ), 50 | '"a" = \'1\' AND ("eac" <= 10.0 OR "eac" >= 1.0)', 51 | ), 52 | ], 53 | ) 54 | def test_ast_to_sql(ast_input: ast._Node, sql_expected: str): 55 | visitor = sql.AstToSqliteSqlVisitor() 56 | res = visitor.visit(ast_input) 57 | 58 | assert res == sql_expected 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "func_name, args, sql_expected", 63 | [ 64 | ("concat", [ast.String("ab"), ast.String("cd")], "'ab' || 'cd'"), 65 | ( 66 | "contains", 67 | [ast.String("abc"), ast.String("b")], 68 | "'abc' LIKE '%b%'", 69 | ), 70 | ( 71 | "endswith", 72 | [ast.String("abc"), ast.String("bc")], 73 | "'abc' LIKE '%bc'", 74 | ), 75 | ( 76 | "indexof", 77 | [ast.String("abc"), ast.String("bc")], 78 | "INSTR('abc', 
'bc') - 1", 79 | ), 80 | ( 81 | "length", 82 | [ast.String("a")], 83 | "LENGTH('a')", 84 | ), 85 | ( 86 | "length", 87 | [ast.Identifier("a")], 88 | 'LENGTH("a")', 89 | ), 90 | ( 91 | "startswith", 92 | [ast.String("abc"), ast.String("ab")], 93 | "'abc' LIKE 'ab%'", 94 | ), 95 | ( 96 | "substring", 97 | [ast.String("abc"), ast.Integer("1")], 98 | "SUBSTR('abc', 1 + 1)", 99 | ), 100 | ( 101 | "substring", 102 | [ast.String("abcdef"), ast.Integer("1"), ast.Integer("2")], 103 | "SUBSTR('abcdef', 1 + 1, 2)", 104 | ), 105 | ("tolower", [ast.String("ABC")], "LOWER('ABC')"), 106 | ("toupper", [ast.String("abc")], "UPPER('abc')"), 107 | ("trim", [ast.String(" abc ")], "TRIM(' abc ')"), 108 | ( 109 | "year", 110 | [ast.DateTime("2018-01-01T10:00:00")], 111 | "CAST(STRFTIME('%Y', DATETIME('2018-01-01T10:00:00')) AS INTEGER)", 112 | ), 113 | ( 114 | "month", 115 | [ast.DateTime("2018-01-01T10:00:00")], 116 | "CAST(STRFTIME('%m', DATETIME('2018-01-01T10:00:00')) AS INTEGER)", 117 | ), 118 | ( 119 | "day", 120 | [ast.DateTime("2018-01-01T10:00:00")], 121 | "CAST(STRFTIME('%d', DATETIME('2018-01-01T10:00:00')) AS INTEGER)", 122 | ), 123 | ( 124 | "hour", 125 | [ast.DateTime("2018-01-01T10:00:00")], 126 | "CAST(STRFTIME('%H', DATETIME('2018-01-01T10:00:00')) AS INTEGER)", 127 | ), 128 | ( 129 | "minute", 130 | [ast.DateTime("2018-01-01T10:00:00")], 131 | "CAST(STRFTIME('%M', DATETIME('2018-01-01T10:00:00')) AS INTEGER)", 132 | ), 133 | ( 134 | "date", 135 | [ast.DateTime("2018-01-01T10:00:00")], 136 | "DATE(DATETIME('2018-01-01T10:00:00'))", 137 | ), 138 | ("now", [], "DATETIME('now')"), 139 | ("round", [ast.Float("123.12")], "TRUNC(123.12 + 0.5)"), 140 | ("floor", [ast.Float("123.12")], "FLOOR(123.12)"), 141 | ("ceiling", [ast.Float("123.12")], "CEILING(123.12)"), 142 | ], 143 | ) 144 | def test_ast_to_sql_functions(func_name: str, args: List[ast._Node], sql_expected: str): 145 | inp_ast = ast.Call(ast.Identifier(func_name), args) 146 | visitor = sql.AstToSqliteSqlVisitor() 147 | res = visitor.visit(inp_ast) 148 | 149 | assert res == sql_expected 150 | -------------------------------------------------------------------------------- /tests/unit/test_rewrite.py: -------------------------------------------------------------------------------- 1 | from odata_query import ast 2 | from odata_query.rewrite import AliasRewriter 3 | 4 | 5 | def test_identifier_rewrite(): 6 | rewriter = AliasRewriter({"a": "author"}) 7 | _ast = ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("Bobby")) 8 | exp = ast.Compare(ast.Eq(), ast.Identifier("author"), ast.String("Bobby")) 9 | 10 | res = rewriter.visit(_ast) 11 | assert res == exp 12 | 13 | 14 | def test_identifier_to_attribute_rewrite(): 15 | rewriter = AliasRewriter({"a": "author/name"}) 16 | _ast = ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("Bobby")) 17 | exp = ast.Compare( 18 | ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby") 19 | ) 20 | 21 | res = rewriter.visit(_ast) 22 | assert res == exp 23 | 24 | 25 | def test_attribute_to_identifier_rewrite(): 26 | rewriter = AliasRewriter({"author/name": "author_name"}) 27 | _ast = ast.Compare( 28 | ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby") 29 | ) 30 | exp = ast.Compare(ast.Eq(), ast.Identifier("author_name"), ast.String("Bobby")) 31 | 32 | res = rewriter.visit(_ast) 33 | assert res == exp 34 | 35 | 36 | def test_attribute_to_attribute_rewrite(): 37 | rewriter = AliasRewriter({"author/name": "author/info/name"}) 38 | _ast = ast.Compare( 39 | 
ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby") 40 | ) 41 | exp = ast.Compare( 42 | ast.Eq(), 43 | ast.Attribute(ast.Attribute(ast.Identifier("author"), "info"), "name"), 44 | ast.String("Bobby"), 45 | ) 46 | 47 | res = rewriter.visit(_ast) 48 | assert res == exp 49 | 50 | 51 | def test_identifier_to_function_rewrite(): 52 | rewriter = AliasRewriter({"author_length": "length(author)"}) 53 | _ast = ast.Compare(ast.Eq(), ast.Identifier("author_length"), ast.Integer(10)) 54 | exp = ast.Compare( 55 | ast.Eq(), 56 | ast.Call(ast.Identifier("length"), [ast.Identifier("author")]), 57 | ast.Integer(10), 58 | ) 59 | 60 | res = rewriter.visit(_ast) 61 | assert res == exp 62 | 63 | 64 | def test_rewrite_attribute_owner(): 65 | rewriter = AliasRewriter({"a": "author"}) 66 | _ast = ast.Compare( 67 | ast.Eq(), ast.Attribute(ast.Identifier("a"), "name"), ast.String("Bobby") 68 | ) 69 | exp = ast.Compare( 70 | ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby") 71 | ) 72 | 73 | res = rewriter.visit(_ast) 74 | assert res == exp 75 | -------------------------------------------------------------------------------- /tests/unit/test_typing.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | import pytest 4 | 5 | from odata_query import ast, typing 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "input_node, expected_type", 10 | [ 11 | (ast.String("abc"), ast.String), 12 | (ast.Compare(ast.Eq(), ast.String("abc"), ast.String("def")), ast.Boolean), 13 | ( 14 | ast.Call(ast.Identifier("contains"), [ast.String("abc"), ast.String("b")]), 15 | ast.Boolean, 16 | ), 17 | (ast.Identifier("a"), None), 18 | ], 19 | ) 20 | def test_infer_type_of_node(input_node: ast._Node, expected_type: Type[ast._Node]): 21 | res = typing.infer_type(input_node) 22 | 23 | assert res is expected_type 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "input_node, expected_type", 28 | [ 29 | ( 30 | ast.Call(ast.Identifier("contains"), [ast.String("abc"), ast.String("b")]), 31 | ast.Boolean, 32 | ), 33 | ( 34 | ast.Call(ast.Identifier("indexof"), [ast.String("abc"), ast.String("b")]), 35 | ast.Integer, 36 | ), 37 | ( 38 | ast.Call(ast.Identifier("floor"), [ast.Float("10.32")]), 39 | ast.Float, 40 | ), 41 | ( 42 | ast.Call(ast.Identifier("date"), [ast.DateTime("2020-01-01T10:10:10")]), 43 | ast.Date, 44 | ), 45 | ( 46 | ast.Call(ast.Identifier("maxdatetime"), []), 47 | ast.DateTime, 48 | ), 49 | ( 50 | ast.Call(ast.Identifier("unknown_function"), []), 51 | None, 52 | ), 53 | ], 54 | ) 55 | def test_infer_return_type_of_call( 56 | input_node: ast.Call, expected_type: Type[ast._Node] 57 | ): 58 | res = typing.infer_type(input_node) 59 | 60 | assert res is expected_type 61 | -------------------------------------------------------------------------------- /tests/unit/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from odata_query import ast, utils 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "expression, expected", 8 | [ 9 | (ast.Attribute(ast.Identifier("id"), "name"), ast.Identifier("name")), 10 | ( 11 | ast.Attribute(ast.Attribute(ast.Identifier("id"), "user"), "name"), 12 | ast.Attribute(ast.Identifier("user"), "name"), 13 | ), 14 | ( 15 | ast.Attribute(ast.Identifier("something else"), "whatever"), 16 | ast.Attribute(ast.Identifier("something else"), "whatever"), 17 | ), 18 | ( 19 | ast.Compare( 20 | ast.Eq(), 21 | ast.Attribute(ast.Identifier("id"), 
"name"), 22 | ast.String("Jozef"), 23 | ), 24 | ast.Compare(ast.Eq(), ast.Identifier("name"), ast.String("Jozef")), 25 | ), 26 | ( 27 | ast.Compare( 28 | ast.In(), 29 | ast.Attribute(ast.Identifier("id"), "name"), 30 | ast.List([ast.String("Jozef")]), 31 | ), 32 | ast.Compare( 33 | ast.In(), ast.Identifier("name"), ast.List([ast.String("Jozef")]) 34 | ), 35 | ), 36 | ], 37 | ) 38 | def test_expression_relative_to_identifier(expression, expected): 39 | res = utils.expression_relative_to_identifier(ast.Identifier("id"), expression) 40 | 41 | assert res == expected 42 | -------------------------------------------------------------------------------- /tests/unit/test_visitor.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from odata_query import ast, visitor 7 | 8 | 9 | @pytest.fixture() 10 | def simple_ast(): 11 | return ast.Compare( 12 | ast.Eq(), 13 | ast.BoolOp( 14 | ast.And(), 15 | ast.Compare(ast.Eq(), ast.Identifier("meter_id"), ast.String("123")), 16 | ast.Compare( 17 | ast.In(), ast.Identifier("category"), ast.List([ast.String("Gas")]) 18 | ), 19 | ), 20 | ast.Boolean("true"), 21 | ) 22 | 23 | 24 | def test_generic_visitor_calls_methods_based_on_node_class(simple_ast): 25 | _visitor = visitor.NodeVisitor() 26 | _visitor.visit_Identifier = mock.MagicMock() 27 | _visitor.visit(simple_ast) 28 | assert _visitor.visit_Identifier.call_count == 2 29 | 30 | 31 | def test_node_transformer_without_methods_doesnt_modify_tree(simple_ast): 32 | transformer = visitor.NodeTransformer() 33 | new_tree = transformer.visit(deepcopy(simple_ast)) 34 | assert simple_ast == new_tree 35 | 36 | 37 | def test_node_transformer_simple(simple_ast): 38 | class Inverter(visitor.NodeTransformer): 39 | """Simple transformer that inverts equality checks""" 40 | 41 | def visit_Compare(self, node: ast.Compare) -> ast.Compare: 42 | left = self.visit(node.left) 43 | right = self.visit(node.right) 44 | if isinstance(node.comparator, ast.Eq): 45 | return ast.Compare(ast.NotEq(), left, right) 46 | elif isinstance(node.comparator, ast.NotEq): 47 | return ast.Compare(ast.Eq(), left, right) 48 | return ast.Compare(node.comparator, left, right) 49 | 50 | transformer = Inverter() 51 | new_tree = transformer.visit(deepcopy(simple_ast)) 52 | 53 | assert simple_ast != new_tree 54 | assert isinstance(new_tree.comparator, ast.NotEq) 55 | assert isinstance(new_tree.left.left.comparator, ast.NotEq) 56 | 57 | roundtrip_tree = transformer.visit(deepcopy(new_tree)) 58 | assert simple_ast == roundtrip_tree 59 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37-django3, py{38,39,310,311}-django{3,4}, linting, docs 3 | skip_missing_interpreters = True 4 | isolated_build = True 5 | 6 | [gh-actions] 7 | python = 8 | 3.7: py37 9 | 3.8: py38, linting, docs 10 | 3.9: py39 11 | 3.10: py310 12 | 3.11: py311 13 | 14 | [testenv:linting] 15 | basepython = python3.8 16 | extras = 17 | linting 18 | commands = 19 | flake8 --show-source odata_query tests 20 | black --check --diff odata_query tests 21 | isort --check-only odata_query tests 22 | mypy --ignore-missing-imports -p odata_query 23 | vulture odata_query/ --min-confidence 80 24 | 25 | [testenv:docs] 26 | basepython = python3.8 27 | extras = 28 | docs 29 | django 30 | sqlalchemy 31 | changedir = 32 | docs/ 33 | commands = 34 | 
sphinx-build source build 35 | 36 | [testenv] 37 | deps = 38 | django3: Django>=3.2,<4 39 | django4: Django>=4,<5 40 | extras = 41 | testing 42 | django 43 | sqlalchemy 44 | setenv = 45 | DJANGO_SETTINGS_MODULE = tests.integration.django.settings 46 | passenv = 47 | PYTHONBREAKPOINT 48 | commands = 49 | pytest {posargs:tests/unit/ tests/integration/} -r fEs 50 | --------------------------------------------------------------------------------