├── .editorconfig
├── .github
    └── workflows
    │   └── tests.yml
├── .gitignore
├── .readthedocs.yaml
├── CHANGELOG.rst
├── CONTRIBUTING.rst
├── LICENSE
├── README.rst
├── docs
    ├── Makefile
    ├── make.bat
    └── source
    │   ├── api
    │       ├── index.rst
    │       ├── modules.rst
    │       ├── odata_query.ast.rst
    │       ├── odata_query.django.django_q.rst
    │       ├── odata_query.django.django_q_ext.rst
    │       ├── odata_query.django.rst
    │       ├── odata_query.django.shorthand.rst
    │       ├── odata_query.django.utils.rst
    │       ├── odata_query.exceptions.rst
    │       ├── odata_query.grammar.rst
    │       ├── odata_query.rewrite.rst
    │       ├── odata_query.roundtrip.rst
    │       ├── odata_query.rst
    │       ├── odata_query.sql.athena.rst
    │       ├── odata_query.sql.base.rst
    │       ├── odata_query.sql.rst
    │       ├── odata_query.sql.sqlite.rst
    │       ├── odata_query.sqlalchemy.common.rst
    │       ├── odata_query.sqlalchemy.core.rst
    │       ├── odata_query.sqlalchemy.functions_ext.rst
    │       ├── odata_query.sqlalchemy.orm.rst
    │       ├── odata_query.sqlalchemy.rst
    │       ├── odata_query.sqlalchemy.shorthand.rst
    │       ├── odata_query.typing.rst
    │       ├── odata_query.utils.rst
    │       └── odata_query.visitor.rst
    │   ├── changelog.rst
    │   ├── conf.py
    │   ├── contributing.rst
    │   ├── deviations-and-extensions.rst
    │   ├── django.rst
    │   ├── glossary.rst
    │   ├── index.rst
    │   ├── parsing-odata.rst
    │   ├── snippets
    │       ├── modifying.rst
    │       └── parsing.rst
    │   ├── sql.rst
    │   ├── sqlalchemy.rst
    │   └── working-with-ast.rst
├── odata_query
    ├── __init__.py
    ├── ast.py
    ├── django
    │   ├── __init__.py
    │   ├── django_q.py
    │   ├── django_q_ext.py
    │   ├── shorthand.py
    │   └── utils.py
    ├── exceptions.py
    ├── grammar.py
    ├── py.typed
    ├── rewrite.py
    ├── roundtrip.py
    ├── sql
    │   ├── __init__.py
    │   ├── athena.py
    │   ├── base.py
    │   └── sqlite.py
    ├── sqlalchemy
    │   ├── __init__.py
    │   ├── common.py
    │   ├── core.py
    │   ├── functions_ext.py
    │   ├── orm.py
    │   └── shorthand.py
    ├── typing.py
    ├── utils.py
    └── visitor.py
├── poetry.lock
├── pyproject.toml
├── setup.cfg
├── sonar-project.properties
├── tests
    ├── __init__.py
    ├── conftest.py
    ├── data
    │   └── world_borders.zip
    ├── integration
    │   ├── __init__.py
    │   ├── django
    │   │   ├── __init__.py
    │   │   ├── apps.py
    │   │   ├── conftest.py
    │   │   ├── models.py
    │   │   ├── settings.py
    │   │   ├── test_odata_to_django_q.py
    │   │   ├── test_querying.py
    │   │   └── test_utils.py
    │   ├── django_geo
    │   │   ├── __init__.py
    │   │   ├── apps.py
    │   │   ├── conftest.py
    │   │   ├── models.py
    │   │   └── test_querying.py
    │   ├── sql
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   ├── test_odata_to_athena_sql.py
    │   │   ├── test_odata_to_sql.py
    │   │   ├── test_odata_to_sqlite.py
    │   │   └── test_querying.py
    │   ├── sqlalchemy
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   ├── models.py
    │   │   ├── test_odata_to_sqlalchemy_core.py
    │   │   ├── test_odata_to_sqlalchemy_orm.py
    │   │   └── test_querying.py
    │   └── test_roundtrip.py
    └── unit
    │   ├── __init__.py
    │   ├── django
    │       └── __init__.py
    │   ├── sql
    │       ├── __init__.py
    │       ├── test_ast_to_sql.py
    │       └── test_ast_to_sqlite.py
    │   ├── test_odata_parser.py
    │   ├── test_rewrite.py
    │   ├── test_typing.py
    │   ├── test_utils.py
    │   └── test_visitor.py
└── tox.ini
/.editorconfig:
--------------------------------------------------------------------------------
 1 | # http://editorconfig.org
 2 | 
 3 | root = true
 4 | 
 5 | [*]
 6 | indent_style = space
 7 | indent_size = 4
 8 | trim_trailing_whitespace = true
 9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 | 
13 | [Makefile]
14 | indent_style = tab
15 | 
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
 1 | name: Run tests & linting
 2 | 
 3 | on: [push, pull_request, workflow_dispatch]
 4 | 
 5 | jobs:
 6 |   build:
 7 |     runs-on: ubuntu-latest
 8 |     strategy:
 9 |       matrix:
10 |         python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
11 | 
12 |     steps:
13 |       - uses: actions/checkout@v4
14 |       - name: Install GIS libraries
15 |         run: sudo apt-get install -y binutils libproj-dev gdal-bin libsqlite3-mod-spatialite
16 |       - name: Set up Python
17 |         uses: actions/setup-python@v5
18 |         with:
19 |           python-version: ${{ matrix.python-version }}
20 |       - name: Install dependencies
21 |         run: |
22 |           python -m pip install --upgrade pip
23 |           pip install tox tox-gh-actions
24 |       - name: Run tests
25 |         run: tox
26 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | env/
 12 | build/
 13 | develop-eggs/
 14 | dist/
 15 | downloads/
 16 | eggs/
 17 | .eggs/
 18 | lib/
 19 | lib64/
 20 | parts/
 21 | sdist/
 22 | var/
 23 | wheels/
 24 | *.egg-info/
 25 | .installed.cfg
 26 | *.egg
 27 | 
 28 | # PyInstaller
 29 | #  Usually these files are written by a python script from a template
 30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 31 | *.manifest
 32 | *.spec
 33 | 
 34 | # Installer logs
 35 | pip-log.txt
 36 | pip-delete-this-directory.txt
 37 | 
 38 | # Unit test / coverage reports
 39 | htmlcov/
 40 | .tox/
 41 | .coverage
 42 | .coverage.*
 43 | .cache
 44 | nosetests.xml
 45 | *.cover
 46 | .hypothesis/
 47 | .pytest_cache/
 48 | coverage.xml
 49 | 
 50 | # Translations
 51 | *.mo
 52 | *.pot
 53 | 
 54 | # Django stuff:
 55 | *.log
 56 | local_settings.py
 57 | 
 58 | # Flask stuff:
 59 | instance/
 60 | .webassets-cache
 61 | 
 62 | # Scrapy stuff:
 63 | .scrapy
 64 | 
 65 | # Sphinx documentation
 66 | docs/_build/
 67 | 
 68 | # PyBuilder
 69 | target/
 70 | 
 71 | # Jupyter Notebook
 72 | .ipynb_checkpoints
 73 | 
 74 | # pyenv
 75 | .python-version
 76 | 
 77 | # celery beat schedule file
 78 | celerybeat-schedule
 79 | 
 80 | # SageMath parsed files
 81 | *.sage.py
 82 | 
 83 | # dotenv
 84 | .env
 85 | .envrc
 86 | 
 87 | # virtualenv
 88 | .venv
 89 | venv/
 90 | ENV/
 91 | 
 92 | # Spyder project settings
 93 | .spyderproject
 94 | .spyproject
 95 | 
 96 | # Rope project settings
 97 | .ropeproject
 98 | 
 99 | # mkdocs documentation
100 | /site
101 | 
102 | # mypy
103 | .mypy_cache/
104 | 
105 | # Serverless
106 | .serverless/
107 | node_modules/
108 | 
109 | tests/integration/django/db/*
110 | tests/data/world_borders/*
111 | 
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
 1 | # .readthedocs.yaml
 2 | # Read the Docs configuration file
 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
 4 | 
 5 | # Required
 6 | version: 2
 7 | 
 8 | build:
 9 |   os: ubuntu-22.04
10 |   tools:
11 |     python: "3.12"
12 | 
13 | # Build documentation in the docs/ directory with Sphinx
14 | sphinx:
15 |     configuration: docs/source/conf.py
16 | 
17 | # Optionally build your docs in additional formats such as PDF
18 | formats: []
19 | 
20 | # Optionally set the version of Python and requirements required to build your docs
21 | python:
22 |     install:
23 |         - method: pip
24 |           path: .
25 |           extra_requirements:
26 |               - docs
27 | 
--------------------------------------------------------------------------------
/CHANGELOG.rst:
--------------------------------------------------------------------------------
  1 | 
  2 | Changelog
  3 | =========
  4 | 
  5 | All notable changes to this project will be documented in this file.
  6 | 
  7 | The format is based on `Keep a Changelog <https://keepachangelog.com/en/1.0.0/>`_,
  8 | and this project adheres to `Semantic Versioning <https://semver.org/spec/v2.0.0.html>`_.
  9 | 
 10 | [Unreleased]
 11 | ---------------------
 12 | 
 13 | Changed
 14 | ^^^^^^^
 15 | 
 16 | * Log ``debug`` instead of ``warning`` when type inference fails.
 17 | 
 18 | 
 19 | [0.10.0] - 2024-01-21
 20 | ---------------------
 21 | 
 22 | Added
 23 | ^^^^^
 24 | 
 25 | * Parser, Django: Added support for the ``geo.intersects``, ``geo.distance`` and ``geo.length`` functions.
 26 | * Python 3.11 coverage in test matrix.
 27 | * Parser: Added support for durations specified in years and months.
 28 | 
 29 | 
 30 | [0.9.1] - 2023-11-15
 31 | --------------------
 32 | 
 33 | Fixed
 34 | ^^^^^
 35 | 
 36 | * SQL: Support function calls on the right-hand side of ``contains`` nodes.
 37 | 
 38 | 
 39 | [0.9.0] - 2023-08-17
 40 | --------------------
 41 | 
 42 | Fixed
 43 | ^^^^^
 44 | 
 45 | * SQLAlchemy functions defined by ``odata_query``
 46 | 
 47 |     - no longer clash with other functions defined in ``sqlalchemy.func`` with
 48 |       the same name.
 49 |     - inherit cache to prevent SQLAlchemy performance warnings.
 50 | 
 51 | 
 52 | [0.8.1] - 2023-02-17
 53 | --------------------
 54 | 
 55 | Fixed
 56 | ^^^^^
 57 | 
 58 | * SQLAlchemy: Fix detection of pre-existing joins for legacy Query objects.
 59 | 
 60 | 
 61 | [0.8.0] - 2023-01-20
 62 | --------------------
 63 | 
 64 | Added
 65 | ^^^^^
 66 | 
 67 | * SQL: Support for the SQLite dialect.
 68 | 
 69 | Changed
 70 | ^^^^^^^
 71 | 
 72 | * SQL: Refactored SQL support
 73 | 
 74 |     - The ``AstToSqlVisitor`` now implements basic, standard SQL and functions
 75 |       as a base class for dialects.
 76 |     - The Athena dialect now lives in ``AstToAthenaSqlVisitor``.
 77 | 
 78 | 
 79 | [0.7.2] - 2022-12-19
 80 | --------------------
 81 | 
 82 | Fixed
 83 | ^^^^^
 84 | 
 85 | * SQLAlchemy: Return a 1.x ``Query`` if a 1.x ``Query`` is passed in the shorthand.
 86 | 
 87 | 
 88 | [0.7.1] - 2022-12-16
 89 | --------------------
 90 | 
 91 | Fixed
 92 | ^^^^^
 93 | 
 94 | * SQLAlchemy: Fix accidentally dropped support for 1.x style queries.
 95 | 
 96 | 
 97 | [0.7.0] - 2022-10-28
 98 | --------------------
 99 | 
100 | Added
101 | ^^^^^
102 | 
103 | * A new visitor that roundtrips OData AST back to an OData query string.
104 | * SQLAlchemy: Support for SQLAlchemy Core through a new visitor.
105 | 
106 | 
107 | Changed
108 | ^^^^^^^
109 | 
110 | * Some lower-level SQLAlchemy related modules have been reorganized to
111 |   facilitate SQLAlchemy Core support.
112 | 
113 | 
114 | [0.6.0] - 2022-10-28
115 | --------------------
116 | 
117 | Changed
118 | ^^^^^^^
119 | 
120 | * Django: Rework Django query transformer to use Django query nodes more.
121 |   With the release of Django 4, we can now use ``Lookup``, ``Function``, and other
122 |   Django query nodes in queries directly instead of relying on the keyworded
123 |   syntax. For older Django versions, we can still transform those nodes to the
124 |   keyworded syntax.
125 | * Deps: Upgraded Sphinx and the ReadTheDocs theme.
126 | 
127 | 
128 | Fixed
129 | ^^^^^
130 | 
131 | * Django: Support comparisons of boolean funcs to booleans
132 |   (e.g. ``contains(a, 'b') eq true``)
133 | 
134 | 
135 | [0.5.2] - 2022-03-14
136 | --------------------
137 | 
138 | Changed
139 | ^^^^^^^
140 | 
141 | * Deps: Upgraded pytest.
142 | 
143 | Fixed
144 | ^^^^^
145 | 
146 | * SQLAlchemy: Fixed datetime extract functions.
147 | 
148 | 
149 | [0.5.1] - 2022-02-28
150 | --------------------
151 | 
152 | Fixed
153 | ^^^^^
154 | 
155 | * QA; Remove ``type:ignore`` from ``grammar.py`` and fix resulting type issues.
156 | 
157 | 
158 | [0.5.0] - 2022-02-28
159 | --------------------
160 | 
161 | Added
162 | ^^^^^
163 | 
164 | * Parser: Rudimentary OData namespace support.
165 | * AST: Literal nodes now have a ``py_val`` getter that returns the closest Python
166 |   approximation to the OData value.
167 | * QA: Added full typing support.
168 | 
169 | Changed
170 | ^^^^^^^
171 | 
172 | * QA: Upgraded linting libraries.
173 | 
174 | 
175 | [0.4.2] - 2021-12-19
176 | --------------------
177 | 
178 | Added
179 | ^^^^^
180 | 
181 | * Docs: Include contribution guidelines and changelog in the main documentation.
182 | 
183 | Changed
184 | ^^^^^^^
185 | 
186 | * Docs: Use ReStructuredText instead of markdown where possible, for easier
187 |   interaction with Sphinx.
188 | 
189 | Removed
190 | ^^^^^^^
191 | 
192 | * Docs: Removed the ``Myst`` dependency as we're no longer mixing markdown into
193 |   our docs.
194 | * Dev: Removed the ``moto`` and ``Faker`` dependencies as they weren't used.
195 | 
196 | [0.4.1] - 2021-07-16
197 | --------------------
198 | 
199 | Added
200 | ^^^^^
201 | 
202 | * Added shorthands for the most common use cases: Applying an OData filter
203 |   straight to a Django QuerySet or SQLAlchemy query.
204 | 
205 | Fixed
206 | ^^^^^
207 | 
208 | * Cleared warnings produced in SLY by wrong regex flag placement.
209 | 
210 | [0.4.0] - 2021-05-28
211 | --------------------
212 | 
213 | Changed
214 | ^^^^^^^
215 | 
216 | * Raise a new ``InvalidFieldException`` if a field in a query doesn't exist.
217 | 
218 | Fixed
219 | ^^^^^
220 | 
221 | * Allow ``AliasRewriter`` to recurse into ``Attribute`` nodes, in order to replace
222 |   nodes in the ``Attribute``\ 's ownership chain.
223 | 
224 | [0.3.0] - 2021-05-17
225 | --------------------
226 | 
227 | Added
228 | ^^^^^
229 | 
230 | * Added ``NodeTransformers``\ , which are like ``NodeVisitors`` but replace visited
231 |   nodes with the returned value.
232 | * Initial API documentation.
233 | 
234 | Changed
235 | ^^^^^^^
236 | 
237 | * The AstTo{ORMQuery} visitors for SQLAlchemy and Django now have the same
238 |   interface.
239 | * ``AstToDjangoQVisitor`` now builds subqueries for ``any()/all()`` itself, instead
240 |   of relying on ``SubQueryToken``\ s and a separate visitor.
241 | * Made all AST Nodes ``frozen`` (read-only), so they can be hashed.
242 | * Replaced ``field_mapping`` on the ORM visitors with a more general
243 |   ``AliasRewriter`` based on the new ``NodeTransformers``.
244 | * Refactored ``IdentifierStripper`` to use the new ``NodeTransformers``.
245 | 
246 | [0.2.0] - 2021-05-05
247 | --------------------
248 | 
249 | Added
250 | ^^^^^
251 | 
252 | * Transform OData queries to SQLAlchemy expressions with the new
253 |   ``AstToSqlAlchemyClauseVisitor``.
254 | 
255 | Changed
256 | ^^^^^^^
257 | 
258 | * Don't write a debugfile for the parser by default.
259 | 
260 | [0.1.0] - 2021-03-12
261 | --------------------
262 | 
263 | Added
264 | ^^^^^
265 | 
266 | * Initial split into a separate package.
267 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
  1 | How to contribute
  2 | =================
  3 | 
  4 | Getting started
  5 | ---------------
  6 | 
  7 | So you're interested in contributing to ``odata-query``? That's great! We're
  8 | excited to hear your ideas and experiences.
  9 | 
 10 | This file describes all the different ways in which you can contribute.
 11 | 
 12 | 
 13 | Reporting a bug
 14 | ------------------
 15 | 
 16 | Have you encountered a bug? Please let us know by reporting it.
 17 | 
 18 | Before doing so, take a look at the existing `Issues`_ to make sure the bug you
 19 | encountered hasn't already been reported by someone else. If so, we ask you to
 20 | reply to the existing Issue rather than creating a new one. Bugs with many
 21 | replies will obviously have a higher priority.
 22 | 
 23 | If the bug you encountered has not been reported yet, create a new Issue for it
 24 | and make sure to label it as a 'bug'. To allow us to help you as efficiently as
 25 | possible, always try to include the following:
 26 | 
 27 | - Which version of ``odata-query`` are you using?
 28 | - What went wrong?
 29 | - What did you expect to happen?
 30 | - Detailed steps to reproduce.
 31 | 
 32 | The maintainer of this repository monitors issues on a regular basis and will
 33 | respond to your bug report as soon as possible.
 34 | 
 35 | 
 36 | Requesting an enhancement
 37 | -------------------------
 38 | 
 39 | Do you have a great idea that could make ``odata-query`` even better?
 40 | Feel free to request an enhancement.
 41 | 
 42 | Before doing so, take a look at the existing `Issues`_ to make sure your idea
 43 | hasn't already been requested by someone else. If so, we ask you to reply or
 44 | give a thumbs-up to the existing Issue rather than creating a new one. Requests
 45 | with many replies will obviously have a higher priority.
 46 | 
 47 | If your idea has not been requested yet, create a new Issue for it and make sure
 48 | to label it as an 'enhancement'. Explain what your idea is in detail and how it
 49 | could improve ``odata-query``.
 50 | 
 51 | The maintainer of this repository monitors issues on a regular basis and will
 52 | respond to your request as soon as possible.
 53 | 
 54 | 
 55 | Contributing code
 56 | -----------------
 57 | 
 58 | Would you prefer to contribute directly by writing some code yourself? That's
 59 | great.
 60 | 
 61 | **If your contribution is minor**, such as fixing a bug or typo, we encourage
 62 | you to open a pull request right away.
 63 | 
 64 | **If your contribution is major**, such as a new feature or a breaking change,
 65 | start by opening an issue first. That way, other people can weigh in on the
 66 | discussion before you do any work.
 67 | 
 68 | The workflow for creating a pull request:
 69 | 
 70 | 1. Fork the repository.
 71 | 2. ``clone`` your forked repository.
 72 | 3. Create a new feature branch from the ``master`` branch.
 73 | 4. Make your contributions to the project's code. Please run the tests before
 74 |    committing, and keep the code style conform to the project's style guide
 75 |    (listed below).
 76 | 5. Add documentation where required. Please keep the style conform.
 77 | 6. Add your changes to the :ref:`changelog` in the "Unreleased"
 78 |    section. Include a link to your GitHub profile for some internet fame.
 79 | 7. ``commit`` your changes in logical chunks.
 80 | 8. ``push`` your branch to your fork on GitHub.
 81 | 9. Create a pull request from your feature branch to the original repository's
 82 |    ``master`` branch.
 83 | 
 84 | The maintainer of this repository monitors pull requests on a regular basis and
 85 | will respond as soon as possible. The smaller individual pull requests are, the
 86 | faster the maintainer will be able to respond.
 87 | 
 88 | 
 89 | Local development
 90 | ^^^^^^^^^^^^^^^^^
 91 | 
 92 | ``odata-query`` uses `poetry`_ to manage the package and its dependencies, and
 93 | `tox`_ to run tests and linting in dedicated environments.
 94 | 
 95 | To install the project for development, run:
 96 | 
 97 | .. code-block:: bash
 98 | 
 99 |     poetry install -E testing -E linting -E docs
100 | 
101 | 
102 | To run all tests and linting, run:
103 | 
104 | .. code-block:: bash
105 | 
106 |     tox
107 | 
108 | To ensure your files are formatted correctly (linting will complain if they're not),
109 | you can configure your IDE to use the included linting tools, or run them manually
110 | as follows:
111 | 
112 | 
113 | .. code-block:: bash
114 | 
115 |     poetry run black .  # Format files to adhere to the Black code style
116 |     poetry run isort .  # Ensure the imports are organised correctly
117 | 
118 | 
119 | Contact
120 | -------
121 | 
122 | Discussions about ``odata-query`` take place on this repository's `Issues`_ and
123 | `Pull Requests`_ sections. Anybody is welcome to join the conversation. Wherever
124 | possible, do not take these conversations to private channels, including
125 | contacting the maintainers directly. Keeping communication public means
126 | everybody can benefit and learn.
127 | 
128 | 
129 | .. _Issues: https://github.com/gorilla-co/odata-query/issues
130 | .. _Pull Requests: https://github.com/gorilla-co/odata-query/pulls
131 | .. _poetry: https://python-poetry.org/
132 | .. _tox: https://tox.readthedocs.io/en/latest/index.html
133 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | The MIT License (MIT)
 2 | 
 3 | Copyright (c) 2021 Gorillini NV
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
23 | 
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
  1 | OData-Query
  2 | ===========
  3 | 
  4 | .. image:: https://readthedocs.org/projects/odata-query/badge/?version=latest
  5 |     :alt: Documentation Status
  6 |     :target: https://odata-query.readthedocs.io/en/latest/?badge=latest
  7 | 
  8 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
  9 |     :alt: Code style: black
 10 |     :target: https://github.com/psf/black
 11 | 
 12 | 
 13 | ``odata-query`` is a library that parses `OData v4`_ filter strings, and can
 14 | convert them to other forms such as `Django Queries`_, `SQLAlchemy Queries`_,
 15 | or just plain SQL.
 16 | 
 17 | 
 18 | Installation
 19 | ------------
 20 | 
 21 | ``odata-query`` is available on pypi, so can be installed with the package manager
 22 | of your choice:
 23 | 
 24 | .. code-block:: bash
 25 | 
 26 |     pip install odata-query
 27 |     # OR
 28 |     poetry add odata-query
 29 |     # OR
 30 |     pipenv install odata-query
 31 | 
 32 | 
 33 | The package defines the following optional extras:
 34 | 
 35 | * ``django``: If you want to pin a compatible Django version.
 36 | * ``sqlalchemy``: If you want to pin a compatible SQLAlchemy version.
 37 | 
 38 | 
 39 | The following extras relate to the development of this library:
 40 | 
 41 | - ``linting``: The linting and code style tools.
 42 | - ``testing``: Packages for running the tests.
 43 | - ``docs``: For building the project documentation.
 44 | 
 45 | 
 46 | You can install extras by adding them between square brackets during
 47 | installation:
 48 | 
 49 | .. code-block:: bash
 50 | 
 51 |     pip install odata-query[sqlalchemy]
 52 | 
 53 | 
 54 | Quickstart
 55 | ----------
 56 | 
 57 | The most common use case is probably parsing an OData query string, and applying
 58 | it to a query your ORM understands. For this purpose there is an all-in-one function:
 59 | ``apply_odata_query``.
 60 | 
 61 | Example for Django:
 62 | 
 63 | .. code-block:: python
 64 | 
 65 |     from odata_query.django import apply_odata_query
 66 | 
 67 |     orm_query = MyModel.objects  # This can be a Manager or a QuerySet.
 68 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
 69 | 
 70 |     query = apply_odata_query(orm_query, odata_query)
 71 |     results = query.all()
 72 | 
 73 | 
 74 | Example for SQLAlchemy ORM:
 75 | 
 76 | .. code-block:: python
 77 | 
 78 |     from odata_query.sqlalchemy import apply_odata_query
 79 | 
 80 |     orm_query = select(MyModel)  # This is any form of Query or Selectable.
 81 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
 82 | 
 83 |     query = apply_odata_query(orm_query, odata_query)
 84 |     results = session.execute(query).scalars().all()
 85 | 
 86 | Example for SQLAlchemy Core:
 87 | 
 88 | .. code-block:: python
 89 | 
 90 |     from odata_query.sqlalchemy import apply_odata_core
 91 | 
 92 |     core_query = select(MyTable)  # This is any form of Query or Selectable.
 93 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
 94 | 
 95 |     query = apply_odata_core(core_query, odata_query)
 96 |     results = session.execute(query).all()
 97 | 
 98 | .. splitinclude-1
 99 | 
100 | Advanced Usage
101 | --------------
102 | 
103 | Not all use cases are as simple as that. Luckily, ``odata-query`` is modular
104 | and extensible. See the `documentation`_ for advanced usage or extending the
105 | library for other cases.
106 | 
107 | .. splitinclude-2
108 | 
109 | Contact
110 | -------
111 | 
112 | Got any questions or ideas? We'd love to hear from you. Check out our
113 | `contributing guidelines`_ for ways to offer feedback and
114 | contribute.
115 | 
116 | 
117 | License
118 | -------
119 | 
120 | Copyright © `Gorillini NV`_.
121 | All rights reserved.
122 | 
123 | Licensed under the MIT License.
124 | 
125 | 
126 | .. _odata v4: https://www.odata.org/
127 | .. _django queries: https://docs.djangoproject.com/en/3.2/topics/db/queries/
128 | .. _sqlalchemy queries: https://docs.sqlalchemy.org/en/14/orm/loading_objects.html
129 | .. _documentation: https://odata-query.readthedocs.io/en/latest
130 | .. _Gorillini NV: https://gorilla.co/
131 | .. _contributing guidelines: ./CONTRIBUTING.rst
132 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
 1 | # Minimal makefile for Sphinx documentation
 2 | #
 3 | 
 4 | # You can set these variables from the command line, and also
 5 | # from the environment for the first two.
 6 | SPHINXOPTS    ?=
 7 | SPHINXBUILD   ?= sphinx-build
 8 | SOURCEDIR     = source
 9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | apidoc:
18 | 	sphinx-apidoc -f -o source/api -e ../odata_query
19 | 
20 | # Catch-all target: route all unknown targets to Sphinx using the new
21 | # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
22 | %: Makefile
23 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
24 | 
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
 1 | @ECHO OFF
 2 | 
 3 | pushd %~dp0
 4 | 
 5 | REM Command file for Sphinx documentation
 6 | 
 7 | if "%SPHINXBUILD%" == "" (
 8 | 	set SPHINXBUILD=sphinx-build
 9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | 
13 | if "%1" == "" goto help
14 | 
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | 	echo.If you don't have Sphinx installed, grab it from
24 | 	echo.http://sphinx-doc.org/
25 | 	exit /b 1
26 | )
27 | 
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 | 
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 | 
34 | :end
35 | popd
36 | 
--------------------------------------------------------------------------------
/docs/source/api/index.rst:
--------------------------------------------------------------------------------
1 | API Docs
2 | ========
3 | 
4 | .. toctree::
5 |    :maxdepth: 4
6 | 
7 |    modules
8 | 
9 | 
--------------------------------------------------------------------------------
/docs/source/api/modules.rst:
--------------------------------------------------------------------------------
1 | odata_query
2 | ===========
3 | 
4 | .. toctree::
5 |    :maxdepth: 4
6 | 
7 |    odata_query
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.ast.rst:
--------------------------------------------------------------------------------
1 | odata\_query.ast module
2 | =======================
3 | 
4 | .. automodule:: odata_query.ast
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.django.django_q.rst:
--------------------------------------------------------------------------------
1 | odata\_query.django.django\_q module
2 | ====================================
3 | 
4 | .. automodule:: odata_query.django.django_q
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.django.django_q_ext.rst:
--------------------------------------------------------------------------------
1 | odata\_query.django.django\_q\_ext module
2 | =========================================
3 | 
4 | .. automodule:: odata_query.django.django_q_ext
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.django.rst:
--------------------------------------------------------------------------------
 1 | odata\_query.django package
 2 | ===========================
 3 | 
 4 | Submodules
 5 | ----------
 6 | 
 7 | .. toctree::
 8 |    :maxdepth: 4
 9 | 
10 |    odata_query.django.django_q
11 |    odata_query.django.django_q_ext
12 |    odata_query.django.shorthand
13 |    odata_query.django.utils
14 | 
15 | Module contents
16 | ---------------
17 | 
18 | .. automodule:: odata_query.django
19 |    :members:
20 |    :undoc-members:
21 |    :show-inheritance:
22 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.django.shorthand.rst:
--------------------------------------------------------------------------------
1 | odata\_query.django.shorthand module
2 | ====================================
3 | 
4 | .. automodule:: odata_query.django.shorthand
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.django.utils.rst:
--------------------------------------------------------------------------------
1 | odata\_query.django.utils module
2 | ================================
3 | 
4 | .. automodule:: odata_query.django.utils
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.exceptions.rst:
--------------------------------------------------------------------------------
1 | odata\_query.exceptions module
2 | ==============================
3 | 
4 | .. automodule:: odata_query.exceptions
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.grammar.rst:
--------------------------------------------------------------------------------
1 | odata\_query.grammar module
2 | ===========================
3 | 
4 | .. automodule:: odata_query.grammar
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.rewrite.rst:
--------------------------------------------------------------------------------
1 | odata\_query.rewrite module
2 | ===========================
3 | 
4 | .. automodule:: odata_query.rewrite
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.roundtrip.rst:
--------------------------------------------------------------------------------
1 | odata\_query.roundtrip module
2 | =============================
3 | 
4 | .. automodule:: odata_query.roundtrip
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.rst:
--------------------------------------------------------------------------------
 1 | odata\_query package
 2 | ====================
 3 | 
 4 | Subpackages
 5 | -----------
 6 | 
 7 | .. toctree::
 8 |    :maxdepth: 4
 9 | 
10 |    odata_query.django
11 |    odata_query.sql
12 |    odata_query.sqlalchemy
13 | 
14 | Submodules
15 | ----------
16 | 
17 | .. toctree::
18 |    :maxdepth: 4
19 | 
20 |    odata_query.ast
21 |    odata_query.exceptions
22 |    odata_query.grammar
23 |    odata_query.rewrite
24 |    odata_query.roundtrip
25 |    odata_query.typing
26 |    odata_query.utils
27 |    odata_query.visitor
28 | 
29 | Module contents
30 | ---------------
31 | 
32 | .. automodule:: odata_query
33 |    :members:
34 |    :undoc-members:
35 |    :show-inheritance:
36 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sql.athena.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sql.athena module
2 | ==============================
3 | 
4 | .. automodule:: odata_query.sql.athena
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sql.base.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sql.base module
2 | ============================
3 | 
4 | .. automodule:: odata_query.sql.base
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sql.rst:
--------------------------------------------------------------------------------
 1 | odata\_query.sql package
 2 | ========================
 3 | 
 4 | Submodules
 5 | ----------
 6 | 
 7 | .. toctree::
 8 |    :maxdepth: 4
 9 | 
10 |    odata_query.sql.athena
11 |    odata_query.sql.base
12 |    odata_query.sql.sqlite
13 | 
14 | Module contents
15 | ---------------
16 | 
17 | .. automodule:: odata_query.sql
18 |    :members:
19 |    :undoc-members:
20 |    :show-inheritance:
21 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sql.sqlite.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sql.sqlite module
2 | ==============================
3 | 
4 | .. automodule:: odata_query.sql.sqlite
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.common.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sqlalchemy.common module
2 | =====================================
3 | 
4 | .. automodule:: odata_query.sqlalchemy.common
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.core.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sqlalchemy.core module
2 | ===================================
3 | 
4 | .. automodule:: odata_query.sqlalchemy.core
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.functions_ext.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sqlalchemy.functions\_ext module
2 | =============================================
3 | 
4 | .. automodule:: odata_query.sqlalchemy.functions_ext
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.orm.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sqlalchemy.orm module
2 | ==================================
3 | 
4 | .. automodule:: odata_query.sqlalchemy.orm
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.rst:
--------------------------------------------------------------------------------
 1 | odata\_query.sqlalchemy package
 2 | ===============================
 3 | 
 4 | Submodules
 5 | ----------
 6 | 
 7 | .. toctree::
 8 |    :maxdepth: 4
 9 | 
10 |    odata_query.sqlalchemy.common
11 |    odata_query.sqlalchemy.core
12 |    odata_query.sqlalchemy.functions_ext
13 |    odata_query.sqlalchemy.orm
14 |    odata_query.sqlalchemy.shorthand
15 | 
16 | Module contents
17 | ---------------
18 | 
19 | .. automodule:: odata_query.sqlalchemy
20 |    :members:
21 |    :undoc-members:
22 |    :show-inheritance:
23 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.sqlalchemy.shorthand.rst:
--------------------------------------------------------------------------------
1 | odata\_query.sqlalchemy.shorthand module
2 | ========================================
3 | 
4 | .. automodule:: odata_query.sqlalchemy.shorthand
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.typing.rst:
--------------------------------------------------------------------------------
1 | odata\_query.typing module
2 | ==========================
3 | 
4 | .. automodule:: odata_query.typing
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.utils.rst:
--------------------------------------------------------------------------------
1 | odata\_query.utils module
2 | =========================
3 | 
4 | .. automodule:: odata_query.utils
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/api/odata_query.visitor.rst:
--------------------------------------------------------------------------------
1 | odata\_query.visitor module
2 | ===========================
3 | 
4 | .. automodule:: odata_query.visitor
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/source/changelog.rst:
--------------------------------------------------------------------------------
1 | .. _changelog:
2 | 
3 | .. include:: ../../CHANGELOG.rst
4 | 
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
 1 | # Configuration file for the Sphinx documentation builder.
 2 | #
 3 | # This file only contains a selection of the most common options. For a full
 4 | # list see the documentation:
 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
 6 | 
 7 | # -- Path setup --------------------------------------------------------------
 8 | 
 9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 | 
17 | 
18 | # -- Project information -----------------------------------------------------
19 | 
20 | project = "OData Query"  # Sphinx project metadata: title, copyright, author and version.
21 | copyright = "2021, Gorillini NV"
22 | author = "Oliver Hofkens"
23 | 
24 | # The full version, including alpha/beta/rc tags. Must match odata_query.__version__ (it is not derived automatically).
25 | release = "0.10.0"
26 | 
27 | 
28 | # -- General configuration ---------------------------------------------------
29 | 
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 |     "sphinx.ext.autodoc",  # backs the automodule/autoclass directives in api/*.rst, sql.rst, etc.
35 |     "sphinx.ext.napoleon",  # parses Google/NumPy-style docstrings for autodoc
36 |     "sphinx.ext.doctest",  # runs the >>> examples in parsing-odata.rst and deviations-and-extensions.rst
37 |     "sphinx.ext.graphviz",  # required by the `.. graphviz::` AST diagram in parsing-odata.rst
38 |     "sphinx.ext.viewcode",  # adds links from API docs to highlighted source
39 | ]
40 | 
41 | # Add any paths that contain templates here, relative to this directory.
42 | templates_path = ["_templates"]
43 | 
44 | # List of patterns, relative to source directory, that match files and
45 | # directories to ignore when looking for source files.
46 | # This pattern also affects html_static_path and html_extra_path.
47 | exclude_patterns = []
48 | 
49 | 
50 | # -- Options for HTML output -------------------------------------------------
51 | 
52 | # The theme to use for HTML and HTML Help pages.  See the documentation for
53 | # a list of builtin themes.
54 | #
55 | html_theme = "sphinx_rtd_theme"  # third-party theme package; must be installed in the docs build environment
56 | 
57 | # Add any paths that contain custom static files (such as style sheets) here,
58 | # relative to this directory. They are copied after the builtin static files,
59 | # so a file named "default.css" will overwrite the builtin "default.css".
60 | html_static_path = ["_static"]  # NOTE(review): no docs/source/_static dir appears in the repo tree; Sphinx warns on a missing entry — confirm or set to []
61 | 
--------------------------------------------------------------------------------
/docs/source/contributing.rst:
--------------------------------------------------------------------------------
1 | .. _contributing:
2 | 
3 | .. include:: ../../CONTRIBUTING.rst
4 | 
--------------------------------------------------------------------------------
/docs/source/deviations-and-extensions.rst:
--------------------------------------------------------------------------------
 1 | Deviations and Extensions to the OData Spec
 2 | ===========================================
 3 | 
 4 | There are some minor cases where this library deviates from the official `spec`_, 
 5 | e.g. when the spec is ambiguous, or to add extra non-breaking functionality.
 6 | 
 7 | 
 8 | Lists with a single item need a trailing comma
 9 | ----------------------------------------------
10 | 
11 | This library needs a trailing comma to represent a list with a single item, 
12 | whereas the spec does not describe this trailing comma (e.g.
13 | ``(item,)`` instead of ``(item)``). 
14 | 
15 | The reason is that the spec doesn't seem to differentiate between a list of a
16 | single item and any other parenthesized expression. This can lead to parsing 
17 | conflicts. Consider the following expression:
18 | 
19 | .. code-block::
20 | 	
21 | 	concat(('a'), ('b'))
22 | 
23 | 
24 | The grammar in the official spec could parse this as either:
25 | 
26 | * Concatenate 2 lists, the first one containing the single string 'a', the second 
27 |   one containing the single string 'b'. The result would be ``('a', 'b')``.
28 | * Concatenate 2 expressions in parentheses. The first one is the string literal
29 |   'a', the second one the string literal 'b'. The result would be ``'ab'``.
30 | 
31 | 
32 | To make this difference explicit, this library requires a trailing comma to 
33 | signify a list. The same behavior is present in Python:
34 | 
35 | >>> ("a")
36 | 'a'
37 | >>> ("a",)
38 | ('a',)
39 | 
40 | >>> from odata_query.grammar import ODataLexer, ODataParser
41 | >>> lexer = ODataLexer()
42 | >>> parser = ODataParser()
43 | >>> parser.parse(lexer.tokenize("('a')"))
44 | String(val='a')
45 | >>> parser.parse(lexer.tokenize("('a',)"))
46 | List(val=[String(val='a')])
47 | 
48 | 
49 | Durations expressed in years and months
50 | ---------------------------------------
51 | 
52 | The official `spec`_ defines a duration with a number of days, hours, minutes,
53 | and seconds. E.g. ``duration'P1DT2H'`` is a duration of 1 **D**\ ay and 2 **H**\ ours.
54 | 
55 | This library adds the ability to express durations containing **Y**\ ears and 
56 | **M**\ onths. E.g. ``duration'P1Y2M3DT4H'`` would express a duration of 1 **Y**\ ear,
57 | 2 **M**\ onths, 3 **D**\ ays, and 4 **H**\ ours.
58 | 
59 | It's important to note that the final internal value is still expressed in days,
60 | based on average durations.
61 | Thus, a month simply represents 30.44 days, while a year represents 365.25 days.
62 | 
63 | 
64 | .. _spec: https://www.odata.org/documentation/
65 | 
--------------------------------------------------------------------------------
/docs/source/django.rst:
--------------------------------------------------------------------------------
 1 | Using OData with Django
 2 | =======================
 3 | 
 4 | Basic Usage
 5 | -----------
 6 | 
 7 | The easiest way to add OData filtering to a Django QuerySet is with the shorthand:
 8 | 
 9 | .. code-block:: python
10 | 
11 |     from odata_query.django import apply_odata_query
12 | 
13 |     orm_query = MyModel.objects  # This can be a Manager or a QuerySet.
14 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
15 | 
16 |     query = apply_odata_query(orm_query, odata_query)
17 |     results = query.all()
18 | 
19 | 
20 | Advanced Usage
21 | --------------
22 | 
23 | If you need some more flexibility or advanced features, the implementation of the
24 | shorthand is a nice starting point: :py:mod:`odata_query.django.shorthand`
25 | 
26 | Let's break it down real quick:
27 | 
28 | 
29 | Parsing the OData Query
30 | ^^^^^^^^^^^^^^^^^^^^^^^
31 | 
32 | .. include:: snippets/parsing.rst
33 | 
34 | 
35 | Optional: Modifying the AST
36 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
37 | 
38 | .. include:: snippets/modifying.rst
39 | 
40 | 
41 | Building a Query Filter
42 | ^^^^^^^^^^^^^^^^^^^^^^^
43 | 
44 | To get from an :term:`AST` to something Django can use, you'll need to use the
45 | :py:class:`odata_query.django.django_q.AstToDjangoQVisitor`. It needs to know
46 | about the 'root model' of your query in order to build relationships if necessary.
47 | In most cases, this will be ``queryset.model``.
48 | 
49 | 
50 | .. code-block:: python
51 | 
52 |     from odata_query.django.django_q import AstToDjangoQVisitor
53 | 
54 |     visitor = AstToDjangoQVisitor(MyModel)
55 |     query_filter = visitor.visit(ast)
56 | 
57 | 
58 | Optional: QuerySet Annotations
59 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
60 | 
61 | For some queries using advanced expressions, the ``AstToDjangoQVisitor`` will
62 | generate `queryset annotations`_. For the query to work, these will need to be
63 | applied:
64 | 
65 | .. code-block:: python
66 | 
67 |     if visitor.queryset_annotations:
68 |         queryset = queryset.annotate(**visitor.queryset_annotations)
69 | 
70 | 
71 | Running the query
72 | ^^^^^^^^^^^^^^^^^
73 | 
74 | Finally, we're ready to run the query:
75 | 
76 | .. code-block:: python
77 | 
78 |     results = queryset.filter(query_filter).all()
79 | 
80 | 
81 | .. _queryset annotations: https://docs.djangoproject.com/en/3.2/ref/models/querysets/#annotate
82 | 
--------------------------------------------------------------------------------
/docs/source/glossary.rst:
--------------------------------------------------------------------------------
 1 | Glossary
 2 | ========
 3 | 
 4 | .. glossary::
 5 | 
 6 |    AST
 7 |       Abstract Syntax Tree. A tree data structure representing the syntactic
 8 |       structure of some text. For example, in OData we might write
 9 |       ``name eq 'Bobby'``, which could be represented as the tree:
10 |       ``Compare(Eq(), Identifier('name'), String('Bobby'))``. For more,
11 |       see `Wikipedia <https://en.wikipedia.org/wiki/Abstract_syntax_tree>`_.
12 | 
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
 1 | .. include:: ../../README.rst
 2 |    :end-before: .. splitinclude-1
 3 | 
 4 | Contents
 5 | ========
 6 | 
 7 | .. toctree::
 8 |    :maxdepth: 2
 9 | 
10 |    Getting Started <self>
11 |    django
12 |    sqlalchemy
13 |    sql
14 |    parsing-odata
15 |    working-with-ast
16 |    deviations-and-extensions
17 |    glossary
18 | 
19 |    contributing
20 |    changelog
21 | 
22 |    api/index
23 | 
24 | .. include:: ../../README.rst
25 |    :start-after: .. splitinclude-2
26 | 
27 | 
28 | Indices and tables
29 | ==================
30 | 
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 | 
--------------------------------------------------------------------------------
/docs/source/parsing-odata.rst:
--------------------------------------------------------------------------------
 1 | .. _ref-parsing-odata:
 2 | 
 3 | Parsing OData
 4 | =============
 5 | 
 6 | ``odata-query`` includes a parser that tries to cover as much as possible of the `OData v4 filter spec`_.
 7 | This parser is built with `SLY`_ and consists of a :ref:`ref-lexer` and a :ref:`ref-parser`.
 8 | 
 9 | 
10 | .. _ref-lexer:
11 | 
12 | Lexer
13 | -----
14 | 
15 | The lexer's job is to take an OData query string and break it up into a series of
16 | meaningful tokens, where each token has a type and a value. These are like the
17 | words of a sentence. At this stage we're not looking at the structure of the
18 | entire sentence yet, just if the individual words make sense.
19 | 
20 | The OData lexer this library uses is defined in :py:class:`odata_query.grammar.ODataLexer`. Here's
21 | an example of what it does:
22 | 
23 | .. doctest::
24 | 
25 |    >>> from odata_query.grammar import ODataLexer
26 |    >>> lexer = ODataLexer()
27 |    >>> list(lexer.tokenize("name eq 'Hello World'"))
28 |    [Token(type='ODATA_IDENTIFIER', value=Identifier(name='name'), lineno=1, index=0), Token(type='EQ', value=Eq(), lineno=1, index=4), Token(type='STRING', value=String(val='Hello World'), lineno=1, index=8)]
29 | 
30 | 
31 | 
32 | .. _ref-parser:
33 | 
34 | Parser
35 | ------
36 | 
37 | The parser's job is to take the tokens as produced by the :ref:`ref-lexer`
38 | and find the language structure in them, according to the grammar rules defined
39 | by the OData standard. In our case, the parser tries to build an :term:`AST` that
40 | represents the entire query. This :term:`AST` is a tree structure that consists
41 | of the nodes found in :py:mod:`odata_query.ast`.
42 | 
43 | As an example, the following OData query::
44 | 
45 |     name eq 'Hello World'
46 | 
47 | can be represented in the following :term:`AST`:
48 | 
49 | .. graphviz::
50 | 
51 |    digraph {
52 |        "Compare()" -> "Identifier('name')" [label = "left"];
53 |        "Compare()" -> "Eq()" [label = "comparator"];
54 |        "Compare()" -> "String('Hello World')" [label = "right"];
55 |    }
56 | 
57 | 
58 | The OData parser this library uses is defined in :py:class:`odata_query.grammar.ODataParser`.
59 | Here's an example of what it does:
60 | 
61 | .. doctest::
62 | 
63 |    >>> from odata_query.grammar import ODataParser
64 |    >>> parser = ODataParser()
65 |    >>> parser.parse(lexer.tokenize("name eq 'Hello World'"))
66 |    Compare(comparator=Eq(), left=Identifier(name='name'), right=String(val='Hello World'))
67 | 
68 | 
69 | 
70 | .. _OData v4 filter spec: https://docs.oasis-open.org/odata/odata/v4.01/cs01/abnf/odata-abnf-construction-rules.txt
71 | .. _SLY: https://github.com/dabeaz/sly
72 | 
73 | 
--------------------------------------------------------------------------------
/docs/source/snippets/modifying.rst:
--------------------------------------------------------------------------------
 1 | There are cases where you'll want to modify the query before executing it. That's
 2 | what :ref:`NodeTransformers <ref-node-transformer>` are for!
 3 | 
 4 | One example might be that certain fields are exposed to end users under a different
 5 | name than the one in the database. In this case, the
 6 | :py:class:`odata_query.rewrite.AliasRewriter` will come in handy. Just pass it a
 7 | mapping of aliases to their full name and let it do its job:
 8 | 
 9 | .. code-block:: python
10 | 
11 |     from odata_query.rewrite import AliasRewriter
12 | 
13 |     rewriter = AliasRewriter({
14 |         "name": "author/name",
15 |     })
16 |     new_ast = rewriter.visit(ast)
17 | 
18 | 
--------------------------------------------------------------------------------
/docs/source/snippets/parsing.rst:
--------------------------------------------------------------------------------
 1 | To get from a string representing an OData query to a usable representation,
 2 | we need to tokenize and parse it as follows:
 3 | 
 4 | .. code-block:: python
 5 | 
 6 |     from odata_query.grammar import ODataParser, ODataLexer
 7 | 
 8 |     lexer = ODataLexer()
 9 |     parser = ODataParser()
10 |     ast = parser.parse(lexer.tokenize(my_odata_query))
11 | 
12 | This process is described in more detail in :ref:`ref-parsing-odata`.
13 | 
--------------------------------------------------------------------------------
/docs/source/sql.rst:
--------------------------------------------------------------------------------
 1 | Using OData with raw SQL
 2 | ========================
 3 | 
 4 | Using a raw SQL interface is slightly more involved and less powerful, but
 5 | offers a lot of flexibility in return.
 6 | 
 7 | 
 8 | Parsing the OData Query
 9 | ^^^^^^^^^^^^^^^^^^^^^^^
10 | 
11 | .. include:: snippets/parsing.rst
12 | 
13 | 
14 | Optional: Modifying the AST
15 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
16 | 
17 | .. include:: snippets/modifying.rst
18 | 
19 | 
20 | Building a Query Filter
21 | ^^^^^^^^^^^^^^^^^^^^^^^
22 | 
23 | To get from an :term:`AST` to a SQL clause, you'll need to use
24 | :py:class:`odata_query.sql.base.AstToSqlVisitor` (standard SQL) or one of its
25 | dialect-specific subclasses, such as
26 | :py:class:`odata_query.sql.sqlite.AstToSqliteSqlVisitor` (SQLite).
27 | 
28 | .. code-block:: python
29 | 
30 |     from odata_query.sql import AstToSqlVisitor
31 | 
32 |     visitor = AstToSqlVisitor()
33 |     where_clause = visitor.visit(ast)
34 | 
35 | 
36 | Running the query
37 | ^^^^^^^^^^^^^^^^^
38 | 
39 | Finally, we're ready to run the query:
40 | 
41 | .. code-block:: python
42 | 
43 |     query = "SELECT * FROM my_table WHERE " + where_clause
44 |     results = conn.execute(query).fetchall()
45 | 
46 | 
47 | Supported dialects
48 | ^^^^^^^^^^^^^^^^^^
49 | 
50 | .. autoclass:: odata_query.sql.base.AstToSqlVisitor
51 | .. autoclass:: odata_query.sql.sqlite.AstToSqliteSqlVisitor
52 | .. autoclass:: odata_query.sql.athena.AstToAthenaSqlVisitor
53 | 
--------------------------------------------------------------------------------
/docs/source/sqlalchemy.rst:
--------------------------------------------------------------------------------
  1 | Using OData with SQLAlchemy
  2 | ===========================
  3 | 
  4 | Basic Usage
  5 | -----------
  6 | 
  7 | The easiest way to add OData filtering to a SQLAlchemy query is with the shorthand:
  8 | 
  9 | 
 10 | SQLAlchemy ORM
 11 | ^^^^^^^^^^^^^^
 12 | 
 13 | .. code-block:: python
 14 | 
 15 |     from odata_query.sqlalchemy import apply_odata_query
 16 | 
 17 |     orm_query = select(MyModel)  # This is any form of Query or Selectable.
 18 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
 19 | 
 20 |     query = apply_odata_query(orm_query, odata_query)
 21 |     results = session.execute(query).scalars().all()
 22 | 
 23 | 
 24 | SQLAlchemy Core
 25 | ^^^^^^^^^^^^^^^
 26 | 
 27 | .. attention::
 28 | 
 29 |    Basic support for SQLAlchemy Core is new since version 0.7.0.
 30 |    It currently does not support relationship traversal or ``any``/``all``
 31 |    yet. Those operations will raise a ``NotImplementedException``.
 32 | 
 33 | 
 34 | .. code-block:: python
 35 | 
 36 |     from odata_query.sqlalchemy import apply_odata_core
 37 | 
 38 |     core_query = select(MyTable)  # This is any form of Query or Selectable.
 39 |     odata_query = "name eq 'test'"  # This will usually come from a query string parameter.
 40 | 
 41 |     query = apply_odata_core(core_query, odata_query)
 42 |     results = session.execute(query).scalars().all()
 43 | 
 44 | 
 45 | Advanced Usage
 46 | --------------
 47 | 
 48 | If you need some more flexibility or advanced features, the implementation of the
 49 | shorthand is a nice starting point: :py:mod:`odata_query.sqlalchemy.shorthand`
 50 | 
 51 | Let's break it down real quick:
 52 | 
 53 | 
 54 | Parsing the OData Query
 55 | ^^^^^^^^^^^^^^^^^^^^^^^
 56 | 
 57 | .. include:: snippets/parsing.rst
 58 | 
 59 | 
 60 | Optional: Modifying the AST
 61 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 62 | 
 63 | .. include:: snippets/modifying.rst
 64 | 
 65 | 
 66 | Building a Query Filter
 67 | ^^^^^^^^^^^^^^^^^^^^^^^
 68 | 
 69 | To get from an :term:`AST` to something SQLAlchemy can use, you'll need to use the
 70 | :py:class:`odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor` (ORM mode) or
 71 | the :py:class:`odata_query.sqlalchemy.core.AstToSqlAlchemyCoreVisitor` (Core
 72 | mode).
 73 | It needs to know about the 'root model' or table of your query in order to see
 74 | which fields exist and how objects are related.
 75 | 
 76 | 
 77 | SQLAlchemy ORM
 78 | """"""""""""""
 79 | 
 80 | .. code-block:: python
 81 | 
 82 |     from odata_query.sqlalchemy.orm import AstToSqlAlchemyOrmVisitor
 83 | 
 84 |     visitor = AstToSqlAlchemyOrmVisitor(MyModel)
 85 |     query_filter = visitor.visit(ast)
 86 | 
 87 | SQLAlchemy Core
 88 | """""""""""""""
 89 | 
 90 | .. code-block:: python
 91 | 
 92 |     from odata_query.sqlalchemy.core import AstToSqlAlchemyCoreVisitor
 93 | 
 94 |     visitor = AstToSqlAlchemyCoreVisitor(MyTable)
 95 |     query_filter = visitor.visit(ast)
 96 | 
 97 | 
 98 | Optional: Joins
 99 | ^^^^^^^^^^^^^^^
100 | 
101 | .. attention::
102 | 
103 |    Relationship traversal and automatic joins are not yet supported for
104 |    SQLAlchemy Core mode.
105 | 
106 | 
107 | If your query spans relationships, the ``AstToSqlAlchemyOrmVisitor`` will
108 | generate join statements. For the query to work, these will need to be
109 | applied explicitly:
110 | 
111 | .. code-block:: python
112 | 
113 |     for j in visitor.join_relationships:
114 |         query = query.join(j)
115 | 
116 | 
117 | Running the query
118 | ^^^^^^^^^^^^^^^^^
119 | 
120 | Finally, we're ready to run the query:
121 | 
122 | .. code-block:: python
123 | 
124 |     query = query.where(query_filter)
125 |     results = session.execute(query).scalars().all()
126 | 
--------------------------------------------------------------------------------
/docs/source/working-with-ast.rst:
--------------------------------------------------------------------------------
 1 | Working with the AST
 2 | ====================
 3 | 
 4 | Now that our :ref:`OData query has been parsed <ref-parsing-odata>` to an :term:`AST`,
 5 | how do we work with it?  `The Visitor Pattern`_ is a popular way to walk tree
 6 | structures such as :term:`AST`\ s and modify or transform them to another
 7 | representation. ``odata-query`` contains the :ref:`ref-node-visitor` and
 8 | :ref:`ref-node-transformer` base classes that implement this pattern, as well
 9 | as some concrete implementations.
10 | 
11 | 
12 | .. _ref-node-visitor:
13 | 
14 | NodeVisitor
15 | -----------
16 | 
17 | A :py:class:`odata_query.visitor.NodeVisitor` is a class that walks an :term:`AST`
18 | (depth-first by default) and calls a ``visit_{node_type}`` method on each
19 | :py:class:`odata_query.ast._Node` it encounters. These methods can return whatever
20 | they want, making this a very flexible pattern! If no ``visit_`` method is
21 | implemented for the type of the node the visitor will continue with the node's
22 | children if it has any, so you only need to implement what you explicitly need.
23 | A simple :py:class:`odata_query.visitor.NodeVisitor` that counts comparison
24 | expressions for example, might look like this:
25 | 
26 | .. code-block:: python
27 | 
28 |     class ComparisonCounter(NodeVisitor):
29 |         def visit_Compare(self, node: ast.Compare) -> int:
30 |             count_lhs = self.visit(node.left) or 0
31 |             count_rhs = self.visit(node.right) or 0
32 |             return 1 + count_lhs + count_rhs
33 | 
34 | 
35 |     count = ComparisonCounter().visit(my_ast)
36 | 
37 | 
38 | This isn't the most useful implementation... For some more realistic examples,
39 | take a look at the :py:class:`odata_query.django.django_q.AstToDjangoQVisitor` or
40 | the :py:class:`odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor`
41 | implementations. They transform an :term:`AST` to Django and SQLAlchemy ORM queries
42 | respectively.
43 | 
44 | 
45 | .. _ref-node-transformer:
46 | 
47 | NodeTransformer
48 | ---------------
49 | 
50 | A :py:class:`odata_query.visitor.NodeTransformer` is very similar to a
51 | :ref:`ref-node-visitor`, with one difference: The ``visit_`` methods should return
52 | an :py:class:`odata_query.ast._Node`, which will replace the node that is being
53 | visited. This allows ``NodeTransformer``\ s to modify the :term:`AST` while it's
54 | being traversed. For example, the following
55 | :py:class:`odata_query.visitor.NodeTransformer` would invert all 'less-than'
56 | comparisons to 'greater-than' and vice-versa:
57 | 
58 | 
59 | .. code-block:: python
60 | 
61 |     class ComparisonInverter(NodeTransformer):
62 |         def visit_Compare(self, node: ast.Compare) -> ast.Compare:
63 |             if node.comparator == ast.Lt():
64 |                 new_comparator = ast.Gt()
65 |             elif node.comparator == ast.Gt():
66 |                 new_comparator = ast.Lt()
67 |             else:
68 |                 new_comparator = node.comparator
69 | 
70 |             return ast.Compare(new_comparator, node.left, node.right)
71 | 
72 | 
73 |     inverted = ComparisonInverter().visit(my_ast)
74 | 
75 | 
76 | An interesting concrete implementation in ``odata-query`` is the
77 | :py:class:`odata_query.rewrite.AliasRewriter`. This transformer looks for
78 | aliases in identifiers and attributes, and replaces them with their full names.
79 | 
80 | 
81 | 
82 | Included Visitors
83 | -----------------
84 | 
85 | 
86 | .. autoclass:: odata_query.django.django_q.AstToDjangoQVisitor
87 | .. autoclass:: odata_query.sqlalchemy.orm.AstToSqlAlchemyOrmVisitor
88 | .. autoclass:: odata_query.sqlalchemy.core.AstToSqlAlchemyCoreVisitor
89 | .. autoclass:: odata_query.rewrite.AliasRewriter
90 | .. autoclass:: odata_query.roundtrip.AstToODataVisitor
91 | 
92 | 
93 | .. _The Visitor Pattern: https://en.wikipedia.org/wiki/Visitor_pattern
94 | 
--------------------------------------------------------------------------------
/odata_query/__init__.py:
--------------------------------------------------------------------------------
# Release version of the odata_query package.
__version__ = "0.10.0"
2 | 
--------------------------------------------------------------------------------
/odata_query/ast.py:
--------------------------------------------------------------------------------
  1 | import datetime as dt
  2 | import re
  3 | from dataclasses import dataclass, field
  4 | from typing import List as ListType, Optional, Tuple
  5 | from uuid import UUID
  6 | 
  7 | from dateutil.parser import isoparse
  8 | 
  9 | DURATION_PATTERN = re.compile(
 10 |     r"([+-])?P(\d+Y)?(\d+M)?(\d+D)?(?:T(\d+H)?(\d+M)?(\d+(?:\.\d+)?S)?)?"
 11 | )
 12 | 
 13 | 
 14 | @dataclass(frozen=True)
 15 | class _Node:
 16 |     pass
 17 | 
 18 | 
 19 | @dataclass(frozen=True)
 20 | class Identifier(_Node):
 21 |     name: str
 22 |     namespace: Tuple[str, ...] = field(default_factory=tuple)
 23 | 
 24 |     def full_name(self):
 25 |         return ".".join(self.namespace + (self.name,))
 26 | 
 27 | 
 28 | @dataclass(frozen=True)
 29 | class Attribute(_Node):
 30 |     owner: _Node
 31 |     attr: str
 32 | 
 33 | 
 34 | ###############################################################################
 35 | # Literals
 36 | ###############################################################################
 37 | @dataclass(frozen=True)
 38 | class _Literal(_Node):
 39 |     pass
 40 | 
 41 |     @property
 42 |     def py_val(self):
 43 |         raise NotImplementedError()
 44 | 
 45 | 
 46 | @dataclass(frozen=True)
 47 | class Null(_Literal):
 48 |     @property
 49 |     def py_val(self) -> None:
 50 |         return None
 51 | 
 52 | 
 53 | @dataclass(frozen=True)
 54 | class Integer(_Literal):
 55 |     val: str
 56 | 
 57 |     @property
 58 |     def py_val(self) -> int:
 59 |         return int(self.val)
 60 | 
 61 | 
 62 | @dataclass(frozen=True)
 63 | class Float(_Literal):
 64 |     val: str
 65 | 
 66 |     @property
 67 |     def py_val(self) -> float:
 68 |         return float(self.val)
 69 | 
 70 | 
 71 | @dataclass(frozen=True)
 72 | class Boolean(_Literal):
 73 |     val: str
 74 | 
 75 |     @property
 76 |     def py_val(self) -> bool:
 77 |         return self.val.lower() == "true"
 78 | 
 79 | 
 80 | @dataclass(frozen=True)
 81 | class String(_Literal):
 82 |     val: str
 83 | 
 84 |     @property
 85 |     def py_val(self) -> str:
 86 |         return self.val
 87 | 
 88 | 
 89 | @dataclass(frozen=True)
 90 | class Geography(_Literal):
 91 |     val: str
 92 | 
 93 |     def wkt(self):
 94 |         return self.val
 95 | 
 96 | 
 97 | @dataclass(frozen=True)
 98 | class Date(_Literal):
 99 |     val: str
100 | 
101 |     @property
102 |     def py_val(self) -> dt.date:
103 |         return dt.date.fromisoformat(self.val)
104 | 
105 | 
106 | @dataclass(frozen=True)
107 | class Time(_Literal):
108 |     val: str
109 | 
110 |     @property
111 |     def py_val(self) -> dt.time:
112 |         return dt.time.fromisoformat(self.val)
113 | 
114 | 
115 | @dataclass(frozen=True)
116 | class DateTime(_Literal):
117 |     val: str
118 | 
119 |     @property
120 |     def py_val(self) -> dt.datetime:
121 |         return isoparse(self.val)
122 | 
123 | 
@dataclass(frozen=True)
class Duration(_Literal):
    """
    A duration literal such as ``P1DT2H`` or ``-PT30M``.

    ``val`` holds the duration text as matched by ``DURATION_PATTERN``.
    """

    val: str

    @property
    def py_val(self) -> dt.timedelta:
        """
        The duration as a ``timedelta``.

        Years and months have no fixed length; they are approximated as
        365.25 and 30.44 days respectively.

        Raises:
            ValueError: If ``val`` does not match ``DURATION_PATTERN``.
        """
        sign, years, months, days, hours, minutes, seconds = self.unpack()

        def as_float(component: Optional[str]) -> float:
            # Components absent from the literal count as zero.
            return float(component) if component else 0.0

        total_days = (
            as_float(days) + as_float(years) * 365.25 + as_float(months) * 30.44
        )
        delta = dt.timedelta(
            days=total_days,
            hours=as_float(hours),
            minutes=as_float(minutes),
            seconds=as_float(seconds),
        )
        return -delta if sign == "-" else delta

    def unpack(
        self,
    ) -> Tuple[
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
    ]:
        """
        Split the duration into its textual components.

        Returns:
            ``(sign, years, months, days, hours, minutes, seconds)``, where each
            component is its number without the unit letter, or ``None`` when
            absent.

        Raises:
            ValueError: If ``val`` does not match ``DURATION_PATTERN``.
        """
        match = DURATION_PATTERN.fullmatch(self.val)
        if match is None:
            raise ValueError(f"Could not unpack Duration with value {self.val}")

        sign, *components = match.groups()
        # Every component group ends in its unit letter (Y, M, D, H, M, S).
        years, months, days, hours, minutes, seconds = (
            part[:-1] if part else None for part in components
        )
        return sign, years, months, days, hours, minutes, seconds
179 | 
180 | 
@dataclass(frozen=True)
class GUID(_Literal):
    """A GUID literal; ``val`` holds its textual form."""

    val: str

    @property
    def py_val(self) -> UUID:
        """The GUID parsed into a ``uuid.UUID``."""
        return UUID(self.val)


@dataclass(frozen=True)
class List(_Literal):
    """A parenthesized list of literals, e.g. ``(1, 2, 3)``."""

    val: ListType[_Literal]

    @property
    def py_val(self) -> list:
        """The Python values of the items, in order."""
        return list(map(lambda item: item.py_val, self.val))
197 | 
198 | 
199 | ###############################################################################
200 | # Arithmetic
201 | ###############################################################################
@dataclass(frozen=True)
class _BinOpToken(_Node):
    """Base class for arithmetic operator tokens used in :class:`BinOp`."""


@dataclass(frozen=True)
class Add(_BinOpToken):
    """Addition (``add``)."""


@dataclass(frozen=True)
class Sub(_BinOpToken):
    """Subtraction (``sub``)."""


@dataclass(frozen=True)
class Mult(_BinOpToken):
    """Multiplication (``mul``)."""


@dataclass(frozen=True)
class Div(_BinOpToken):
    """Division (``div``)."""


@dataclass(frozen=True)
class Mod(_BinOpToken):
    """Modulo (``mod``)."""


@dataclass(frozen=True)
class BinOp(_Node):
    """A binary arithmetic expression ``left <op> right``."""

    op: _BinOpToken
    left: _Node
    right: _Node
237 | 
238 | 
239 | ###############################################################################
240 | # Comparison
241 | ###############################################################################
@dataclass(frozen=True)
class _Comparator(_Node):
    """Base class for comparison operator tokens used in :class:`Compare`."""


@dataclass(frozen=True)
class Eq(_Comparator):
    """Equality (``eq``)."""


@dataclass(frozen=True)
class NotEq(_Comparator):
    """Inequality (``ne``)."""


@dataclass(frozen=True)
class Lt(_Comparator):
    """Less than (``lt``)."""


@dataclass(frozen=True)
class LtE(_Comparator):
    """Less than or equal (``le``)."""


@dataclass(frozen=True)
class Gt(_Comparator):
    """Greater than (``gt``)."""


@dataclass(frozen=True)
class GtE(_Comparator):
    """Greater than or equal (``ge``)."""


@dataclass(frozen=True)
class In(_Comparator):
    """Membership test (``in``)."""


@dataclass(frozen=True)
class Compare(_Node):
    """A comparison expression ``left <comparator> right``."""

    comparator: _Comparator
    left: _Node
    right: _Node
287 | 
288 | 
289 | ###############################################################################
290 | # Boolean ops
291 | ###############################################################################
@dataclass(frozen=True)
class _BoolOpToken(_Node):
    """Base class for boolean operator tokens used in :class:`BoolOp`."""


@dataclass(frozen=True)
class And(_BoolOpToken):
    """Logical conjunction (``and``)."""


@dataclass(frozen=True)
class Or(_BoolOpToken):
    """Logical disjunction (``or``)."""


@dataclass(frozen=True)
class BoolOp(_Node):
    """A binary boolean expression ``left <op> right``."""

    op: _BoolOpToken
    left: _Node
    right: _Node
312 | 
313 | 
314 | ###############################################################################
315 | # Unary ops
316 | ###############################################################################
@dataclass(frozen=True)
class _UnaryOpToken(_Node):
    """Base class for unary operator tokens used in :class:`UnaryOp`."""


@dataclass(frozen=True)
class Not(_UnaryOpToken):
    """Logical negation (``not``)."""


@dataclass(frozen=True)
class USub(_UnaryOpToken):
    """Arithmetic negation (``-``)."""


@dataclass(frozen=True)
class UnaryOp(_Node):
    """A unary expression ``<op> operand``."""

    op: _UnaryOpToken
    operand: _Node
336 | 
337 | 
338 | ###############################################################################
339 | # Function calls
340 | ###############################################################################
@dataclass(frozen=True)
class NamedParam(_Node):
    """A named parameter: ``name`` identifies it, ``param`` is its value."""

    name: Identifier
    param: _Node


@dataclass(frozen=True)
class Call(_Node):
    """A function call ``func(args...)``; ``args`` are in call order."""

    func: Identifier
    args: ListType[_Node]
351 | 
352 | 
353 | ###############################################################################
354 | # Collections
355 | ###############################################################################
@dataclass(frozen=True)
class _CollectionOperator(_Node):
    """Base class for collection operators used in :class:`CollectionLambda`."""


@dataclass(frozen=True)
class Any(_CollectionOperator):
    """The ``any`` collection operator."""


@dataclass(frozen=True)
class All(_CollectionOperator):
    """The ``all`` collection operator."""


@dataclass(frozen=True)
class Lambda(_Node):
    """A lambda ``identifier: expression`` passed to a collection operator."""

    identifier: Identifier
    expression: _Node


@dataclass(frozen=True)
class CollectionLambda(_Node):
    """
    A collection operator applied to ``owner``, e.g. ``owner/any(x: ...)``.

    ``lambda_`` is ``None`` when the operator is called without arguments,
    e.g. ``owner/any()``.
    """

    owner: _Node
    operator: _CollectionOperator
    lambda_: Optional[Lambda]
382 | 
--------------------------------------------------------------------------------
/odata_query/django/__init__.py:
--------------------------------------------------------------------------------
1 | from .django_q import AstToDjangoQVisitor
2 | from .django_q_ext import *
3 | from .shorthand import apply_odata_query
4 | 
--------------------------------------------------------------------------------
/odata_query/django/django_q_ext.py:
--------------------------------------------------------------------------------
 1 | from django.db.models import CharField, Lookup, Subquery, fields, functions
 2 | from django.db.models.query import QuerySet
 3 | 
 4 | 
 5 | @fields.Field.register_lookup
 6 | class NotEqual(Lookup):
 7 |     """https://docs.djangoproject.com/en/2.2/howto/custom-lookups/"""
 8 | 
 9 |     lookup_name = "ne"
10 | 
11 |     def as_sql(self, compiler, connection):  # type: ignore
12 |         lhs, lhs_params = self.process_lhs(compiler, connection)
13 |         rhs, rhs_params = self.process_rhs(compiler, connection)
14 |         params = lhs_params + rhs_params
15 |         return "%s <> %s" % (lhs, rhs), params
16 | 
17 | 
# Register these database functions as transforms on ``CharField`` so they can
# be chained into field lookups.
for _char_transform in (
    functions.Length,
    functions.Upper,
    functions.Lower,
    functions.Trim,
):
    CharField.register_lookup(_char_transform)
# Keep the module namespace clean (this module is star-imported).
del _char_transform
22 | 
23 | 
class _AnyAll(Subquery):
    """
    Base for ``Subquery`` expressions rendered through ``template`` with the
    subquery substituted for ``%(subquery)s``.

    Args:
        queryset: The queryset to embed. Its ordering is dropped.
        negated: Initial value of the ``negated`` flag. NOTE(review): ``~``
            toggles this flag, but ``as_sql`` does not currently apply it.
        **kwargs: Passed through to ``Subquery``.
    """

    template = ""

    def __init__(self, queryset: QuerySet, negated: bool = False, **kwargs):
        # As a performance optimization, remove ordering: it does not affect
        # which rows the subquery yields.
        queryset = queryset.order_by()
        self.negated = negated
        super().__init__(queryset, **kwargs)

    def __invert__(self) -> Subquery:
        """Return a copy with the ``negated`` flag flipped."""
        clone = self.copy()
        clone.negated = not self.negated
        return clone

    def __repr__(self) -> str:
        # Newer Django versions keep the wrapped query only on ``self.query``
        # (there is no ``self.queryset``); older ones only set ``self.queryset``.
        query = getattr(self, "query", None)
        if query is None:
            query = self.queryset.query
        return self.template % {"subquery": query}

    def as_sql(  # type: ignore
        self, compiler, connection, template=None, **extra_context
    ):
        # Delegates entirely to Subquery.as_sql; ``self.negated`` is not applied.
        sql, params = super().as_sql(compiler, connection, template, **extra_context)
        # if self.negated:
        #     sql = "NOT {}".format(sql)
        return sql, params

    def select_format(self, compiler, sql, params):  # type:ignore
        # Return the compiled SQL and params unchanged (no select formatting).
        return sql, params
52 | 
53 | 
class Any(_AnyAll):
    """Renders the subquery as ``ANY(<subquery>)``."""

    template = "ANY(%(subquery)s)"


class All(_AnyAll):
    """Renders the subquery as ``ALL(<subquery>)``."""

    template = "ALL(%(subquery)s)"
60 | 
--------------------------------------------------------------------------------
/odata_query/django/shorthand.py:
--------------------------------------------------------------------------------
 1 | from django.db.models.query import QuerySet
 2 | 
 3 | from odata_query.grammar import ODataLexer, ODataParser  # type: ignore
 4 | 
 5 | from .django_q import AstToDjangoQVisitor
 6 | 
 7 | 
 8 | def apply_odata_query(queryset: QuerySet, odata_query: str) -> QuerySet:
 9 |     """
10 |     Shorthand for applying an OData query to a Django QuerySet.
11 | 
12 |     Args:
13 |         queryset: Django QuerySet to apply the OData query to.
14 |         odata_query: OData query string.
15 |     Returns:
16 |         QuerySet: The modified QuerySet
17 |     """
18 |     lexer = ODataLexer()
19 |     parser = ODataParser()
20 |     model = queryset.model
21 | 
22 |     ast = parser.parse(lexer.tokenize(odata_query))
23 |     transformer = AstToDjangoQVisitor(model)
24 |     where_clause = transformer.visit(ast)
25 | 
26 |     if transformer.queryset_annotations:
27 |         queryset = queryset.annotate(**transformer.queryset_annotations)
28 | 
29 |     return queryset.filter(where_clause)
30 | 
--------------------------------------------------------------------------------
/odata_query/django/utils.py:
--------------------------------------------------------------------------------
 1 | from typing import Tuple, Type
 2 | 
 3 | from django.db.models import Model
 4 | 
 5 | 
 6 | def reverse_relationship(
 7 |     relationship_expr: str, root_model: Type[Model]
 8 | ) -> Tuple[str, Type[Model]]:
 9 |     """
10 |     Reverses a relationship expression relative to root_model.
11 | 
12 |     Args:
13 |         relationship_expr: The Django relationship string, with underscores to
14 |             represent relationship traversal.
15 |         root_model: The model to which relationship_expr is relative.
16 | 
17 |     Returns:
18 |         str: The django relationship string in reverse, so from the last joined
19 |             relationship back to the root model.
20 |         Type[Model]: The model to which the returned expression is relative.
21 |     """
22 |     relation_steps = relationship_expr.split("__")
23 | 
24 |     related_model = root_model
25 |     path_to_outerref_parts = []
26 |     for step in relation_steps:
27 |         related_field = related_model._meta.get_field(step)
28 |         related_model = related_field.related_model
29 |         path_to_outerref_parts.append(related_field.remote_field.name)
30 | 
31 |     path_to_outerref = "__".join(reversed(path_to_outerref_parts))
32 | 
33 |     return (path_to_outerref, related_model)
34 | 
--------------------------------------------------------------------------------
/odata_query/exceptions.py:
--------------------------------------------------------------------------------
  1 | from typing import Any, Optional
  2 | 
  3 | from sly.lex import Token
  4 | 
  5 | 
  6 | class ODataException(Exception):
  7 |     """
  8 |     Base class for all exceptions in this library.
  9 |     """
 10 | 
 11 |     pass
 12 | 
 13 | 
 14 | class ODataSyntaxError(ODataException):
 15 |     """
 16 |     Base class for syntax errors.
 17 |     """
 18 | 
 19 |     pass
 20 | 
 21 | 
 22 | class TokenizingException(ODataSyntaxError):
 23 |     """
 24 |     Thrown when the lexer cannot tokenize the query.
 25 |     """
 26 | 
 27 |     def __init__(self, token: Token):
 28 |         self.token = token
 29 |         super().__init__(f"Failed to tokenize at: {token}")
 30 | 
 31 | 
 32 | class ParsingException(ODataSyntaxError):
 33 |     """
 34 |     Thrown when the parser cannot parse the query.
 35 |     """
 36 | 
 37 |     def __init__(self, token: Optional[Token], eof: bool = False):
 38 |         self.token = token
 39 |         self.eof = eof
 40 |         super().__init__(f"Failed to parse at: {token}")
 41 | 
 42 | 
 43 | class FunctionCallException(ODataException):
 44 |     """
 45 |     Base class for errors in function calls.
 46 |     """
 47 | 
 48 |     pass
 49 | 
 50 | 
 51 | class UnknownFunctionException(FunctionCallException):
 52 |     """
 53 |     Thrown when the parser encounters an undefined function call.
 54 |     """
 55 | 
 56 |     def __init__(self, function_name: str):
 57 |         self.function_name = function_name
 58 |         super().__init__(f"Unknown function: '{function_name}'")
 59 | 
 60 | 
 61 | class ArgumentCountException(FunctionCallException):
 62 |     """
 63 |     Thrown when the parser encounters a function called with a wrong number
 64 |     of arguments.
 65 |     """
 66 | 
 67 |     def __init__(
 68 |         self, function_name: str, exp_min_args: int, exp_max_args: int, given_args: int
 69 |     ):
 70 |         self.function_name = function_name
 71 |         self.exp_min_args = exp_min_args
 72 |         self.exp_max_args = exp_max_args
 73 |         self.n_args_given = given_args
 74 |         if exp_min_args != exp_max_args:
 75 |             super().__init__(
 76 |                 f"Function '{function_name}' takes between {exp_min_args} and "
 77 |                 f"{exp_max_args} arguments. {given_args} given."
 78 |             )
 79 |         else:
 80 |             super().__init__(
 81 |                 f"Function '{function_name}' takes {exp_min_args} arguments. "
 82 |                 f"{given_args} given."
 83 |             )
 84 | 
 85 | 
 86 | class UnsupportedFunctionException(FunctionCallException):
 87 |     """
 88 |     Thrown when a function is used that is not implemented yet.
 89 |     """
 90 | 
 91 |     def __init__(self, function_name: str):
 92 |         self.function_name = function_name
 93 |         super().__init__(f"Function '{function_name}' is not implemented yet.")
 94 | 
 95 | 
 96 | class ArgumentTypeException(FunctionCallException):
 97 |     """
 98 |     Thrown when a function is called with argument of the wrong type.
 99 |     """
100 | 
101 |     def __init__(
102 |         self,
103 |         function_name: Optional[str] = None,
104 |         expected_type: Optional[str] = None,
105 |         actual_type: Optional[str] = None,
106 |     ):
107 |         self.function_name = function_name
108 |         self.expected_type = expected_type
109 |         self.actual_type = actual_type
110 | 
111 |         if function_name:
112 |             message = f"Unsupported or invalid type for function or operator '{function_name}'"
113 |         else:
114 |             message = "Invalid argument type for function or operator."
115 |         if expected_type:
116 |             message += f" Expected {expected_type}"
117 |             if actual_type:
118 |                 message += f", got {actual_type}"
119 | 
120 |         super().__init__(message)
121 | 
122 | 
class TypeException(ODataException):
    """
    Raised when an operation is applied to a value it does not support,
    e.g. ``10 gt null`` or ``~date()``.

    Attributes:
        operation: The attempted operation.
        value: The value it was applied to.
    """

    def __init__(self, operation: str, value: str):
        self.operation = operation
        self.value = value
        message = f"Cannot apply '{operation}' to '{value}'"
        super().__init__(message)
133 | 
134 | 
class ValueException(ODataException):
    """
    Raised when a value is invalid, such as an invalid datetime.

    Attributes:
        value: The offending value.
    """

    def __init__(self, value: Any):
        self.value = value
        message = f"Invalid value: {value}"
        super().__init__(message)
143 | 
144 | 
class InvalidFieldException(ODataException):
    """
    Raised when a query mentions a field that does not exist.

    Attributes:
        field_name: Name of the missing field.
    """

    def __init__(self, field_name: str):
        self.field_name = field_name
        message = f"Invalid field: {field_name}"
        super().__init__(message)
153 | 
--------------------------------------------------------------------------------
/odata_query/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/odata_query/py.typed
--------------------------------------------------------------------------------
/odata_query/rewrite.py:
--------------------------------------------------------------------------------
 1 | from typing import Dict, Optional
 2 | 
 3 | from . import ast
 4 | from .grammar import ODataLexer, ODataParser  # type: ignore
 5 | from .visitor import NodeTransformer
 6 | 
 7 | 
 8 | class AliasRewriter(NodeTransformer):
 9 |     """
10 |     A :class:`NodeTransformer` that replaces aliases in the :term:`AST` with their
11 |     aliased identifiers or attributes.
12 | 
13 |     Args:
14 |         field_aliases: A mapping of aliases to their full name. These can
15 |             be identifiers, attributes, and even function calls in odata
16 |             syntax.
17 |         lexer: Optional lexer instance to use. If not passed, will construct
18 |             the default one.
19 |         parser: Optional parser instance to use. If not passed, will construct
20 |             the default one.
21 |     """
22 | 
23 |     def __init__(
24 |         self,
25 |         field_aliases: Dict[str, str],
26 |         lexer: Optional[ODataLexer] = None,
27 |         parser: Optional[ODataParser] = None,
28 |     ):
29 |         self.field_aliases = field_aliases
30 | 
31 |         if not lexer:
32 |             lexer = ODataLexer()
33 |         if not parser:
34 |             parser = ODataParser()
35 | 
36 |         self.replacements = {
37 |             parser.parse(lexer.tokenize(k)): parser.parse(lexer.tokenize(v))
38 |             for k, v in self.field_aliases.items()
39 |         }
40 | 
41 |     def visit_Identifier(self, node: ast.Identifier) -> ast._Node:
42 |         """:meta private:"""
43 |         if node in self.replacements:
44 |             return self.replacements[node]
45 |         return node
46 | 
47 |     def visit_Attribute(self, node: ast.Attribute) -> ast._Node:
48 |         """:meta private:"""
49 |         if node in self.replacements:
50 |             return self.replacements[node]
51 |         else:
52 |             new_owner = self.visit(node.owner)
53 |             return ast.Attribute(new_owner, node.attr)
54 | 
55 | 
class IdentifierStripper(NodeTransformer):
    """
    A :class:`NodeTransformer` that strips the given identifier off of
    attributes. E.g. ``author/name`` -> ``name``.

    Nested attributes are handled by recursing into their owner, so
    ``author/address/city`` becomes ``address/city``.

    Args:
        strip: The identifier to strip off of all attributes in the :term:`AST`
    """

    def __init__(self, strip: ast.Identifier):
        self.strip = strip

    def visit_Attribute(self, node: ast.Attribute) -> ast._Node:
        """:meta private:"""
        owner = node.owner
        if owner == self.strip:
            return ast.Identifier(node.attr)
        if isinstance(owner, ast.Attribute):
            return ast.Attribute(self.visit(owner), node.attr)
        return node
76 | 
--------------------------------------------------------------------------------
/odata_query/roundtrip.py:
--------------------------------------------------------------------------------
  1 | import sys
  2 | from typing import Type
  3 | 
  4 | if sys.version_info >= (3, 8):
  5 |     from typing import Protocol
  6 | else:
  7 |     from typing_extensions import Protocol
  8 | 
  9 | from odata_query import ast, visitor
 10 | 
 11 | 
 12 | class LiteralValNode(Protocol):
 13 |     """:meta private:"""
 14 | 
 15 |     val: str
 16 | 
 17 | 
 18 | PRECEDENCE = {
 19 |     ast.Attribute: 10,
 20 |     ast.Call: 10,
 21 |     ast.Not: 9,
 22 |     ast.USub: 9,
 23 |     ast.Mult: 8,
 24 |     ast.Div: 8,
 25 |     ast.Mod: 8,
 26 |     ast.Add: 7,
 27 |     ast.Sub: 7,
 28 |     ast.Gt: 6,
 29 |     ast.GtE: 6,
 30 |     ast.Lt: 6,
 31 |     ast.LtE: 6,
 32 |     ast.Eq: 5,
 33 |     ast.NotEq: 5,
 34 |     ast.And: 4,
 35 |     ast.Or: 3,
 36 | }
 37 | 
 38 | 
 39 | class AstToODataVisitor(visitor.NodeVisitor):
 40 |     """
 41 |     :class:`NodeVisitor` that transforms an :term:`AST` back into an OData
 42 |     query string.
 43 |     """
 44 | 
 45 |     def visit_Identifier(self, node: ast.Identifier) -> str:
 46 |         """:meta private:"""
 47 |         if node.namespace:
 48 |             prefix = ".".join(node.namespace)
 49 |             return prefix + "." + node.name
 50 |         return node.name
 51 | 
 52 |     def visit_Attribute(self, node: ast.Attribute) -> str:
 53 |         """:meta private:"""
 54 |         return self.visit(node.owner) + "/" + node.attr
 55 | 
 56 |     def visit_Null(self, node: ast.Null) -> str:
 57 |         """:meta private:"""
 58 |         return "null"
 59 | 
 60 |     def visit_String(self, node: ast.String) -> str:
 61 |         """:meta private:"""
 62 |         return "'" + node.val + "'"
 63 | 
 64 |     def visit_Duration(self, node: ast.Duration) -> str:
 65 |         """:meta private:"""
 66 |         return "duration'" + node.val + "'"
 67 | 
 68 |     def _visit_Literal(self, node: LiteralValNode) -> str:
 69 |         """:meta private:"""
 70 |         return node.val
 71 | 
 72 |     visit_Integer = _visit_Literal
 73 |     visit_Float = _visit_Literal
 74 |     visit_Boolean = _visit_Literal
 75 |     visit_Date = _visit_Literal
 76 |     visit_Time = _visit_Literal
 77 |     visit_DateTime = _visit_Literal
 78 |     visit_GUID = _visit_Literal
 79 | 
 80 |     def visit_List(self, node: ast.List) -> str:
 81 |         """:meta private:"""
 82 |         return "(" + ", ".join(self.visit(v) for v in node.val) + ")"
 83 | 
 84 |     def visit_Add(self, node: ast.Add) -> str:
 85 |         """:meta private:"""
 86 |         return "add"
 87 | 
 88 |     def visit_Sub(self, node: ast.Sub) -> str:
 89 |         """:meta private:"""
 90 |         return "sub"
 91 | 
 92 |     def visit_Mult(self, node: ast.Mult) -> str:
 93 |         """:meta private:"""
 94 |         return "mul"
 95 | 
 96 |     def visit_Div(self, node: ast.Div) -> str:
 97 |         """:meta private:"""
 98 |         return "div"
 99 | 
100 |     def visit_Mod(self, node: ast.Mod) -> str:
101 |         """:meta private:"""
102 |         return "mod"
103 | 
104 |     def visit_BinOp(self, node: ast.BinOp) -> str:
105 |         """:meta private:"""
106 |         left = self._visit_and_paren_if_precedence_lower(node.left, type(node.op))
107 |         right = self._visit_and_paren_if_precedence_lower(node.right, type(node.op))
108 |         return left + " " + self.visit(node.op) + " " + right
109 | 
110 |     def visit_Eq(self, node: ast.Eq) -> str:
111 |         """:meta private:"""
112 |         return "eq"
113 | 
114 |     def visit_NotEq(self, node: ast.NotEq) -> str:
115 |         """:meta private:"""
116 |         return "ne"
117 | 
118 |     def visit_Lt(self, node: ast.Lt) -> str:
119 |         """:meta private:"""
120 |         return "lt"
121 | 
122 |     def visit_LtE(self, node: ast.LtE) -> str:
123 |         """:meta private:"""
124 |         return "le"
125 | 
126 |     def visit_Gt(self, node: ast.Gt) -> str:
127 |         """:meta private:"""
128 |         return "gt"
129 | 
130 |     def visit_GtE(self, node: ast.GtE) -> str:
131 |         """:meta private:"""
132 |         return "ge"
133 | 
134 |     def visit_In(self, node: ast.In) -> str:
135 |         """:meta private:"""
136 |         return "in"
137 | 
138 |     def visit_Compare(self, node: ast.Compare) -> str:
139 |         """:meta private:"""
140 |         left = self._visit_and_paren_if_precedence_lower(
141 |             node.left, type(node.comparator)
142 |         )
143 |         right = self._visit_and_paren_if_precedence_lower(
144 |             node.right, type(node.comparator)
145 |         )
146 |         return left + " " + self.visit(node.comparator) + " " + right
147 | 
148 |     def visit_And(self, node: ast.And) -> str:
149 |         """:meta private:"""
150 |         return "and"
151 | 
152 |     def visit_Or(self, node: ast.Or) -> str:
153 |         """:meta private:"""
154 |         return "or"
155 | 
156 |     def visit_BoolOp(self, node: ast.BoolOp) -> str:
157 |         """:meta private:"""
158 |         left = self._visit_and_paren_if_precedence_lower(node.left, type(node.op))
159 |         right = self._visit_and_paren_if_precedence_lower(node.right, type(node.op))
160 |         return left + " " + self.visit(node.op) + " " + right
161 | 
162 |     def visit_Not(self, node: ast.Not) -> str:
163 |         """:meta private:"""
164 |         return "not"
165 | 
166 |     def visit_USub(self, node: ast.USub) -> str:
167 |         """:meta private:"""
168 |         return "-"
169 | 
170 |     def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
171 |         """:meta private:"""
172 |         operand = self._visit_and_paren_if_precedence_lower(node.operand, type(node.op))
173 |         return self.visit(node.op) + " " + operand
174 | 
175 |     def visit_Call(self, node: ast.Call) -> str:
176 |         """:meta private:"""
177 |         return (
178 |             self.visit(node.func)
179 |             + "("
180 |             + ", ".join(self.visit(n) for n in node.args)
181 |             + ")"
182 |         )
183 | 
184 |     def visit_Any(self, node: ast.Any) -> str:
185 |         """:meta private:"""
186 |         return "any"
187 | 
188 |     def visit_All(self, node: ast.All) -> str:
189 |         """:meta private:"""
190 |         return "all"
191 | 
192 |     def visit_Lambda(self, node: ast.Lambda) -> str:
193 |         """:meta private:"""
194 |         return self.visit(node.identifier) + ": " + self.visit(node.expression)
195 | 
196 |     def visit_CollectionLambda(self, node: ast.CollectionLambda) -> str:
197 |         """:meta private:"""
198 |         return (
199 |             self.visit(node.owner)
200 |             + "/"
201 |             + self.visit(node.operator)
202 |             + "("
203 |             + (self.visit(node.lambda_) if node.lambda_ else "")
204 |             + ")"
205 |         )
206 | 
207 |     def _visit_and_paren_if_precedence_lower(
208 |         self, node: ast._Node, precedence: Type[ast._Node]
209 |     ) -> str:
210 |         """
211 |         Transform `node` by visiting it, then wrap the result in parentheses if
212 |         the expressions precedence is lower than that of `precedence`.
213 | 
214 |         :meta private:
215 |         """
216 |         res = self.visit(node)
217 | 
218 |         if hasattr(node, "op"):
219 |             node_op = type(node.op)  # type: ignore
220 |         elif hasattr(node, "comparator"):
221 |             node_op = type(node.comparator)  # type: ignore
222 |         else:
223 |             node_op = type(node)
224 | 
225 |         node_prec = PRECEDENCE.get(node_op, 100)
226 |         check_prec = PRECEDENCE.get(precedence, 100)
227 | 
228 |         if node_prec < check_prec:
229 |             res = "(" + res + ")"
230 | 
231 |         return res
232 | 
--------------------------------------------------------------------------------
/odata_query/sql/__init__.py:
--------------------------------------------------------------------------------
1 | from .athena import AstToAthenaSqlVisitor
2 | from .base import AstToSqlVisitor
3 | from .sqlite import AstToSqliteSqlVisitor
4 | 
--------------------------------------------------------------------------------
/odata_query/sql/athena.py:
--------------------------------------------------------------------------------
  1 | import re
  2 | 
  3 | from odata_query import ast, exceptions, typing
  4 | 
  5 | from .base import AstToSqlVisitor
  6 | 
  7 | UNSAFE_CHARS = re.compile(r"[^a-zA-Z0-9_]")
  8 | 
  9 | 
 10 | def clean_athena_identifier(identifier: str) -> str:
 11 |     """
 12 |     Cleans an Athena identifier so it passes the following validation rules:
 13 | 
 14 |     - Table names and table column names in Athena must be lowercase
 15 |     - Athena table, view, database, and column names allow only underscore special characters
 16 |     - Names should be quoted or backticked when starting with a number or underscore
 17 | 
 18 |     Source: https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
 19 |     """
 20 |     id_new = identifier.lower()
 21 |     id_new = UNSAFE_CHARS.sub("_", id_new)
 22 |     return id_new
 23 | 
 24 | 
 25 | class AstToAthenaSqlVisitor(AstToSqlVisitor):
 26 |     """
 27 |     :class:`NodeVisitor` that transforms an :term:`AST` into an Athena SQL
 28 |     ``WHERE`` clause.
 29 | 
 30 |     Args:
 31 |         table_alias: Optional alias for the root table.
 32 |     """
 33 | 
 34 |     def visit_Identifier(self, node: ast.Identifier) -> str:
 35 |         ":meta private:"
 36 |         # Double quotes for column names acc SQL Standard
 37 |         sql_id = f'"{clean_athena_identifier(node.name)}"'
 38 | 
 39 |         if self.table_alias:
 40 |             sql_id = f'"{self.table_alias}".' + sql_id
 41 | 
 42 |         return sql_id
 43 | 
 44 |     def visit_DateTime(self, node: ast.DateTime) -> str:
 45 |         ":meta private:"
 46 |         return f"FROM_ISO8601_TIMESTAMP('{node.val}')"
 47 | 
 48 |     def sqlfunc_length(self, arg: ast._Node) -> str:
 49 |         ":meta private:"
 50 |         arg_sql = self.visit(arg)
 51 |         inferred_type = typing.infer_type(arg)
 52 | 
 53 |         # If the input is a string or default, assume str-length:
 54 |         if inferred_type is ast.String or inferred_type is None:
 55 |             return f"LENGTH({arg_sql})"
 56 | 
 57 |         # If the input is a list, assume list-length:
 58 |         if inferred_type is ast.List:
 59 |             return f"CARDINALITY({arg_sql})"
 60 | 
 61 |         raise exceptions.ArgumentTypeException("length")
 62 | 
 63 |     def sqlfunc_substring(self, *args: ast._Node) -> str:
 64 |         ":meta private:"
 65 |         args_sql = [self.visit(arg) for arg in args]
 66 |         inferred_type = typing.infer_type(args[0])
 67 | 
 68 |         # If the first input is a string or default, assume str-substr:
 69 |         if inferred_type is ast.String or inferred_type is None:
 70 |             if len(args) == 2:
 71 |                 return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1)"
 72 |             if len(args) == 3:
 73 |                 return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})"
 74 | 
 75 |         # If the first input is a list, assume list-substr:
 76 |         if inferred_type is ast.List:
 77 |             if len(args) == 2:
 78 |                 return f"SLICE({args_sql[0]}, {args_sql[1]})"
 79 |             if len(args) == 3:
 80 |                 return f"SLICE({args_sql[0]}, {args_sql[1]}, {args_sql[2]})"
 81 | 
 82 |         raise exceptions.ArgumentTypeException("substring")
 83 | 
 84 |     def sqlfunc_round(self, arg: ast._Node) -> str:
 85 |         ":meta private:"
 86 |         arg_sql = self.visit(arg)
 87 |         return f"ROUND({arg_sql})"
 88 | 
 89 |     def sqlfunc_floor(self, arg: ast._Node) -> str:
 90 |         ":meta private:"
 91 |         arg_sql = self.visit(arg)
 92 |         return f"FLOOR({arg_sql})"
 93 | 
 94 |     def sqlfunc_ceiling(self, arg: ast._Node) -> str:
 95 |         ":meta private:"
 96 |         arg_sql = self.visit(arg)
 97 |         return f"CEILING({arg_sql})"
 98 | 
 99 |     def sqlfunc_hassubset(self, *args: ast._Node) -> str:
100 |         ":meta private:"
101 |         args_sql = [self.visit(arg) for arg in args]
102 |         return f"CARDINALITY(ARRAY_INTERSECT({args_sql[0]}, {args_sql[1]})) = CARDINALITY({args_sql[1]})"
103 | 
--------------------------------------------------------------------------------
/odata_query/sql/sqlite.py:
--------------------------------------------------------------------------------
  1 | from odata_query import ast, exceptions, typing
  2 | 
  3 | from .base import AstToSqlVisitor
  4 | 
  5 | 
  6 | class AstToSqliteSqlVisitor(AstToSqlVisitor):
  7 |     """
  8 |     :class:`NodeVisitor` that transforms an :term:`AST` into a SQLite SQL
  9 |     ``WHERE`` clause.
 10 | 
 11 |     Args:
 12 |         table_alias: Optional alias for the root table.
 13 |     """
 14 | 
 15 |     def visit_Boolean(self, node: ast.Boolean) -> str:
 16 |         """:meta private:"""
 17 |         if node.py_val:
 18 |             return "1"
 19 |         return "0"
 20 | 
 21 |     def visit_Date(self, node: ast.Date) -> str:
 22 |         """:meta private:"""
 23 |         return f"DATE('{node.val}')"
 24 | 
 25 |     def visit_DateTime(self, node: ast.DateTime) -> str:
 26 |         """:meta private:"""
 27 |         return f"DATETIME('{node.val}')"
 28 | 
 29 |     def sqlfunc_indexof(self, *args: ast._Node) -> str:
 30 |         """:meta private:"""
 31 |         args_sql = [self.visit(arg) for arg in args]
 32 |         inferred_type = [typing.infer_type(arg) for arg in args]
 33 | 
 34 |         # If any of the inputs is a string, assume str-indexof:
 35 |         if any(typ is ast.String for typ in inferred_type) or all(
 36 |             typ is None for typ in inferred_type
 37 |         ):
 38 |             return f"INSTR({args_sql[0]}, {args_sql[1]}) - 1"
 39 | 
 40 |         # If any of the inputs is a list, assume list-indexof
 41 |         # which isn't easily doable at the moment:
 42 |         if any(typ is ast.List for typ in inferred_type):
 43 |             raise exceptions.UnsupportedFunctionException("indexof")
 44 | 
 45 |         raise exceptions.ArgumentTypeException("indexof")
 46 | 
 47 |     def sqlfunc_length(self, arg: ast._Node) -> str:
 48 |         """:meta private:"""
 49 |         arg_sql = self.visit(arg)
 50 |         return f"LENGTH({arg_sql})"
 51 | 
 52 |     def sqlfunc_substring(self, *args: ast._Node) -> str:
 53 |         """:meta private:"""
 54 |         args_sql = [self.visit(arg) for arg in args]
 55 |         inferred_type = typing.infer_type(args[0])
 56 | 
 57 |         # If the first input is a string or default, assume str-substr:
 58 |         if inferred_type is ast.String or inferred_type is None:
 59 |             if len(args) == 2:
 60 |                 return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1)"
 61 |             if len(args) == 3:
 62 |                 return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})"
 63 | 
 64 |         # If the first input is a list, assume list-substr:
 65 |         if inferred_type is ast.List:
 66 |             raise exceptions.UnsupportedFunctionException("substring")
 67 | 
 68 |         raise exceptions.ArgumentTypeException("substring")
 69 | 
 70 |     def sqlfunc_year(self, arg: ast._Node) -> str:
 71 |         """:meta private:"""
 72 |         arg_sql = self.visit(arg)
 73 |         return f"CAST(STRFTIME('%Y', {arg_sql}) AS INTEGER)"
 74 | 
 75 |     def sqlfunc_month(self, arg: ast._Node) -> str:
 76 |         """:meta private:"""
 77 |         arg_sql = self.visit(arg)
 78 |         return f"CAST(STRFTIME('%m', {arg_sql}) AS INTEGER)"
 79 | 
 80 |     def sqlfunc_day(self, arg: ast._Node) -> str:
 81 |         """:meta private:"""
 82 |         arg_sql = self.visit(arg)
 83 |         return f"CAST(STRFTIME('%d', {arg_sql}) AS INTEGER)"
 84 | 
 85 |     def sqlfunc_hour(self, arg: ast._Node) -> str:
 86 |         """:meta private:"""
 87 |         arg_sql = self.visit(arg)
 88 |         return f"CAST(STRFTIME('%H', {arg_sql}) AS INTEGER)"
 89 | 
 90 |     def sqlfunc_minute(self, arg: ast._Node) -> str:
 91 |         """:meta private:"""
 92 |         arg_sql = self.visit(arg)
 93 |         return f"CAST(STRFTIME('%M', {arg_sql}) AS INTEGER)"
 94 | 
 95 |     def sqlfunc_date(self, arg: ast._Node) -> str:
 96 |         """:meta private:"""
 97 |         arg_sql = self.visit(arg)
 98 |         return f"DATE({arg_sql})"
 99 | 
100 |     def sqlfunc_now(self) -> str:
101 |         """:meta private:"""
102 |         return "DATETIME('now')"
103 | 
104 |     def sqlfunc_round(self, arg: ast._Node) -> str:
105 |         """:meta private:"""
106 |         arg_sql = self.visit(arg)
107 |         return f"TRUNC({arg_sql} + 0.5)"
108 | 
109 |     def sqlfunc_floor(self, arg: ast._Node) -> str:
110 |         """:meta private:"""
111 |         arg_sql = self.visit(arg)
112 |         return f"FLOOR({arg_sql})"
113 | 
114 |     def sqlfunc_ceiling(self, arg: ast._Node) -> str:
115 |         """:meta private:"""
116 |         arg_sql = self.visit(arg)
117 |         return f"CEILING({arg_sql})"
118 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import AstToSqlAlchemyCoreVisitor
2 | from .orm import AstToSqlAlchemyOrmVisitor
3 | from .shorthand import apply_odata_core, apply_odata_query
4 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/common.py:
--------------------------------------------------------------------------------
  1 | import operator
  2 | from typing import Any, Callable, Optional, Union
  3 | 
  4 | from sqlalchemy.sql import functions
  5 | from sqlalchemy.sql.expression import (
  6 |     BinaryExpression,
  7 |     BindParameter,
  8 |     BooleanClauseList,
  9 |     ClauseElement,
 10 |     False_,
 11 |     Null,
 12 |     True_,
 13 |     and_,
 14 |     cast,
 15 |     extract,
 16 |     false,
 17 |     literal,
 18 |     null,
 19 |     or_,
 20 |     true,
 21 | )
 22 | from sqlalchemy.types import Date, Time
 23 | 
 24 | from odata_query import ast, exceptions as ex, typing, visitor
 25 | 
 26 | from . import functions_ext
 27 | 
 28 | 
 29 | class _CommonVisitors(visitor.NodeVisitor):
 30 |     """
 31 |     Contains the visitor methods that are equal between SQLAlchemy Core and ORM.
 32 |     """
 33 | 
 34 |     def visit_Null(self, node: ast.Null) -> Null:
 35 |         ":meta private:"
 36 |         return null()
 37 | 
 38 |     def visit_Integer(self, node: ast.Integer) -> BindParameter:
 39 |         ":meta private:"
 40 |         return literal(node.py_val)
 41 | 
 42 |     def visit_Float(self, node: ast.Float) -> BindParameter:
 43 |         ":meta private:"
 44 |         return literal(node.py_val)
 45 | 
 46 |     def visit_Boolean(self, node: ast.Boolean) -> Union[True_, False_]:
 47 |         ":meta private:"
 48 |         if node.val == "true":
 49 |             return true()
 50 |         else:
 51 |             return false()
 52 | 
 53 |     def visit_String(self, node: ast.String) -> BindParameter:
 54 |         ":meta private:"
 55 |         return literal(node.py_val)
 56 | 
 57 |     def visit_Date(self, node: ast.Date) -> BindParameter:
 58 |         ":meta private:"
 59 |         try:
 60 |             return literal(node.py_val)
 61 |         except ValueError:
 62 |             raise ex.ValueException(node.val)
 63 | 
 64 |     def visit_DateTime(self, node: ast.DateTime) -> BindParameter:
 65 |         ":meta private:"
 66 |         try:
 67 |             return literal(node.py_val)
 68 |         except ValueError:
 69 |             raise ex.ValueException(node.val)
 70 | 
 71 |     def visit_Time(self, node: ast.Time) -> BindParameter:
 72 |         ":meta private:"
 73 |         try:
 74 |             return literal(node.py_val)
 75 |         except ValueError:
 76 |             raise ex.ValueException(node.val)
 77 | 
 78 |     def visit_Duration(self, node: ast.Duration) -> BindParameter:
 79 |         ":meta private:"
 80 |         return literal(node.py_val)
 81 | 
 82 |     def visit_GUID(self, node: ast.GUID) -> BindParameter:
 83 |         ":meta private:"
 84 |         return literal(node.val)
 85 | 
 86 |     def visit_List(self, node: ast.List) -> list:
 87 |         ":meta private:"
 88 |         return [self.visit(n) for n in node.val]
 89 | 
 90 |     def visit_Add(self, node: ast.Add) -> Callable[[Any, Any], Any]:
 91 |         ":meta private:"
 92 |         return operator.add
 93 | 
 94 |     def visit_Sub(self, node: ast.Sub) -> Callable[[Any, Any], Any]:
 95 |         ":meta private:"
 96 |         return operator.sub
 97 | 
 98 |     def visit_Mult(self, node: ast.Mult) -> Callable[[Any, Any], Any]:
 99 |         ":meta private:"
100 |         return operator.mul
101 | 
102 |     def visit_Div(self, node: ast.Div) -> Callable[[Any, Any], Any]:
103 |         ":meta private:"
104 |         return operator.truediv
105 | 
106 |     def visit_Mod(self, node: ast.Mod) -> Callable[[Any, Any], Any]:
107 |         ":meta private:"
108 |         return operator.mod
109 | 
110 |     def visit_BinOp(self, node: ast.BinOp) -> Any:
111 |         ":meta private:"
112 |         left = self.visit(node.left)
113 |         right = self.visit(node.right)
114 |         op = self.visit(node.op)
115 | 
116 |         return op(left, right)
117 | 
118 |     def visit_Eq(
119 |         self, node: ast.Eq
120 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
121 |         ":meta private:"
122 |         return operator.eq
123 | 
124 |     def visit_NotEq(
125 |         self, node: ast.NotEq
126 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
127 |         ":meta private:"
128 |         return operator.ne
129 | 
130 |     def visit_Lt(
131 |         self, node: ast.Lt
132 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
133 |         ":meta private:"
134 |         return operator.lt
135 | 
136 |     def visit_LtE(
137 |         self, node: ast.LtE
138 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
139 |         ":meta private:"
140 |         return operator.le
141 | 
142 |     def visit_Gt(
143 |         self, node: ast.Gt
144 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
145 |         ":meta private:"
146 |         return operator.gt
147 | 
148 |     def visit_GtE(
149 |         self, node: ast.GtE
150 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
151 |         ":meta private:"
152 |         return operator.ge
153 | 
154 |     def visit_In(
155 |         self, node: ast.In
156 |     ) -> Callable[[ClauseElement, ClauseElement], BinaryExpression]:
157 |         ":meta private:"
158 |         return lambda a, b: a.in_(b)
159 | 
160 |     def visit_And(
161 |         self, node: ast.And
162 |     ) -> Callable[[ClauseElement, ClauseElement], BooleanClauseList]:
163 |         ":meta private:"
164 |         return and_
165 | 
166 |     def visit_Or(
167 |         self, node: ast.Or
168 |     ) -> Callable[[ClauseElement, ClauseElement], BooleanClauseList]:
169 |         ":meta private:"
170 |         return or_
171 | 
172 |     def visit_BoolOp(self, node: ast.BoolOp) -> BooleanClauseList:
173 |         ":meta private:"
174 |         left = self.visit(node.left)
175 |         right = self.visit(node.right)
176 |         op = self.visit(node.op)
177 |         return op(left, right)
178 | 
179 |     def visit_Not(self, node: ast.Not) -> Callable[[ClauseElement], ClauseElement]:
180 |         ":meta private:"
181 |         return operator.invert
182 | 
183 |     def visit_UnaryOp(self, node: ast.UnaryOp) -> ClauseElement:
184 |         ":meta private:"
185 |         mod = self.visit(node.op)
186 |         val = self.visit(node.operand)
187 | 
188 |         try:
189 |             return mod(val)
190 |         except TypeError:
191 |             raise ex.TypeException(node.op.__class__.__name__, val)
192 | 
193 |     def visit_Call(self, node: ast.Call) -> ClauseElement:
194 |         ":meta private:"
195 |         try:
196 |             handler = getattr(self, "func_" + node.func.name.lower())
197 |         except AttributeError:
198 |             raise ex.UnsupportedFunctionException(node.func.name)
199 | 
200 |         return handler(*node.args)
201 | 
202 |     def func_contains(self, field: ast._Node, substr: ast._Node) -> ClauseElement:
203 |         ":meta private:"
204 |         return self._substr_function(field, substr, "contains")
205 | 
206 |     def func_startswith(self, field: ast._Node, substr: ast._Node) -> ClauseElement:
207 |         ":meta private:"
208 |         return self._substr_function(field, substr, "startswith")
209 | 
210 |     def func_endswith(self, field: ast._Node, substr: ast._Node) -> ClauseElement:
211 |         ":meta private:"
212 |         return self._substr_function(field, substr, "endswith")
213 | 
214 |     def func_length(self, arg: ast._Node) -> functions.Function:
215 |         ":meta private:"
216 |         return functions.char_length(self.visit(arg))
217 | 
218 |     def func_concat(self, *args: ast._Node) -> functions.Function:
219 |         ":meta private:"
220 |         return functions.concat(*[self.visit(arg) for arg in args])
221 | 
222 |     def func_indexof(self, first: ast._Node, second: ast._Node) -> functions.Function:
223 |         ":meta private:"
224 |         # TODO: Highly dialect dependent, might want to implement in GenericFunction:
225 |         # Subtract 1 because OData is 0-indexed while SQL is 1-indexed
226 |         return functions_ext.strpos(self.visit(first), self.visit(second)) - 1
227 | 
228 |     def func_substring(
229 |         self, fullstr: ast._Node, index: ast._Node, nchars: Optional[ast._Node] = None
230 |     ) -> functions.Function:
231 |         ":meta private:"
232 |         # Add 1 because OData is 0-indexed while SQL is 1-indexed
233 |         if nchars:
234 |             return functions_ext.substr(
235 |                 self.visit(fullstr),
236 |                 self.visit(index) + 1,
237 |                 self.visit(nchars),
238 |             )
239 |         else:
240 |             return functions_ext.substr(self.visit(fullstr), self.visit(index) + 1)
241 | 
242 |     def func_matchespattern(
243 |         self, field: ast._Node, pattern: ast._Node
244 |     ) -> functions.Function:
245 |         ":meta private:"
246 |         identifier = self.visit(field)
247 |         return identifier.regexp_match(self.visit(pattern))
248 | 
249 |     def func_tolower(self, field: ast._Node) -> functions.Function:
250 |         ":meta private:"
251 |         return functions_ext.lower(self.visit(field))
252 | 
253 |     def func_toupper(self, field: ast._Node) -> functions.Function:
254 |         ":meta private:"
255 |         return functions_ext.upper(self.visit(field))
256 | 
257 |     def func_trim(self, field: ast._Node) -> functions.Function:
258 |         ":meta private:"
259 |         return functions_ext.ltrim(functions_ext.rtrim(self.visit(field)))
260 | 
261 |     def func_date(self, field: ast._Node) -> ClauseElement:
262 |         ":meta private:"
263 |         return cast(self.visit(field), Date)
264 | 
265 |     def func_day(self, field: ast._Node) -> functions.Function:
266 |         ":meta private:"
267 |         return extract("day", self.visit(field))
268 | 
269 |     def func_hour(self, field: ast._Node) -> functions.Function:
270 |         ":meta private:"
271 |         return extract("hour", self.visit(field))
272 | 
273 |     def func_minute(self, field: ast._Node) -> functions.Function:
274 |         ":meta private:"
275 |         return extract("minute", self.visit(field))
276 | 
277 |     def func_month(self, field: ast._Node) -> functions.Function:
278 |         ":meta private:"
279 |         return extract("month", self.visit(field))
280 | 
281 |     def func_now(self) -> functions.Function:
282 |         ":meta private:"
283 |         return functions.now()
284 | 
285 |     def func_second(self, field: ast._Node) -> functions.Function:
286 |         ":meta private:"
287 |         return extract("second", self.visit(field))
288 | 
289 |     def func_time(self, field: ast._Node) -> functions.Function:
290 |         ":meta private:"
291 |         return cast(self.visit(field), Time)
292 | 
293 |     def func_year(self, field: ast._Node) -> functions.Function:
294 |         ":meta private:"
295 |         return extract("year", self.visit(field))
296 | 
297 |     def func_ceiling(self, field: ast._Node) -> functions.Function:
298 |         ":meta private:"
299 |         return functions_ext.ceil(self.visit(field))
300 | 
301 |     def func_floor(self, field: ast._Node) -> functions.Function:
302 |         ":meta private:"
303 |         return functions_ext.floor(self.visit(field))
304 | 
305 |     def func_round(self, field: ast._Node) -> functions.Function:
306 |         ":meta private:"
307 |         return functions_ext.round(self.visit(field))
308 | 
309 |     def _substr_function(
310 |         self, field: ast._Node, substr: ast._Node, func: str
311 |     ) -> ClauseElement:
312 |         ":meta private:"
313 |         typing.typecheck(field, (ast.Identifier, ast.String), "field")
314 |         typing.typecheck(substr, ast.String, "substring")
315 | 
316 |         identifier = self.visit(field)
317 |         substring = self.visit(substr)
318 |         op = getattr(identifier, func)
319 | 
320 |         return op(substring)
321 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/core.py:
--------------------------------------------------------------------------------
 1 | from typing import Type
 2 | 
 3 | import sqlalchemy as sa
 4 | from sqlalchemy.sql.expression import BinaryExpression, ClauseElement, ColumnClause
 5 | 
 6 | from odata_query import ast, exceptions as ex, visitor
 7 | 
 8 | from . import common
 9 | 
10 | 
class AstToSqlAlchemyCoreVisitor(common._CommonVisitors, visitor.NodeVisitor):
    """
    :class:`NodeVisitor` that transforms an :term:`AST` into a SQLAlchemy where
    clause using Core features.

    Identifiers are resolved against the columns of ``table`` (``table.c``).
    Relationship traversal and collection lambdas are not supported.

    Args:
        table: A SQLalchemy table
    """

    # NOTE(review): the visitor reads ``self.table.c``, i.e. it uses a Table
    # *instance*; the ``Type[sa.Table]`` annotation looks wrong — confirm.
    def __init__(self, table: Type[sa.Table]):
        self.table = table

    def visit_Identifier(self, node: ast.Identifier) -> ColumnClause:
        """:meta private:"""
        columns = self.table.c
        try:
            return columns[node.name]
        except KeyError:
            raise ex.InvalidFieldException(node.name)

    def visit_Attribute(self, node: ast.Attribute) -> ColumnClause:
        """:meta private:"""
        message = "Relationship traversal is not yet implemented for the SQLAlchemy Core visitor."
        raise NotImplementedError(message)

    def visit_Compare(self, node: ast.Compare) -> BinaryExpression:
        """:meta private:"""
        # Operands are visited before the comparator, left before right.
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        compare = self.visit(node.comparator)
        return compare(lhs, rhs)

    def visit_CollectionLambda(self, node: ast.CollectionLambda) -> ClauseElement:
        """:meta private:"""
        message = "Collection lambda is not yet implemented for the SQLAlchemy Core visitor."
        raise NotImplementedError(message)
48 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/functions_ext.py:
--------------------------------------------------------------------------------
 1 | from sqlalchemy.sql.functions import GenericFunction
 2 | from sqlalchemy.types import Integer, String
 3 | 
 4 | 
 5 | class strpos(GenericFunction):
 6 |     type = Integer
 7 |     package = "odata"
 8 |     inherit_cache = True
 9 | 
10 | 
class substr(GenericFunction):
    """``substr`` SQL function in the ``odata`` package, typed as ``String``."""

    package = "odata"
    type = String
    inherit_cache = True
15 | 
16 | 
class lower(GenericFunction):
    """``lower`` SQL function in the ``odata`` package, typed as ``String``."""

    package = "odata"
    type = String
    inherit_cache = True
21 | 
22 | 
class upper(GenericFunction):
    """``upper`` SQL function in the ``odata`` package, typed as ``String``."""

    package = "odata"
    type = String
    inherit_cache = True
27 | 
28 | 
class ltrim(GenericFunction):
    """``ltrim`` SQL function in the ``odata`` package, typed as ``String``."""

    package = "odata"
    type = String
    inherit_cache = True
33 | 
34 | 
class rtrim(GenericFunction):
    """``rtrim`` SQL function in the ``odata`` package, typed as ``String``."""

    package = "odata"
    type = String
    inherit_cache = True
39 | 
40 | 
class ceil(GenericFunction):
    """``ceil`` SQL function in the ``odata`` package, typed as ``Integer``."""

    package = "odata"
    type = Integer
    inherit_cache = True
45 | 
46 | 
class floor(GenericFunction):
    """``floor`` SQL function in the ``odata`` package, typed as ``Integer``."""

    package = "odata"
    type = Integer
    inherit_cache = True
51 | 
52 | 
class round(GenericFunction):
    """``round`` SQL function in the ``odata`` package, typed as ``Integer``."""

    package = "odata"
    type = Integer
    inherit_cache = True
57 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/orm.py:
--------------------------------------------------------------------------------
  1 | from typing import List, Type
  2 | 
  3 | from sqlalchemy.inspection import inspect
  4 | from sqlalchemy.orm.attributes import InstrumentedAttribute
  5 | from sqlalchemy.orm.decl_api import DeclarativeMeta
  6 | from sqlalchemy.orm.relationships import RelationshipProperty
  7 | from sqlalchemy.sql.expression import BinaryExpression, ClauseElement, ColumnClause
  8 | 
  9 | from odata_query import ast, exceptions as ex, utils, visitor
 10 | 
 11 | from . import common
 12 | 
 13 | 
 14 | class AstToSqlAlchemyOrmVisitor(common._CommonVisitors, visitor.NodeVisitor):
 15 |     """
 16 |     :class:`NodeVisitor` that transforms an :term:`AST` into a SQLAlchemy query
 17 |     filter clause using ORM features.
 18 | 
 19 |     Args:
 20 |         root_model: The root model of the query.
 21 |     """
 22 | 
 23 |     def __init__(self, root_model: Type[DeclarativeMeta]):
 24 |         self.root_model = root_model
 25 |         self.join_relationships: List[InstrumentedAttribute] = []
 26 | 
 27 |     def visit_Identifier(self, node: ast.Identifier) -> ColumnClause:
 28 |         ":meta private:"
 29 |         try:
 30 |             return getattr(self.root_model, node.name)
 31 |         except AttributeError:
 32 |             raise ex.InvalidFieldException(node.name)
 33 | 
 34 |     def visit_Attribute(self, node: ast.Attribute) -> ColumnClause:
 35 |         ":meta private:"
 36 |         rel_attr = self.visit(node.owner)
 37 |         # Owner is an InstrumentedAttribute, hopefully of a relationship.
 38 |         # But we need the model pointed to by the relationship.
 39 |         prop_inspect = inspect(rel_attr).property
 40 |         if not isinstance(prop_inspect, RelationshipProperty):
 41 |             # TODO: new exception:
 42 |             raise ValueError(f"Not a relationship: {node.owner}")
 43 |         self.join_relationships.append(rel_attr)
 44 | 
 45 |         # We'd like to reference the column on the related class:
 46 |         owner_cls = prop_inspect.entity.class_
 47 |         try:
 48 |             return getattr(owner_cls, node.attr)
 49 |         except AttributeError:
 50 |             raise ex.InvalidFieldException(node.attr)
 51 | 
 52 |     def visit_Compare(self, node: ast.Compare) -> BinaryExpression:
 53 |         ":meta private:"
 54 |         left = self.visit(node.left)
 55 |         right = self.visit(node.right)
 56 |         op = self.visit(node.comparator)
 57 | 
 58 |         # If a node is a `relationship` representing a single foreign key,
 59 |         # the client meant to compare the foreign key, not the related object.
 60 |         # E.g. In "blogpost/author eq 1", left should be "blogpost/author_id"
 61 |         left = self._maybe_sub_relationship_with_foreign_key(left)
 62 |         right = self._maybe_sub_relationship_with_foreign_key(right)
 63 | 
 64 |         return op(left, right)
 65 | 
 66 |     def visit_CollectionLambda(self, node: ast.CollectionLambda) -> ClauseElement:
 67 |         ":meta private:"
 68 |         owner_prop = self.visit(node.owner)
 69 |         collection_model = inspect(owner_prop).property.entity.class_
 70 | 
 71 |         if node.lambda_:
 72 |             # For the lambda, we want to strip the identifier off, because
 73 |             # we will execute this as a subquery in the wanted model's context.
 74 |             subq_ast = utils.expression_relative_to_identifier(
 75 |                 node.lambda_.identifier, node.lambda_.expression
 76 |             )
 77 |             subq_transformer = self.__class__(collection_model)
 78 |             subquery_filter = subq_transformer.visit(subq_ast)
 79 |         else:
 80 |             subquery_filter = None
 81 | 
 82 |         if isinstance(node.operator, ast.Any):
 83 |             return owner_prop.any(subquery_filter)
 84 |         else:
 85 |             # For an ALL query, invert both the filter and the EXISTS:
 86 |             if node.lambda_:
 87 |                 subquery_filter = ~subquery_filter
 88 |             return ~owner_prop.any(subquery_filter)
 89 | 
 90 |     def _maybe_sub_relationship_with_foreign_key(
 91 |         self, elem: ClauseElement
 92 |     ) -> ClauseElement:
 93 |         """
 94 |         If the given ClauseElement is a `relationship` with a single ForeignKey,
 95 |         replace it with the `ForeignKey` itself.
 96 | 
 97 |         :meta private:
 98 |         """
 99 |         try:
100 |             prop_inspect = inspect(elem).property
101 |             if isinstance(prop_inspect, RelationshipProperty):
102 |                 foreign_key = prop_inspect._calculated_foreign_keys
103 |                 if len(foreign_key) == 1:
104 |                     return next(iter(foreign_key))
105 |         except Exception:
106 |             pass
107 | 
108 |         return elem
109 | 
--------------------------------------------------------------------------------
/odata_query/sqlalchemy/shorthand.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | 
 3 | from sqlalchemy.orm.query import Query
 4 | from sqlalchemy.sql.expression import ClauseElement, Select
 5 | 
 6 | from odata_query.grammar import ODataLexer, ODataParser  # type: ignore
 7 | 
 8 | from .core import AstToSqlAlchemyCoreVisitor
 9 | from .orm import AstToSqlAlchemyOrmVisitor
10 | 
11 | 
12 | def _get_joined_attrs(query: Select) -> List[str]:
13 |     # use _legacy_setup_joins for legacy Query objects
14 |     setup_joins = (
15 |         getattr(query, "_legacy_setup_joins", query._setup_joins) or query._setup_joins
16 |     )
17 |     return [str(join[0]) for join in setup_joins]
18 | 
19 | 
def apply_odata_query(query: ClauseElement, odata_query: str) -> ClauseElement:
    """
    Shorthand for applying an OData query to a SQLAlchemy query.

    The filter is built with :class:`AstToSqlAlchemyOrmVisitor` against the
    entity of the query's first FROM clause. Each relationship that the
    filter traverses is joined exactly once, unless the query already joins
    it.

    Args:
        query: SQLAlchemy query to apply the OData query to.
        odata_query: OData query string.
    Returns:
        ClauseElement: The modified query
    """
    lexer = ODataLexer()
    parser = ODataParser()

    clause_elem: ClauseElement
    if isinstance(query, Query):
        # For now, we keep supporting the 1.x style of queries unofficially.
        # GITHUB-34
        clause_elem = query.__clause_element__()
    else:
        clause_elem = query

    model = clause_elem.columns_clause_froms[0].entity_namespace

    odata_ast = parser.parse(lexer.tokenize(odata_query))
    transformer = AstToSqlAlchemyOrmVisitor(model)
    where_clause = transformer.visit(odata_ast)

    existing_joins = set(_get_joined_attrs(query))
    for required_join in transformer.join_relationships:
        join_name = str(required_join)
        if join_name in existing_joins or str(required_join.key) in existing_joins:
            continue
        query = query.join(required_join)
        # The visitor records a relationship once per reference, e.g. twice
        # for "author/name eq 'a' and author/id eq 1". Remember joins added
        # here so the same table is not joined a second time.
        existing_joins.add(join_name)

    return query.filter(where_clause)
56 | 
57 | 
def apply_odata_core(query: ClauseElement, odata_query: str) -> ClauseElement:
    """
    Shorthand for applying an OData query to a SQLAlchemy Core query.

    OData identifiers are resolved against the first FROM object derived
    from the query's columns clause (``columns_clause_froms[0]``).

    Args:
        query: SQLAlchemy query to apply the OData query to.
        odata_query: OData query string.
    Returns:
        ClauseElement: The modified query
    """
    lexer, parser = ODataLexer(), ODataParser()
    table = query.columns_clause_froms[0]

    tree = parser.parse(lexer.tokenize(odata_query))
    where_clause = AstToSqlAlchemyCoreVisitor(table).visit(tree)
    return query.filter(where_clause)
76 | 
--------------------------------------------------------------------------------
/odata_query/typing.py:
--------------------------------------------------------------------------------
  1 | import logging
  2 | import operator
  3 | from typing import Optional, Tuple, Type, Union
  4 | 
  5 | from . import ast, exceptions as ex
  6 | 
  7 | log = logging.getLogger(__name__)
  8 | 
  9 | 
 10 | def typecheck(
 11 |     node: ast._Node, expected_type: Union[Type, Tuple[Type, ...]], field_name: str
 12 | ) -> None:
 13 |     """
 14 |     Checks that the inferred type of ``node`` is (one) of ``expected_type``, and
 15 |     raises :class:`ArgumentTypeException` if not.
 16 | 
 17 |     Args:
 18 |         node: The node to type check.
 19 |         expected_type: The allowed type(s) the node can have.
 20 |         field_name: The name of the field you're typechecking. Only used in the
 21 |             exception.
 22 |     Raises:
 23 |         ArgumentTypeException
 24 |     """
 25 |     actual_type = infer_type(node)
 26 |     compare = operator.contains if isinstance(expected_type, tuple) else operator.eq
 27 |     if actual_type and not compare(expected_type, actual_type):
 28 |         allowed = (
 29 |             [t.__name__ for t in expected_type]
 30 |             if isinstance(expected_type, tuple)
 31 |             else expected_type.__name__
 32 |         )
 33 |         raise ex.ArgumentTypeException(field_name, str(allowed), actual_type.__name__)
 34 | 
 35 | 
 36 | def infer_type(node: ast._Node) -> Optional[Type[ast._Node]]:
 37 |     """
 38 |     Tries to infer the type of ``node``.
 39 | 
 40 |     Args:
 41 |         node: The node to infer the type for.
 42 |     Returns:
 43 |         The inferred type or ``None`` if unable to infer.
 44 |     """
 45 |     if isinstance(node, (ast._Literal)):
 46 |         return type(node)
 47 | 
 48 |     if isinstance(node, (ast.Compare, ast.BoolOp)):
 49 |         return ast.Boolean
 50 | 
 51 |     if isinstance(node, ast.Call):
 52 |         return infer_return_type(node)
 53 | 
 54 |     log.debug("Failed to infer type for %s", node)
 55 |     return None
 56 | 
 57 | 
 58 | def infer_return_type(node: ast.Call) -> Optional[Type[ast._Node]]:
 59 |     """
 60 |     Tries to infer the type of a function call ``node``.
 61 | 
 62 |     Args:
 63 |         node: The node to infer the type for.
 64 |     Returns:
 65 |         The inferred type or ``None`` if unable to infer.
 66 |     """
 67 |     func = node.func.full_name()
 68 | 
 69 |     if func in (
 70 |         "contains",
 71 |         "endswith",
 72 |         "startswith",
 73 |         "hassubset",
 74 |         "hassubsequence",
 75 |         "geo.intersects",
 76 |     ):
 77 |         return ast.Boolean
 78 | 
 79 |     if func in (
 80 |         "indexof",
 81 |         "length",
 82 |         "year",
 83 |         "month",
 84 |         "day",
 85 |         "hour",
 86 |         "minute",
 87 |         "second",
 88 |         "totaloffsetminutes",
 89 |     ):
 90 |         return ast.Integer
 91 | 
 92 |     if func in (
 93 |         "fractionalseconds",
 94 |         "totalseconds",
 95 |         "ceiling",
 96 |         "floor",
 97 |         "round",
 98 |         "geo.distance",
 99 |         "geo.length",
100 |     ):
101 |         return ast.Float
102 | 
103 |     if func in ("tolower", "toupper", "trim"):
104 |         return ast.String
105 | 
106 |     if func == "date":
107 |         return ast.Date
108 | 
109 |     if func in ("maxdatetime", "mindatetime", "now"):
110 |         return ast.DateTime
111 | 
112 |     if func == "concat":
113 |         return infer_type(node.args[0]) or infer_type(node.args[1])
114 | 
115 |     if func == "substring":
116 |         return infer_type(node.args[0])
117 | 
118 |     return None
119 | 
--------------------------------------------------------------------------------
/odata_query/utils.py:
--------------------------------------------------------------------------------
 1 | from . import ast
 2 | from .rewrite import IdentifierStripper
 3 | 
 4 | 
 5 | def expression_relative_to_identifier(
 6 |     identifier: ast.Identifier, expression: ast._Node
 7 | ) -> ast._Node:
 8 |     """
 9 |     Shorthand for the :class:`IdentifierStripper`.
10 | 
11 |     Args:
12 |         identifier: Identifier to strip from ``expression``.
13 |         expression: Expression to strip the ``identifier`` from.
14 | 
15 |     Returns:
16 |         The ``expression`` relative to the ``identifier``.
17 |     """
18 |     stripper = IdentifierStripper(identifier)
19 |     result = stripper.visit(expression)
20 |     return result
21 | 
--------------------------------------------------------------------------------
/odata_query/visitor.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import fields
 2 | from typing import Any, Iterator, Tuple
 3 | 
 4 | from . import ast
 5 | 
 6 | 
 7 | def iter_dataclass_fields(node: ast._Node) -> Iterator[Tuple[str, Any]]:
 8 |     """
 9 |     Loops over all fields of the given node, yielding the field's name and
10 |     the current value.
11 | 
12 |     Yields:
13 |         Tuples of ``(fieldname, value)`` for each field in ``node._fields``.
14 |     """
15 |     for field in fields(node):
16 |         yield field.name, getattr(node, field.name)
17 | 
18 | 
class NodeVisitor:
    """
    Base class for visitors that walk the :term:`AST`, calling a visitor
    method for every node encountered. The method's return value is passed
    back through :func:`visit`.

    Subclass it and add methods named ``'visit_'`` followed by the node's
    class name (e.g. ``visit_Identifier(self, identifier)``). Nodes without
    such a method are handled by :func:`generic_visit`.
    """

    def visit(self, node: ast._Node) -> Any:
        """
        Dispatches ``node`` to ``visit_<ClassName>`` when ``self`` defines it,
        and to :func:`generic_visit` otherwise.

        Returns:
            Whatever the dispatched method returns. Subclasses decide what a
            :class:`NodeVisitor` returns.
        """
        handler_name = f"visit_{node.__class__.__name__}"
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node: ast._Node):
        """
        Recursively visits the children of ``node``: node-valued fields and
        node items inside list fields. Other field values are ignored.
        Used when no dedicated visitor method exists for a node.
        """
        for _, value in iter_dataclass_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, ast._Node):
                    self.visit(child)
58 | 
59 | 
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` that rebuilds the :term:`AST` while walking it.
    Visitor methods return the :class:`_Node` that takes the place of the
    visited node.
    """

    def generic_visit(self, node: ast._Node) -> ast._Node:
        """
        Builds a new node of the same type as ``node``. Node-valued fields,
        and node items inside list fields, are replaced by the result of
        visiting them. Every other value is copied over unchanged.
        """

        def transform(value: Any) -> Any:
            return self.visit(value) if isinstance(value, ast._Node) else value

        new_kwargs = {
            name: (
                [transform(item) for item in value]
                if isinstance(value, list)
                else transform(value)
            )
            for name, value in iter_dataclass_fields(node)
        }
        return type(node)(**new_kwargs)
86 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "odata-query"
 3 | version = "0.10.0"
 4 | description = "An OData query parser and transpiler."
 5 | authors = ["Oliver Hofkens "]
 6 | readme = "README.rst"
 7 | license = "MIT"
 8 | keywords = ["OData", "Query", "Parser"]
 9 | classifiers = [
10 |     "Development Status :: 5 - Production/Stable",
11 |     "Environment :: Web Environment",
12 |     "Framework :: Django",
13 |     "Intended Audience :: Developers",
14 |     "License :: OSI Approved :: MIT License",
15 |     "Operating System :: OS Independent",
16 |     "Programming Language :: Python :: 3 :: Only",
17 |     "Programming Language :: SQL",
18 |     "Topic :: Database",
19 |     "Topic :: Internet :: WWW/HTTP",
20 |     "Topic :: Internet :: WWW/HTTP :: Indexing/Search",
21 |     "Topic :: Software Development :: Compilers"
22 | ]
23 | include = ["odata_query/py.typed"]
24 | 
25 | [tool.poetry.dependencies]
26 | python = "^3.7"
27 | 
28 | python-dateutil = "^2.8.1"
29 | sly = "^0.4"
30 | 
31 | django = { version = ">=2.2", optional = true }
32 | sqlalchemy = { version = "^1.4", optional = true }
33 | 
34 | black = { version = "^22.1", optional = true }
35 | bump2version = { version = "^1.0", optional = true }
36 | flake8 = { version = "^3.8", optional = true }
37 | isort = { version = "^5.7", optional = true }
38 | mypy = { version = "^0.931", optional = true }
39 | types-python-dateutil = { version = "^2.8.1", optional = true }
40 | pytest = { version = "^6.2 || ^7.0", optional = true }
41 | pytest-cov = { version = "*", optional = true }
42 | sphinx = { version = "^5.3", optional = true }
43 | sphinx-rtd-theme = { version = "^2.0", optional = true }
44 | vulture = { version = "^2.3", optional = true }
45 | 
46 | [tool.poetry.extras]
47 | dev = ["bump2version"]
48 | django = ["django"]
49 | docs = ["sphinx", "sphinx-rtd-theme"]
50 | linting = ["flake8", "black", "isort", "mypy", "types-python-dateutil", "yamllint", "vulture"]
51 | sqlalchemy = ["sqlalchemy"]
52 | testing = ["pytest", "pytest-cov", "pytest-xdist"]
53 | 
54 | [build-system]
55 | requires = ["poetry-core>=1.0.0"]
56 | build-backend = "poetry.core.masonry.api"
57 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
 1 | [bumpversion]
 2 | current_version = 0.10.0
 3 | commit = True
 4 | tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:(?P<release>a|b|rc)(?P<build>\d+))?
 6 | serialize = 
 7 | 	{major}.{minor}.{patch}{release}{build}
 8 | 	{major}.{minor}.{patch}
 9 | 
10 | [bumpversion:part:release]
11 | values = 
12 | 	a
13 | 	b
14 | 	rc
15 | 
16 | [bumpversion:part:build]
17 | first_value = 0
18 | 
19 | [flake8]
20 | ignore = E203, E266, E501, W503
21 | max-line-length = 80
22 | max-complexity = 18
23 | select = B,C,E,F,W,T4,B9
24 | show_source = True
25 | per_file_ignores = 
26 | 	__init__.py: F401,F403
27 | 	odata_query/grammar.py:F821,F811
28 | 
29 | [tool:isort]
30 | multi_line_output = 3
31 | include_trailing_comma = True
32 | force_grid_wrap = 0
33 | combine_as_imports = True
34 | line_length = 88
35 | default_section = THIRDPARTY
36 | known_first_party = odata_query, tests
37 | 
38 | [tool:pytest]
39 | addopts = 
40 | 	--cov .
41 | 	--cov-branch
42 | 	--cov-config setup.cfg
43 | 	--cov-report term-missing
44 | 	--cov-report xml:coverage.xml
45 | testpaths = tests
46 | markers = 
47 | 	slow: marks tests as slow (deselect with '-m "not slow"')
48 | 
49 | [coverage:run]
50 | omit = ./tests/*, ./.tox/*
51 | 
52 | [bumpversion:file:pyproject.toml]
53 | search = version = "{current_version}"
54 | replace = version = "{new_version}"
55 | 
56 | [bumpversion:file:sonar-project.properties]
57 | search = sonar.projectVersion={current_version}
58 | replace = sonar.projectVersion={new_version}
59 | 
60 | [bumpversion:file:odata_query/__init__.py]
61 | search = __version__ = "{current_version}"
62 | replace = __version__ = "{new_version}"
63 | 
64 | [bumpversion:file:docs/source/conf.py]
65 | search = release = "{current_version}"
66 | replace = release = "{new_version}"
67 | 
--------------------------------------------------------------------------------
/sonar-project.properties:
--------------------------------------------------------------------------------
 1 | sonar.projectKey=gorillaco_odata-query
 2 | sonar.projectName=OData Query
 3 | sonar.projectVersion=0.10.0
 4 | 
 5 | sonar.sources=./odata_query
 6 | sonar.tests=./tests
 7 | 
 8 | sonar.python.coverage.reportPaths=coverage.xml
 9 | 
10 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | 
 3 | import pytest
 4 | 
 5 | from odata_query.grammar import ODataLexer, ODataParser
 6 | 
 7 | 
 8 | @pytest.fixture(scope="session")
 9 | def lexer():
10 |     return ODataLexer()
11 | 
12 | 
@pytest.fixture
def parser():
    """A new :class:`ODataParser` for every test function."""
    fresh_parser = ODataParser()
    return fresh_parser
16 | 
17 | 
@pytest.fixture(scope="session")
def data_dir():
    """Path to the ``data`` directory next to this file, created if missing."""
    path = Path(__file__).parent.joinpath("data")
    path.mkdir(exist_ok=True)
    return path
24 | 
--------------------------------------------------------------------------------
/tests/data/world_borders.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/data/world_borders.zip
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/__init__.py
--------------------------------------------------------------------------------
/tests/integration/django/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/django/__init__.py
--------------------------------------------------------------------------------
/tests/integration/django/apps.py:
--------------------------------------------------------------------------------
 1 | from django.apps import AppConfig
 2 | 
 3 | 
 4 | class ODataQueryConfig(AppConfig):
 5 |     name = "tests.integration.django"
 6 |     verbose_name = "OData Query Django test app"
 7 |     default = True
 8 | 
 9 | 
class DbRouter:
    """
    Database router for the integration tests. Models of the GeoDjango test
    app (``django_geo``) are read from and written to the SpatiaLite ``geo``
    database; all other models are left to Django's default choice.
    """

    GEO_APP = "django_geo"

    def _geo_db_or_none(self, model):
        # ``None`` expresses "no opinion", so Django falls back to its default.
        return "geo" if model._meta.app_label == self.GEO_APP else None

    def db_for_read(self, model, **hints):
        return self._geo_db_or_none(model)

    def db_for_write(self, model, **hints):
        return self._geo_db_or_none(model)

    def allow_relation(self, obj1, obj2, **hints):
        # Relations are only allowed between objects of the same app.
        return obj1._meta.app_label == obj2._meta.app_label

    def allow_migrate(self, db: str, app_label: str, model_name=None, **hints):
        # Geo app tables live only in "geo"; every other table only in "default".
        target_db = "geo" if app_label == self.GEO_APP else "default"
        return db == target_db
37 | 
--------------------------------------------------------------------------------
/tests/integration/django/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from django.core import management
3 | 
4 | 
@pytest.fixture(scope="session")
def django_db():
    """Creates the test model tables once per session (``migrate --run-syncdb``)."""
    migrate_args = ("migrate", "--run-syncdb")
    management.call_command(*migrate_args)
8 | 
--------------------------------------------------------------------------------
/tests/integration/django/models.py:
--------------------------------------------------------------------------------
 1 | import uuid
 2 | 
 3 | import django
 4 | from django.db import models
 5 | 
# Set Django up before the model classes below are declared.
# NOTE(review): relies on DJANGO_SETTINGS_MODULE being configured by the test
# runner (presumably pointing at tests/integration/django/settings.py) — confirm.
django.setup()
 7 | 
 8 | 
class Author(models.Model):
    """Test model: a blog post / comment author with a UUID primary key."""

    # Primary key generated client-side with ``uuid.uuid4``; not editable.
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=128)
12 | 
13 | 
class BlogPost(models.Model):
    """Test model: a blog post linked to its authors (many-to-many)."""

    published_at = models.DateTimeField()
    title = models.CharField(max_length=128)
    content = models.TextField()

    # The reverse accessor on ``Author`` is ``blogposts``.
    authors = models.ManyToManyField(Author, related_name="blogposts")
20 | 
21 | 
class Comment(models.Model):
    """Test model: a comment written by an ``Author`` on a ``BlogPost``."""

    # NOTE: keep the field order; the expected SQL in
    # test_odata_filter_to_sql_query lists columns as id, content, author_id,
    # blogpost_id.
    content = models.TextField()

    # Reverse accessors: ``Author.comments`` and ``BlogPost.comments``.
    author = models.ForeignKey(
        Author, on_delete=models.CASCADE, related_name="comments"
    )
    blogpost = models.ForeignKey(
        BlogPost, on_delete=models.CASCADE, related_name="comments"
    )
31 | 
--------------------------------------------------------------------------------
/tests/integration/django/settings.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | 
 3 | DB_DIR = Path(__file__).parent / "db"
 4 | DB_DIR.mkdir(exist_ok=True)
 5 | 
 6 | DATABASES = {
 7 |     "default": {
 8 |         "ENGINE": "django.db.backends.sqlite3",
 9 |         "NAME": str(DB_DIR / "odata-query"),
10 |     },
11 |     "geo": {
12 |         "ENGINE": "django.contrib.gis.db.backends.spatialite",
13 |         "NAME": str(DB_DIR / "odata-query-geo"),
14 |     },
15 | }
16 | DATABASE_ROUTERS = ["tests.integration.django.apps.DbRouter"]
17 | DEBUG = True
18 | INSTALLED_APPS = [
19 |     "tests.integration.django.apps.ODataQueryConfig",
20 |     # GEO:
21 |     "django.contrib.gis",
22 |     "tests.integration.django_geo.apps.ODataQueryConfig",
23 | ]
24 | 
--------------------------------------------------------------------------------
/tests/integration/django/test_querying.py:
--------------------------------------------------------------------------------
  1 | import datetime as dt
  2 | from typing import Type
  3 | 
  4 | import pytest
  5 | from django.db import models, transaction
  6 | 
  7 | from odata_query.django import apply_odata_query
  8 | from odata_query.django.django_q import DJANGO_LT_4
  9 | 
 10 | from .models import Author, BlogPost, Comment
 11 | 
 12 | 
 13 | @pytest.fixture
 14 | def sample_data_sess(django_db):
 15 |     # https://docs.djangoproject.com/en/3.1/topics/db/transactions/#managing-autocommit
 16 |     transaction.set_autocommit(False)
 17 |     a1 = Author(name="Gorilla")
 18 |     a1.save()
 19 |     a2 = Author(name="Baboon")
 20 |     a2.save()
 21 |     a3 = Author(name="Saki")
 22 |     a3.save()
 23 | 
 24 |     bp1 = BlogPost(
 25 |         title="Querying Data",
 26 |         published_at=dt.datetime(2020, 1, 1),
 27 |         content="How 2 query data...",
 28 |     )
 29 |     bp1.save()
 30 |     bp2 = BlogPost(
 31 |         title="Automating Monkey Jobs",
 32 |         published_at=dt.datetime(2019, 1, 1),
 33 |         content="How 2 automate monkey jobs...",
 34 |     )
 35 |     bp2.save()
 36 |     bp1.authors.set([a1])
 37 |     bp2.authors.set([a2, a3])
 38 | 
 39 |     c1 = Comment(content="Dope!", author=a2, blogpost=bp1)
 40 |     c1.save()
 41 |     c2 = Comment(content="Cool!", author=a1, blogpost=bp2)
 42 |     c2.save()
 43 |     yield
 44 |     transaction.rollback()
 45 | 
 46 | 
 47 | @pytest.mark.parametrize(
 48 |     "model, query, exp_results",
 49 |     [
 50 |         (Author, "name eq 'Baboon'", 1),
 51 |         (Author, "startswith(name, 'Gori')", 1),
 52 |         (BlogPost, "contains(content, 'How')", 2),
 53 |         (BlogPost, "published_at gt 2019-06-01", 1),
 54 |         (Author, "contains(blogposts/title, 'Monkey')", 2),
 55 |         (Author, "startswith(blogposts/comments/content, 'Cool')", 2),
 56 |         (Author, "comments/any()", 2),
 57 |         (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2),
 58 |         (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1),
 59 |         (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2),
 60 |         (Author, "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0),
 61 |         pytest.param(
 62 |             Author,
 63 |             "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
 64 |             0,
 65 |             marks=pytest.mark.xfail(
 66 |                 not DJANGO_LT_4,
 67 |                 reason="Bug related to https://code.djangoproject.com/ticket/33705",
 68 |             ),
 69 |         ),
 70 |         (BlogPost, "comments/author eq 0", 0),
 71 |         (BlogPost, "substring(content, 0) eq 'test'", 0),
 72 |         (BlogPost, "year(published_at) eq 2019", 1),
 73 |         # GITHUB-19
 74 |         (BlogPost, "contains(title, 'Query') eq true", 1),
 75 |         (BlogPost, "contains(title, 'Query') eq false", 1),
 76 |     ],
 77 | )
 78 | def test_query_with_odata(
 79 |     model: Type[models.Model],
 80 |     query: str,
 81 |     exp_results: int,
 82 |     sample_data_sess,
 83 | ):
 84 |     q = apply_odata_query(model.objects, query)
 85 |     results = q.all()
 86 |     assert len(results) == exp_results
 87 | 
 88 | 
 89 | @pytest.mark.xfail(
 90 |     not DJANGO_LT_4,
 91 |     reason="Bug fixed in unreleased version: https://code.djangoproject.com/ticket/33705",
 92 | )
 93 | @pytest.mark.parametrize(
 94 |     "odata_query, expected_sql",
 95 |     [
 96 |         (
 97 |             "author eq null",
 98 |             (
 99 |                 'SELECT DISTINCT "django_comment"."id", "django_comment"."content", "django_comment"."author_id", "django_comment"."blogpost_id" '
100 |                 'FROM "django_comment" '
101 |                 'WHERE "django_comment"."author_id" IS NULL'
102 |             ),
103 |         ),
104 |         (
105 |             "author ne null",
106 |             (
107 |                 'SELECT DISTINCT "django_comment"."id", "django_comment"."content", "django_comment"."author_id", "django_comment"."blogpost_id" '
108 |                 'FROM "django_comment" '
109 |                 'WHERE "django_comment"."author_id" IS NOT NULL'
110 |             ),
111 |         ),
112 |     ],
113 | )
114 | def test_odata_filter_to_sql_query(odata_query: str, expected_sql: str):
115 |     q = apply_odata_query(Comment.objects, odata_query)
116 |     queryset = q.distinct()
117 |     sql = str(queryset.query)
118 |     assert sql == expected_sql
119 | 
--------------------------------------------------------------------------------
/tests/integration/django/test_utils.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from odata_query.django import utils
 4 | 
 5 | from .models import Author, BlogPost
 6 | 
 7 | 
 8 | @pytest.mark.parametrize(
 9 |     "root_model, rel, exp_model, exp_rel",
10 |     [
11 |         (Author, "blogposts", BlogPost, "authors"),
12 |         (BlogPost, "comments__author", Author, "comments__blogpost"),
13 |     ],
14 | )
15 | def test_reverse_relationship(root_model, rel, exp_model, exp_rel):
16 |     res_rel, res_model = utils.reverse_relationship(rel, root_model)
17 | 
18 |     assert res_rel == exp_rel
19 |     assert res_model is exp_model
20 | 
--------------------------------------------------------------------------------
/tests/integration/django_geo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/django_geo/__init__.py
--------------------------------------------------------------------------------
/tests/integration/django_geo/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 | 
3 | 
class ODataQueryConfig(AppConfig):
    """App configuration for the GeoDjango integration test models."""

    default = True
    name = "tests.integration.django_geo"
    verbose_name = "OData Query GeoDjango test app"
8 | 
--------------------------------------------------------------------------------
/tests/integration/django_geo/conftest.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | from zipfile import ZipFile
 3 | 
 4 | import pytest
 5 | from django.core import management
 6 | 
 7 | 
 8 | @pytest.fixture(scope="session")
 9 | def django_db():
10 |     management.call_command("migrate", "--run-syncdb", "--database", "geo")
11 | 
12 | 
@pytest.fixture(scope="session")
def world_borders_dataset(data_dir: Path):
    """
    Directory holding the extracted world borders dataset. The first time,
    it is extracted from ``world_borders.zip`` in ``data_dir``; afterwards an
    existing directory is reused as-is.
    """
    target_dir = data_dir / "world_borders"

    if not target_dir.exists():
        archive_path = target_dir.with_suffix(".zip")
        with ZipFile(archive_path, "r") as archive:
            archive.extractall(target_dir)

    return target_dir
25 | 
--------------------------------------------------------------------------------
/tests/integration/django_geo/models.py:
--------------------------------------------------------------------------------
 1 | # https://docs.djangoproject.com/en/4.2/ref/contrib/gis/tutorial/#geographic-models
 2 | 
 3 | from django.core.exceptions import ImproperlyConfigured
 4 | from django.db import models
 5 | 
 6 | # This file needs to be importable even without Geo system libraries installed.
 7 | # Tests using these libraries will be skipped using pytest.skip
try:
    from django.contrib.gis.db.models import MultiPolygonField
except (ImportError, ImproperlyConfigured):
    # Stand-in field so ``WorldBorder`` can still be declared without the GIS
    # libraries. NOTE(review): a CharField without ``max_length`` may be
    # rejected by Django's system checks — confirm this path is never checked.
    MultiPolygonField = models.CharField
12 | 
13 | 
class WorldBorder(models.Model):
    """
    Model from the GeoDjango tutorial: one record of the world borders
    shapefile, loaded by the geo tests via ``LayerMapping``.
    """

    # Regular Django fields corresponding to the attributes in the
    # world borders shapefile.
    name = models.CharField(max_length=50)
    area = models.IntegerField()
    pop2005 = models.IntegerField("Population 2005")
    fips = models.CharField("FIPS Code", max_length=2, null=True)
    iso2 = models.CharField("2 Digit ISO", max_length=2)
    iso3 = models.CharField("3 Digit ISO", max_length=3)
    un = models.IntegerField("United Nations Code")
    region = models.IntegerField("Region Code")
    subregion = models.IntegerField("Sub-Region Code")
    lon = models.FloatField()
    lat = models.FloatField()

    # GeoDjango-specific: a geometry field (MultiPolygonField)
    # (a CharField stand-in when the GIS libraries could not be imported).
    mpoly = MultiPolygonField()

    def __str__(self):
        """Uses the ``name`` field as the display string."""
        return self.name
34 | 
--------------------------------------------------------------------------------
/tests/integration/django_geo/test_querying.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | from typing import Type
 3 | 
 4 | import pytest
 5 | from django.core.exceptions import ImproperlyConfigured
 6 | 
 7 | from odata_query.django import apply_odata_query
 8 | 
 9 | from .models import WorldBorder
10 | 
# The GIS modules below may fail to import without the GIS system libraries;
# in that case skip this whole test module instead of erroring on import.
try:
    from django.contrib.gis.db import models
    from django.contrib.gis.utils import LayerMapping
except (ImportError, ImproperlyConfigured):
    pytest.skip(allow_module_level=True, reason="Could not load GIS libraries")

# The default spatial reference system for geometry fields is WGS84
# (meaning the SRID is 4326)
# Used as the prefix of the geography literals in the OData queries below.
SRID = "SRID=4326"
20 | 
21 | 
@pytest.fixture(scope="session")
def sample_data_sess(django_db, world_borders_dataset: Path):
    """
    Loads the world borders shapefile into ``WorldBorder`` once per session
    and deletes every ``WorldBorder`` row at teardown.
    """
    world_shp = world_borders_dataset / "TM_WORLD_BORDERS-0.3.shp"
    # Model field name -> shapefile attribute (or geometry type for ``mpoly``).
    world_mapping = {
        "fips": "FIPS",
        "iso2": "ISO2",
        "iso3": "ISO3",
        "un": "UN",
        "name": "NAME",
        "area": "AREA",
        "pop2005": "POP2005",
        "region": "REGION",
        "subregion": "SUBREGION",
        "lon": "LON",
        "lat": "LAT",
        "mpoly": "MULTIPOLYGON",
    }

    mapping = LayerMapping(WorldBorder, world_shp, world_mapping, transform=False)
    mapping.save(strict=True, verbose=True)
    yield
    WorldBorder.objects.all().delete()
44 | 
45 | 
@pytest.mark.parametrize(
    "model, query, exp_results",
    [
        (
            WorldBorder,
            "geo.length(mpoly) gt 1000000",
            154,
        ),
        (
            WorldBorder,
            f"geo.intersects(mpoly, geography'{SRID};Point(-95.3385 29.7245)')",
            1,
        ),
    ],
)
def test_query_with_odata(
    model: Type[models.Model],
    query: str,
    exp_results: int,
    sample_data_sess,
):
    """Filtering ``model`` with the geo OData ``query`` yields ``exp_results`` rows."""
    matching = apply_odata_query(model.objects, query).all()
    assert len(matching) == exp_results
70 | 
--------------------------------------------------------------------------------
/tests/integration/sql/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/sql/__init__.py
--------------------------------------------------------------------------------
/tests/integration/sql/conftest.py:
--------------------------------------------------------------------------------
 1 | import sqlite3
 2 | 
 3 | import pytest
 4 | 
 5 | 
 6 | @pytest.fixture(scope="session")
 7 | def db_conn():
 8 |     conn = sqlite3.connect(":memory:", isolation_level=None)
 9 |     yield conn
10 |     conn.close()
11 | 
12 | 
@pytest.fixture(scope="session")
def db_schema(db_conn):
    """Create the ``author``, ``blogpost``, ``comment`` and ``author_blogpost`` tables.

    ``author_blogpost`` is a link table between authors and blogposts; it has
    no primary key of its own.
    """
    ddl_statements = (
        "CREATE TABLE author (id INTEGER PRIMARY KEY, name TEXT);",
        (
            "CREATE TABLE blogpost ("
            "id INTEGER PRIMARY KEY, "
            "published_at TEXT, "
            "title TEXT, "
            "content TEXT);"
        ),
        (
            "CREATE TABLE comment ("
            "id INTEGER PRIMARY KEY, "
            "content TEXT, "
            "author_id INTEGER, "
            "blogpost_id INTEGER, "
            "FOREIGN KEY (author_id) REFERENCES author(id), "
            "FOREIGN KEY (blogpost_id) REFERENCES blogpost(id)"
            ");"
        ),
        (
            "CREATE TABLE author_blogpost ("
            "author_id INTEGER, "
            "blogpost_id INTEGER, "
            "FOREIGN KEY (author_id) REFERENCES author(id), "
            "FOREIGN KEY (blogpost_id) REFERENCES blogpost(id)"
            ");"
        ),
    )

    cursor = db_conn.cursor()
    for statement in ddl_statements:
        cursor.execute(statement)
    # With the autocommit connection from ``db_conn`` there is normally no
    # open transaction, so this commit is a no-op; it is kept as a safeguard.
    db_conn.commit()
44 | 
--------------------------------------------------------------------------------
/tests/integration/sql/test_odata_to_athena_sql.py:
--------------------------------------------------------------------------------
  1 | import pytest
  2 | 
  3 | from odata_query import sql
  4 | 
  5 | 
  6 | @pytest.mark.parametrize(
  7 |     "odata_query, expected",
  8 |     [
  9 |         ("meter_id eq '1'", "\"meter_id\" = '1'"),
 10 |         ("meter_id ne '1'", "\"meter_id\" != '1'"),
 11 |         ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"),
 12 |         (
 13 |             "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67",
 14 |             "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'",
 15 |         ),
 16 |         ("meter_id in ('1',)", "\"meter_id\" IN ('1')"),
 17 |         ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"),
 18 |         ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"),
 19 |         ("meter_id eq null", '"meter_id" IS NULL'),
 20 |         ("meter_id ne null", '"meter_id" IS NOT NULL'),
 21 |         ("eac gt 10", '"eac" > 10'),
 22 |         ("eac ge 10", '"eac" >= 10'),
 23 |         ("eac lt 10", '"eac" < 10'),
 24 |         ("eac le 10", '"eac" <= 10'),
 25 |         ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'),
 26 |         ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'),
 27 |         (
 28 |             "eac gt 1 and eac lt 1 and meter_id eq '1'",
 29 |             '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'',
 30 |         ),
 31 |         # OData spec defines AND with higher precedence than OR:
 32 |         (
 33 |             "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10",
 34 |             '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)',
 35 |         ),
 36 |         # Unless overridden by parentheses:
 37 |         (
 38 |             "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10",
 39 |             '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10',
 40 |         ),
 41 |         ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'),
 42 |         ("eac gt 1 eq true", '("eac" > 1) = TRUE'),
 43 |         ("true eq eac gt 1", 'TRUE = ("eac" > 1)'),
 44 |         ("eac add 10 gt 1000", '"eac" + 10 > 1000'),
 45 |         ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'),
 46 |         ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'),
 47 |         ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'),
 48 |         (
 49 |             "period_start gt 2020-01-01T00:00:00",
 50 |             "\"period_start\" > FROM_ISO8601_TIMESTAMP('2020-01-01T00:00:00')",
 51 |         ),
 52 |         (
 53 |             "period_start add duration'P365D' ge period_end",
 54 |             '"period_start" + INTERVAL \'365\' DAY >= "period_end"',
 55 |         ),
 56 |         (
 57 |             "period_start add duration'P1Y' ge period_end",
 58 |             '"period_start" + INTERVAL \'1\' YEAR >= "period_end"',
 59 |         ),
 60 |         (
 61 |             "period_start add duration'P2M' ge period_end",
 62 |             '"period_start" + INTERVAL \'2\' MONTH >= "period_end"',
 63 |         ),
 64 |         (
 65 |             "period_start add duration'P365DT12H1M1.1S' ge period_end",
 66 |             "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"",
 67 |         ),
 68 |         (
 69 |             "period_start add duration'PT1S' ge period_end",
 70 |             '"period_start" + INTERVAL \'1\' SECOND >= "period_end"',
 71 |         ),
 72 |         ("year(period_start) eq 2019", 'EXTRACT (YEAR FROM "period_start") = 2019'),
 73 |         (
 74 |             "period_end lt now() sub duration'P365D'",
 75 |             "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '365' DAY",
 76 |         ),
 77 |         (
 78 |             "startswith(trim(meter_id), '999')",
 79 |             "TRIM(\"meter_id\") LIKE '999%'",
 80 |         ),
 81 |         (
 82 |             "year(date(now())) eq 2020",
 83 |             "EXTRACT (YEAR FROM CAST (CURRENT_TIMESTAMP AS DATE)) = 2020",
 84 |         ),
 85 |         ("length(concat('abc', 'def')) lt 10", "LENGTH('abc' || 'def') < 10"),
 86 |         (
 87 |             "length(concat(('1', '2'), ('3', '4'))) eq 4",
 88 |             "CARDINALITY(('1', '2') || ('3', '4')) = 4",
 89 |         ),
 90 |         (
 91 |             "indexof(substring('abcdefghi', 3), 'hi') gt 1",
 92 |             "POSITION('hi' IN SUBSTR('abcdefghi', 3 + 1)) - 1 > 1",
 93 |         ),
 94 |         (
 95 |             "substring('hello', 1, 3) eq 'ell'",
 96 |             "SUBSTR('hello', 1 + 1, 3) = 'ell'",
 97 |         ),
 98 |         ("substring((1, 2, 3), 1)", "SLICE((1, 2, 3), 1)"),
 99 |         ("substring((1, 2, 3), 1, 1)", "SLICE((1, 2, 3), 1, 1)"),
100 |         (
101 |             "contains(meter_id, sub_meter_id)",
102 |             "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'",
103 |         ),
104 |         (
105 |             "year(supply_start_date) eq (year(now()) sub 1)",
106 |             'EXTRACT (YEAR FROM "supply_start_date") = EXTRACT (YEAR FROM CURRENT_TIMESTAMP) - 1',
107 |         ),
108 |         (
109 |             "measurement_class eq 'C' and endswith(data_collector, 'rie')",
110 |             "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'",
111 |         ),
112 |         # GITHUB-47
113 |         (
114 |             "contains(tolower(name), tolower('A'))",
115 |             "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'",
116 |         ),
117 |     ],
118 | )
119 | def test_odata_filter_to_sql(odata_query: str, expected: str, lexer, parser):
120 |     ast = parser.parse(lexer.tokenize(odata_query))
121 |     visitor = sql.AstToAthenaSqlVisitor()
122 | 
123 |     if isinstance(expected, str):
124 |         res = visitor.visit(ast)
125 |         assert res == expected
126 |     else:
127 |         with pytest.raises(expected):
128 |             res = visitor.visit(ast)
129 | 
--------------------------------------------------------------------------------
/tests/integration/sql/test_odata_to_sql.py:
--------------------------------------------------------------------------------
  1 | from typing import Union
  2 | 
  3 | import pytest
  4 | 
  5 | from odata_query import exceptions as ex, sql
  6 | 
  7 | 
  8 | @pytest.mark.parametrize(
  9 |     "odata_query, expected",
 10 |     [
 11 |         ("meter_id eq '1'", "\"meter_id\" = '1'"),
 12 |         ("meter_id ne '1'", "\"meter_id\" != '1'"),
 13 |         ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"),
 14 |         (
 15 |             "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67",
 16 |             "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'",
 17 |         ),
 18 |         ("meter_id in ('1',)", "\"meter_id\" IN ('1')"),
 19 |         ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"),
 20 |         ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"),
 21 |         ("meter_id eq null", '"meter_id" IS NULL'),
 22 |         ("meter_id ne null", '"meter_id" IS NOT NULL'),
 23 |         ("eac gt 10", '"eac" > 10'),
 24 |         ("eac ge 10", '"eac" >= 10'),
 25 |         ("eac lt 10", '"eac" < 10'),
 26 |         ("eac le 10", '"eac" <= 10'),
 27 |         ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'),
 28 |         ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'),
 29 |         (
 30 |             "eac gt 1 and eac lt 1 and meter_id eq '1'",
 31 |             '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'',
 32 |         ),
 33 |         # OData spec defines AND with higher precedence than OR:
 34 |         (
 35 |             "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10",
 36 |             '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)',
 37 |         ),
 38 |         # Unless overridden by parentheses:
 39 |         (
 40 |             "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10",
 41 |             '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10',
 42 |         ),
 43 |         ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'),
 44 |         ("eac gt 1 eq true", '("eac" > 1) = TRUE'),
 45 |         ("true eq eac gt 1", 'TRUE = ("eac" > 1)'),
 46 |         ("eac add 10 gt 1000", '"eac" + 10 > 1000'),
 47 |         ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'),
 48 |         ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'),
 49 |         ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'),
 50 |         (
 51 |             "period_start gt 2020-01-01T00:00:00",
 52 |             "\"period_start\" > TIMESTAMP '2020-01-01 00:00:00'",
 53 |         ),
 54 |         (
 55 |             "period_start add duration'P365D' ge period_end",
 56 |             '"period_start" + INTERVAL \'365\' DAY >= "period_end"',
 57 |         ),
 58 |         (
 59 |             "period_start add duration'P365DT12H1M1.1S' ge period_end",
 60 |             "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"",
 61 |         ),
 62 |         (
 63 |             "period_start add duration'PT1S' ge period_end",
 64 |             '"period_start" + INTERVAL \'1\' SECOND >= "period_end"',
 65 |         ),
 66 |         (
 67 |             "period_start add duration'P1Y' ge period_end",
 68 |             '"period_start" + INTERVAL \'1\' YEAR >= "period_end"',
 69 |         ),
 70 |         (
 71 |             "period_start add duration'P2M' ge period_end",
 72 |             '"period_start" + INTERVAL \'2\' MONTH >= "period_end"',
 73 |         ),
 74 |         ("year(period_start) eq 2019", 'EXTRACT (YEAR FROM "period_start") = 2019'),
 75 |         (
 76 |             "period_end lt now() sub duration'P365D'",
 77 |             "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '365' DAY",
 78 |         ),
 79 |         (
 80 |             "period_end lt now() sub duration'P1Y'",
 81 |             "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '1' YEAR",
 82 |         ),
 83 |         (
 84 |             "period_end lt now() sub duration'P2M'",
 85 |             "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '2' MONTH",
 86 |         ),
 87 |         (
 88 |             "startswith(trim(meter_id), '999')",
 89 |             "TRIM(\"meter_id\") LIKE '999%'",
 90 |         ),
 91 |         (
 92 |             "year(date(now())) eq 2020",
 93 |             "EXTRACT (YEAR FROM CAST (CURRENT_TIMESTAMP AS DATE)) = 2020",
 94 |         ),
 95 |         ("length(concat('abc', 'def')) lt 10", "CHAR_LENGTH('abc' || 'def') < 10"),
 96 |         (
 97 |             "length(concat(('1', '2'), ('3', '4'))) eq 4",
 98 |             "CARDINALITY(('1', '2') || ('3', '4')) = 4",
 99 |         ),
100 |         (
101 |             "indexof(substring('abcdefghi', 3), 'hi') gt 1",
102 |             "POSITION('hi' IN SUBSTRING('abcdefghi' FROM 3 + 1)) - 1 > 1",
103 |         ),
104 |         (
105 |             "substring('hello', 1, 3) eq 'ell'",
106 |             "SUBSTRING('hello' FROM 1 + 1 FOR 3) = 'ell'",
107 |         ),
108 |         ("substring((1, 2, 3), 1)", ex.UnsupportedFunctionException),
109 |         ("substring((1, 2, 3), 1, 1)", ex.UnsupportedFunctionException),
110 |         (
111 |             "contains(meter_id, sub_meter_id)",
112 |             "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'",
113 |         ),
114 |         (
115 |             "year(supply_start_date) eq (year(now()) sub 1)",
116 |             'EXTRACT (YEAR FROM "supply_start_date") = EXTRACT (YEAR FROM CURRENT_TIMESTAMP) - 1',
117 |         ),
118 |         (
119 |             "measurement_class eq 'C' and endswith(data_collector, 'rie')",
120 |             "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'",
121 |         ),
122 |         # GITHUB-47
123 |         (
124 |             "contains(tolower(name), tolower('A'))",
125 |             "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'",
126 |         ),
127 |     ],
128 | )
129 | def test_odata_filter_to_sql(
130 |     odata_query: str, expected: Union[str, type], lexer, parser
131 | ):
132 |     ast = parser.parse(lexer.tokenize(odata_query))
133 |     visitor = sql.AstToSqlVisitor()
134 | 
135 |     if isinstance(expected, str):
136 |         res = visitor.visit(ast)
137 |         assert res == expected
138 |     else:
139 |         with pytest.raises(expected):
140 |             res = visitor.visit(ast)
141 | 
--------------------------------------------------------------------------------
/tests/integration/sql/test_odata_to_sqlite.py:
--------------------------------------------------------------------------------
from typing import Union

import pytest

from odata_query import exceptions as ex, sql
  4 | 
  5 | 
  6 | @pytest.mark.parametrize(
  7 |     "odata_query, expected",
  8 |     [
  9 |         ("meter_id eq '1'", "\"meter_id\" = '1'"),
 10 |         ("meter_id ne '1'", "\"meter_id\" != '1'"),
 11 |         ("meter_id eq 'o''reilly'''", "\"meter_id\" = 'o''reilly'''"),
 12 |         (
 13 |             "meter_id eq 6c0e37e3-e856-45ee-bd58-484b11882c67",
 14 |             "\"meter_id\" = '6c0e37e3-e856-45ee-bd58-484b11882c67'",
 15 |         ),
 16 |         ("meter_id in ('1',)", "\"meter_id\" IN ('1')"),
 17 |         ("meter_id in ('1', '2')", "\"meter_id\" IN ('1', '2')"),
 18 |         ("not (meter_id in ('1', '2'))", "NOT \"meter_id\" IN ('1', '2')"),
 19 |         ("meter_id eq null", '"meter_id" IS NULL'),
 20 |         ("meter_id ne null", '"meter_id" IS NOT NULL'),
 21 |         ("eac gt 10", '"eac" > 10'),
 22 |         ("eac ge 10", '"eac" >= 10'),
 23 |         ("eac lt 10", '"eac" < 10'),
 24 |         ("eac le 10", '"eac" <= 10'),
 25 |         ("eac gt 1.0 and eac lt 10.0", '"eac" > 1.0 AND "eac" < 10.0'),
 26 |         ("eac ge 1.0 and eac le 10.0", '"eac" >= 1.0 AND "eac" <= 10.0'),
 27 |         (
 28 |             "eac gt 1 and eac lt 1 and meter_id eq '1'",
 29 |             '"eac" > 1 AND "eac" < 1 AND "meter_id" = \'1\'',
 30 |         ),
 31 |         # OData spec defines AND with higher precedence than OR:
 32 |         (
 33 |             "eac gt 1 and eac lt 10 or eac eq 5 and eac ne 10",
 34 |             '("eac" > 1 AND "eac" < 10) OR ("eac" = 5 AND "eac" != 10)',
 35 |         ),
 36 |         # Unless overridden by parentheses:
 37 |         (
 38 |             "eac gt 1 and (eac lt 10 or eac eq 5) and eac ne 10",
 39 |             '"eac" > 1 AND ("eac" < 10 OR "eac" = 5) AND "eac" != 10',
 40 |         ),
 41 |         ("not (eac gt 10 and eac lt 20)", 'NOT ("eac" > 10 AND "eac" < 20)'),
 42 |         ("eac gt 1 eq true", '("eac" > 1) = 1'),
 43 |         ("true eq eac gt 1", '1 = ("eac" > 1)'),
 44 |         ("eac add 10 gt 1000", '"eac" + 10 > 1000'),
 45 |         ("eac add 10 gt eac sub 10", '"eac" + 10 > "eac" - 10'),
 46 |         ("eac mul 10 div 10 eq eac", '"eac" * 10 / 10 = "eac"'),
 47 |         ("eac mod 10 add -1 le eac", '"eac" % 10 + -1 <= "eac"'),
 48 |         (
 49 |             "period_start gt 2020-01-01T00:00:00",
 50 |             "\"period_start\" > DATETIME('2020-01-01T00:00:00')",
 51 |         ),
 52 |         (
 53 |             "period_start add duration'P365D' ge period_end",
 54 |             '"period_start" + INTERVAL \'365\' DAY >= "period_end"',
 55 |         ),
 56 |         (
 57 |             "period_start add duration'P365DT12H1M1.1S' ge period_end",
 58 |             "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"",
 59 |         ),
 60 |         (
 61 |             "period_start add duration'PT1S' ge period_end",
 62 |             '"period_start" + INTERVAL \'1\' SECOND >= "period_end"',
 63 |         ),
 64 |         (
 65 |             "period_start add duration'P1Y' ge period_end",
 66 |             '"period_start" + INTERVAL \'1\' YEAR >= "period_end"',
 67 |         ),
 68 |         (
 69 |             "period_start add duration'P2M' ge period_end",
 70 |             '"period_start" + INTERVAL \'2\' MONTH >= "period_end"',
 71 |         ),
 72 |         (
 73 |             "year(period_start) eq 2019",
 74 |             "CAST(STRFTIME('%Y', \"period_start\") AS INTEGER) = 2019",
 75 |         ),
 76 |         (
 77 |             "period_end lt now() sub duration'P365D'",
 78 |             "\"period_end\" < DATETIME('now') - INTERVAL '365' DAY",
 79 |         ),
 80 |         (
 81 |             "startswith(trim(meter_id), '999')",
 82 |             "TRIM(\"meter_id\") LIKE '999%'",
 83 |         ),
 84 |         (
 85 |             "year(date(now())) eq 2020",
 86 |             "CAST(STRFTIME('%Y', DATE(DATETIME('now'))) AS INTEGER) = 2020",
 87 |         ),
 88 |         ("length(concat('abc', 'def')) lt 10", "LENGTH('abc' || 'def') < 10"),
 89 |         (
 90 |             "length(concat(('1', '2'), ('3', '4'))) eq 4",
 91 |             "LENGTH(('1', '2') || ('3', '4')) = 4",
 92 |         ),
 93 |         (
 94 |             "indexof(substring('abcdefghi', 3), 'hi') gt 1",
 95 |             "INSTR(SUBSTR('abcdefghi', 3 + 1), 'hi') - 1 > 1",
 96 |         ),
 97 |         (
 98 |             "substring('hello', 1, 3) eq 'ell'",
 99 |             "SUBSTR('hello', 1 + 1, 3) = 'ell'",
100 |         ),
101 |         ("substring((1, 2, 3), 1)", ex.UnsupportedFunctionException),
102 |         ("substring((1, 2, 3), 1, 1)", ex.UnsupportedFunctionException),
103 |         (
104 |             "contains(meter_id, sub_meter_id)",
105 |             "\"meter_id\" LIKE '%' || \"sub_meter_id\" || '%'",
106 |         ),
107 |         (
108 |             "year(supply_start_date) eq (year(now()) sub 1)",
109 |             "CAST(STRFTIME('%Y', \"supply_start_date\") AS INTEGER) = CAST(STRFTIME('%Y', DATETIME('now')) AS INTEGER) - 1",
110 |         ),
111 |         (
112 |             "measurement_class eq 'C' and endswith(data_collector, 'rie')",
113 |             "\"measurement_class\" = 'C' AND \"data_collector\" LIKE '%rie'",
114 |         ),
115 |         # GITHUB-47
116 |         (
117 |             "contains(tolower(name), tolower('A'))",
118 |             "LOWER(\"name\") LIKE '%' || LOWER('A') || '%'",
119 |         ),
120 |     ],
121 | )
122 | def test_odata_filter_to_sql(odata_query: str, expected: str, lexer, parser):
123 |     ast = parser.parse(lexer.tokenize(odata_query))
124 |     visitor = sql.AstToSqliteSqlVisitor()
125 | 
126 |     if isinstance(expected, str):
127 |         res = visitor.visit(ast)
128 |         assert res == expected
129 |     else:
130 |         with pytest.raises(expected):
131 |             res = visitor.visit(ast)
132 | 
--------------------------------------------------------------------------------
/tests/integration/sql/test_querying.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from odata_query import sql
 4 | 
 5 | 
 6 | @pytest.fixture
 7 | def sample_data_sess(db_conn, db_schema):
 8 |     cur = db_conn.cursor()
 9 | 
10 |     cur.execute("BEGIN")
11 |     cur.execute("INSERT INTO author (name) VALUES ('Gorilla'), ('Baboon'), ('Saki')")
12 |     cur.execute(
13 |         "INSERT INTO blogpost (title, published_at, content) VALUES "
14 |         "('Querying Data', '2020-01-01T00:00:00', 'How 2 query data...'), "
15 |         "('Automating Monkey Jobs', '2019-01-01T00:00:00', 'How 2 automate monkey jobs...')"
16 |     )
17 |     cur.execute(
18 |         "INSERT INTO author_blogpost (author_id, blogpost_id) VALUES "
19 |         "(1, 1), "
20 |         "(2, 2), "
21 |         "(3, 2)"
22 |     )
23 |     cur.execute(
24 |         "INSERT INTO comment (content, author_id, blogpost_id) VALUES "
25 |         "('Dope!', 2, 1), "
26 |         "('Cool!', 1, 2)"
27 |     )
28 | 
29 |     yield cur
30 | 
31 |     cur.execute("ROLLBACK")
32 | 
33 | 
@pytest.mark.parametrize(
    "table, query, exp_results",
    [
        ("author", "name eq 'Baboon'", 1),
        ("author", "startswith(name, 'Gori')", 1),
        ("blogpost", "contains(content, 'How')", 2),
        ("blogpost", "published_at gt 2019-06-01", 1),
        # (Author, "contains(blogposts/title, 'Monkey')", 2),
        # (Author, "startswith(blogposts/comments/content, 'Cool')", 2),
        # (Author, "comments/any()", 2),
        # (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2),
        # (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1),
        # (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2),
        ("author", "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0),
        (
            "author",
            "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
            0,
        ),
        # (BlogPost, "comments/author eq 0", 0),
        ("blogpost", "substring(content, 0) eq 'test'", 0),
        ("blogpost", "year(published_at) eq 2019", 1),
        # GITHUB-19
        ("blogpost", "contains(title, 'Query') eq true", 1),
        ("blogpost", "contains(tolower(title), tolower('Query'))", 1),
    ],
)
def test_query_with_odata(
    table: str, query: str, exp_results: int, lexer, parser, sample_data_sess
):
    """Translate an OData filter into a SQLite WHERE clause and count the matching rows."""
    tree = parser.parse(lexer.tokenize(query))
    condition = sql.AstToSqliteSqlVisitor().visit(tree)

    statement = f'SELECT * FROM "{table}" WHERE {condition}'
    rows = sample_data_sess.execute(statement).fetchall()
    assert len(rows) == exp_results
72 | 
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/integration/sqlalchemy/__init__.py
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/conftest.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from sqlalchemy import create_engine
 3 | from sqlalchemy.orm import sessionmaker
 4 | 
 5 | from .models import Base
 6 | 
 7 | 
 8 | @pytest.fixture(scope="session")
 9 | def db_engine():
10 |     engine = create_engine("sqlite://", future=True)
11 |     Base.metadata.create_all(engine)
12 |     yield engine
13 |     engine.dispose()
14 | 
15 | 
@pytest.fixture(scope="session")
def db_session(db_engine):
    """Return a ``sessionmaker`` bound to ``db_engine``.

    Despite the name, this is a session *factory*; callers instantiate it to
    obtain an actual session.
    """
    return sessionmaker(bind=db_engine)
20 | 
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/models.py:
--------------------------------------------------------------------------------
 1 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
 2 | from sqlalchemy.orm import declarative_base, relationship
 3 | 
 4 | Base = declarative_base()
 5 | 
 6 | 
 7 | author_blogpost = Table(
 8 |     "author_blogpost",
 9 |     Base.metadata,
10 |     Column("author_id", Integer, ForeignKey("author.id")),
11 |     Column("blogpost_id", Integer, ForeignKey("blogpost.id")),
12 | )
13 | 
14 | 
class Author(Base):
    """Blog author: many-to-many with blogposts, one-to-many with comments."""

    __tablename__ = "author"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)

    # Linked through the ``author_blogpost`` association table.
    blogposts = relationship(
        "BlogPost",
        secondary=author_blogpost,
        back_populates="authors",
    )
    comments = relationship("Comment", back_populates="author")
24 |     comments = relationship("Comment", back_populates="author")
25 | 
26 | 
class BlogPost(Base):
    """Blogpost: many-to-many with authors, one-to-many with comments."""

    __tablename__ = "blogpost"

    id = Column(Integer, primary_key=True)
    published_at = Column(DateTime, nullable=False)
    title = Column(String, nullable=False)
    content = Column(Text)

    # Linked through the ``author_blogpost`` association table.
    authors = relationship(
        "Author",
        secondary=author_blogpost,
        back_populates="blogposts",
    )
    comments = relationship("Comment", back_populates="blogpost")
39 | 
40 | 
class Comment(Base):
    """Comment on a blogpost, written by an author."""

    __tablename__ = "comment"

    id = Column(Integer, primary_key=True)
    content = Column(Text)

    # Each foreign key is paired with the relationship that navigates it.
    author_id = Column(Integer, ForeignKey("author.id"))
    author = relationship("Author", back_populates="comments")
    blogpost_id = Column(Integer, ForeignKey("blogpost.id"))
    blogpost = relationship("BlogPost", back_populates="comments")
51 | 
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py:
--------------------------------------------------------------------------------
  1 | import datetime as dt
  2 | 
  3 | import pytest
  4 | from sqlalchemy.sql import functions
  5 | from sqlalchemy.sql.expression import cast, extract, literal
  6 | from sqlalchemy.types import Date, Time
  7 | 
  8 | from odata_query.sqlalchemy import AstToSqlAlchemyCoreVisitor, functions_ext
  9 | 
 10 | from . import models
 11 | 
 12 | BlogPost = models.Base.metadata.tables["blogpost"]
 13 | Author = models.Base.metadata.tables["author"]
 14 | Comment = models.Base.metadata.tables["comment"]
 15 | 
 16 | 
 17 | def tz(offset: int) -> dt.tzinfo:
 18 |     return dt.timezone(dt.timedelta(hours=offset))
 19 | 
 20 | 
 21 | @pytest.mark.parametrize(
 22 |     "odata_query, expected_q",
 23 |     [
 24 |         (
 25 |             "id eq a7af27e6-f5a0-11e9-9649-0a252986adba",
 26 |             BlogPost.c.id == "a7af27e6-f5a0-11e9-9649-0a252986adba",
 27 |         ),
 28 |         ("my_app.c.id eq 1", BlogPost.c.id == 1),
 29 |         (
 30 |             "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
 31 |             BlogPost.c.id.in_(
 32 |                 [
 33 |                     literal("a7af27e6-f5a0-11e9-9649-0a252986adba"),
 34 |                     literal("800c56e4-354d-11eb-be38-3af9d323e83c"),
 35 |                 ]
 36 |             ),
 37 |         ),
 38 |         ("id eq 4", BlogPost.c.id == 4),
 39 |         ("id ne 4", BlogPost.c.id != 4),
 40 |         ("4 eq id", 4 == BlogPost.c.id),
 41 |         ("4 ne id", 4 != BlogPost.c.id),
 42 |         (
 43 |             "published_at gt 2018-01-01",
 44 |             BlogPost.c.published_at > dt.date(2018, 1, 1),
 45 |         ),
 46 |         (
 47 |             "published_at ge 2018-01-01",
 48 |             BlogPost.c.published_at >= dt.date(2018, 1, 1),
 49 |         ),
 50 |         (
 51 |             "published_at lt 2018-01-01",
 52 |             BlogPost.c.published_at < dt.date(2018, 1, 1),
 53 |         ),
 54 |         (
 55 |             "published_at le 2018-01-01",
 56 |             BlogPost.c.published_at <= dt.date(2018, 1, 1),
 57 |         ),
 58 |         (
 59 |             "2018-01-01 gt published_at",
 60 |             literal(dt.date(2018, 1, 1)) > BlogPost.c.published_at,
 61 |         ),
 62 |         (
 63 |             "2018-01-01 ge published_at",
 64 |             literal(dt.date(2018, 1, 1)) >= BlogPost.c.published_at,
 65 |         ),
 66 |         (
 67 |             "published_at gt 2018-01-01T01:02",
 68 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2),
 69 |         ),
 70 |         (
 71 |             "published_at gt 2018-01-01T01:02:03",
 72 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3),
 73 |         ),
 74 |         (
 75 |             "published_at gt 2018-01-01T01:02:03.123",
 76 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000),
 77 |         ),
 78 |         (
 79 |             "published_at gt 2018-01-01T01:02:03.123456",
 80 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_456),
 81 |         ),
 82 |         (
 83 |             "published_at gt 2018-01-01T01:02Z",
 84 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(0)),
 85 |         ),
 86 |         (
 87 |             "published_at gt 2018-01-01T01:02:03Z",
 88 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, tzinfo=tz(0)),
 89 |         ),
 90 |         (
 91 |             "published_at gt 2018-01-01T01:02:03.123Z",
 92 |             BlogPost.c.published_at
 93 |             > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000, tzinfo=tz(0)),
 94 |         ),
 95 |         (
 96 |             "published_at gt 2018-01-01T01:02+02:00",
 97 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(+2)),
 98 |         ),
 99 |         (
100 |             "published_at gt 2018-01-01T01:02-02:00",
101 |             BlogPost.c.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(-2)),
102 |         ),
103 |         (
104 |             "id in (1, 2, 3)",
105 |             BlogPost.c.id.in_([literal(1), literal(2), literal(3)]),
106 |         ),
107 |         ("id eq null", BlogPost.c.id == None),  # noqa:E711
108 |         ("id ne null", BlogPost.c.id != None),  # noqa:E711
109 |         ("not (id eq 1)", ~(BlogPost.c.id == 1)),
110 |         ("id eq 1 or id eq 2", (BlogPost.c.id == 1) | (BlogPost.c.id == 2)),
111 |         (
112 |             "id eq 1 and content eq 'executing'",
113 |             (BlogPost.c.id == 1) & (BlogPost.c.content == "executing"),
114 |         ),
115 |         (
116 |             "id eq 1 and (content eq 'executing' or content eq 'failed')",
117 |             (BlogPost.c.id == 1)
118 |             & ((BlogPost.c.content == "executing") | (BlogPost.c.content == "failed")),
119 |         ),
120 |         ("id eq 1 add 1", BlogPost.c.id == literal(1) + literal(1)),
121 |         ("id eq 2 sub 1", BlogPost.c.id == literal(2) - literal(1)),
122 |         ("id eq 2 mul 2", BlogPost.c.id == literal(2) * literal(2)),
123 |         ("id eq 2 div 2", BlogPost.c.id == literal(2) / literal(2)),
124 |         ("id eq 5 mod 4", BlogPost.c.id == literal(5) % literal(4)),
125 |         ("id eq 2 add -1", BlogPost.c.id == literal(2) + literal(-1)),
126 |         ("id eq id sub 1", BlogPost.c.id == BlogPost.c.id - literal(1)),
127 |         (
128 |             "title eq 'donut' add 'tello'",
129 |             BlogPost.c.title == literal("donut") + literal("tello"),
130 |         ),
131 |         (
132 |             "title eq content add content",
133 |             BlogPost.c.title == BlogPost.c.content + BlogPost.c.content,
134 |         ),
135 |         (
136 |             "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'",
137 |             BlogPost.c.published_at
138 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
139 |             + dt.timedelta(days=1, hours=1, minutes=1, seconds=1),
140 |         ),
141 |         (
142 |             "published_at eq 2019-01-01T00:00:00 add duration'P1Y'",
143 |             BlogPost.c.published_at
144 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
145 |             + dt.timedelta(days=365.25),  # 1 times 365.25 (average year in days)
146 |         ),
147 |         (
148 |             "published_at eq 2019-01-01T00:00:00 add duration'P2M'",
149 |             BlogPost.c.published_at
150 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
151 |             + dt.timedelta(days=60.88),  # 2 times 30.44 (average month in days)
152 |         ),
153 |         ("contains(title, 'copy')", BlogPost.c.title.contains("copy")),
154 |         ("startswith(title, 'copy')", BlogPost.c.title.startswith("copy")),
155 |         ("endswith(title, 'bla')", BlogPost.c.title.endswith("bla")),
156 |         (
157 |             "id eq length(title)",
158 |             BlogPost.c.id == functions.char_length(BlogPost.c.title),
159 |         ),
160 |         ("length(title) eq 10", functions.char_length(BlogPost.c.title) == 10),
161 |         ("10 eq length(title)", 10 == functions.char_length(BlogPost.c.title)),
162 |         (
163 |             "length(title) eq length('flippot')",
164 |             functions.char_length(BlogPost.c.title) == functions.char_length("flippot"),
165 |         ),
166 |         (
167 |             "title eq concat('a', 'b')",
168 |             BlogPost.c.title == functions.concat("a", "b"),
169 |         ),
170 |         (
171 |             "title eq concat('test', id)",
172 |             BlogPost.c.title == functions.concat("test", BlogPost.c.id),
173 |         ),
174 |         (
175 |             "title eq concat(concat('a', 'b'), 'c')",
176 |             BlogPost.c.title == functions.concat(functions.concat("a", "b"), "c"),
177 |         ),
178 |         (
179 |             "concat(title, 'a') eq 'testa'",
180 |             functions.concat(BlogPost.c.title, "a") == "testa",
181 |         ),
182 |         (
183 |             "indexof(title, 'Copy') eq 6",
184 |             functions_ext.strpos(BlogPost.c.title, "Copy") - 1 == 6,
185 |         ),
186 |         (
187 |             "substring(title, 0) eq 'Copy'",
188 |             functions_ext.substr(BlogPost.c.title, literal(0) + 1) == "Copy",
189 |         ),
190 |         (
191 |             "substring(title, 0, 4) eq 'Copy'",
192 |             functions_ext.substr(BlogPost.c.title, literal(0) + 1, 4) == "Copy",
193 |         ),
194 |         (
195 |             "matchesPattern(title, 'C.py')",
196 |             BlogPost.c.title.regexp_match("C.py"),
197 |         ),
198 |         (
199 |             "tolower(title) eq 'copy'",
200 |             functions_ext.lower(BlogPost.c.title) == "copy",
201 |         ),
202 |         (
203 |             "toupper(title) eq 'COPY'",
204 |             functions_ext.upper(BlogPost.c.title) == "COPY",
205 |         ),
206 |         (
207 |             "trim(title) eq 'copy'",
208 |             functions_ext.ltrim(functions_ext.rtrim(BlogPost.c.title)) == "copy",
209 |         ),
210 |         (
211 |             "date(published_at) eq 2019-01-01",
212 |             cast(BlogPost.c.published_at, Date) == dt.date(2019, 1, 1),
213 |         ),
214 |         (
215 |             "day(published_at) eq 1",
216 |             extract("day", BlogPost.c.published_at) == 1,
217 |         ),
218 |         (
219 |             "hour(published_at) eq 1",
220 |             extract("hour", BlogPost.c.published_at) == 1,
221 |         ),
222 |         (
223 |             "minute(published_at) eq 1",
224 |             extract("minute", BlogPost.c.published_at) == 1,
225 |         ),
226 |         (
227 |             "month(published_at) eq 1",
228 |             extract("month", BlogPost.c.published_at) == 1,
229 |         ),
230 |         ("published_at eq now()", BlogPost.c.published_at == functions.now()),
231 |         (
232 |             "second(published_at) eq 1",
233 |             extract("second", BlogPost.c.published_at) == 1,
234 |         ),
235 |         (
236 |             "time(published_at) eq 14:00:00",
237 |             cast(BlogPost.c.published_at, Time) == dt.time(14, 0, 0),
238 |         ),
239 |         (
240 |             "year(published_at) eq 2019",
241 |             extract("year", BlogPost.c.published_at) == 2019,
242 |         ),
243 |         ("ceiling(id) eq 1", functions_ext.ceil(BlogPost.c.id) == 1),
244 |         ("floor(id) eq 1", functions_ext.floor(BlogPost.c.id) == 1),
245 |         ("round(id) eq 1", functions_ext.round(BlogPost.c.id) == 1),
246 |         (
247 |             "date(published_at) eq 2019-01-01 add duration'P1D'",
248 |             cast(BlogPost.c.published_at, Date)
249 |             == literal(dt.date(2019, 1, 1)) + dt.timedelta(days=1),
250 |         ),
251 |         (
252 |             "date(published_at) eq 2019-01-01 add duration'-P1D'",
253 |             cast(BlogPost.c.published_at, Date)
254 |             == literal(dt.date(2019, 1, 1)) + -1 * dt.timedelta(days=1),
255 |         ),
256 |         pytest.param(
257 |             "authors/name eq 'Ruben'",
258 |             Author.c.name == "Ruben",
259 |             marks=pytest.mark.xfail(reason="Not implemented yet."),
260 |         ),
261 |         pytest.param(
262 |             "authors/comments/content eq 'Cool!'",
263 |             Comment.c.content == "Cool!",
264 |             marks=pytest.mark.xfail(reason="Not implemented yet."),
265 |         ),
266 |         pytest.param(
267 |             "contains(comments/content, 'Cool')",
268 |             Comment.c.content.contains("Cool"),
269 |             marks=pytest.mark.xfail(reason="Not implemented yet."),
270 |         ),
271 |         # GITHUB-19
272 |         (
273 |             "contains(title, 'TEST') eq true",
274 |             BlogPost.c.title.contains("TEST") == True,  # noqa:E712
275 |         ),
276 |         # GITHUB-47
277 |         (
278 |             "contains(tolower(title), tolower('A'))",
279 |             functions_ext.lower(BlogPost.c.title).contains(functions_ext.lower("A")),
280 |         ),
281 |     ],
282 | )
283 | def test_odata_filter_to_sqlalchemy_query(
284 |     odata_query: str, expected_q: str, lexer, parser
285 | ):
286 |     ast = parser.parse(lexer.tokenize(odata_query))
287 |     transformer = AstToSqlAlchemyCoreVisitor(BlogPost)
288 |     res_q = transformer.visit(ast)
289 | 
290 |     assert res_q.compare(expected_q), (str(res_q), str(expected_q))
291 | 
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py:
--------------------------------------------------------------------------------
  1 | import datetime as dt
  2 | 
  3 | import pytest
  4 | from sqlalchemy.sql import functions
  5 | from sqlalchemy.sql.expression import cast, extract, literal
  6 | from sqlalchemy.types import Date, Time
  7 | 
  8 | from odata_query.sqlalchemy import AstToSqlAlchemyOrmVisitor, functions_ext
  9 | 
 10 | from .models import Author, BlogPost, Comment
 11 | 
 12 | 
 13 | def tz(offset: int) -> dt.tzinfo:
 14 |     return dt.timezone(dt.timedelta(hours=offset))
 15 | 
 16 | 
 17 | @pytest.mark.parametrize(
 18 |     "odata_query, expected_q",
 19 |     [
 20 |         (
 21 |             "id eq a7af27e6-f5a0-11e9-9649-0a252986adba",
 22 |             BlogPost.id == "a7af27e6-f5a0-11e9-9649-0a252986adba",
 23 |         ),
 24 |         ("my_app.id eq 1", BlogPost.id == 1),
 25 |         (
 26 |             "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
 27 |             BlogPost.id.in_(
 28 |                 [
 29 |                     literal("a7af27e6-f5a0-11e9-9649-0a252986adba"),
 30 |                     literal("800c56e4-354d-11eb-be38-3af9d323e83c"),
 31 |                 ]
 32 |             ),
 33 |         ),
 34 |         ("id eq 4", BlogPost.id == 4),
 35 |         ("id ne 4", BlogPost.id != 4),
 36 |         ("4 eq id", 4 == BlogPost.id),
 37 |         ("4 ne id", 4 != BlogPost.id),
 38 |         ("published_at gt 2018-01-01", BlogPost.published_at > dt.date(2018, 1, 1)),
 39 |         ("published_at ge 2018-01-01", BlogPost.published_at >= dt.date(2018, 1, 1)),
 40 |         ("published_at lt 2018-01-01", BlogPost.published_at < dt.date(2018, 1, 1)),
 41 |         ("published_at le 2018-01-01", BlogPost.published_at <= dt.date(2018, 1, 1)),
 42 |         (
 43 |             "2018-01-01 gt published_at",
 44 |             literal(dt.date(2018, 1, 1)) > BlogPost.published_at,
 45 |         ),
 46 |         (
 47 |             "2018-01-01 ge published_at",
 48 |             literal(dt.date(2018, 1, 1)) >= BlogPost.published_at,
 49 |         ),
 50 |         (
 51 |             "published_at gt 2018-01-01T01:02",
 52 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2),
 53 |         ),
 54 |         (
 55 |             "published_at gt 2018-01-01T01:02:03",
 56 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3),
 57 |         ),
 58 |         (
 59 |             "published_at gt 2018-01-01T01:02:03.123",
 60 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000),
 61 |         ),
 62 |         (
 63 |             "published_at gt 2018-01-01T01:02:03.123456",
 64 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, 123_456),
 65 |         ),
 66 |         (
 67 |             "published_at gt 2018-01-01T01:02Z",
 68 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(0)),
 69 |         ),
 70 |         (
 71 |             "published_at gt 2018-01-01T01:02:03Z",
 72 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, 3, tzinfo=tz(0)),
 73 |         ),
 74 |         (
 75 |             "published_at gt 2018-01-01T01:02:03.123Z",
 76 |             BlogPost.published_at
 77 |             > dt.datetime(2018, 1, 1, 1, 2, 3, 123_000, tzinfo=tz(0)),
 78 |         ),
 79 |         (
 80 |             "published_at gt 2018-01-01T01:02+02:00",
 81 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(+2)),
 82 |         ),
 83 |         (
 84 |             "published_at gt 2018-01-01T01:02-02:00",
 85 |             BlogPost.published_at > dt.datetime(2018, 1, 1, 1, 2, tzinfo=tz(-2)),
 86 |         ),
 87 |         ("id in (1, 2, 3)", BlogPost.id.in_([literal(1), literal(2), literal(3)])),
 88 |         ("id eq null", BlogPost.id == None),  # noqa:E711
 89 |         ("id ne null", BlogPost.id != None),  # noqa:E711
 90 |         ("not (id eq 1)", ~(BlogPost.id == 1)),
 91 |         ("id eq 1 or id eq 2", (BlogPost.id == 1) | (BlogPost.id == 2)),
 92 |         (
 93 |             "id eq 1 and content eq 'executing'",
 94 |             (BlogPost.id == 1) & (BlogPost.content == "executing"),
 95 |         ),
 96 |         (
 97 |             "id eq 1 and (content eq 'executing' or content eq 'failed')",
 98 |             (BlogPost.id == 1)
 99 |             & ((BlogPost.content == "executing") | (BlogPost.content == "failed")),
100 |         ),
101 |         ("id eq 1 add 1", BlogPost.id == literal(1) + literal(1)),
102 |         ("id eq 2 sub 1", BlogPost.id == literal(2) - literal(1)),
103 |         ("id eq 2 mul 2", BlogPost.id == literal(2) * literal(2)),
104 |         ("id eq 2 div 2", BlogPost.id == literal(2) / literal(2)),
105 |         ("id eq 5 mod 4", BlogPost.id == literal(5) % literal(4)),
106 |         ("id eq 2 add -1", BlogPost.id == literal(2) + literal(-1)),
107 |         ("id eq id sub 1", BlogPost.id == BlogPost.id - literal(1)),
108 |         (
109 |             "title eq 'donut' add 'tello'",
110 |             BlogPost.title == literal("donut") + literal("tello"),
111 |         ),
112 |         (
113 |             "title eq content add content",
114 |             BlogPost.title == BlogPost.content + BlogPost.content,
115 |         ),
116 |         (
117 |             "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'",
118 |             BlogPost.published_at
119 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
120 |             + dt.timedelta(days=1, hours=1, minutes=1, seconds=1),
121 |         ),
122 |         (
123 |             "published_at eq 2019-01-01T00:00:00 add duration'P1Y'",
124 |             BlogPost.published_at
125 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
126 |             + dt.timedelta(days=365.25),  # 1 times 365.25 (average year in days)
127 |         ),
128 |         (
129 |             "published_at eq 2019-01-01T00:00:00 add duration'P2M'",
130 |             BlogPost.published_at
131 |             == literal(dt.datetime(2019, 1, 1, 0, 0, 0))
132 |             + dt.timedelta(days=60.88),  # 2 times 30.44 (average month in days)
133 |         ),
134 |         ("contains(title, 'copy')", BlogPost.title.contains("copy")),
135 |         ("startswith(title, 'copy')", BlogPost.title.startswith("copy")),
136 |         ("endswith(title, 'bla')", BlogPost.title.endswith("bla")),
137 |         ("id eq length(title)", BlogPost.id == functions.char_length(BlogPost.title)),
138 |         ("length(title) eq 10", functions.char_length(BlogPost.title) == 10),
139 |         ("10 eq length(title)", 10 == functions.char_length(BlogPost.title)),
140 |         (
141 |             "length(title) eq length('flippot')",
142 |             functions.char_length(BlogPost.title) == functions.char_length("flippot"),
143 |         ),
144 |         ("title eq concat('a', 'b')", BlogPost.title == functions.concat("a", "b")),
145 |         (
146 |             "title eq concat('test', id)",
147 |             BlogPost.title == functions.concat("test", BlogPost.id),
148 |         ),
149 |         (
150 |             "title eq concat(concat('a', 'b'), 'c')",
151 |             BlogPost.title == functions.concat(functions.concat("a", "b"), "c"),
152 |         ),
153 |         (
154 |             "concat(title, 'a') eq 'testa'",
155 |             functions.concat(BlogPost.title, "a") == "testa",
156 |         ),
157 |         (
158 |             "indexof(title, 'Copy') eq 6",
159 |             functions_ext.strpos(BlogPost.title, "Copy") - 1 == 6,
160 |         ),
161 |         (
162 |             "substring(title, 0) eq 'Copy'",
163 |             functions_ext.substr(BlogPost.title, literal(0) + 1) == "Copy",
164 |         ),
165 |         (
166 |             "substring(title, 0, 4) eq 'Copy'",
167 |             functions_ext.substr(BlogPost.title, literal(0) + 1, 4) == "Copy",
168 |         ),
169 |         ("matchesPattern(title, 'C.py')", BlogPost.title.regexp_match("C.py")),
170 |         ("tolower(title) eq 'copy'", functions_ext.lower(BlogPost.title) == "copy"),
171 |         ("toupper(title) eq 'COPY'", functions_ext.upper(BlogPost.title) == "COPY"),
172 |         (
173 |             "trim(title) eq 'copy'",
174 |             functions_ext.ltrim(functions_ext.rtrim(BlogPost.title)) == "copy",
175 |         ),
176 |         (
177 |             "date(published_at) eq 2019-01-01",
178 |             cast(BlogPost.published_at, Date) == dt.date(2019, 1, 1),
179 |         ),
180 |         ("day(published_at) eq 1", extract("day", BlogPost.published_at) == 1),
181 |         ("hour(published_at) eq 1", extract("hour", BlogPost.published_at) == 1),
182 |         ("minute(published_at) eq 1", extract("minute", BlogPost.published_at) == 1),
183 |         ("month(published_at) eq 1", extract("month", BlogPost.published_at) == 1),
184 |         ("published_at eq now()", BlogPost.published_at == functions.now()),
185 |         ("second(published_at) eq 1", extract("second", BlogPost.published_at) == 1),
186 |         (
187 |             "time(published_at) eq 14:00:00",
188 |             cast(BlogPost.published_at, Time) == dt.time(14, 0, 0),
189 |         ),
190 |         ("year(published_at) eq 2019", extract("year", BlogPost.published_at) == 2019),
191 |         ("ceiling(id) eq 1", functions_ext.ceil(BlogPost.id) == 1),
192 |         ("floor(id) eq 1", functions_ext.floor(BlogPost.id) == 1),
193 |         ("round(id) eq 1", functions_ext.round(BlogPost.id) == 1),
194 |         (
195 |             "date(published_at) eq 2019-01-01 add duration'P1D'",
196 |             cast(BlogPost.published_at, Date)
197 |             == literal(dt.date(2019, 1, 1)) + dt.timedelta(days=1),
198 |         ),
199 |         (
200 |             "date(published_at) eq 2019-01-01 add duration'-P1D'",
201 |             cast(BlogPost.published_at, Date)
202 |             == literal(dt.date(2019, 1, 1)) + -1 * dt.timedelta(days=1),
203 |         ),
204 |         ("authors/name eq 'Ruben'", Author.name == "Ruben"),
205 |         ("authors/comments/content eq 'Cool!'", Comment.content == "Cool!"),
206 |         ("contains(comments/content, 'Cool')", Comment.content.contains("Cool")),
207 |         # GITHUB-19
208 |         (
209 |             "contains(title, 'TEST') eq true",
210 |             BlogPost.title.contains("TEST") == True,  # noqa:E712
211 |         ),
212 |         # GITHUB-47
213 |         (
214 |             "contains(tolower(title), tolower('A'))",
215 |             functions_ext.lower(BlogPost.title).contains(functions_ext.lower("A")),
216 |         ),
217 |     ],
218 | )
219 | def test_odata_filter_to_sqlalchemy_query(
220 |     odata_query: str, expected_q: str, lexer, parser
221 | ):
222 |     ast = parser.parse(lexer.tokenize(odata_query))
223 |     transformer = AstToSqlAlchemyOrmVisitor(BlogPost)
224 |     res_q = transformer.visit(ast)
225 | 
226 |     assert res_q.compare(expected_q), (str(res_q), str(expected_q))
227 | 
--------------------------------------------------------------------------------
/tests/integration/sqlalchemy/test_querying.py:
--------------------------------------------------------------------------------
  1 | import datetime as dt
  2 | from typing import Callable, Type
  3 | 
  4 | import pytest
  5 | from sqlalchemy import select
  6 | 
  7 | from odata_query.sqlalchemy import apply_odata_core, apply_odata_query
  8 | 
  9 | from .models import Author, Base, BlogPost, Comment
 10 | 
 11 | 
 12 | @pytest.fixture
 13 | def sample_data_sess(db_session):
 14 |     s = db_session()
 15 |     a1 = Author(name="Gorilla")
 16 |     a2 = Author(name="Baboon")
 17 |     a3 = Author(name="Saki")
 18 |     bp1 = BlogPost(
 19 |         title="Querying Data",
 20 |         published_at=dt.datetime(2020, 1, 1),
 21 |         content="How 2 query data...",
 22 |         authors=[a1],
 23 |     )
 24 |     bp2 = BlogPost(
 25 |         title="Automating Monkey Jobs",
 26 |         published_at=dt.datetime(2019, 1, 1),
 27 |         content="How 2 automate monkey jobs...",
 28 |         authors=[a2, a3],
 29 |     )
 30 |     c1 = Comment(content="Dope!", author=a2, blogpost=bp1)
 31 |     c2 = Comment(content="Cool!", author=a1, blogpost=bp2)
 32 |     s.add_all([a1, a2, a3, bp1, bp2, c1, c2])
 33 |     s.flush()
 34 |     yield s
 35 |     s.rollback()
 36 | 
 37 | 
 38 | def apply_odata_query_bc_sqla1(*args, **kwargs):
 39 |     """
 40 |     Check for backwards compatibility with the "old style" of ORM querying.
 41 |     See: GITHUB-34
 42 |     """
 43 |     return apply_odata_query(*args, **kwargs)
 44 | 
 45 | 
 46 | @pytest.mark.parametrize(
 47 |     "model, query, exp_results",
 48 |     [
 49 |         (Author, "name eq 'Baboon'", 1),
 50 |         (Author, "startswith(name, 'Gori')", 1),
 51 |         (BlogPost, "contains(content, 'How')", 2),
 52 |         (BlogPost, "published_at gt 2019-06-01", 1),
 53 |         (Author, "contains(blogposts/title, 'Monkey')", 2),
 54 |         (Author, "startswith(blogposts/comments/content, 'Cool')", 2),
 55 |         (Author, "comments/any()", 2),
 56 |         (BlogPost, "authors/any(a: contains(a/name, 'o'))", 2),
 57 |         (BlogPost, "authors/all(a: contains(a/name, 'o'))", 1),
 58 |         (Author, "blogposts/comments/any(c: contains(c/content, 'Cool'))", 2),
 59 |         (Author, "id eq a7af27e6-f5a0-11e9-9649-0a252986adba", 0),
 60 |         (
 61 |             Author,
 62 |             "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
 63 |             0,
 64 |         ),
 65 |         (BlogPost, "comments/author eq 0", 0),
 66 |         (BlogPost, "substring(content, 0) eq 'test'", 0),
 67 |         (BlogPost, "year(published_at) eq 2019", 1),
 68 |         # GITHUB-19
 69 |         (BlogPost, "contains(title, 'Query') eq true", 1),
 70 |     ],
 71 | )
 72 | @pytest.mark.parametrize(
 73 |     "apply_func",
 74 |     [
 75 |         pytest.param(apply_odata_query, id="ORM"),
 76 |         pytest.param(apply_odata_query_bc_sqla1, id="ORM 1.x"),
 77 |         pytest.param(apply_odata_core, id="Core"),
 78 |     ],
 79 | )
 80 | def test_query_with_odata(
 81 |     model: Type[Base],
 82 |     query: str,
 83 |     exp_results: int,
 84 |     apply_func: Callable,
 85 |     sample_data_sess,
 86 | ):
 87 |     # ORM mode 1.x:
 88 |     if apply_func is apply_odata_query_bc_sqla1:
 89 |         base_q = sample_data_sess.query(model)
 90 |         q = apply_func(base_q, query)
 91 |         results = q.all()
 92 |         assert len(results) == exp_results
 93 |         return
 94 | 
 95 |     # ORM mode:
 96 |     elif apply_func is apply_odata_query:
 97 |         base_q = select(model)
 98 | 
 99 |     # Core mode:
100 |     elif apply_func is apply_odata_core:
101 |         base_q = select(model.__table__)
102 | 
103 |     else:
104 |         raise ValueError(apply_func)
105 | 
106 |     try:
107 |         q = apply_func(base_q, query)
108 |     except NotImplementedError:
109 |         pytest.xfail("Not implemented yet.")
110 | 
111 |     results = sample_data_sess.execute(q).scalars().all()
112 |     assert len(results) == exp_results
113 | 
114 | 
@pytest.mark.parametrize(
    "apply_func",
    [
        pytest.param(apply_odata_query, id="ORM"),
        pytest.param(apply_odata_query_bc_sqla1, id="ORM 1.x"),
    ],
)
def test_query_with_existing_join(apply_func, sample_data_sess):
    """
    GITHUB-37

    Filter on ``author/name`` when the base query over ``Comment`` already joins
    ``Author``, and expect exactly one matching comment.
    """
    odata_query = "author/name eq 'Gorilla'"
    exp_results = 1

    if apply_func is apply_odata_query_bc_sqla1:
        # Legacy ``Session.query`` with an explicit ON clause for the join.
        joined_q = sample_data_sess.query(Comment).join(
            Author, Comment.author_id == Author.id
        )
        matches = apply_func(joined_q, odata_query).all()
        assert len(matches) == exp_results
        return

    if apply_func is not apply_odata_query:
        raise ValueError(apply_func)

    # 2.0-style ``select`` joined through the ``Comment.author`` relationship.
    joined_stmt = select(Comment).join(Comment.author)
    filtered_stmt = apply_func(joined_stmt, odata_query)

    matches = sample_data_sess.execute(filtered_stmt).scalars().all()
    assert len(matches) == exp_results
150 | 
--------------------------------------------------------------------------------
/tests/integration/test_roundtrip.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from odata_query.roundtrip import AstToODataVisitor
 4 | 
 5 | 
 6 | @pytest.mark.parametrize(
 7 |     "odata_query",
 8 |     [
 9 |         "id eq a7af27e6-f5a0-11e9-9649-0a252986adba",
10 |         "my_app.id eq 1",
11 |         "id in (a7af27e6-f5a0-11e9-9649-0a252986adba, 800c56e4-354d-11eb-be38-3af9d323e83c)",
12 |         "id eq 4",
13 |         "id ne 4",
14 |         "4 eq id",
15 |         "4 ne id",
16 |         "published_at gt 2018-01-01",
17 |         "published_at ge 2018-01-01",
18 |         "published_at lt 2018-01-01",
19 |         "published_at le 2018-01-01",
20 |         "2018-01-01 gt published_at",
21 |         "2018-01-01 ge published_at",
22 |         "published_at gt 2018-01-01T01:02",
23 |         "published_at gt 2018-01-01T01:02:03",
24 |         "published_at gt 2018-01-01T01:02:03.123",
25 |         "published_at gt 2018-01-01T01:02:03.123456",
26 |         "published_at gt 2018-01-01T01:02Z",
27 |         "published_at gt 2018-01-01T01:02:03Z",
28 |         "published_at gt 2018-01-01T01:02:03.123Z",
29 |         "published_at gt 2018-01-01T01:02+02:00",
30 |         "published_at gt 2018-01-01T01:02-02:00",
31 |         "id in (1, 2, 3)",
32 |         "id eq null",
33 |         "id ne null",
34 |         "not (id eq 1)",
35 |         "id eq 1 or id eq 2",
36 |         "id eq 1 and content eq 'executing'",
37 |         "id eq 1 and (content eq 'executing' or content eq 'failed')",
38 |         "id eq 1 add 1",
39 |         "id eq 2 sub 1",
40 |         "id eq 2 mul 2",
41 |         "id eq 2 div 2",
42 |         "id eq 5 mod 4",
43 |         "id eq 2 add -1",
44 |         "id eq id sub 1",
45 |         "title eq 'donut' add 'tello'",
46 |         "title eq content add content",
47 |         "published_at eq 2019-01-01T00:00:00 add duration'P1DT1H1M1S'",
48 |         "contains(title, 'copy')",
49 |         "startswith(title, 'copy')",
50 |         "endswith(title, 'bla')",
51 |         "id eq length(title)",
52 |         "length(title) eq 10",
53 |         "10 eq length(title)",
54 |         "length(title) eq length('flippot')",
55 |         "title eq concat('a', 'b')",
56 |         "title eq concat('test', id)",
57 |         "title eq concat(concat('a', 'b'), 'c')",
58 |         "concat(title, 'a') eq 'testa'",
59 |         "indexof(title, 'Copy') eq 6",
60 |         "substring(title, 0) eq 'Copy'",
61 |         "substring(title, 0, 4) eq 'Copy'",
62 |         "matchesPattern(title, 'C.py')",
63 |         "tolower(title) eq 'copy'",
64 |         "toupper(title) eq 'COPY'",
65 |         "trim(title) eq 'copy'",
66 |         "date(published_at) eq 2019-01-01",
67 |         "day(published_at) eq 1",
68 |         "hour(published_at) eq 1",
69 |         "minute(published_at) eq 1",
70 |         "month(published_at) eq 1",
71 |         "published_at eq now()",
72 |         "second(published_at) eq 1",
73 |         "time(published_at) eq 14:00:00",
74 |         "year(published_at) eq 2019",
75 |         "ceiling(id) eq 1",
76 |         "floor(id) eq 1",
77 |         "round(id) eq 1",
78 |         "date(published_at) eq 2019-01-01 add duration'P1D'",
79 |         "date(published_at) eq 2019-01-01 add duration'-P1D'",
80 |         "authors/name eq 'Ruben'",
81 |         "authors/comments/content eq 'Cool!'",
82 |         "contains(comments/content, 'Cool')",
83 |         "contains(title, 'TEST') eq true",
84 |         # Precedence checks:
85 |         "1 mul (2 add -3 sub 4) div 5",
86 |     ],
87 | )
88 | def test_odata_filter_roundtrip(odata_query: str, lexer, parser):
89 |     ast = parser.parse(lexer.tokenize(odata_query))
90 |     transformer = AstToODataVisitor()
91 |     res = transformer.visit(ast)
92 | 
93 |     assert res == odata_query
94 | 
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/django/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/django/__init__.py
--------------------------------------------------------------------------------
/tests/unit/sql/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gorilla-co/odata-query/2d30c4a90d6b0142b7d0fdcc27c8954e5ab95898/tests/unit/sql/__init__.py
--------------------------------------------------------------------------------
/tests/unit/sql/test_ast_to_sql.py:
--------------------------------------------------------------------------------
  1 | from typing import List
  2 | 
  3 | import pytest
  4 | 
  5 | from odata_query import ast, sql
  6 | 
  7 | 
  8 | @pytest.mark.parametrize(
  9 |     "ast_input, sql_expected",
 10 |     [
 11 |         (ast.Compare(ast.Eq(), ast.Integer("1"), ast.Integer("1")), "1 = 1"),
 12 |         (
 13 |             ast.Compare(ast.NotEq(), ast.Boolean("true"), ast.Boolean("false")),
 14 |             "TRUE != FALSE",
 15 |         ),
 16 |         (
 17 |             ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("123.12")),
 18 |             '"eac" <= 123.12',
 19 |         ),
 20 |         (
 21 |             ast.Compare(
 22 |                 ast.Lt(), ast.Identifier("period_start"), ast.Date("2019-01-01")
 23 |             ),
 24 |             "\"period_start\" < DATE '2019-01-01'",
 25 |         ),
 26 |         (
 27 |             ast.BoolOp(
 28 |                 ast.And(),
 29 |                 ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("123.12")),
 30 |                 ast.Compare(
 31 |                     ast.In(),
 32 |                     ast.Identifier("meter_id"),
 33 |                     ast.List([ast.String("1"), ast.String("2"), ast.String("3")]),
 34 |                 ),
 35 |             ),
 36 |             "\"eac\" >= 123.12 AND \"meter_id\" IN ('1', '2', '3')",
 37 |         ),
 38 |         (
 39 |             ast.BoolOp(
 40 |                 ast.And(),
 41 |                 ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("1")),
 42 |                 ast.BoolOp(
 43 |                     ast.Or(),
 44 |                     ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("10.0")),
 45 |                     ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("1.0")),
 46 |                 ),
 47 |             ),
 48 |             '"a" = \'1\' AND ("eac" <= 10.0 OR "eac" >= 1.0)',
 49 |         ),
 50 |     ],
 51 | )
 52 | def test_ast_to_sql(ast_input: ast._Node, sql_expected: str):
 53 |     visitor = sql.AstToSqlVisitor()
 54 |     res = visitor.visit(ast_input)
 55 | 
 56 |     assert res == sql_expected
 57 | 
 58 | 
 59 | @pytest.mark.parametrize(
 60 |     "func_name, args, sql_expected",
 61 |     [
 62 |         ("concat", [ast.String("ab"), ast.String("cd")], "'ab' || 'cd'"),
 63 |         (
 64 |             "contains",
 65 |             [ast.String("abc"), ast.String("b")],
 66 |             "'abc' LIKE '%b%'",
 67 |         ),
 68 |         (
 69 |             "endswith",
 70 |             [ast.String("abc"), ast.String("bc")],
 71 |             "'abc' LIKE '%bc'",
 72 |         ),
 73 |         (
 74 |             "indexof",
 75 |             [ast.String("abc"), ast.String("bc")],
 76 |             "POSITION('bc' IN 'abc') - 1",
 77 |         ),
 78 |         ("length", [ast.String("abc")], "CHAR_LENGTH('abc')"),
 79 |         (
 80 |             "length",
 81 |             [ast.List([ast.String("a"), ast.String("b")])],
 82 |             "CARDINALITY(('a', 'b'))",
 83 |         ),
 84 |         (
 85 |             "startswith",
 86 |             [ast.String("abc"), ast.String("ab")],
 87 |             "'abc' LIKE 'ab%'",
 88 |         ),
 89 |         (
 90 |             "substring",
 91 |             [ast.String("abc"), ast.Integer("1")],
 92 |             "SUBSTRING('abc' FROM 1 + 1)",
 93 |         ),
 94 |         (
 95 |             "substring",
 96 |             [ast.String("abcdef"), ast.Integer("1"), ast.Integer("2")],
 97 |             "SUBSTRING('abcdef' FROM 1 + 1 FOR 2)",
 98 |         ),
 99 |         ("tolower", [ast.String("ABC")], "LOWER('ABC')"),
100 |         ("toupper", [ast.String("abc")], "UPPER('abc')"),
101 |         ("trim", [ast.String(" abc ")], "TRIM(' abc ')"),
102 |         (
103 |             "year",
104 |             [ast.DateTime("2018-01-01T10:00:00")],
105 |             "EXTRACT (YEAR FROM TIMESTAMP '2018-01-01 10:00:00')",
106 |         ),
107 |         (
108 |             "month",
109 |             [ast.DateTime("2018-01-01T10:00:00")],
110 |             "EXTRACT (MONTH FROM TIMESTAMP '2018-01-01 10:00:00')",
111 |         ),
112 |         (
113 |             "day",
114 |             [ast.DateTime("2018-01-01T10:00:00")],
115 |             "EXTRACT (DAY FROM TIMESTAMP '2018-01-01 10:00:00')",
116 |         ),
117 |         (
118 |             "hour",
119 |             [ast.DateTime("2018-01-01T10:00:00")],
120 |             "EXTRACT (HOUR FROM TIMESTAMP '2018-01-01 10:00:00')",
121 |         ),
122 |         (
123 |             "minute",
124 |             [ast.DateTime("2018-01-01T10:00:00")],
125 |             "EXTRACT (MINUTE FROM TIMESTAMP '2018-01-01 10:00:00')",
126 |         ),
127 |         (
128 |             "date",
129 |             [ast.DateTime("2018-01-01T10:00:00")],
130 |             "CAST (TIMESTAMP '2018-01-01 10:00:00' AS DATE)",
131 |         ),
132 |         ("now", [], "CURRENT_TIMESTAMP"),
133 |         ("round", [ast.Float("123.12")], "CAST (123.12 + 0.5 AS INTEGER)"),
134 |         (
135 |             "floor",
136 |             [ast.Float("123.12")],
137 |             """CASE 123.12
138 |     WHEN > 0 CAST (123.12 AS INTEGER)
139 |     WHEN < 0 CAST (0 - (ABS(123.12) + 0.5) AS INTEGER))
140 |     ELSE 123.12
141 | END""",
142 |         ),
143 |         (
144 |             "ceiling",
145 |             [ast.Float("123.12")],
146 |             """CASE 123.12 - CAST (123.12 AS INTEGER)
147 |     WHEN > 0 123.12+1
148 |     WHEN < 0 123.12-1
149 |     ELSE 123.12
150 | END""",
151 |         ),
152 |     ],
153 | )
154 | def test_ast_to_sql_functions(func_name: str, args: List[ast._Node], sql_expected: str):
155 |     inp_ast = ast.Call(ast.Identifier(func_name), args)
156 |     visitor = sql.AstToSqlVisitor()
157 |     res = visitor.visit(inp_ast)
158 | 
159 |     assert res == sql_expected
160 | 
--------------------------------------------------------------------------------
/tests/unit/sql/test_ast_to_sqlite.py:
--------------------------------------------------------------------------------
  1 | from typing import List
  2 | 
  3 | import pytest
  4 | 
  5 | from odata_query import ast, sql
  6 | 
  7 | 
  8 | @pytest.mark.parametrize(
  9 |     "ast_input, sql_expected",
 10 |     [
 11 |         (ast.Compare(ast.Eq(), ast.Integer("1"), ast.Integer("1")), "1 = 1"),
 12 |         (
 13 |             ast.Compare(ast.NotEq(), ast.Boolean("true"), ast.Boolean("false")),
 14 |             "1 != 0",
 15 |         ),
 16 |         (
 17 |             ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("123.12")),
 18 |             '"eac" <= 123.12',
 19 |         ),
 20 |         (
 21 |             ast.Compare(
 22 |                 ast.Lt(),
 23 |                 ast.Identifier("period_start"),
 24 |                 ast.Date("2019-01-01"),
 25 |             ),
 26 |             "\"period_start\" < DATE('2019-01-01')",
 27 |         ),
 28 |         (
 29 |             ast.BoolOp(
 30 |                 ast.And(),
 31 |                 ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("123.12")),
 32 |                 ast.Compare(
 33 |                     ast.In(),
 34 |                     ast.Identifier("meter_id"),
 35 |                     ast.List([ast.String("1"), ast.String("2"), ast.String("3")]),
 36 |                 ),
 37 |             ),
 38 |             "\"eac\" >= 123.12 AND \"meter_id\" IN ('1', '2', '3')",
 39 |         ),
 40 |         (
 41 |             ast.BoolOp(
 42 |                 ast.And(),
 43 |                 ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("1")),
 44 |                 ast.BoolOp(
 45 |                     ast.Or(),
 46 |                     ast.Compare(ast.LtE(), ast.Identifier("eac"), ast.Float("10.0")),
 47 |                     ast.Compare(ast.GtE(), ast.Identifier("eac"), ast.Float("1.0")),
 48 |                 ),
 49 |             ),
 50 |             '"a" = \'1\' AND ("eac" <= 10.0 OR "eac" >= 1.0)',
 51 |         ),
 52 |     ],
 53 | )
 54 | def test_ast_to_sql(ast_input: ast._Node, sql_expected: str):
 55 |     visitor = sql.AstToSqliteSqlVisitor()
 56 |     res = visitor.visit(ast_input)
 57 | 
 58 |     assert res == sql_expected
 59 | 
 60 | 
 61 | @pytest.mark.parametrize(
 62 |     "func_name, args, sql_expected",
 63 |     [
 64 |         ("concat", [ast.String("ab"), ast.String("cd")], "'ab' || 'cd'"),
 65 |         (
 66 |             "contains",
 67 |             [ast.String("abc"), ast.String("b")],
 68 |             "'abc' LIKE '%b%'",
 69 |         ),
 70 |         (
 71 |             "endswith",
 72 |             [ast.String("abc"), ast.String("bc")],
 73 |             "'abc' LIKE '%bc'",
 74 |         ),
 75 |         (
 76 |             "indexof",
 77 |             [ast.String("abc"), ast.String("bc")],
 78 |             "INSTR('abc', 'bc') - 1",
 79 |         ),
 80 |         (
 81 |             "length",
 82 |             [ast.String("a")],
 83 |             "LENGTH('a')",
 84 |         ),
 85 |         (
 86 |             "length",
 87 |             [ast.Identifier("a")],
 88 |             'LENGTH("a")',
 89 |         ),
 90 |         (
 91 |             "startswith",
 92 |             [ast.String("abc"), ast.String("ab")],
 93 |             "'abc' LIKE 'ab%'",
 94 |         ),
 95 |         (
 96 |             "substring",
 97 |             [ast.String("abc"), ast.Integer("1")],
 98 |             "SUBSTR('abc', 1 + 1)",
 99 |         ),
100 |         (
101 |             "substring",
102 |             [ast.String("abcdef"), ast.Integer("1"), ast.Integer("2")],
103 |             "SUBSTR('abcdef', 1 + 1, 2)",
104 |         ),
105 |         ("tolower", [ast.String("ABC")], "LOWER('ABC')"),
106 |         ("toupper", [ast.String("abc")], "UPPER('abc')"),
107 |         ("trim", [ast.String(" abc ")], "TRIM(' abc ')"),
108 |         (
109 |             "year",
110 |             [ast.DateTime("2018-01-01T10:00:00")],
111 |             "CAST(STRFTIME('%Y', DATETIME('2018-01-01T10:00:00')) AS INTEGER)",
112 |         ),
113 |         (
114 |             "month",
115 |             [ast.DateTime("2018-01-01T10:00:00")],
116 |             "CAST(STRFTIME('%m', DATETIME('2018-01-01T10:00:00')) AS INTEGER)",
117 |         ),
118 |         (
119 |             "day",
120 |             [ast.DateTime("2018-01-01T10:00:00")],
121 |             "CAST(STRFTIME('%d', DATETIME('2018-01-01T10:00:00')) AS INTEGER)",
122 |         ),
123 |         (
124 |             "hour",
125 |             [ast.DateTime("2018-01-01T10:00:00")],
126 |             "CAST(STRFTIME('%H', DATETIME('2018-01-01T10:00:00')) AS INTEGER)",
127 |         ),
128 |         (
129 |             "minute",
130 |             [ast.DateTime("2018-01-01T10:00:00")],
131 |             "CAST(STRFTIME('%M', DATETIME('2018-01-01T10:00:00')) AS INTEGER)",
132 |         ),
133 |         (
134 |             "date",
135 |             [ast.DateTime("2018-01-01T10:00:00")],
136 |             "DATE(DATETIME('2018-01-01T10:00:00'))",
137 |         ),
138 |         ("now", [], "DATETIME('now')"),
139 |         ("round", [ast.Float("123.12")], "TRUNC(123.12 + 0.5)"),
140 |         ("floor", [ast.Float("123.12")], "FLOOR(123.12)"),
141 |         ("ceiling", [ast.Float("123.12")], "CEILING(123.12)"),
142 |     ],
143 | )
144 | def test_ast_to_sql_functions(func_name: str, args: List[ast._Node], sql_expected: str):
145 |     inp_ast = ast.Call(ast.Identifier(func_name), args)
146 |     visitor = sql.AstToSqliteSqlVisitor()
147 |     res = visitor.visit(inp_ast)
148 | 
149 |     assert res == sql_expected
150 | 
--------------------------------------------------------------------------------
/tests/unit/test_rewrite.py:
--------------------------------------------------------------------------------
 1 | from odata_query import ast
 2 | from odata_query.rewrite import AliasRewriter
 3 | 
 4 | 
 5 | def test_identifier_rewrite():
 6 |     rewriter = AliasRewriter({"a": "author"})
 7 |     _ast = ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("Bobby"))
 8 |     exp = ast.Compare(ast.Eq(), ast.Identifier("author"), ast.String("Bobby"))
 9 | 
10 |     res = rewriter.visit(_ast)
11 |     assert res == exp
12 | 
13 | 
14 | def test_identifier_to_attribute_rewrite():
15 |     rewriter = AliasRewriter({"a": "author/name"})
16 |     _ast = ast.Compare(ast.Eq(), ast.Identifier("a"), ast.String("Bobby"))
17 |     exp = ast.Compare(
18 |         ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby")
19 |     )
20 | 
21 |     res = rewriter.visit(_ast)
22 |     assert res == exp
23 | 
24 | 
25 | def test_attribute_to_identifier_rewrite():
26 |     rewriter = AliasRewriter({"author/name": "author_name"})
27 |     _ast = ast.Compare(
28 |         ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby")
29 |     )
30 |     exp = ast.Compare(ast.Eq(), ast.Identifier("author_name"), ast.String("Bobby"))
31 | 
32 |     res = rewriter.visit(_ast)
33 |     assert res == exp
34 | 
35 | 
36 | def test_attribute_to_attribute_rewrite():
37 |     rewriter = AliasRewriter({"author/name": "author/info/name"})
38 |     _ast = ast.Compare(
39 |         ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby")
40 |     )
41 |     exp = ast.Compare(
42 |         ast.Eq(),
43 |         ast.Attribute(ast.Attribute(ast.Identifier("author"), "info"), "name"),
44 |         ast.String("Bobby"),
45 |     )
46 | 
47 |     res = rewriter.visit(_ast)
48 |     assert res == exp
49 | 
50 | 
51 | def test_identifier_to_function_rewrite():
52 |     rewriter = AliasRewriter({"author_length": "length(author)"})
53 |     _ast = ast.Compare(ast.Eq(), ast.Identifier("author_length"), ast.Integer(10))
54 |     exp = ast.Compare(
55 |         ast.Eq(),
56 |         ast.Call(ast.Identifier("length"), [ast.Identifier("author")]),
57 |         ast.Integer(10),
58 |     )
59 | 
60 |     res = rewriter.visit(_ast)
61 |     assert res == exp
62 | 
63 | 
64 | def test_rewrite_attribute_owner():
65 |     rewriter = AliasRewriter({"a": "author"})
66 |     _ast = ast.Compare(
67 |         ast.Eq(), ast.Attribute(ast.Identifier("a"), "name"), ast.String("Bobby")
68 |     )
69 |     exp = ast.Compare(
70 |         ast.Eq(), ast.Attribute(ast.Identifier("author"), "name"), ast.String("Bobby")
71 |     )
72 | 
73 |     res = rewriter.visit(_ast)
74 |     assert res == exp
75 | 
--------------------------------------------------------------------------------
/tests/unit/test_typing.py:
--------------------------------------------------------------------------------
 1 | from typing import Type
 2 | 
 3 | import pytest
 4 | 
 5 | from odata_query import ast, typing
 6 | 
 7 | 
 8 | @pytest.mark.parametrize(
 9 |     "input_node, expected_type",
10 |     [
11 |         (ast.String("abc"), ast.String),
12 |         (ast.Compare(ast.Eq(), ast.String("abc"), ast.String("def")), ast.Boolean),
13 |         (
14 |             ast.Call(ast.Identifier("contains"), [ast.String("abc"), ast.String("b")]),
15 |             ast.Boolean,
16 |         ),
17 |         (ast.Identifier("a"), None),
18 |     ],
19 | )
20 | def test_infer_type_of_node(input_node: ast._Node, expected_type: Type[ast._Node]):
21 |     res = typing.infer_type(input_node)
22 | 
23 |     assert res is expected_type
24 | 
25 | 
26 | @pytest.mark.parametrize(
27 |     "input_node, expected_type",
28 |     [
29 |         (
30 |             ast.Call(ast.Identifier("contains"), [ast.String("abc"), ast.String("b")]),
31 |             ast.Boolean,
32 |         ),
33 |         (
34 |             ast.Call(ast.Identifier("indexof"), [ast.String("abc"), ast.String("b")]),
35 |             ast.Integer,
36 |         ),
37 |         (
38 |             ast.Call(ast.Identifier("floor"), [ast.Float("10.32")]),
39 |             ast.Float,
40 |         ),
41 |         (
42 |             ast.Call(ast.Identifier("date"), [ast.DateTime("2020-01-01T10:10:10")]),
43 |             ast.Date,
44 |         ),
45 |         (
46 |             ast.Call(ast.Identifier("maxdatetime"), []),
47 |             ast.DateTime,
48 |         ),
49 |         (
50 |             ast.Call(ast.Identifier("unknown_function"), []),
51 |             None,
52 |         ),
53 |     ],
54 | )
55 | def test_infer_return_type_of_call(
56 |     input_node: ast.Call, expected_type: Type[ast._Node]
57 | ):
58 |     res = typing.infer_type(input_node)
59 | 
60 |     assert res is expected_type
61 | 
--------------------------------------------------------------------------------
/tests/unit/test_utils.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from odata_query import ast, utils
 4 | 
 5 | 
 6 | @pytest.mark.parametrize(
 7 |     "expression, expected",
 8 |     [
 9 |         (ast.Attribute(ast.Identifier("id"), "name"), ast.Identifier("name")),
10 |         (
11 |             ast.Attribute(ast.Attribute(ast.Identifier("id"), "user"), "name"),
12 |             ast.Attribute(ast.Identifier("user"), "name"),
13 |         ),
14 |         (
15 |             ast.Attribute(ast.Identifier("something else"), "whatever"),
16 |             ast.Attribute(ast.Identifier("something else"), "whatever"),
17 |         ),
18 |         (
19 |             ast.Compare(
20 |                 ast.Eq(),
21 |                 ast.Attribute(ast.Identifier("id"), "name"),
22 |                 ast.String("Jozef"),
23 |             ),
24 |             ast.Compare(ast.Eq(), ast.Identifier("name"), ast.String("Jozef")),
25 |         ),
26 |         (
27 |             ast.Compare(
28 |                 ast.In(),
29 |                 ast.Attribute(ast.Identifier("id"), "name"),
30 |                 ast.List([ast.String("Jozef")]),
31 |             ),
32 |             ast.Compare(
33 |                 ast.In(), ast.Identifier("name"), ast.List([ast.String("Jozef")])
34 |             ),
35 |         ),
36 |     ],
37 | )
38 | def test_expression_relative_to_identifier(expression, expected):
39 |     res = utils.expression_relative_to_identifier(ast.Identifier("id"), expression)
40 | 
41 |     assert res == expected
42 | 
--------------------------------------------------------------------------------
/tests/unit/test_visitor.py:
--------------------------------------------------------------------------------
 1 | from copy import deepcopy
 2 | from unittest import mock
 3 | 
 4 | import pytest
 5 | 
 6 | from odata_query import ast, visitor
 7 | 
 8 | 
 9 | @pytest.fixture()
10 | def simple_ast():
11 |     return ast.Compare(
12 |         ast.Eq(),
13 |         ast.BoolOp(
14 |             ast.And(),
15 |             ast.Compare(ast.Eq(), ast.Identifier("meter_id"), ast.String("123")),
16 |             ast.Compare(
17 |                 ast.In(), ast.Identifier("category"), ast.List([ast.String("Gas")])
18 |             ),
19 |         ),
20 |         ast.Boolean("true"),
21 |     )
22 | 
23 | 
24 | def test_generic_visitor_calls_methods_based_on_node_class(simple_ast):
25 |     _visitor = visitor.NodeVisitor()
26 |     _visitor.visit_Identifier = mock.MagicMock()
27 |     _visitor.visit(simple_ast)
28 |     assert _visitor.visit_Identifier.call_count == 2
29 | 
30 | 
31 | def test_node_transformer_without_methods_doesnt_modify_tree(simple_ast):
32 |     transformer = visitor.NodeTransformer()
33 |     new_tree = transformer.visit(deepcopy(simple_ast))
34 |     assert simple_ast == new_tree
35 | 
36 | 
37 | def test_node_transformer_simple(simple_ast):
38 |     class Inverter(visitor.NodeTransformer):
39 |         """Simple transformer that inverts equality checks"""
40 | 
41 |         def visit_Compare(self, node: ast.Compare) -> ast.Compare:
42 |             left = self.visit(node.left)
43 |             right = self.visit(node.right)
44 |             if isinstance(node.comparator, ast.Eq):
45 |                 return ast.Compare(ast.NotEq(), left, right)
46 |             elif isinstance(node.comparator, ast.NotEq):
47 |                 return ast.Compare(ast.Eq(), left, right)
48 |             return ast.Compare(node.comparator, left, right)
49 | 
50 |     transformer = Inverter()
51 |     new_tree = transformer.visit(deepcopy(simple_ast))
52 | 
53 |     assert simple_ast != new_tree
54 |     assert isinstance(new_tree.comparator, ast.NotEq)
55 |     assert isinstance(new_tree.left.left.comparator, ast.NotEq)
56 | 
57 |     roundtrip_tree = transformer.visit(deepcopy(new_tree))
58 |     assert simple_ast == roundtrip_tree
59 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
 1 | [tox]
 2 | envlist = py37-django3, py{38,39,310,311}-django{3,4}, linting, docs
 3 | skip_missing_interpreters = True
 4 | isolated_build = True
 5 | 
 6 | [gh-actions]
 7 | python =
 8 |     3.7: py37
 9 |     3.8: py38, linting, docs
10 |     3.9: py39
11 |     3.10: py310
12 |     3.11: py311
13 | 
14 | [testenv:linting]
15 | basepython = python3.8
16 | extras =
17 |     linting
18 | commands =
19 |     flake8 --show-source odata_query tests
20 |     black --check --diff odata_query tests
21 |     isort --check-only odata_query tests
22 |     mypy --ignore-missing-imports -p odata_query
23 |     vulture odata_query/ --min-confidence 80
24 | 
25 | [testenv:docs]
26 | basepython = python3.8
27 | extras =
28 |     docs
29 |     django
30 |     sqlalchemy
31 | changedir =
32 |     docs/
33 | commands =
34 |     sphinx-build source build
35 | 
36 | [testenv]
37 | deps =
38 |     django3: Django>=3.2,<4
39 |     django4: Django>=4,<5
40 | extras =
41 |     testing
42 |     django
43 |     sqlalchemy
44 | setenv =
45 |     DJANGO_SETTINGS_MODULE = tests.integration.django.settings
46 | passenv =
47 |     PYTHONBREAKPOINT
48 | commands =
49 |     pytest {posargs:tests/unit/ tests/integration/} -r fEs
50 | 
--------------------------------------------------------------------------------