├── docs ├── _config.yml └── index.md ├── django_cte ├── __init__.py ├── jitmixin.py ├── raw.py ├── join.py ├── meta.py ├── _deprecated.py ├── query.py └── cte.py ├── tests ├── test_v1 │ ├── __init__.py │ ├── test_raw.py │ ├── models.py │ ├── test_django.py │ ├── test_manager.py │ ├── test_combinators.py │ ├── test_recursive.py │ └── test_cte.py ├── settings.py ├── __init__.py ├── models.py ├── test_django.py ├── test_raw.py ├── django_setup.py ├── test_manager.py ├── test_combinators.py ├── test_recursive.py └── test_cte.py ├── LICENSE ├── README.md ├── .gitignore ├── pyproject.toml ├── .github └── workflows │ ├── pypi.yml │ └── tests.yml └── CHANGELOG.md /docs/_config.yml: -------------------------------------------------------------------------------- 1 | title: django-cte 2 | author: Dimagi 3 | markdown: kramdown 4 | kramdown: 5 | toc_levels: 2..3 -------------------------------------------------------------------------------- /django_cte/__init__.py: -------------------------------------------------------------------------------- 1 | from .cte import CTE, with_cte, CTEManager, CTEQuerySet, With # noqa 2 | 3 | __version__ = "2.0.1" 4 | __all__ = ["CTE", "with_cte"] 5 | -------------------------------------------------------------------------------- /tests/test_v1/__init__.py: -------------------------------------------------------------------------------- 1 | from unmagic import fixture 2 | 3 | from .. import ignore_v1_warnings 4 | 5 | 6 | @fixture(autouse=__file__) 7 | def ignore_v1_deprecations(): 8 | with ignore_v1_warnings(): 9 | yield 10 | 11 | 12 | @fixture(autouse=__file__, scope="class") 13 | def ignore_v1_deprecations_in_class_setup(): 14 | with ignore_v1_warnings(): 15 | yield 16 | 17 | 18 | with ignore_v1_warnings(): 19 | from . 
import models # noqa: F401 20 | -------------------------------------------------------------------------------- /tests/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | if "DB_SETTINGS" in os.environ: 5 | _db_settings = json.loads(os.environ["DB_SETTINGS"]) 6 | else: 7 | # sqlite3 by default 8 | # must be sqlite3 >= 3.8.3 supporting WITH clause 9 | # must be sqlite3 >= 3.35.0 supporting MATERIALIZED option 10 | _db_settings = { 11 | "ENGINE": "django.db.backends.sqlite3", 12 | "NAME": ":memory:", 13 | } 14 | 15 | DATABASES = {'default': _db_settings} 16 | 17 | INSTALLED_APPS = ["tests"] 18 | 19 | SECRET_KEY = "test" 20 | USE_TZ = False 21 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from contextlib import contextmanager 4 | 5 | import django 6 | from unmagic import fixture 7 | 8 | # django setup must occur before importing models 9 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") 10 | django.setup() 11 | 12 | from .django_setup import init_db, destroy_db # noqa 13 | 14 | 15 | @fixture(autouse=__file__, scope="package") 16 | def test_db(): 17 | with ignore_v1_warnings(): 18 | init_db() 19 | yield 20 | destroy_db() 21 | 22 | 23 | @contextmanager 24 | def ignore_v1_warnings(): 25 | msg = ( 26 | r"CTE(Manager|QuerySet) is deprecated.*" 27 | r"|" 28 | r"Use `django_cte\.with_cte\(.*\)` instead\." 29 | r"|" 30 | r"Use `django_cte\.CTE(\.recursive)?` instead\." 
31 | ) 32 | with warnings.catch_warnings(): 33 | warnings.filterwarnings("ignore", message=msg, category=DeprecationWarning) 34 | yield 35 | -------------------------------------------------------------------------------- /django_cte/jitmixin.py: -------------------------------------------------------------------------------- 1 | def jit_mixin(obj, mixin): 2 | """Apply mixin to object and return the object""" 3 | if not isinstance(obj, mixin): 4 | obj.__class__ = jit_mixin_type(obj.__class__, mixin) 5 | return obj 6 | 7 | 8 | def jit_mixin_type(base, *mixins): 9 | assert not issubclass(base, mixins), (base, mixins) 10 | mixed = _mixin_cache.get((base, mixins)) 11 | if mixed is None: 12 | prefix = "".join(m._jit_mixin_prefix for m in mixins) 13 | name = f"{prefix}{base.__name__}" 14 | mixed = _mixin_cache[(base, mixins)] = type(name, (*mixins, base), { 15 | "_jit_mixin_base": getattr(base, "_jit_mixin_base", base), 16 | "_jit_mixins": mixins + getattr(base, "_jit_mixins", ()), 17 | }) 18 | return mixed 19 | 20 | 21 | _mixin_cache = {} 22 | 23 | 24 | class JITMixin: 25 | 26 | def __reduce__(self): 27 | # make JITMixin classes pickleable 28 | return (jit_mixin_type, (self._jit_mixin_base, *self._jit_mixins)) 29 | -------------------------------------------------------------------------------- /django_cte/raw.py: -------------------------------------------------------------------------------- 1 | def raw_cte_sql(sql, params, refs): 2 | """Raw CTE SQL 3 | 4 | :param sql: SQL query (string). 5 | :param params: List of bind parameters. 6 | :param refs: Dict of output fields: `{"name": output_field}`. 7 | :returns: Object that can be passed to `With`.
8 | """ 9 | 10 | class raw_cte_ref: 11 | def __init__(self, output_field): 12 | self.output_field = output_field 13 | 14 | def get_source_expressions(self): 15 | return [] 16 | 17 | class raw_cte_compiler: 18 | 19 | def __init__(self, connection): 20 | self.connection = connection 21 | 22 | def as_sql(self): 23 | return sql, params 24 | 25 | def quote_name_unless_alias(self, name): 26 | return self.connection.ops.quote_name(name) 27 | 28 | class raw_cte_queryset: 29 | class query: 30 | @staticmethod 31 | def get_compiler(connection, *, elide_empty=None): 32 | return raw_cte_compiler(connection) 33 | 34 | @staticmethod 35 | def resolve_ref(name): 36 | return raw_cte_ref(refs[name]) 37 | 38 | return raw_cte_queryset 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Dimagi Inc., and individual contributors. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name Dimagi, nor the names of its contributors, may be used 12 | to endorse or promote products derived from this software without 13 | specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL DIMAGI INC. 
BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Common Table Expressions with Django 2 | 3 | [![Build Status](https://github.com/dimagi/django-cte/actions/workflows/tests.yml/badge.svg)](https://github.com/dimagi/django-cte/actions/workflows/tests.yml) 4 | [![PyPI version](https://badge.fury.io/py/django-cte.svg)](https://badge.fury.io/py/django-cte) 5 | 6 | ## Installation 7 | ``` 8 | pip install django-cte 9 | ``` 10 | 11 | 12 | ## Documentation 13 | 14 | The [django-cte documentation](https://dimagi.github.io/django-cte/) shows how 15 | to use Common Table Expressions with the Django ORM. 16 | 17 | 18 | ## Running tests 19 | 20 | ``` 21 | cd django-cte 22 | uv sync 23 | 24 | pytest 25 | ruff check 26 | 27 | # To run tests against postgres 28 | psql -U username -h localhost -p 5432 -c 'create database django_cte;' 29 | export PG_DB_SETTINGS='{ 30 | "ENGINE":"django.db.backends.postgresql_psycopg2", 31 | "NAME":"django_cte", 32 | "USER":"username", 33 | "PASSWORD":"password", 34 | "HOST":"localhost", 35 | "PORT":"5432"}' 36 | 37 | # WARNING pytest will delete the test_django_cte database if it exists! 38 | DB_SETTINGS="$PG_DB_SETTINGS" pytest 39 | ``` 40 | 41 | All feature and bug contributions are expected to be covered by tests. 
42 | 43 | 44 | ## Publishing a new version to PyPI 45 | 46 | Push a new tag to Github using the format vX.Y.Z where X.Y.Z matches the version 47 | in [`__init__.py`](django_cte/__init__.py). 48 | 49 | A new version is published to https://test.pypi.org/p/django-cte on every 50 | push to the *main* branch. 51 | 52 | Publishing is automated with [Github Actions](.github/workflows/pypi.yml). 53 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | from django.db.models import ( 2 | CASCADE, 3 | Manager, 4 | Model, 5 | QuerySet, 6 | AutoField, 7 | CharField, 8 | ForeignKey, 9 | IntegerField, 10 | TextField, 11 | ) 12 | 13 | 14 | class LT40QuerySet(QuerySet): 15 | 16 | def lt40(self): 17 | return self.filter(amount__lt=40) 18 | 19 | 20 | class LT25QuerySet(QuerySet): 21 | 22 | def lt25(self): 23 | return self.filter(amount__lt=25) 24 | 25 | 26 | class Region(Model): 27 | name = TextField(primary_key=True) 28 | parent = ForeignKey("self", null=True, on_delete=CASCADE) 29 | 30 | class Meta: 31 | db_table = "region" 32 | 33 | 34 | class User(Model): 35 | id = AutoField(primary_key=True) 36 | name = TextField() 37 | 38 | class Meta: 39 | db_table = "user" 40 | 41 | 42 | class Order(Model): 43 | id = AutoField(primary_key=True) 44 | region = ForeignKey(Region, on_delete=CASCADE) 45 | amount = IntegerField(default=0) 46 | user = ForeignKey(User, null=True, on_delete=CASCADE) 47 | 48 | class Meta: 49 | db_table = "orders" 50 | 51 | 52 | class OrderFromLT40(Order): 53 | class Meta: 54 | proxy = True 55 | objects = Manager.from_queryset(LT40QuerySet)() 56 | 57 | 58 | class OrderCustomManagerNQuery(Order): 59 | class Meta: 60 | proxy = True 61 | objects = Manager.from_queryset(LT25QuerySet)() 62 | 63 | 64 | class KeyPair(Model): 65 | key = CharField(max_length=32) 66 | value = IntegerField(default=0) 67 | parent = ForeignKey("self", null=True,
on_delete=CASCADE) 68 | 69 | class Meta: 70 | db_table = "keypair" 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "django-cte" 3 | description = "Common Table Expressions (CTE) for Django" 4 | authors = [{name = "Daniel Miller", email = "millerdev@gmail.com"}] 5 | license = "BSD-3-Clause" 6 | license-files = ["LICENSE"] 7 | readme = {file = "README.md", content-type = "text/markdown"} 8 | dynamic = ["version"] 9 | requires-python = ">= 3.9" 10 | # Python and Django versions are read from this file by GitHub Actions. 11 | # Precise formatting is important. 12 | classifiers = [ 13 | "Development Status :: 5 - Production/Stable", 14 | 'Environment :: Web Environment', 15 | 'Intended Audience :: Developers', 16 | 'Operating System :: OS Independent', 17 | 'Programming Language :: Python', 18 | 'Programming Language :: Python :: 3', 19 | 'Programming Language :: Python :: 3.9', 20 | 'Programming Language :: Python :: 3.10', 21 | 'Programming Language :: Python :: 3.11', 22 | 'Programming Language :: Python :: 3.12', 23 | 'Programming Language :: Python :: 3.13', 24 | 'Programming Language :: Python :: 3.14', 25 | 'Framework :: Django', 26 | 'Framework :: Django :: 4', 27 | 'Framework :: Django :: 4.2', 28 | 'Framework :: Django :: 5', 29 | 'Framework :: Django :: 5.1', 30 | 'Framework :: Django :: 5.2', 31 | 'Topic :: Software Development :: Libraries :: Python Modules', 32 | ] 33 | dependencies = ["django"] 34 | 35 | [dependency-groups] 36 | dev = [ 37 | "psycopg2-binary", 38 | "pytest-unmagic", 39 | "ruff", 40 | ] 41 | 42 | [project.urls] 43 | Home = "https://github.com/dimagi/django-cte" 44 | 45 | [build-system] 46 | requires = ["flit_core >=3.2,<4"] 47 | build-backend = "flit_core.buildapi" 48 | 49 | [tool.flit.module] 50 | name = "django_cte" 51 | 52 | [tool.distutils.bdist_wheel] 53 | universal = true 54 | -------------------------------------------------------------------------------- /tests/test_django.py: 
-------------------------------------------------------------------------------- 1 | from unittest import SkipTest 2 | 3 | import django 4 | from django.db import OperationalError, ProgrammingError 5 | from django.db.models import Window 6 | from django.db.models.functions import Rank 7 | from django.test import TestCase 8 | 9 | from django_cte import CTE, with_cte 10 | 11 | from .models import Order, Region 12 | 13 | 14 | class WindowFunctions(TestCase): 15 | 16 | def test_heterogeneous_filter_in_cte(self): 17 | if django.VERSION < (4, 2): 18 | raise SkipTest("feature added in Django 4.2") 19 | cte = CTE( 20 | Order.objects.annotate( 21 | region_amount_rank=Window( 22 | Rank(), partition_by="region_id", order_by="-amount" 23 | ), 24 | ) 25 | .order_by("region_id") 26 | .values("region_id", "region_amount_rank") 27 | .filter(region_amount_rank=1, region_id__in=["sun", "moon"]) 28 | ) 29 | qs = with_cte(cte, select=cte.join(Region, name=cte.col.region_id)) 30 | print(qs.query) 31 | # ProgrammingError: column cte.region_id does not exist 32 | # WITH RECURSIVE "cte" AS (SELECT * FROM ( 33 | # SELECT "orders"."region_id" AS "col1", ... 
34 | # "region" INNER JOIN "cte" ON "region"."name" = ("cte"."region_id") 35 | try: 36 | self.assertEqual({r.name for r in qs}, {"moon", "sun"}) 37 | except (OperationalError, ProgrammingError) as err: 38 | if "cte.region_id" in str(err): 39 | raise SkipTest( 40 | "window function auto-aliasing breaks CTE " 41 | "column references" 42 | ) 43 | raise 44 | if django.VERSION < (5, 2): 45 | assert 0, "unexpected pass" 46 | -------------------------------------------------------------------------------- /tests/test_raw.py: -------------------------------------------------------------------------------- 1 | from django.db.models import IntegerField, TextField 2 | from django.test import TestCase 3 | 4 | from django_cte import CTE, with_cte 5 | from django_cte.raw import raw_cte_sql 6 | 7 | from .models import Region 8 | 9 | int_field = IntegerField() 10 | text_field = TextField() 11 | 12 | 13 | class TestRawCTE(TestCase): 14 | 15 | def test_raw_cte_sql(self): 16 | cte = CTE(raw_cte_sql( 17 | """ 18 | SELECT region_id, AVG(amount) AS avg_order 19 | FROM orders 20 | WHERE region_id = %s 21 | GROUP BY region_id 22 | """, 23 | ["moon"], 24 | {"region_id": text_field, "avg_order": int_field}, 25 | )) 26 | moon_avg = with_cte( 27 | cte, select=cte.join(Region, name=cte.col.region_id) 28 | ).annotate(avg_order=cte.col.avg_order) 29 | print(moon_avg.query) 30 | 31 | data = [(r.name, r.parent.name, r.avg_order) for r in moon_avg] 32 | self.assertEqual(data, [('moon', 'earth', 2)]) 33 | 34 | def test_raw_cte_sql_name_escape(self): 35 | cte = CTE( 36 | raw_cte_sql( 37 | """ 38 | SELECT region_id, AVG(amount) AS avg_order 39 | FROM orders 40 | WHERE region_id = %s 41 | GROUP BY region_id 42 | """, 43 | ["moon"], 44 | {"region_id": text_field, "avg_order": int_field}, 45 | ), 46 | name="mixedCaseCTEName" 47 | ) 48 | moon_avg = with_cte( 49 | cte, select=cte.join(Region, name=cte.col.region_id) 50 | ).annotate(avg_order=cte.col.avg_order) 51 | self.assertTrue( 52 | 
str(moon_avg.query).startswith( 53 | 'WITH RECURSIVE "mixedCaseCTEName"') 54 | ) 55 | -------------------------------------------------------------------------------- /tests/test_v1/test_raw.py: -------------------------------------------------------------------------------- 1 | from django.db.models import IntegerField, TextField 2 | from django.test import TestCase 3 | 4 | from django_cte import With 5 | from django_cte.raw import raw_cte_sql 6 | 7 | from .models import Region 8 | 9 | int_field = IntegerField() 10 | text_field = TextField() 11 | 12 | 13 | class TestRawCTE(TestCase): 14 | 15 | def test_raw_cte_sql(self): 16 | cte = With(raw_cte_sql( 17 | """ 18 | SELECT region_id, AVG(amount) AS avg_order 19 | FROM orders 20 | WHERE region_id = %s 21 | GROUP BY region_id 22 | """, 23 | ["moon"], 24 | {"region_id": text_field, "avg_order": int_field}, 25 | )) 26 | moon_avg = ( 27 | cte 28 | .join(Region, name=cte.col.region_id) 29 | .annotate(avg_order=cte.col.avg_order) 30 | .with_cte(cte) 31 | ) 32 | print(moon_avg.query) 33 | 34 | data = [(r.name, r.parent.name, r.avg_order) for r in moon_avg] 35 | self.assertEqual(data, [('moon', 'earth', 2)]) 36 | 37 | def test_raw_cte_sql_name_escape(self): 38 | cte = With( 39 | raw_cte_sql( 40 | """ 41 | SELECT region_id, AVG(amount) AS avg_order 42 | FROM orders 43 | WHERE region_id = %s 44 | GROUP BY region_id 45 | """, 46 | ["moon"], 47 | {"region_id": text_field, "avg_order": int_field}, 48 | ), 49 | name="mixedCaseCTEName" 50 | ) 51 | moon_avg = ( 52 | cte 53 | .join(Region, name=cte.col.region_id) 54 | .annotate(avg_order=cte.col.avg_order) 55 | .with_cte(cte) 56 | ) 57 | self.assertTrue( 58 | str(moon_avg.query).startswith( 59 | 'WITH RECURSIVE "mixedCaseCTEName"') 60 | ) 61 | -------------------------------------------------------------------------------- /tests/test_v1/models.py: -------------------------------------------------------------------------------- 1 | from django.db.models import Manager 2 | 3 | from 
django_cte import CTEManager, CTEQuerySet 4 | 5 | from ..models import ( 6 | KeyPair as V2KeyPair, 7 | Order as V2Order, 8 | Region as V2Region, 9 | User, # noqa: F401 10 | ) 11 | 12 | 13 | class LT40QuerySet(CTEQuerySet): 14 | 15 | def lt40(self): 16 | return self.filter(amount__lt=40) 17 | 18 | 19 | class LT30QuerySet(CTEQuerySet): 20 | 21 | def lt30(self): 22 | return self.filter(amount__lt=30) 23 | 24 | 25 | class LT25QuerySet(CTEQuerySet): 26 | 27 | def lt25(self): 28 | return self.filter(amount__lt=25) 29 | 30 | 31 | class LTManager(CTEManager): 32 | pass 33 | 34 | 35 | class V1Region(V2Region): 36 | objects = CTEManager() 37 | 38 | class Meta: 39 | proxy = True 40 | 41 | 42 | Region = V1Region 43 | 44 | 45 | class V1Order(V2Order): 46 | objects = CTEManager() 47 | 48 | class Meta: 49 | proxy = True 50 | 51 | 52 | Order = V1Order 53 | 54 | 55 | class V1OrderFromLT40(Order): 56 | class Meta: 57 | proxy = True 58 | objects = CTEManager.from_queryset(LT40QuerySet)() 59 | 60 | 61 | class V1OrderLT40AsManager(Order): 62 | class Meta: 63 | proxy = True 64 | objects = LT40QuerySet.as_manager() 65 | 66 | 67 | class V1OrderCustomManagerNQuery(Order): 68 | class Meta: 69 | proxy = True 70 | objects = LTManager.from_queryset(LT25QuerySet)() 71 | 72 | 73 | class V1OrderCustomManager(Order): 74 | class Meta: 75 | proxy = True 76 | objects = LTManager() 77 | 78 | 79 | class V1OrderPlainManager(Order): 80 | class Meta: 81 | proxy = True 82 | objects = Manager() 83 | 84 | 85 | class V1KeyPair(V2KeyPair): 86 | objects = CTEManager() 87 | 88 | class Meta: 89 | proxy = True 90 | 91 | 92 | KeyPair = V1KeyPair 93 | OrderCustomManager = V1OrderCustomManager 94 | OrderCustomManagerNQuery = V1OrderCustomManagerNQuery 95 | OrderFromLT40 = V1OrderFromLT40 96 | OrderLT40AsManager = V1OrderLT40AsManager 97 | OrderPlainManager = V1OrderPlainManager 98 | -------------------------------------------------------------------------------- /tests/django_setup.py: 
-------------------------------------------------------------------------------- 1 | from django.db import connection 2 | 3 | from .models import KeyPair, Region, Order, User 4 | 5 | is_initialized = False 6 | 7 | 8 | def init_db(): 9 | global is_initialized 10 | if is_initialized: 11 | return 12 | is_initialized = True 13 | 14 | connection.creation.create_test_db(verbosity=0, autoclobber=True) 15 | 16 | setup_data() 17 | 18 | 19 | def destroy_db(): 20 | connection.creation.destroy_test_db(verbosity=0) 21 | 22 | 23 | def setup_data(): 24 | admin = User.objects.create(name="admin") 25 | 26 | regions = {None: None} 27 | for name, parent in [ 28 | ("sun", None), 29 | ("mercury", "sun"), 30 | ("venus", "sun"), 31 | ("earth", "sun"), 32 | ("moon", "earth"), 33 | ("mars", "sun"), 34 | ("deimos", "mars"), 35 | ("phobos", "mars"), 36 | ("proxima centauri", None), 37 | ("proxima centauri b", "proxima centauri"), 38 | ("bernard's star", None), 39 | ]: 40 | region = Region(name=name, parent=regions[parent]) 41 | region.save() 42 | regions[name] = region 43 | 44 | for region, amount in [ 45 | ("sun", 1000), 46 | ("mercury", 10), 47 | ("mercury", 11), 48 | ("mercury", 12), 49 | ("venus", 20), 50 | ("venus", 21), 51 | ("venus", 22), 52 | ("venus", 23), 53 | ("earth", 30), 54 | ("earth", 31), 55 | ("earth", 32), 56 | ("earth", 33), 57 | ("moon", 1), 58 | ("moon", 2), 59 | ("moon", 3), 60 | ("mars", 40), 61 | ("mars", 41), 62 | ("mars", 42), 63 | ("proxima centauri", 2000), 64 | ("proxima centauri b", 10), 65 | ("proxima centauri b", 11), 66 | ("proxima centauri b", 12), 67 | ]: 68 | order = Order(amount=amount, region=regions[region], user=admin) 69 | order.save() 70 | 71 | for key, value, parent in [ 72 | ("level 1", 1, None), 73 | ("level 2", 1, "level 1"), 74 | ("level 2", 2, "level 1"), 75 | ("level 3", 1, "level 2"), 76 | ]: 77 | parent = parent and KeyPair.objects.filter(key=parent).first() 78 | KeyPair.objects.create(key=key, value=value, parent=parent) 79 | 
-------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python distribution to PyPI and TestPyPI 2 | # Source: 3 | # https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ 4 | on: 5 | push: 6 | branches: 7 | - main 8 | tags: 9 | - 'v*' 10 | workflow_dispatch: 11 | jobs: 12 | build: 13 | name: Build distribution package 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: astral-sh/setup-uv@v6 18 | with: 19 | version: '>=0.7' 20 | - name: Check for version match in git tag and django_cte.__version__ 21 | if: startsWith(github.ref, 'refs/tags/v') 22 | run: uvx pyverno check django_cte/__init__.py "${{ github.ref }}" 23 | - name: Add untagged version suffix 24 | if: ${{ ! startsWith(github.ref, 'refs/tags/v') }} 25 | run: uvx pyverno update django_cte/__init__.py 26 | - name: Build a binary wheel and a source tarball 27 | run: uv build 28 | - name: Store the distribution packages 29 | uses: actions/upload-artifact@v4 30 | with: 31 | name: python-package-distributions 32 | path: dist/ 33 | pypi-publish: 34 | name: Upload release to PyPI 35 | needs: [build] 36 | runs-on: ubuntu-latest 37 | environment: 38 | name: pypi 39 | url: https://pypi.org/p/django-cte 40 | permissions: 41 | id-token: write 42 | steps: 43 | - name: Download all the dists 44 | uses: actions/download-artifact@v4 45 | with: 46 | name: python-package-distributions 47 | path: dist/ 48 | - name: Publish package distributions to PyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | pypi-test-publish: 51 | name: Upload release to test PyPI 52 | needs: [build] 53 | runs-on: ubuntu-latest 54 | environment: 55 | name: testpypi 56 | url: https://test.pypi.org/p/django-cte 57 | permissions: 58 | id-token: write 59 | steps: 60 | - name: Download all the 
dists 61 | uses: actions/download-artifact@v4 62 | with: 63 | name: python-package-distributions 64 | path: dist/ 65 | - name: Publish package distributions to PyPI 66 | uses: pypa/gh-action-pypi-publish@release/v1 67 | with: 68 | repository-url: https://test.pypi.org/legacy/ 69 | -------------------------------------------------------------------------------- /tests/test_manager.py: -------------------------------------------------------------------------------- 1 | from django.db.models.expressions import F 2 | from django.test import TestCase 3 | 4 | from django_cte import CTE, with_cte 5 | 6 | from .models import ( 7 | OrderFromLT40, 8 | OrderCustomManagerNQuery, 9 | LT40QuerySet, 10 | ) 11 | 12 | 13 | class TestCTE(TestCase): 14 | 15 | def test_cte_queryset_with_from_queryset(self): 16 | self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet) 17 | 18 | cte = CTE( 19 | OrderFromLT40.objects 20 | .annotate(region_parent=F("region__parent_id")) 21 | .filter(region__parent_id="sun") 22 | ) 23 | orders = with_cte( 24 | cte, 25 | select=cte.queryset() 26 | .lt40() # custom queryset method 27 | .order_by("region_id", "amount") 28 | ) 29 | print(orders.query) 30 | 31 | data = [(x.region_id, x.amount, x.region_parent) for x in orders] 32 | self.assertEqual(data, [ 33 | ("earth", 30, "sun"), 34 | ("earth", 31, "sun"), 35 | ("earth", 32, "sun"), 36 | ("earth", 33, "sun"), 37 | ('mercury', 10, 'sun'), 38 | ('mercury', 11, 'sun'), 39 | ('mercury', 12, 'sun'), 40 | ('venus', 20, 'sun'), 41 | ('venus', 21, 'sun'), 42 | ('venus', 22, 'sun'), 43 | ('venus', 23, 'sun'), 44 | ]) 45 | 46 | def test_cte_queryset_with_custom_queryset(self): 47 | cte = CTE( 48 | OrderCustomManagerNQuery.objects 49 | .annotate(region_parent=F("region__parent_id")) 50 | .filter(region__parent_id="sun") 51 | ) 52 | orders = with_cte( 53 | cte, 54 | select=cte.queryset() 55 | .lt25() # custom queryset method 56 | .order_by("region_id", "amount") 57 | ) 58 | print(orders.query) 59 | 60 | data = 
[(x.region_id, x.amount, x.region_parent) for x in orders] 61 | self.assertEqual(data, [ 62 | ('mercury', 10, 'sun'), 63 | ('mercury', 11, 'sun'), 64 | ('mercury', 12, 'sun'), 65 | ('venus', 20, 'sun'), 66 | ('venus', 21, 'sun'), 67 | ('venus', 22, 'sun'), 68 | ('venus', 23, 'sun'), 69 | ]) 70 | 71 | def test_cte_queryset_with_deferred_loading(self): 72 | cte = CTE( 73 | OrderCustomManagerNQuery.objects.order_by("id").only("id")[:1] 74 | ) 75 | orders = with_cte(cte, select=cte) 76 | print(orders.query) 77 | 78 | self.assertEqual([x.id for x in orders], [1]) 79 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: django-cte tests 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | configure: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Read Python versions from pyproject.toml 15 | id: read-python-versions 16 | # produces output like: python_versions=[ "3.9", "3.10", "3.11", "3.12" ] 17 | run: >- 18 | echo "python_versions=$( 19 | grep -oP '(?<=Language :: Python :: )\d\.\d+' pyproject.toml 20 | | jq --raw-input . 21 | | jq --slurp . 22 | | tr '\n' ' ' 23 | )" >> $GITHUB_OUTPUT 24 | - name: Read Django versions from pyproject.toml 25 | id: read-django-versions 26 | # django_versions=[ "Django~=4.2.0", "Django~=5.1.0", "Django~=5.2.0" ] 27 | run: >- 28 | echo "django_versions=$( 29 | grep -oP '(?<=Framework :: Django :: )\d+\.\d+' pyproject.toml 30 | | sed -E 's/(.+)/Django~=\1.0/' 31 | | jq --raw-input . 32 | | jq --slurp . 
33 | | tr '\n' ' ' 34 | )" >> $GITHUB_OUTPUT 35 | outputs: 36 | python_versions: ${{ steps.read-python-versions.outputs.python_versions }} 37 | django_versions: ${{ steps.read-django-versions.outputs.django_versions }} 38 | 39 | tests: 40 | needs: [configure] 41 | runs-on: ubuntu-latest 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | python: ${{ fromJSON(needs.configure.outputs.python_versions) }} 46 | django: ${{ fromJSON(needs.configure.outputs.django_versions) }} 47 | exclude: 48 | - {python: '3.9', django: 'Django~=5.1.0'} 49 | - {python: '3.9', django: 'Django~=5.2.0'} 50 | env: 51 | allowed_python_failure: '3.14' 52 | services: 53 | postgres: 54 | image: postgres:latest 55 | env: 56 | POSTGRES_DB: postgres 57 | POSTGRES_PASSWORD: postgres 58 | POSTGRES_USER: postgres 59 | ports: 60 | - 5432:5432 61 | options: >- 62 | --health-cmd pg_isready 63 | --health-interval 10s 64 | --health-timeout 5s 65 | --health-retries 5 66 | steps: 67 | - uses: actions/checkout@v4 68 | - uses: astral-sh/setup-uv@v6 69 | with: 70 | version: '>=0.7' 71 | python-version: ${{ matrix.python }} 72 | - name: Setup 73 | run: | 74 | uv sync --locked --no-install-package=django 75 | uv pip install "${{ matrix.django }}" 76 | - name: Run tests on PostgreSQL 77 | env: 78 | DB_SETTINGS: >- 79 | { 80 | "ENGINE":"django.db.backends.postgresql_psycopg2", 81 | "NAME":"django_cte", 82 | "USER":"postgres", 83 | "PASSWORD":"postgres", 84 | "HOST":"localhost", 85 | "PORT":"5432" 86 | } 87 | run: .venv/bin/pytest -v 88 | continue-on-error: ${{ matrix.python == env.allowed_python_failure }} 89 | - name: Run tests on SQLite 90 | run: .venv/bin/pytest -v 91 | continue-on-error: ${{ matrix.python == env.allowed_python_failure }} 92 | - name: Check style 93 | run: .venv/bin/ruff check 94 | -------------------------------------------------------------------------------- /django_cte/join.py: -------------------------------------------------------------------------------- 1 | from 
django.db.models.sql.constants import INNER 2 | 3 | 4 | class QJoin: 5 | """Join clause with join condition from Q object clause 6 | 7 | :param parent_alias: Alias of parent table. 8 | :param table_name: Name of joined table. 9 | :param table_alias: Alias of joined table. 10 | :param on_clause: Query `where_class` instance representing the ON clause. 11 | :param join_type: Join type (INNER or LOUTER). 12 | """ 13 | 14 | filtered_relation = None 15 | 16 | def __init__(self, parent_alias, table_name, table_alias, 17 | on_clause, join_type=INNER, nullable=None): 18 | self.parent_alias = parent_alias 19 | self.table_name = table_name 20 | self.table_alias = table_alias 21 | self.on_clause = on_clause 22 | self.join_type = join_type # LOUTER or INNER 23 | self.nullable = join_type != INNER if nullable is None else nullable 24 | 25 | @property 26 | def identity(self): 27 | return ( 28 | self.__class__, 29 | self.table_name, 30 | self.parent_alias, 31 | self.join_type, 32 | self.on_clause, 33 | ) 34 | 35 | def __hash__(self): 36 | return hash(self.identity) 37 | 38 | def __eq__(self, other): 39 | if not isinstance(other, QJoin): 40 | return NotImplemented 41 | return self.identity == other.identity 42 | 43 | def equals(self, other): 44 | return self.identity == other.identity 45 | 46 | def as_sql(self, compiler, connection): 47 | """Generate join clause SQL""" 48 | on_clause_sql, params = self.on_clause.as_sql(compiler, connection) 49 | if self.table_alias == self.table_name: 50 | alias = '' 51 | else: 52 | alias = ' %s' % self.table_alias 53 | qn = compiler.quote_name_unless_alias 54 | sql = '%s %s%s ON %s' % ( 55 | self.join_type, 56 | qn(self.table_name), 57 | alias, 58 | on_clause_sql 59 | ) 60 | return sql, params 61 | 62 | def relabeled_clone(self, change_map): 63 | return self.__class__( 64 | parent_alias=change_map.get(self.parent_alias, self.parent_alias), 65 | table_name=self.table_name, 66 | table_alias=change_map.get(self.table_alias, self.table_alias), 67 |
on_clause=self.on_clause.relabeled_clone(change_map), 68 | join_type=self.join_type, 69 | nullable=self.nullable, 70 | ) 71 | 72 | class join_field: 73 | # `Join.join_field` is used internally by `Join` as well as in 74 | # `QuerySet.resolve_expression()`: 75 | # 76 | # isinstance(table, Join) 77 | # and table.join_field.related_model._meta.db_table != alias 78 | # 79 | # Currently that does not apply here since `QJoin` is not an 80 | # instance of `Join`, although maybe it should? Maybe this 81 | # should have `related_model._meta.db_table` return 82 | # `.table_name` or `.table_alias`? 83 | # 84 | # `PathInfo.join_field` is another similarly named attribute in 85 | # Django that has a much more complicated interface, but luckily 86 | # seems unrelated to `Join.join_field`. 87 | 88 | class related_model: 89 | class _meta: 90 | # for QuerySet.set_group_by(allow_aliases=True) 91 | local_concrete_fields = () 92 | -------------------------------------------------------------------------------- /tests/test_v1/test_django.py: -------------------------------------------------------------------------------- 1 | from unittest import SkipTest 2 | 3 | import django 4 | from django.db import OperationalError, ProgrammingError 5 | from django.db.models import Window 6 | from django.db.models.functions import Rank 7 | from django.test import TestCase, skipUnlessDBFeature 8 | 9 | from .models import Order, Region, User 10 | 11 | 12 | @skipUnlessDBFeature("supports_select_union") 13 | class NonCteQueries(TestCase): 14 | """Test non-CTE queries 15 | 16 | These tests were adapted from the Django test suite. The models used 17 | here use CTEManager and CTEQuerySet to verify feature parity with 18 | their base classes Manager and QuerySet. 
19 | """ 20 | 21 | @classmethod 22 | def setUpTestData(cls): 23 | Order.objects.all().delete() 24 | 25 | def test_union_with_select_related_and_order(self): 26 | e1 = User.objects.create(name="e1") 27 | a1 = Order.objects.create(region_id="earth", user=e1) 28 | a2 = Order.objects.create(region_id="moon", user=e1) 29 | Order.objects.create(region_id="sun", user=e1) 30 | base_qs = Order.objects.select_related("user").order_by() 31 | qs1 = base_qs.filter(region_id="earth") 32 | qs2 = base_qs.filter(region_id="moon") 33 | print(qs1.union(qs2).order_by("pk").query) 34 | self.assertSequenceEqual(qs1.union(qs2).order_by("pk"), [a1, a2]) 35 | 36 | @skipUnlessDBFeature("supports_slicing_ordering_in_compound") 37 | def test_union_with_select_related_and_first(self): 38 | e1 = User.objects.create(name="e1") 39 | a1 = Order.objects.create(region_id="earth", user=e1) 40 | Order.objects.create(region_id="moon", user=e1) 41 | base_qs = Order.objects.select_related("user") 42 | qs1 = base_qs.filter(region_id="earth") 43 | qs2 = base_qs.filter(region_id="moon") 44 | self.assertEqual(qs1.union(qs2).first(), a1) 45 | 46 | def test_union_with_first(self): 47 | e1 = User.objects.create(name="e1") 48 | a1 = Order.objects.create(region_id="earth", user=e1) 49 | base_qs = Order.objects.order_by() 50 | qs1 = base_qs.filter(region_id="earth") 51 | qs2 = base_qs.filter(region_id="moon") 52 | self.assertEqual(qs1.union(qs2).first(), a1) 53 | 54 | 55 | class WindowFunctions(TestCase): 56 | 57 | def test_heterogeneous_filter_in_cte(self): 58 | if django.VERSION < (4, 2): 59 | raise SkipTest("feature added in Django 4.2") 60 | from django_cte import With 61 | cte = With( 62 | Order.objects.annotate( 63 | region_amount_rank=Window( 64 | Rank(), partition_by="region_id", order_by="-amount" 65 | ), 66 | ) 67 | .order_by("region_id") 68 | .values("region_id", "region_amount_rank") 69 | .filter(region_amount_rank=1, region_id__in=["sun", "moon"]) 70 | ) 71 | qs = cte.join(Region, 
class CTEColumn(Expression):
    """Expression referring to the column ``name`` of a CTE.

    :param cte: The CTE object owning the column; its ``name`` becomes
        the initial table alias.
    :param name: Column name within the CTE.
    :param output_field: Optional explicit output field. When omitted,
        the output field of the resolved reference is used.
    """

    def __init__(self, cte, name, output_field=None):
        self._cte = cte
        self.table_alias = cte.name
        self.name = name
        self.alias = name
        self._output_field = output_field

    def __repr__(self):
        return f"<{self.__class__.__name__} {self._cte.name}.{self.name}>"

    @property
    def _ref(self):
        """Resolve this column against the CTE's query.

        Raises ``ValueError`` when the CTE has no query yet or when the
        resolved reference points back at this column.
        """
        cte = self._cte
        if cte.query is None:
            message = (
                "cannot resolve '{cte}.{name}' in recursive CTE setup. "
                "Hint: use ExpressionWrapper({cte}.col.{name}, "
                "output_field=...)"
            )
            raise ValueError(message.format(cte=cte.name, name=self.name))
        resolved = cte._resolve_ref(self.name)
        if resolved is self or self in resolved.get_source_expressions():
            raise ValueError(f"Circular reference: {self} = {resolved}")
        return resolved

    @property
    def target(self):
        return self._ref.target

    @property
    def output_field(self):
        # required to fix error caused by django commit
        #     9d519d3dc4e5bd1d9ff3806b44624c3e487d61c1
        if self._cte.query is None:
            raise AttributeError

        explicit = self._output_field
        if explicit is None:
            return self._ref.output_field
        return explicit

    def as_sql(self, compiler, connection):
        """Return ``("<alias>.<column>", [])`` with both parts quoted."""
        quote = compiler.quote_name_unless_alias
        resolved = self._ref
        # "pk" is not a real column name: when it resolved to a concrete
        # `Col`, use the target field's column instead.
        use_target_column = isinstance(resolved, Col) and self.name == "pk"
        column = resolved.target.column if use_target_column else self.name
        return f"{quote(self.table_alias)}.{quote(column)}", []

    def relabeled_clone(self, relabels):
        """Return a copy with a new table alias, or ``self`` if unaffected."""
        current = self.table_alias
        if current is None or current not in relabels:
            return self
        clone = self.copy()
        clone.table_alias = relabels[current]
        return clone
class TestCTE(TestCase):
    """Manager/queryset class checks for the deprecated v1 CTE API.

    Each test asserts which manager or queryset class a test model
    exposes, or that a CTE built from a custom queryset keeps that
    queryset's custom methods.
    """

    def test_cte_queryset_correct_defaultmanager(self):
        # Exact type checks: Order must expose CTEManager / CTEQuerySet.
        self.assertEqual(type(Order._default_manager), CTEManager)
        self.assertEqual(type(Order.objects.all()), CTEQuerySet)

    def test_cte_queryset_correct_from_queryset(self):
        self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet)

    def test_cte_queryset_correct_queryset_as_manager(self):
        self.assertEqual(type(OrderLT40AsManager.objects.all()), LT40QuerySet)

    def test_cte_queryset_correct_manager_n_from_queryset(self):
        self.assertIsInstance(
            OrderCustomManagerNQuery._default_manager, LTManager)
        self.assertEqual(type(
            OrderCustomManagerNQuery.objects.all()), LT25QuerySet)

    def test_cte_create_manager_from_non_cteQuery(self):
        # Instantiating a manager built from a plain QuerySet subclass
        # is expected to raise TypeError.
        class BrokenQuerySet(QuerySet):
            "This should be a CTEQuerySet if we want this to work"

        with self.assertRaises(TypeError):
            CTEManager.from_queryset(BrokenQuerySet)()

    def test_cte_queryset_correct_limitedmanager(self):
        self.assertEqual(type(OrderCustomManager._default_manager), LTManager)
        # Check the expected even if not ideal behavior occurs
        self.assertIsInstance(OrderCustomManager.objects.all(), CTEQuerySet)

    def test_cte_queryset_with_from_queryset(self):
        self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet)

        # CTE over orders whose region's parent is "sun", exposing the
        # parent id as `region_parent`.
        cte = With(
            OrderFromLT40.objects
            .annotate(region_parent=F("region__parent_id"))
            .filter(region__parent_id="sun")
        )
        orders = (
            cte.queryset()
            .with_cte(cte)
            .lt40()  # custom queryset method
            .order_by("region_id", "amount")
        )
        print(orders.query)

        data = [(x.region_id, x.amount, x.region_parent) for x in orders]
        self.assertEqual(data, [
            ("earth", 30, "sun"),
            ("earth", 31, "sun"),
            ("earth", 32, "sun"),
            ("earth", 33, "sun"),
            ('mercury', 10, 'sun'),
            ('mercury', 11, 'sun'),
            ('mercury', 12, 'sun'),
            ('venus', 20, 'sun'),
            ('venus', 21, 'sun'),
            ('venus', 22, 'sun'),
            ('venus', 23, 'sun'),
        ])

    def test_cte_queryset_with_custom_queryset(self):
        # Same CTE shape as above, but built from a model with a custom
        # manager and queryset; `lt25()` must remain available.
        cte = With(
            OrderCustomManagerNQuery.objects
            .annotate(region_parent=F("region__parent_id"))
            .filter(region__parent_id="sun")
        )
        orders = (
            cte.queryset()
            .with_cte(cte)
            .lt25()  # custom queryset method
            .order_by("region_id", "amount")
        )
        print(orders.query)

        data = [(x.region_id, x.amount, x.region_parent) for x in orders]
        self.assertEqual(data, [
            ('mercury', 10, 'sun'),
            ('mercury', 11, 'sun'),
            ('mercury', 12, 'sun'),
            ('venus', 20, 'sun'),
            ('venus', 21, 'sun'),
            ('venus', 22, 'sun'),
            ('venus', 23, 'sun'),
        ])

    def test_cte_queryset_with_deferred_loading(self):
        # The CTE query uses `.only("id")` (deferred loading) and is
        # sliced to a single row.
        cte = With(
            OrderCustomManagerNQuery.objects.order_by("id").only("id")[:1]
        )
        orders = cte.queryset().with_cte(cte)
        print(orders.query)

        self.assertEqual([x.id for x in orders], [1])
30 | - Dev versions of django-cte are now published on PyPI, making them easier to 31 | test and use before an official release is cut. 32 | 33 | ## 1.3.3 - 2024-06-07 34 | 35 | - Handle empty result sets in CTEs ([#92](https://github.com/dimagi/django-cte/pull/92)). 36 | - Fix `.explain()` in Django >= 4.0 ([#91](https://github.com/dimagi/django-cte/pull/91)). 37 | - Fixed bug in deferred loading ([#90](https://github.com/dimagi/django-cte/pull/90)). 38 | 39 | ## 1.3.2 - 2023-11-20 40 | 41 | - Work around changes in Django 4.2 that broke CTE queries due to internally 42 | generated column aliases in the query compiler. The workaround is not always 43 | effective. Some queries will produce malformed SQL. For example, CTE queries 44 | with window functions. 45 | 46 | ## 1.3.1 - 2023-06-13 47 | 48 | - Fix: `.update()` did not work when using CTE manager or when accessing nested 49 | tables. 50 | 51 | ## 1.3.0 - 2023-05-24 52 | 53 | - Add support for Materialized CTEs. 54 | - Fix: add EXPLAIN clause in correct position when using `.explain()` method. 55 | 56 | ## v1.2.1 - 2022-07-07 57 | 58 | - Fix compatibility with non-CTE models. 59 | 60 | ## v1.2.0 - 2022-03-30 61 | 62 | - Add support for Django 3.1, 3.2 and 4.0. 63 | - Quote the CTE table name if needed. 64 | - Resolve `OuterRef` in CTE `Subquery`. 65 | - Fix default `CTEManager` so it can use `from_queryset` correctly. 66 | - Fix for Django 3.0.5+. 67 | 68 | ## v1.1.5 - 2020-02-07 69 | 70 | - Django 3 compatibility. Thank you @tim-schilling and @ryanhiebert! 71 | 72 | ## v1.1.4 - 2018-07-30 73 | 74 | - Python 3 compatibility. 75 | 76 | ## v1.1.3 - 2018-06-19 77 | 78 | - Fix CTE alias bug. 79 | 80 | ## v1.1.2 - 2018-05-22 81 | 82 | - Use `_default_manager` instead of `objects`. 83 | 84 | ## v1.1.1 - 2018-04-13 85 | 86 | - Fix recursive CTE pickling. Note: this is currently [broken on Django 87 | master](https://github.com/django/django/pull/9134#pullrequestreview-112057277).
try:
    from warnings import deprecated
except ImportError:
    from warnings import warn

    # Copied from Python 3.13, lightly modified for Python 3.9 compatibility.
    # Can be removed when the oldest supported Python version is 3.13.
    class deprecated:
        """Indicate that a class, function or overload is deprecated.

        When this decorator is applied to an object, the type checker
        will generate a diagnostic on usage of the deprecated object.

        Usage:

            @deprecated("Use B instead")
            class A:
                pass

            @deprecated("Use g instead")
            def f():
                pass

            @overload
            @deprecated("int support is deprecated")
            def g(x: int) -> int: ...
            @overload
            def g(x: str) -> int: ...

        The warning specified by *category* will be emitted at runtime
        on use of deprecated objects. For functions, that happens on calls;
        for classes, on instantiation and on creation of subclasses.
        If the *category* is ``None``, no warning is emitted at runtime.
        The *stacklevel* determines where the
        warning is emitted. If it is ``1`` (the default), the warning
        is emitted at the direct caller of the deprecated object; if it
        is higher, it is emitted further up the stack.
        Static type checker behavior is not affected by the *category*
        and *stacklevel* arguments.

        The deprecation message passed to the decorator is saved in the
        ``__deprecated__`` attribute on the decorated object.
        If applied to an overload, the decorator
        must be after the ``@overload`` decorator for the attribute to
        exist on the overload as returned by ``get_overloads()``.

        See PEP 702 for details.

        """
        def __init__(
            self,
            message: str,
            /,
            *,
            category=DeprecationWarning,
            stacklevel=1,
        ):
            if not isinstance(message, str):
                raise TypeError(
                    f"Expected an object of type str for 'message', not {type(message).__name__!r}"
                )
            self.message = message
            self.category = category
            self.stacklevel = stacklevel

        def __call__(self, arg, /):
            """Decorate ``arg`` (a class or callable) and return the result.

            Classes are patched in place and returned; other callables
            are wrapped. Raises ``TypeError`` when ``category`` is not
            ``None`` and ``arg`` is neither a class nor callable.
            """
            # Make sure the inner functions created below don't
            # retain a reference to self.
            msg = self.message
            category = self.category
            stacklevel = self.stacklevel
            if category is None:
                # No runtime warning requested: only record the message.
                arg.__deprecated__ = msg
                return arg
            elif isinstance(arg, type):
                import functools
                from types import MethodType

                original_new = arg.__new__

                @functools.wraps(original_new)
                def __new__(cls, /, *args, **kwargs):
                    # Warn only for direct instantiation; subclasses
                    # already warned when they were created.
                    if cls is arg:
                        warn(msg, category=category, stacklevel=stacklevel + 1)
                    if original_new is not object.__new__:
                        return original_new(cls, *args, **kwargs)
                    # Mirrors a similar check in object.__new__.
                    elif cls.__init__ is object.__init__ and (args or kwargs):
                        raise TypeError(f"{cls.__name__}() takes no arguments")
                    else:
                        return original_new(cls)

                arg.__new__ = staticmethod(__new__)

                original_init_subclass = arg.__init_subclass__
                # We need slightly different behavior if __init_subclass__
                # is a bound method (likely if it was implemented in Python)
                if isinstance(original_init_subclass, MethodType):
                    original_init_subclass = original_init_subclass.__func__

                    @functools.wraps(original_init_subclass)
                    def __init_subclass__(*args, **kwargs):
                        warn(msg, category=category, stacklevel=stacklevel + 1)
                        return original_init_subclass(*args, **kwargs)

                    arg.__init_subclass__ = classmethod(__init_subclass__)
                # Or otherwise, which likely means it's a builtin such as
                # object's implementation of __init_subclass__.
                else:
                    @functools.wraps(original_init_subclass)
                    def __init_subclass__(*args, **kwargs):
                        warn(msg, category=category, stacklevel=stacklevel + 1)
                        return original_init_subclass(*args, **kwargs)

                    arg.__init_subclass__ = __init_subclass__

                arg.__deprecated__ = __new__.__deprecated__ = msg
                __init_subclass__.__deprecated__ = msg
                return arg
            elif callable(arg):
                import functools
                import inspect

                @functools.wraps(arg)
                def wrapper(*args, **kwargs):
                    warn(msg, category=category, stacklevel=stacklevel + 1)
                    return arg(*args, **kwargs)

                if inspect.iscoroutinefunction(arg):
                    # `inspect.markcoroutinefunction` only exists on
                    # Python 3.12+, but this fallback also runs on
                    # 3.9-3.11, where calling it unconditionally raised
                    # AttributeError for every deprecated `async def`.
                    mark = getattr(inspect, "markcoroutinefunction", None)
                    if mark is not None:
                        wrapper = mark(wrapper)
                    else:
                        # Older Pythons: set the marker attribute that
                        # `asyncio.iscoroutinefunction()` checks for.
                        import asyncio.coroutines

                        wrapper._is_coroutine = (
                            asyncio.coroutines._is_coroutine
                        )

                arg.__deprecated__ = wrapper.__deprecated__ = msg
                return wrapper
            else:
                raise TypeError(
                    "@deprecated decorator with non-None category must be applied to "
                    f"a class or callable, not {arg!r}"
                )
def generate_cte_sql(connection, query, as_sql):
    """Compile ``query`` with its CTEs prepended as a WITH clause.

    :param connection: Database connection used to compile each CTE and
        to build the EXPLAIN prefix.
    :param query: Query being compiled; its ``_with_ctes`` are rendered.
    :param as_sql: Zero-argument callable returning the base query's
        ``(sql, params)``.
    :returns: Whatever ``as_sql()`` returns when the query has no CTEs,
        otherwise ``(sql, params)`` with ``params`` as a tuple holding
        the CTE parameters followed by the base query's.
    :raises EmptyResultSet: Re-raised when compiling a CTE raises it.
    """
    if not query._with_ctes:
        return as_sql()

    ctes = []
    params = []
    for cte in query._with_ctes:
        # NOTE(review): `django.VERSION` is a longer tuple than (4, 2),
        # so this is also true on 4.2.x - presumably intended; confirm.
        if django.VERSION > (4, 2):
            _ignore_with_col_aliases(cte.query)

        alias = query.alias_map.get(cte.name)
        # Empty results are elided unless the CTE is LEFT OUTER joined.
        should_elide_empty = (
            not isinstance(alias, QJoin) or alias.join_type != LOUTER
        )

        compiler = cte.query.get_compiler(
            connection=connection, elide_empty=should_elide_empty
        )

        qn = compiler.quote_name_unless_alias
        try:
            cte_sql, cte_params = compiler.as_sql()
        except EmptyResultSet:
            # If the CTE raises an EmptyResultSet the SqlCompiler still
            # needs to know the information about this base compiler
            # like, col_count and klass_info.
            as_sql()
            raise
        template = get_cte_query_template(cte)
        ctes.append(template.format(name=qn(cte.name), query=cte_sql))
        params.extend(cte_params)

    explain_attribute = "explain_info"
    explain_info = getattr(query, explain_attribute, None)

    sql = []
    if explain_info:
        sql.append(
            connection.ops.explain_query_prefix(
                getattr(explain_info, "format", None),
                **getattr(explain_info, "options", {})
            )
        )
        # this needs to get set to None so that the base as_sql() doesn't
        # insert the EXPLAIN statement where it would end up between the
        # WITH ... clause and the final SELECT
        setattr(query, explain_attribute, None)

    if ctes:
        # Always use WITH RECURSIVE
        # https://www.postgresql.org/message-id/13122.1339829536%40sss.pgh.pa.us
        sql.extend(["WITH RECURSIVE", ", ".join(ctes)])

    try:
        base_sql, base_params = as_sql()
    finally:
        # Put the explain info back even if the base compilation raises,
        # so the query object is never left with EXPLAIN stripped.
        if explain_info:
            setattr(query, explain_attribute, explain_info)

    sql.append(base_sql)
    params.extend(base_params)
    return " ".join(sql), tuple(params)


def get_cte_query_template(cte):
    """Return the ``str.format`` template used to render one CTE.

    The template has ``{name}`` and ``{query}`` placeholders and adds
    ``MATERIALIZED`` when ``cte.materialized`` is truthy.
    """
    if cte.materialized:
        return "{name} AS MATERIALIZED ({query})"
    return "{name} AS ({query})"


def _ignore_with_col_aliases(cte_query):
    """Mix ``NoAliasQuery`` into every combined sub-query of ``cte_query``.

    Does nothing when the query has no ``combined_queries``.
    """
    if getattr(cte_query, "combined_queries", None):
        cte_query.combined_queries = tuple(
            jit_mixin(q, NoAliasQuery) for q in cte_query.combined_queries
        )
class NoAliasQuery(JITMixin):
    """Query mixin whose compilers are wrapped with ``NoAliasCompiler``."""
    _jit_mixin_prefix = "NoAlias"

    def get_compiler(self, *args, **kwargs):
        base_compiler = super().get_compiler(*args, **kwargs)
        return jit_mixin(base_compiler, NoAliasCompiler)


class NoAliasCompiler(JITMixin):
    """SQLCompiler mixin that never passes ``with_col_aliases`` on."""
    _jit_mixin_prefix = "NoAlias"

    def get_select(self, *, with_col_aliases=False, **kwargs):
        # `with_col_aliases` is accepted but deliberately dropped, so the
        # parent implementation runs with its own default for it.
        return super().get_select(**kwargs)
def test_cte_union_with_non_cte_query(self):
    """UNION ALL of a CTE-backed query with a plain (non-CTE) query."""
    # CTE with the total amount per region.
    one = CTE(
        Order.objects
        .values("region_id")
        .annotate(total=Sum("amount")),
    )

    # "earth" orders joined to the CTE, annotated with the region total.
    earths = with_cte(
        one,
        select=one.join(
            Order.objects.filter(region_id="earth"),
            region=one.col.region_id
        ).annotate(region_total=one.col.total)
    )
    # Plain queryset: "mars" orders with a constant region_total of 0.
    plain_mars = (
        Order.objects.filter(region_id="mars")
        .annotate(region_total=Value(0))
    )
    # Note: this does not work in the opposite order. A CTE query
    # must come first to invoke custom CTE combinator logic.
    combined = earths.union(plain_mars, all=True) \
        .values_list("amount", "region_id", "region_total")
    print(combined.query)

    self.assertEqual(sorted(combined), [
        (30, 'earth', 126),
        (31, 'earth', 126),
        (32, 'earth', 126),
        (33, 'earth', 126),
        (40, 'mars', 0),
        (41, 'mars', 0),
        (42, 'mars', 0),
    ])
def test_cte_intersection(self):
    """INTERSECT of two CTE-backed queries that use differently named CTEs."""
    # Both CTEs compute the same per-region total; only the names differ.
    cte_big = CTE(
        Order.objects
        .values("region_id")
        .annotate(total=Sum("amount")),
        name='big'
    )
    cte_small = CTE(
        Order.objects
        .values("region_id")
        .annotate(total=Sum("amount")),
        name='small'
    )
    # Orders whose region total is >= 86.
    orders_big = with_cte(
        cte_big,
        select=cte_big.join(Order, region=cte_big.col.region_id)
        .annotate(region_total=cte_big.col.total)
        .filter(region_total__gte=86)
    )
    # Orders whose region total is <= 123.
    orders_small = with_cte(
        cte_small,
        select=cte_small.join(Order, region=cte_small.col.region_id)
        .annotate(region_total=cte_small.col.total)
        .filter(region_total__lte=123)
    )

    # The intersection keeps rows satisfying both bounds.
    orders = orders_small.intersection(orders_big) \
        .values_list("amount", "region_id", "region_total")
    print(orders.query)

    self.assertEqual(sorted(orders), [
        (20, 'venus', 86),
        (21, 'venus', 86),
        (22, 'venus', 86),
        (23, 'venus', 86),
        (40, 'mars', 123),
        (41, 'mars', 123),
        (42, 'mars', 123),
    ])
.values("region_id") 240 | .annotate(total=Sum("amount")), 241 | name='small' 242 | ) 243 | orders_big = with_cte( 244 | cte_big, 245 | select=cte_big.join(Order, region=cte_big.col.region_id) 246 | .annotate(region_total=cte_big.col.total) 247 | .filter(region_total__gte=86) 248 | ) 249 | orders_small = with_cte( 250 | cte_small, 251 | select=cte_small.join(Order, region=cte_small.col.region_id) 252 | .annotate(region_total=cte_small.col.total) 253 | .filter(region_total__lte=123) 254 | ) 255 | 256 | orders = orders_small.difference(orders_big) \ 257 | .values_list("amount", "region_id", "region_total") 258 | print(orders.query) 259 | 260 | self.assertEqual(sorted(orders), [ 261 | (1, 'moon', 6), 262 | (2, 'moon', 6), 263 | (3, 'moon', 6), 264 | (10, 'mercury', 33), 265 | (10, 'proxima centauri b', 33), 266 | (11, 'mercury', 33), 267 | (11, 'proxima centauri b', 33), 268 | (12, 'mercury', 33), 269 | (12, 'proxima centauri b', 33), 270 | ]) 271 | -------------------------------------------------------------------------------- /tests/test_v1/test_combinators.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from django.db.models import Value 3 | from django.db.models.aggregates import Sum 4 | from django.test import TestCase 5 | 6 | from django_cte import With 7 | 8 | from .models import Order, OrderPlainManager 9 | 10 | 11 | class TestCTECombinators(TestCase): 12 | 13 | def test_cte_union_query(self): 14 | one = With( 15 | Order.objects 16 | .values("region_id") 17 | .annotate(total=Sum("amount")), 18 | name="one" 19 | ) 20 | two = With( 21 | Order.objects 22 | .values("region_id") 23 | .annotate(total=Sum("amount") * 2), 24 | name="two" 25 | ) 26 | 27 | earths = ( 28 | one.join( 29 | Order.objects.filter(region_id="earth"), 30 | region=one.col.region_id 31 | ) 32 | .with_cte(one) 33 | .annotate(region_total=one.col.total) 34 | .values_list("amount", "region_id", "region_total") 35 | ) 36 | mars = ( 37 | 
two.join( 38 | Order.objects.filter(region_id="mars"), 39 | region=two.col.region_id 40 | ) 41 | .with_cte(two) 42 | .annotate(region_total=two.col.total) 43 | .values_list("amount", "region_id", "region_total") 44 | ) 45 | combined = earths.union(mars, all=True) 46 | print(combined.query) 47 | 48 | self.assertEqual(sorted(combined), [ 49 | (30, 'earth', 126), 50 | (31, 'earth', 126), 51 | (32, 'earth', 126), 52 | (33, 'earth', 126), 53 | (40, 'mars', 246), 54 | (41, 'mars', 246), 55 | (42, 'mars', 246), 56 | ]) 57 | 58 | # queries used in union should still work on their own 59 | print(earths.query) 60 | self.assertEqual(sorted(earths),[ 61 | (30, 'earth', 126), 62 | (31, 'earth', 126), 63 | (32, 'earth', 126), 64 | (33, 'earth', 126), 65 | ]) 66 | print(mars.query) 67 | self.assertEqual(sorted(mars),[ 68 | (40, 'mars', 246), 69 | (41, 'mars', 246), 70 | (42, 'mars', 246), 71 | ]) 72 | 73 | def test_cte_union_with_non_cte_query(self): 74 | one = With( 75 | Order.objects 76 | .values("region_id") 77 | .annotate(total=Sum("amount")), 78 | ) 79 | 80 | earths = ( 81 | one.join( 82 | Order.objects.filter(region_id="earth"), 83 | region=one.col.region_id 84 | ) 85 | .with_cte(one) 86 | .annotate(region_total=one.col.total) 87 | ) 88 | plain_mars = ( 89 | OrderPlainManager.objects.filter(region_id="mars") 90 | .annotate(region_total=Value(0)) 91 | ) 92 | # Note: this does not work in the opposite order. A CTE query 93 | # must come first to invoke custom CTE combinator logic. 
94 | combined = earths.union(plain_mars, all=True) \ 95 | .values_list("amount", "region_id", "region_total") 96 | print(combined.query) 97 | 98 | self.assertEqual(sorted(combined), [ 99 | (30, 'earth', 126), 100 | (31, 'earth', 126), 101 | (32, 'earth', 126), 102 | (33, 'earth', 126), 103 | (40, 'mars', 0), 104 | (41, 'mars', 0), 105 | (42, 'mars', 0), 106 | ]) 107 | 108 | def test_cte_union_with_duplicate_names(self): 109 | cte_sun = With( 110 | Order.objects 111 | .filter(region__parent="sun") 112 | .values("region_id") 113 | .annotate(total=Sum("amount")), 114 | ) 115 | cte_proxima = With( 116 | Order.objects 117 | .filter(region__parent="proxima centauri") 118 | .values("region_id") 119 | .annotate(total=2 * Sum("amount")), 120 | ) 121 | 122 | orders_sun = ( 123 | cte_sun.join(Order, region=cte_sun.col.region_id) 124 | .with_cte(cte_sun) 125 | .annotate(region_total=cte_sun.col.total) 126 | ) 127 | orders_proxima = ( 128 | cte_proxima.join(Order, region=cte_proxima.col.region_id) 129 | .with_cte(cte_proxima) 130 | .annotate(region_total=cte_proxima.col.total) 131 | ) 132 | 133 | msg = "Found two or more CTEs named 'cte'" 134 | with pytest.raises(ValueError, match=msg): 135 | orders_sun.union(orders_proxima) 136 | 137 | def test_cte_union_of_same_cte(self): 138 | cte = With( 139 | Order.objects 140 | .filter(region__parent="sun") 141 | .values("region_id") 142 | .annotate(total=Sum("amount")), 143 | ) 144 | 145 | orders_big = ( 146 | cte.join(Order, region=cte.col.region_id) 147 | .with_cte(cte) 148 | .annotate(region_total=3 * cte.col.total) 149 | ) 150 | orders_small = ( 151 | cte.join(Order, region=cte.col.region_id) 152 | .with_cte(cte) 153 | .annotate(region_total=cte.col.total) 154 | ) 155 | 156 | orders = orders_big.union(orders_small) \ 157 | .values_list("amount", "region_id", "region_total") 158 | print(orders.query) 159 | 160 | self.assertEqual(sorted(orders), [ 161 | (10, 'mercury', 33), 162 | (10, 'mercury', 99), 163 | (11, 'mercury', 33), 164 | 
(11, 'mercury', 99), 165 | (12, 'mercury', 33), 166 | (12, 'mercury', 99), 167 | (20, 'venus', 86), 168 | (20, 'venus', 258), 169 | (21, 'venus', 86), 170 | (21, 'venus', 258), 171 | (22, 'venus', 86), 172 | (22, 'venus', 258), 173 | (23, 'venus', 86), 174 | (23, 'venus', 258), 175 | (30, 'earth', 126), 176 | (30, 'earth', 378), 177 | (31, 'earth', 126), 178 | (31, 'earth', 378), 179 | (32, 'earth', 126), 180 | (32, 'earth', 378), 181 | (33, 'earth', 126), 182 | (33, 'earth', 378), 183 | (40, 'mars', 123), 184 | (40, 'mars', 369), 185 | (41, 'mars', 123), 186 | (41, 'mars', 369), 187 | (42, 'mars', 123), 188 | (42, 'mars', 369) 189 | ]) 190 | 191 | def test_cte_intersection(self): 192 | cte_big = With( 193 | Order.objects 194 | .values("region_id") 195 | .annotate(total=Sum("amount")), 196 | name='big' 197 | ) 198 | cte_small = With( 199 | Order.objects 200 | .values("region_id") 201 | .annotate(total=Sum("amount")), 202 | name='small' 203 | ) 204 | orders_big = ( 205 | cte_big.join(Order, region=cte_big.col.region_id) 206 | .with_cte(cte_big) 207 | .annotate(region_total=cte_big.col.total) 208 | .filter(region_total__gte=86) 209 | ) 210 | orders_small = ( 211 | cte_small.join(Order, region=cte_small.col.region_id) 212 | .with_cte(cte_small) 213 | .annotate(region_total=cte_small.col.total) 214 | .filter(region_total__lte=123) 215 | ) 216 | 217 | orders = orders_small.intersection(orders_big) \ 218 | .values_list("amount", "region_id", "region_total") 219 | print(orders.query) 220 | 221 | self.assertEqual(sorted(orders), [ 222 | (20, 'venus', 86), 223 | (21, 'venus', 86), 224 | (22, 'venus', 86), 225 | (23, 'venus', 86), 226 | (40, 'mars', 123), 227 | (41, 'mars', 123), 228 | (42, 'mars', 123), 229 | ]) 230 | 231 | def test_cte_difference(self): 232 | cte_big = With( 233 | Order.objects 234 | .values("region_id") 235 | .annotate(total=Sum("amount")), 236 | name='big' 237 | ) 238 | cte_small = With( 239 | Order.objects 240 | .values("region_id") 241 | 
.annotate(total=Sum("amount")), 242 | name='small' 243 | ) 244 | orders_big = ( 245 | cte_big.join(Order, region=cte_big.col.region_id) 246 | .with_cte(cte_big) 247 | .annotate(region_total=cte_big.col.total) 248 | .filter(region_total__gte=86) 249 | ) 250 | orders_small = ( 251 | cte_small.join(Order, region=cte_small.col.region_id) 252 | .with_cte(cte_small) 253 | .annotate(region_total=cte_small.col.total) 254 | .filter(region_total__lte=123) 255 | ) 256 | 257 | orders = orders_small.difference(orders_big) \ 258 | .values_list("amount", "region_id", "region_total") 259 | print(orders.query) 260 | 261 | self.assertEqual(sorted(orders), [ 262 | (1, 'moon', 6), 263 | (2, 'moon', 6), 264 | (3, 'moon', 6), 265 | (10, 'mercury', 33), 266 | (10, 'proxima centauri b', 33), 267 | (11, 'mercury', 33), 268 | (11, 'proxima centauri b', 33), 269 | (12, 'mercury', 33), 270 | (12, 'proxima centauri b', 33), 271 | ]) 272 | -------------------------------------------------------------------------------- /django_cte/cte.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | import django 4 | from django.db.models import Manager, sql 5 | from django.db.models.expressions import Ref 6 | from django.db.models.query import Q, QuerySet, ValuesIterable 7 | from django.db.models.sql.datastructures import BaseTable 8 | 9 | from .jitmixin import jit_mixin 10 | from .join import QJoin, INNER 11 | from .meta import CTEColumnRef, CTEColumns 12 | from .query import CTEQuery 13 | from ._deprecated import deprecated 14 | 15 | __all__ = ["CTE", "with_cte"] 16 | 17 | 18 | def with_cte(*ctes, select): 19 | """Add Common Table Expression(s) (CTEs) to a model or queryset 20 | 21 | :param *ctes: One or more CTE objects. 22 | :param select: A model class, queryset, or CTE to use as the base 23 | query to which CTEs are attached. 24 | :returns: A queryset with the given CTE added to it. 
25 | """ 26 | if isinstance(select, CTE): 27 | select = select.queryset() 28 | elif not isinstance(select, QuerySet): 29 | select = select._default_manager.all() 30 | jit_mixin(select.query, CTEQuery) 31 | select.query._with_ctes += ctes 32 | return select 33 | 34 | 35 | class CTE: 36 | """Common Table Expression 37 | 38 | :param queryset: A queryset to use as the body of the CTE. 39 | :param name: Optional name parameter for the CTE (default: "cte"). 40 | This must be a unique name that does not conflict with other 41 | entities (tables, views, functions, other CTE(s), etc.) referenced 42 | in the given query as well as any query to which this CTE will 43 | eventually be added. 44 | :param materialized: Optional parameter (default: False) which enforces 45 | use of the MATERIALIZED keyword on databases that support it. 46 | """ 47 | 48 | def __init__(self, queryset, name="cte", materialized=False): 49 | self._set_queryset(queryset) 50 | self.name = name 51 | self.col = CTEColumns(self) 52 | self.materialized = materialized 53 | 54 | def __getstate__(self): 55 | return (self.query, self.name, self.materialized, self._iterable_class) 56 | 57 | def __setstate__(self, state): 58 | if len(state) == 3: 59 | # Keep compatibility with the previous serialization method 60 | self.query, self.name, self.materialized = state 61 | self._iterable_class = ValuesIterable 62 | else: 63 | self.query, self.name, self.materialized, self._iterable_class = state 64 | self.col = CTEColumns(self) 65 | 66 | def __repr__(self): 67 | return f"<{type(self).__name__} {self.name}>" 68 | 69 | def _set_queryset(self, queryset): 70 | self.query = None if queryset is None else queryset.query 71 | self._iterable_class = getattr(queryset, "_iterable_class", ValuesIterable) 72 | 73 | @classmethod 74 | def recursive(cls, make_cte_queryset, name="cte", materialized=False): 75 | """Recursive Common Table Expression 76 | 77 | :param make_cte_queryset: Function taking a single argument (a 78 |
not-yet-fully-constructed cte object) and returning a `QuerySet` 79 | object. The returned `QuerySet` normally consists of an initial 80 | statement unioned with a recursive statement. 81 | :param name: See `name` parameter of `__init__`. 82 | :param materialized: See `materialized` parameter of `__init__`. 83 | :returns: The fully constructed recursive cte object. 84 | """ 85 | cte = cls(None, name, materialized) 86 | cte._set_queryset(make_cte_queryset(cte)) 87 | return cte 88 | 89 | def join(self, model_or_queryset, *filter_q, **filter_kw): 90 | """Join this CTE to the given model or queryset 91 | 92 | This CTE will be referenced by the returned queryset, but the 93 | corresponding `WITH ...` statement will not be prepended to the 94 | queryset's SQL output; use `with_cte(cte, select=cte.join(...))` 95 | to achieve that outcome. 96 | 97 | :param model_or_queryset: Model class or queryset to which the 98 | CTE should be joined. 99 | :param *filter_q: Join condition Q expressions (optional). 100 | :param **filter_kw: Join conditions. All LHS fields (kwarg keys) 101 | are assumed to reference `model_or_queryset` fields. Use 102 | `cte.col.name` on the RHS to recursively reference CTE query 103 | columns. For example: `cte.join(Book, id=cte.col.id)` 104 | :returns: A queryset with the given model or queryset joined to 105 | this CTE. 
106 | """ 107 | if isinstance(model_or_queryset, QuerySet): 108 | queryset = model_or_queryset.all() 109 | else: 110 | queryset = model_or_queryset._default_manager.all() 111 | join_type = filter_kw.pop("_join_type", INNER) 112 | query = queryset.query 113 | 114 | # based on Query.add_q: add necessary joins to query, but no filter 115 | q_object = Q(*filter_q, **filter_kw) 116 | map = query.alias_map 117 | existing_inner = set(a for a in map if map[a].join_type == INNER) 118 | if django.VERSION >= (5, 2): 119 | on_clause, _ = query._add_q( 120 | q_object, query.used_aliases, update_join_types=(join_type == INNER) 121 | ) 122 | else: 123 | on_clause, _ = query._add_q(q_object, query.used_aliases) 124 | query.demote_joins(existing_inner) 125 | 126 | parent = query.get_initial_alias() 127 | query.join(QJoin(parent, self.name, self.name, on_clause, join_type)) 128 | return queryset 129 | 130 | def queryset(self): 131 | """Get a queryset selecting from this CTE 132 | 133 | This CTE will be referenced by the returned queryset, but the 134 | corresponding `WITH ...` statement will not be prepended to the 135 | queryset's SQL output; use `with_cte(cte, select=cte)` to do 136 | that. 137 | 138 | :returns: A queryset. 
139 | """ 140 | cte_query = self.query 141 | qs = cte_query.model._default_manager.get_queryset() 142 | qs._iterable_class = self._iterable_class 143 | qs._fields = () # Allow any field names to be used in further annotations 144 | 145 | query = jit_mixin(sql.Query(cte_query.model), CTEQuery) 146 | query.join(BaseTable(self.name, None)) 147 | query.default_cols = cte_query.default_cols 148 | query.deferred_loading = cte_query.deferred_loading 149 | 150 | if django.VERSION < (5, 2) and cte_query.values_select: 151 | query.set_values(cte_query.values_select) 152 | 153 | if cte_query.annotations: 154 | for alias, value in cte_query.annotations.items(): 155 | col = CTEColumnRef(alias, self.name, value.output_field) 156 | query.add_annotation(col, alias) 157 | query.annotation_select_mask = cte_query.annotation_select_mask 158 | 159 | if selected := getattr(cte_query, "selected", None): 160 | for alias in selected: 161 | if alias not in cte_query.annotations: 162 | output_field = cte_query.resolve_ref(alias).output_field 163 | col = CTEColumnRef(alias, self.name, output_field) 164 | query.add_annotation(col, alias) 165 | query.selected = {alias: alias for alias in selected} 166 | 167 | qs.query = query 168 | return qs 169 | 170 | def _resolve_ref(self, name): 171 | selected = getattr(self.query, "selected", None) 172 | if selected and name in selected and name not in self.query.annotations: 173 | return Ref(name, self.query.resolve_ref(name)) 174 | return self.query.resolve_ref(name) 175 | 176 | def resolve_expression(self, *args, **kw): 177 | if self.query is None: 178 | raise ValueError("Cannot resolve recursive CTE without a query.") 179 | clone = copy(self) 180 | clone.query = clone.query.resolve_expression(*args, **kw) 181 | return clone 182 | 183 | 184 | @deprecated("Use `django_cte.CTE` instead.") 185 | class With(CTE): 186 | 187 | @staticmethod 188 | @deprecated("Use `django_cte.CTE.recursive` instead.") 189 | def recursive(*args, **kw): 190 | return 
CTE.recursive(*args, **kw) 191 | 192 | 193 | @deprecated("CTEQuerySet is deprecated. " 194 | "CTEs can now be applied to any queryset using `with_cte()`") 195 | class CTEQuerySet(QuerySet): 196 | """QuerySet with support for Common Table Expressions""" 197 | 198 | def __init__(self, model=None, query=None, using=None, hints=None): 199 | # Only create an instance of a Query if this is the first invocation in 200 | # a query chain. 201 | super(CTEQuerySet, self).__init__(model, query, using, hints) 202 | jit_mixin(self.query, CTEQuery) 203 | 204 | @deprecated("Use `django_cte.with_cte(cte, select=...)` instead.") 205 | def with_cte(self, cte): 206 | qs = self._clone() 207 | qs.query._with_ctes += cte, 208 | return qs 209 | 210 | def as_manager(cls): 211 | # Address the circular dependency between 212 | # `CTEQuerySet` and `CTEManager`. 213 | manager = CTEManager.from_queryset(cls)() 214 | manager._built_with_as_manager = True 215 | return manager 216 | as_manager.queryset_only = True 217 | as_manager = classmethod(as_manager) 218 | 219 | 220 | @deprecated("CTEMAnager is deprecated. 
" 221 | "CTEs can now be applied to any queryset using `with_cte()`") 222 | class CTEManager(Manager.from_queryset(CTEQuerySet)): 223 | """Manager for models that perform CTE queries""" 224 | 225 | @classmethod 226 | def from_queryset(cls, queryset_class, class_name=None): 227 | if not issubclass(queryset_class, CTEQuerySet): 228 | raise TypeError( 229 | "models with CTE support need to use a CTEQuerySet") 230 | return super(CTEManager, cls).from_queryset( 231 | queryset_class, class_name=class_name) 232 | -------------------------------------------------------------------------------- /tests/test_recursive.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from unittest import SkipTest 3 | 4 | from django.db.models import IntegerField, TextField 5 | from django.db.models.expressions import ( 6 | Case, 7 | Exists, 8 | ExpressionWrapper, 9 | F, 10 | OuterRef, 11 | Q, 12 | Value, 13 | When, 14 | ) 15 | from django.db.models.functions import Concat 16 | from django.db.utils import DatabaseError 17 | from django.test import TestCase 18 | 19 | from django_cte import CTE, with_cte 20 | 21 | from .models import KeyPair, Region 22 | 23 | int_field = IntegerField() 24 | text_field = TextField() 25 | 26 | 27 | class TestRecursiveCTE(TestCase): 28 | 29 | def test_recursive_cte_query(self): 30 | def make_regions_cte(cte): 31 | return Region.objects.filter( 32 | # non-recursive: get root nodes 33 | parent__isnull=True 34 | ).values( 35 | "name", 36 | path=F("name"), 37 | depth=Value(0, output_field=int_field), 38 | ).union( 39 | # recursive union: get descendants 40 | cte.join(Region, parent=cte.col.name).values( 41 | "name", 42 | path=Concat( 43 | cte.col.path, Value(" / "), F("name"), 44 | output_field=text_field, 45 | ), 46 | depth=cte.col.depth + Value(1, output_field=int_field), 47 | ), 48 | all=True, 49 | ) 50 | 51 | cte = CTE.recursive(make_regions_cte) 52 | 53 | regions = with_cte( 54 | cte, 55 | 
select=cte.join(Region, name=cte.col.name) 56 | .annotate( 57 | path=cte.col.path, 58 | depth=cte.col.depth, 59 | ) 60 | .filter(depth=2) 61 | .order_by("path") 62 | ) 63 | print(regions.query) 64 | 65 | data = [(r.name, r.path, r.depth) for r in regions] 66 | self.assertEqual(data, [ 67 | ('moon', 'sun / earth / moon', 2), 68 | ('deimos', 'sun / mars / deimos', 2), 69 | ('phobos', 'sun / mars / phobos', 2), 70 | ]) 71 | 72 | def test_recursive_cte_reference_in_condition(self): 73 | def make_regions_cte(cte): 74 | return Region.objects.filter( 75 | parent__isnull=True 76 | ).values( 77 | "name", 78 | path=F("name"), 79 | depth=Value(0, output_field=int_field), 80 | is_planet=Value(0, output_field=int_field), 81 | ).union( 82 | cte.join( 83 | Region, parent=cte.col.name 84 | ).annotate( 85 | # annotations for filter and CASE/WHEN conditions 86 | parent_name=ExpressionWrapper( 87 | cte.col.name, 88 | output_field=text_field, 89 | ), 90 | parent_depth=ExpressionWrapper( 91 | cte.col.depth, 92 | output_field=int_field, 93 | ), 94 | ).filter( 95 | ~Q(parent_name="mars"), 96 | ).values( 97 | "name", 98 | path=Concat( 99 | cte.col.path, Value("\x01"), F("name"), 100 | output_field=text_field, 101 | ), 102 | depth=cte.col.depth + Value(1, output_field=int_field), 103 | is_planet=Case( 104 | When(parent_depth=0, then=Value(1)), 105 | default=Value(0), 106 | output_field=int_field, 107 | ), 108 | ), 109 | all=True, 110 | ) 111 | cte = CTE.recursive(make_regions_cte) 112 | regions = with_cte( 113 | cte, select=cte.join(Region, name=cte.col.name) 114 | ).annotate( 115 | path=cte.col.path, 116 | depth=cte.col.depth, 117 | is_planet=cte.col.is_planet, 118 | ).order_by("path") 119 | 120 | data = [(r.path.split("\x01"), r.is_planet) for r in regions] 121 | print(data) 122 | self.assertEqual(data, [ 123 | (["bernard's star"], 0), 124 | (['proxima centauri'], 0), 125 | (['proxima centauri', 'proxima centauri b'], 1), 126 | (['sun'], 0), 127 | (['sun', 'earth'], 1), 128 | (['sun', 
'earth', 'moon'], 0), 129 | (['sun', 'mars'], 1), # mars moons excluded: parent_name != 'mars' 130 | (['sun', 'mercury'], 1), 131 | (['sun', 'venus'], 1), 132 | ]) 133 | 134 | def test_recursive_cte_with_empty_union_part(self): 135 | def make_regions_cte(cte): 136 | return Region.objects.none().union( 137 | cte.join(Region, parent=cte.col.name), 138 | all=True, 139 | ) 140 | cte = CTE.recursive(make_regions_cte) 141 | regions = with_cte(cte, select=cte.join(Region, name=cte.col.name)) 142 | 143 | print(regions.query) 144 | try: 145 | self.assertEqual(regions.count(), 0) 146 | except DatabaseError: 147 | raise SkipTest( 148 | "Expected failure: QuerySet omits `EmptyQuerySet` from " 149 | "UNION queries resulting in invalid CTE SQL" 150 | ) 151 | 152 | # -- recursive query "cte" does not have the form 153 | # -- non-recursive-term UNION [ALL] recursive-term 154 | # WITH RECURSIVE cte AS ( 155 | # SELECT "tests_region"."name", "tests_region"."parent_id" 156 | # FROM "tests_region", "cte" 157 | # WHERE "tests_region"."parent_id" = ("cte"."name") 158 | # ) 159 | # SELECT COUNT(*) 160 | # FROM "tests_region", "cte" 161 | # WHERE "tests_region"."name" = ("cte"."name") 162 | 163 | def test_circular_ref_error(self): 164 | def make_bad_cte(cte): 165 | # NOTE: not a valid recursive CTE query 166 | return cte.join(Region, parent=cte.col.name).values( 167 | depth=cte.col.depth + 1, 168 | ) 169 | cte = CTE.recursive(make_bad_cte) 170 | regions = with_cte(cte, select=cte.join(Region, name=cte.col.name)) 171 | with self.assertRaises(ValueError) as context: 172 | print(regions.query) 173 | self.assertIn("Circular reference:", str(context.exception)) 174 | 175 | def test_attname_should_not_mask_col_name(self): 176 | def make_regions_cte(cte): 177 | return Region.objects.filter( 178 | name="moon" 179 | ).values( 180 | "name", 181 | "parent_id", 182 | ).union( 183 | cte.join(Region, name=cte.col.parent_id).values( 184 | "name", 185 | "parent_id", 186 | ), 187 | all=True, 188 | ) 189 | 
cte = CTE.recursive(make_regions_cte) 190 | regions = with_cte( 191 | cte, 192 | select=Region.objects.annotate(_ex=Exists( 193 | cte.queryset() 194 | .values(value=Value("1", output_field=int_field)) 195 | .filter(name=OuterRef("name")) 196 | )) 197 | .filter(_ex=True) 198 | .order_by("name") 199 | ) 200 | print(regions.query) 201 | 202 | data = [r.name for r in regions] 203 | self.assertEqual(data, ['earth', 'moon', 'sun']) 204 | 205 | def test_pickle_recursive_cte_queryset(self): 206 | def make_regions_cte(cte): 207 | return Region.objects.filter( 208 | parent__isnull=True 209 | ).annotate( 210 | depth=Value(0, output_field=int_field), 211 | ).union( 212 | cte.join(Region, parent=cte.col.name).annotate( 213 | depth=cte.col.depth + Value(1, output_field=int_field), 214 | ), 215 | all=True, 216 | ) 217 | cte = CTE.recursive(make_regions_cte) 218 | regions = with_cte(cte, select=cte).filter(depth=2).order_by("name") 219 | 220 | pickled_qs = pickle.loads(pickle.dumps(regions)) 221 | 222 | data = [(r.name, r.depth) for r in pickled_qs] 223 | self.assertEqual(data, [(r.name, r.depth) for r in regions]) 224 | self.assertEqual(data, [('deimos', 2), ('moon', 2), ('phobos', 2)]) 225 | 226 | def test_alias_change_in_annotation(self): 227 | def make_regions_cte(cte): 228 | return Region.objects.filter( 229 | parent__name="sun", 230 | ).annotate( 231 | value=F('name'), 232 | ).union( 233 | cte.join( 234 | Region.objects.annotate(value=F('name')), 235 | parent_id=cte.col.name, 236 | ), 237 | all=True, 238 | ) 239 | cte = CTE.recursive(make_regions_cte) 240 | query = with_cte(cte, select=cte) 241 | 242 | exclude_leaves = CTE(cte.queryset().filter( 243 | parent__name='sun', 244 | ).annotate( 245 | value=Concat(F('name'), F('name')) 246 | ), name='value_cte') 247 | 248 | query = with_cte(exclude_leaves, select=query.annotate( 249 | _exclude_leaves=Exists( 250 | exclude_leaves.queryset().filter( 251 | name=OuterRef("name"), 252 | value=OuterRef("value"), 253 | ) 254 | ) 255 | 
).filter(_exclude_leaves=True)) 256 | print(query.query) 257 | 258 | # Nothing should be returned. 259 | self.assertFalse(query) 260 | 261 | def test_alias_as_subquery(self): 262 | # This test covers CTEColumnRef.relabeled_clone 263 | def make_regions_cte(cte): 264 | return KeyPair.objects.filter( 265 | parent__key="level 1", 266 | ).annotate( 267 | rank=F('value'), 268 | ).union( 269 | cte.join( 270 | KeyPair.objects.order_by(), 271 | parent_id=cte.col.id, 272 | ).annotate( 273 | rank=F('value'), 274 | ), 275 | all=True, 276 | ) 277 | cte = CTE.recursive(make_regions_cte) 278 | children = with_cte(cte, select=cte) 279 | 280 | xdups = CTE(cte.queryset().filter( 281 | parent__key="level 1", 282 | ).annotate( 283 | rank=F('value') 284 | ).values('id', 'rank'), name='xdups') 285 | 286 | children = with_cte(xdups, select=children.annotate( 287 | _exclude=Exists( 288 | ( 289 | xdups.queryset().filter( 290 | id=OuterRef("id"), 291 | rank=OuterRef("rank"), 292 | ) 293 | ) 294 | ) 295 | ).filter(_exclude=True)) 296 | 297 | print(children.query) 298 | query = KeyPair.objects.filter(parent__in=children) 299 | print(query.query) 300 | print(children.query) 301 | self.assertEqual(query.get().key, 'level 3') 302 | # Tests the case in which children's query was modified since it was 303 | # used in a subquery to define `query` above. 
304 | self.assertEqual( 305 | list(c.key for c in children), 306 | ['level 2', 'level 2'] 307 | ) 308 | 309 | def test_materialized(self): 310 | # This test covers MATERIALIZED option in SQL query 311 | def make_regions_cte(cte): 312 | return KeyPair.objects.all() 313 | cte = CTE.recursive(make_regions_cte, materialized=True) 314 | 315 | query = with_cte(cte, select=KeyPair) 316 | print(query.query) 317 | self.assertTrue( 318 | str(query.query).startswith('WITH RECURSIVE "cte" AS MATERIALIZED') 319 | ) 320 | 321 | def test_recursive_self_queryset(self): 322 | def make_regions_cte(cte): 323 | return Region.objects.filter( 324 | pk="earth" 325 | ).values("pk").union( 326 | cte.join(Region, parent=cte.col.pk).values("pk") 327 | ) 328 | cte = CTE.recursive(make_regions_cte) 329 | queryset = with_cte(cte, select=cte).order_by("pk") 330 | print(queryset.query) 331 | self.assertEqual(list(queryset), [ 332 | {'pk': 'earth'}, 333 | {'pk': 'moon'}, 334 | ]) 335 | -------------------------------------------------------------------------------- /tests/test_v1/test_recursive.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from unittest import SkipTest 3 | 4 | from django.db.models import IntegerField, TextField 5 | from django.db.models.expressions import ( 6 | Case, 7 | Exists, 8 | ExpressionWrapper, 9 | F, 10 | OuterRef, 11 | Q, 12 | Value, 13 | When, 14 | ) 15 | from django.db.models.functions import Concat 16 | from django.db.utils import DatabaseError 17 | from django.test import TestCase 18 | 19 | from django_cte import With 20 | 21 | from .models import KeyPair, Region 22 | 23 | int_field = IntegerField() 24 | text_field = TextField() 25 | 26 | 27 | class TestRecursiveCTE(TestCase): 28 | 29 | def test_recursive_cte_query(self): 30 | def make_regions_cte(cte): 31 | return Region.objects.filter( 32 | # non-recursive: get root nodes 33 | parent__isnull=True 34 | ).values( 35 | "name", 36 | path=F("name"), 37 | 
depth=Value(0, output_field=int_field), 38 | ).union( 39 | # recursive union: get descendants 40 | cte.join(Region, parent=cte.col.name).values( 41 | "name", 42 | path=Concat( 43 | cte.col.path, Value(" / "), F("name"), 44 | output_field=text_field, 45 | ), 46 | depth=cte.col.depth + Value(1, output_field=int_field), 47 | ), 48 | all=True, 49 | ) 50 | 51 | cte = With.recursive(make_regions_cte) 52 | 53 | regions = ( 54 | cte.join(Region, name=cte.col.name) 55 | .with_cte(cte) 56 | .annotate( 57 | path=cte.col.path, 58 | depth=cte.col.depth, 59 | ) 60 | .filter(depth=2) 61 | .order_by("path") 62 | ) 63 | print(regions.query) 64 | 65 | data = [(r.name, r.path, r.depth) for r in regions] 66 | self.assertEqual(data, [ 67 | ('moon', 'sun / earth / moon', 2), 68 | ('deimos', 'sun / mars / deimos', 2), 69 | ('phobos', 'sun / mars / phobos', 2), 70 | ]) 71 | 72 | def test_recursive_cte_reference_in_condition(self): 73 | def make_regions_cte(cte): 74 | return Region.objects.filter( 75 | parent__isnull=True 76 | ).values( 77 | "name", 78 | path=F("name"), 79 | depth=Value(0, output_field=int_field), 80 | is_planet=Value(0, output_field=int_field), 81 | ).union( 82 | cte.join( 83 | Region, parent=cte.col.name 84 | ).annotate( 85 | # annotations for filter and CASE/WHEN conditions 86 | parent_name=ExpressionWrapper( 87 | cte.col.name, 88 | output_field=text_field, 89 | ), 90 | parent_depth=ExpressionWrapper( 91 | cte.col.depth, 92 | output_field=int_field, 93 | ), 94 | ).filter( 95 | ~Q(parent_name="mars"), 96 | ).values( 97 | "name", 98 | path=Concat( 99 | cte.col.path, Value("\x01"), F("name"), 100 | output_field=text_field, 101 | ), 102 | depth=cte.col.depth + Value(1, output_field=int_field), 103 | is_planet=Case( 104 | When(parent_depth=0, then=Value(1)), 105 | default=Value(0), 106 | output_field=int_field, 107 | ), 108 | ), 109 | all=True, 110 | ) 111 | cte = With.recursive(make_regions_cte) 112 | regions = cte.join(Region, name=cte.col.name).with_cte(cte).annotate( 113 
| path=cte.col.path, 114 | depth=cte.col.depth, 115 | is_planet=cte.col.is_planet, 116 | ).order_by("path") 117 | 118 | data = [(r.path.split("\x01"), r.is_planet) for r in regions] 119 | print(data) 120 | self.assertEqual(data, [ 121 | (["bernard's star"], 0), 122 | (['proxima centauri'], 0), 123 | (['proxima centauri', 'proxima centauri b'], 1), 124 | (['sun'], 0), 125 | (['sun', 'earth'], 1), 126 | (['sun', 'earth', 'moon'], 0), 127 | (['sun', 'mars'], 1), # mars moons excluded: parent_name != 'mars' 128 | (['sun', 'mercury'], 1), 129 | (['sun', 'venus'], 1), 130 | ]) 131 | 132 | def test_recursive_cte_with_empty_union_part(self): 133 | def make_regions_cte(cte): 134 | return Region.objects.none().union( 135 | cte.join(Region, parent=cte.col.name), 136 | all=True, 137 | ) 138 | cte = With.recursive(make_regions_cte) 139 | regions = cte.join(Region, name=cte.col.name).with_cte(cte) 140 | 141 | print(regions.query) 142 | try: 143 | self.assertEqual(regions.count(), 0) 144 | except DatabaseError: 145 | raise SkipTest( 146 | "Expected failure: QuerySet omits `EmptyQuerySet` from " 147 | "UNION queries resulting in invalid CTE SQL" 148 | ) 149 | 150 | # -- recursive query "cte" does not have the form 151 | # -- non-recursive-term UNION [ALL] recursive-term 152 | # WITH RECURSIVE cte AS ( 153 | # SELECT "tests_region"."name", "tests_region"."parent_id" 154 | # FROM "tests_region", "cte" 155 | # WHERE "tests_region"."parent_id" = ("cte"."name") 156 | # ) 157 | # SELECT COUNT(*) 158 | # FROM "tests_region", "cte" 159 | # WHERE "tests_region"."name" = ("cte"."name") 160 | 161 | def test_circular_ref_error(self): 162 | def make_bad_cte(cte): 163 | # NOTE: not a valid recursive CTE query 164 | return cte.join(Region, parent=cte.col.name).values( 165 | depth=cte.col.depth + 1, 166 | ) 167 | cte = With.recursive(make_bad_cte) 168 | regions = cte.join(Region, name=cte.col.name).with_cte(cte) 169 | with self.assertRaises(ValueError) as context: 170 | print(regions.query) 171 | 
self.assertIn("Circular reference:", str(context.exception)) 172 | 173 | def test_attname_should_not_mask_col_name(self): 174 | def make_regions_cte(cte): 175 | return Region.objects.filter( 176 | name="moon" 177 | ).values( 178 | "name", 179 | "parent_id", 180 | ).union( 181 | cte.join(Region, name=cte.col.parent_id).values( 182 | "name", 183 | "parent_id", 184 | ), 185 | all=True, 186 | ) 187 | cte = With.recursive(make_regions_cte) 188 | regions = ( 189 | Region.objects.all() 190 | .with_cte(cte) 191 | .annotate(_ex=Exists( 192 | cte.queryset() 193 | .values(value=Value("1", output_field=int_field)) 194 | .filter(name=OuterRef("name")) 195 | )) 196 | .filter(_ex=True) 197 | .order_by("name") 198 | ) 199 | print(regions.query) 200 | 201 | data = [r.name for r in regions] 202 | self.assertEqual(data, ['earth', 'moon', 'sun']) 203 | 204 | def test_pickle_recursive_cte_queryset(self): 205 | def make_regions_cte(cte): 206 | return Region.objects.filter( 207 | parent__isnull=True 208 | ).annotate( 209 | depth=Value(0, output_field=int_field), 210 | ).union( 211 | cte.join(Region, parent=cte.col.name).annotate( 212 | depth=cte.col.depth + Value(1, output_field=int_field), 213 | ), 214 | all=True, 215 | ) 216 | cte = With.recursive(make_regions_cte) 217 | regions = cte.queryset().with_cte(cte).filter(depth=2).order_by("name") 218 | 219 | pickled_qs = pickle.loads(pickle.dumps(regions)) 220 | 221 | data = [(r.name, r.depth) for r in pickled_qs] 222 | self.assertEqual(data, [(r.name, r.depth) for r in regions]) 223 | self.assertEqual(data, [('deimos', 2), ('moon', 2), ('phobos', 2)]) 224 | 225 | def test_alias_change_in_annotation(self): 226 | def make_regions_cte(cte): 227 | return Region.objects.filter( 228 | parent__name="sun", 229 | ).annotate( 230 | value=F('name'), 231 | ).union( 232 | cte.join( 233 | Region.objects.all().annotate( 234 | value=F('name'), 235 | ), 236 | parent_id=cte.col.name, 237 | ), 238 | all=True, 239 | ) 240 | cte = 
With.recursive(make_regions_cte) 241 | query = cte.queryset().with_cte(cte) 242 | 243 | exclude_leaves = With(cte.queryset().filter( 244 | parent__name='sun', 245 | ).annotate( 246 | value=Concat(F('name'), F('name')) 247 | ), name='value_cte') 248 | 249 | query = query.annotate( 250 | _exclude_leaves=Exists( 251 | exclude_leaves.queryset().filter( 252 | name=OuterRef("name"), 253 | value=OuterRef("value"), 254 | ) 255 | ) 256 | ).filter(_exclude_leaves=True).with_cte(exclude_leaves) 257 | print(query.query) 258 | 259 | # Nothing should be returned. 260 | self.assertFalse(query) 261 | 262 | def test_alias_as_subquery(self): 263 | # This test covers CTEColumnRef.relabeled_clone 264 | def make_regions_cte(cte): 265 | return KeyPair.objects.filter( 266 | parent__key="level 1", 267 | ).annotate( 268 | rank=F('value'), 269 | ).union( 270 | cte.join( 271 | KeyPair.objects.all().order_by(), 272 | parent_id=cte.col.id, 273 | ).annotate( 274 | rank=F('value'), 275 | ), 276 | all=True, 277 | ) 278 | cte = With.recursive(make_regions_cte) 279 | children = cte.queryset().with_cte(cte) 280 | 281 | xdups = With(cte.queryset().filter( 282 | parent__key="level 1", 283 | ).annotate( 284 | rank=F('value') 285 | ).values('id', 'rank'), name='xdups') 286 | 287 | children = children.annotate( 288 | _exclude=Exists( 289 | ( 290 | xdups.queryset().filter( 291 | id=OuterRef("id"), 292 | rank=OuterRef("rank"), 293 | ) 294 | ) 295 | ) 296 | ).filter(_exclude=True).with_cte(xdups) 297 | 298 | print(children.query) 299 | query = KeyPair.objects.filter(parent__in=children) 300 | print(query.query) 301 | print(children.query) 302 | self.assertEqual(query.get().key, 'level 3') 303 | # Tests the case in which children's query was modified since it was 304 | # used in a subquery to define `query` above. 
305 | self.assertEqual( 306 | list(c.key for c in children), 307 | ['level 2', 'level 2'] 308 | ) 309 | 310 | def test_materialized(self): 311 | # This test covers MATERIALIZED option in SQL query 312 | def make_regions_cte(cte): 313 | return KeyPair.objects.all() 314 | cte = With.recursive(make_regions_cte, materialized=True) 315 | 316 | query = KeyPair.objects.with_cte(cte) 317 | print(query.query) 318 | self.assertTrue( 319 | str(query.query).startswith('WITH RECURSIVE "cte" AS MATERIALIZED') 320 | ) 321 | 322 | def test_recursive_self_queryset(self): 323 | def make_regions_cte(cte): 324 | return Region.objects.filter( 325 | pk="earth" 326 | ).values("pk").union( 327 | cte.join(Region, parent=cte.col.pk).values("pk") 328 | ) 329 | cte = With.recursive(make_regions_cte) 330 | queryset = cte.queryset().with_cte(cte).order_by("pk") 331 | print(queryset.query) 332 | self.assertEqual(list(queryset), [ 333 | {'pk': 'earth'}, 334 | {'pk': 'moon'}, 335 | ]) 336 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Common Table Expressions with Django 2 | 3 | * Table of contents (this line will not be displayed). 4 | {:toc} 5 | 6 | A Common Table Expression acts like a temporary table or view that exists only 7 | for the duration of the query it is attached to. django-cte allows common table 8 | expressions to be attached to normal Django ORM queries. 9 | 10 | 11 | ## Simple Common Table Expressions 12 | 13 | See [Appendix A](#appendix-a-model-definitions-used-in-sample-code) for model 14 | definitions used in sample code. 15 | 16 | Simple CTEs are constructed using `CTE(...)`. A CTE is added to a queryset using 17 | `with_cte(cte, select=queryset)`, which adds the `WITH` expression before the 18 | main `SELECT` query. 
A CTE can be joined to a model or other `QuerySet` using 19 | its `.join(...)` method, which creates a new queryset with a `JOIN` and 20 | `ON` condition. 21 | 22 | ```py 23 | from django_cte import CTE, with_cte 24 | 25 | cte = CTE( 26 | Order.objects 27 | .values("region_id") 28 | .annotate(total=Sum("amount")) 29 | ) 30 | 31 | orders = with_cte( 32 | # WITH cte ... 33 | cte, 34 | 35 | # SELECT ... FROM orders INNER JOIN cte ON orders.region_id = cte.region_id 36 | select=cte.join(Order, region=cte.col.region_id) 37 | 38 | # Annotate each Order with a "region_total" 39 | .annotate(region_total=cte.col.total) 40 | ) 41 | 42 | print(orders.query) # print SQL 43 | ``` 44 | 45 | The `orders` SQL, after formatting for readability, would look something like 46 | this: 47 | 48 | ```sql 49 | WITH RECURSIVE "cte" AS ( 50 | SELECT 51 | "orders"."region_id", 52 | SUM("orders"."amount") AS "total" 53 | FROM "orders" 54 | GROUP BY "orders"."region_id" 55 | ) 56 | SELECT 57 | "orders"."id", 58 | "orders"."region_id", 59 | "orders"."amount", 60 | "cte"."total" AS "region_total" 61 | FROM "orders" 62 | INNER JOIN "cte" ON "orders"."region_id" = "cte"."region_id" 63 | ``` 64 | 65 | The `orders` query is a queryset containing annotated `Order` objects, just as 66 | you would get from a query like `Order.objects.annotate(region_total=...)`. Each 67 | `Order` object will be annotated with a `region_total` attribute, which is 68 | populated with the value of the corresponding total from the joined CTE query. 69 | 70 | You may have noticed the CTE in this query uses `WITH RECURSIVE` even though 71 | this is not a [Recursive Common Table Expression](#recursive-common-table-expressions). 72 | The `RECURSIVE` keyword is always used, even for non-recursive CTEs. On 73 | databases such as PostgreSQL and SQLite this has no effect other than allowing 74 | recursive CTEs to be included in the WITH block. 
75 | 76 | 77 | ## Recursive Common Table Expressions 78 | 79 | Recursive CTE queries allow fundamentally new types of queries that are 80 | not otherwise possible. 81 | 82 | Recursive CTEs are constructed using `CTE.recursive()`, which takes as its 83 | first argument a function that constructs and returns a recursive query. 84 | Recursive queries have two elements: first a non-recursive query element, and 85 | second a recursive query element. The second is typically attached to the first 86 | using `QuerySet.union()`. 87 | 88 | ```py 89 | def make_regions_cte(cte): 90 | # non-recursive: get root nodes 91 | return Region.objects.filter( 92 | parent__isnull=True 93 | ).values( 94 | "name", 95 | path=F("name"), 96 | depth=Value(0, output_field=IntegerField()), 97 | ).union( 98 | # recursive union: get descendants 99 | cte.join(Region, parent=cte.col.name).values( 100 | "name", 101 | path=Concat( 102 | cte.col.path, Value(" / "), F("name"), 103 | output_field=TextField(), 104 | ), 105 | depth=cte.col.depth + Value(1, output_field=IntegerField()), 106 | ), 107 | all=True, 108 | ) 109 | 110 | cte = CTE.recursive(make_regions_cte) 111 | 112 | regions = with_cte( 113 | cte, 114 | select=cte.join(Region, name=cte.col.name) 115 | .annotate( 116 | path=cte.col.path, 117 | depth=cte.col.depth, 118 | ) 119 | .filter(depth=2) 120 | .order_by("path") 121 | ) 122 | ``` 123 | 124 | `Region` objects returned by this query will have `path` and `depth` attributes. 125 | The results will be ordered by `path` (hierarchically by region name). 
The SQL 126 | produced by this query looks something like this: 127 | 128 | ```sql 129 | WITH RECURSIVE "cte" AS ( 130 | SELECT 131 | "region"."name", 132 | "region"."name" AS "path", 133 | 0 AS "depth" 134 | FROM "region" 135 | WHERE "region"."parent_id" IS NULL 136 | 137 | UNION ALL 138 | 139 | SELECT 140 | "region"."name", 141 | "cte"."path" || ' / ' || "region"."name" AS "path", 142 | "cte"."depth" + 1 AS "depth" 143 | FROM "region" 144 | INNER JOIN "cte" ON "region"."parent_id" = "cte"."name" 145 | ) 146 | SELECT 147 | "region"."name", 148 | "region"."parent_id", 149 | "cte"."path" AS "path", 150 | "cte"."depth" AS "depth" 151 | FROM "region" 152 | INNER JOIN "cte" ON "region"."name" = "cte"."name" 153 | WHERE "cte"."depth" = 2 154 | ORDER BY "path" ASC 155 | ``` 156 | 157 | 158 | ## Named Common Table Expressions 159 | 160 | It is possible to add more than one CTE to a query. To do this, each CTE must 161 | have a unique name. `CTE(queryset)` returns a CTE with the name `'cte'` by 162 | default, but that can be overridden: `CTE(queryset, name='custom')` or 163 | `CTE.recursive(make_queryset, name='custom')`. This allows each CTE to be 164 | referenced uniquely within a single query. 165 | 166 | Also note that a CTE may reference other CTEs in the same query. 
167 | 168 | Example query with two CTEs, where the second (`totals`) CTE references the first 169 | (`rootmap`): 170 | 171 | ```py 172 | def make_root_mapping(rootmap): 173 | return Region.objects.filter( 174 | parent__isnull=True 175 | ).values( 176 | "name", 177 | root=F("name"), 178 | ).union( 179 | rootmap.join(Region, parent=rootmap.col.name).values( 180 | "name", 181 | root=rootmap.col.root, 182 | ), 183 | all=True, 184 | ) 185 | rootmap = CTE.recursive(make_root_mapping, name="rootmap") 186 | 187 | totals = CTE( 188 | rootmap.join(Order, region_id=rootmap.col.name) 189 | .values( 190 | root=rootmap.col.root, 191 | ).annotate( 192 | orders_count=Count("id"), 193 | region_total=Sum("amount"), 194 | ), 195 | name="totals", 196 | ) 197 | 198 | root_regions = with_cte( 199 | # Important: add both CTEs to the query 200 | rootmap, 201 | totals, 202 | 203 | select=totals.join(Region, name=totals.col.root) 204 | .annotate( 205 | # count of orders in this region and all subregions 206 | orders_count=totals.col.orders_count, 207 | # sum of order amounts in this region and all subregions 208 | region_total=totals.col.region_total, 209 | ) 210 | ) 211 | ``` 212 | 213 | And the resulting SQL: 
214 | 215 | ```sql 216 | WITH RECURSIVE "rootmap" AS ( 217 | SELECT 218 | "region"."name", 219 | "region"."name" AS "root" 220 | FROM "region" 221 | WHERE "region"."parent_id" IS NULL 222 | 223 | UNION ALL 224 | 225 | SELECT 226 | "region"."name", 227 | "rootmap"."root" AS "root" 228 | FROM "region" 229 | INNER JOIN "rootmap" ON "region"."parent_id" = "rootmap"."name" 230 | ), 231 | "totals" AS ( 232 | SELECT 233 | "rootmap"."root" AS "root", 234 | COUNT("orders"."id") AS "orders_count", 235 | SUM("orders"."amount") AS "region_total" 236 | FROM "orders" 237 | INNER JOIN "rootmap" ON "orders"."region_id" = "rootmap"."name" 238 | GROUP BY "rootmap"."root" 239 | ) 240 | SELECT 241 | "region"."name", 242 | "region"."parent_id", 243 | "totals"."orders_count" AS "orders_count", 244 | "totals"."region_total" AS "region_total" 245 | FROM "region" 246 | INNER JOIN "totals" ON "region"."name" = "totals"."root" 247 | ``` 248 | 249 | 250 | ## Selecting FROM a Common Table Expression 251 | 252 | Sometimes it is useful to construct queries where the final `FROM` clause 253 | contains only common table expression(s). This is possible with 254 | `CTE(...).queryset()`. 
255 | 256 | Each returned row may be a model object: 257 | 258 | ```py 259 | cte = CTE( 260 | Order.objects 261 | .annotate(region_parent=F("region__parent_id")), 262 | ) 263 | orders = with_cte(cte, select=cte.queryset()) 264 | ``` 265 | 266 | And the resulting SQL: 267 | 268 | ```sql 269 | WITH RECURSIVE "cte" AS ( 270 | SELECT 271 | "orders"."id", 272 | "orders"."region_id", 273 | "orders"."amount", 274 | "region"."parent_id" AS "region_parent" 275 | FROM "orders" 276 | INNER JOIN "region" ON "orders"."region_id" = "region"."name" 277 | ) 278 | SELECT 279 | "cte"."id", 280 | "cte"."region_id", 281 | "cte"."amount", 282 | "cte"."region_parent" AS "region_parent" 283 | FROM "cte" 284 | ``` 285 | 286 | It is also possible to do the same with `values(...)` queries: 287 | 288 | ```py 289 | cte = CTE( 290 | Order.objects 291 | .values( 292 | "region_id", 293 | region_parent=F("region__parent_id"), 294 | ) 295 | .distinct() 296 | ) 297 | values = with_cte(cte, select=cte).filter(region_parent__isnull=False) 298 | ``` 299 | 300 | Which produces this SQL: 301 | 302 | ```sql 303 | WITH RECURSIVE "cte" AS ( 304 | SELECT DISTINCT 305 | "orders"."region_id", 306 | "region"."parent_id" AS "region_parent" 307 | FROM "orders" 308 | INNER JOIN "region" ON "orders"."region_id" = "region"."name" 309 | ) 310 | SELECT 311 | "cte"."region_id", 312 | "cte"."region_parent" AS "region_parent" 313 | FROM "cte" 314 | WHERE "cte"."region_parent" IS NOT NULL 315 | ``` 316 | 317 | You may have noticed that when a CTE is passed to the `select=...` argument as 318 | in `with_cte(cte, select=cte)`, the `.queryset()` call is optional and may be 319 | omitted. 320 | 321 | 322 | ## Experimental: Left Outer Join 323 | 324 | Django does not provide precise control over joins, but there is an experimental 325 | way to perform a `LEFT OUTER JOIN` with a CTE query using the `_join_type` 326 | keyword argument of `CTE.join(...)`. 
327 | 328 | ```py 329 | from django.db.models.sql.constants import LOUTER 330 | 331 | totals = CTE( 332 | Order.objects 333 | .values("region_id") 334 | .annotate(total=Sum("amount")) 335 | .filter(total__gt=100) 336 | ) 337 | orders = with_cte( 338 | totals, 339 | select=totals 340 | .join(Order, region=totals.col.region_id, _join_type=LOUTER) 341 | .annotate(region_total=totals.col.total) 342 | ) 343 | ``` 344 | 345 | Which produces the following SQL: 346 | 347 | ```sql 348 | WITH RECURSIVE "cte" AS ( 349 | SELECT 350 | "orders"."region_id", 351 | SUM("orders"."amount") AS "total" 352 | FROM "orders" 353 | GROUP BY "orders"."region_id" 354 | HAVING SUM("orders"."amount") > 100 355 | ) 356 | SELECT 357 | "orders"."id", 358 | "orders"."region_id", 359 | "orders"."amount", 360 | "cte"."total" AS "region_total" 361 | FROM "orders" 362 | LEFT OUTER JOIN "cte" ON "orders"."region_id" = "cte"."region_id" 363 | ``` 364 | 365 | WARNING: as noted, this feature is experimental. There may be scenarios where 366 | Django automatically converts a `LEFT OUTER JOIN` to an `INNER JOIN` in the 367 | process of building the query. Be sure to test your queries to ensure they 368 | produce the desired SQL. 369 | 370 | 371 | ## Materialized CTE 372 | 373 | Both PostgreSQL 12+ and SQLite 3.35+ support the `MATERIALIZED` keyword for CTE 374 | queries. To enforce usage of this keyword, add `materialized` as a parameter of 375 | `CTE(..., materialized=True)`. 376 | 377 | 378 | ```py 379 | cte = CTE( 380 | Order.objects.values('id'), 381 | materialized=True 382 | ) 383 | ``` 384 | 385 | Which produces this SQL: 386 | 387 | ```sql 388 | WITH RECURSIVE "cte" AS MATERIALIZED ( 389 | SELECT 390 | "orders"."id" 391 | FROM "orders" 392 | ) 393 | ... 394 | ``` 395 | 396 | 397 | ## Raw CTE SQL 398 | 399 | Some queries are easier to construct with raw SQL than with the Django ORM. 400 | `raw_cte_sql()` is one solution for situations like that. 
The down-side is that 401 | each result field in the raw query must be explicitly mapped to a field type. 402 | The up-side is that there is no need to compromise result-set expressiveness 403 | with the likes of `Manager.raw()`. 404 | 405 | A short example: 406 | 407 | ```py 408 | from django.db.models import IntegerField, TextField 409 | from django_cte.raw import raw_cte_sql 410 | 411 | cte = CTE(raw_cte_sql( 412 | """ 413 | SELECT region_id, AVG(amount) AS avg_order 414 | FROM orders 415 | WHERE region_id = %s 416 | GROUP BY region_id 417 | """, 418 | ["moon"], 419 | { 420 | "region_id": TextField(), 421 | "avg_order": IntegerField(), 422 | }, 423 | )) 424 | moon_avg = with_cte( 425 | cte, 426 | select=cte 427 | .join(Region, name=cte.col.region_id) 428 | .annotate(avg_order=cte.col.avg_order) 429 | ) 430 | ``` 431 | 432 | Which produces this SQL: 433 | 434 | ```sql 435 | WITH RECURSIVE "cte" AS ( 436 | SELECT region_id, AVG(amount) AS avg_order 437 | FROM orders 438 | WHERE region_id = 'moon' 439 | GROUP BY region_id 440 | ) 441 | SELECT 442 | "region"."name", 443 | "region"."parent_id", 444 | "cte"."avg_order" AS "avg_order" 445 | FROM "region" 446 | INNER JOIN "cte" ON "region"."name" = "cte"."region_id" 447 | ``` 448 | 449 | **WARNING**: Be very careful when writing raw SQL. Use bind parameters to 450 | prevent SQL injection attacks. 
451 | 452 | 453 | ## More Advanced Use Cases 454 | 455 | A few more advanced techniques as well as example query results can be found 456 | in the tests: 457 | 458 | - [`test_cte.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_cte.py) 459 | - [`test_recursive.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_recursive.py) 460 | - [`test_raw.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_raw.py) 461 | 462 | 463 | ## Appendix A: Model definitions used in sample code 464 | 465 | ```py 466 | class Order(Model): 467 | id = AutoField(primary_key=True) 468 | region = ForeignKey("Region", on_delete=CASCADE) 469 | amount = IntegerField(default=0) 470 | 471 | class Meta: 472 | db_table = "orders" 473 | 474 | 475 | class Region(Model): 476 | name = TextField(primary_key=True) 477 | parent = ForeignKey("self", null=True, on_delete=CASCADE) 478 | 479 | class Meta: 480 | db_table = "region" 481 | ``` 482 | 483 | 484 | ## Appendix B: django-cte v1 documentation (DEPRECATED) 485 | 486 | The syntax for constructing CTE queries changed slightly in django-cte 2.0. The 487 | most important change is that a custom model manager is no longer required on 488 | models used to construct CTE queries. The documentation has been updated to use 489 | v2 syntax, but the [documentation for v1](https://github.com/dimagi/django-cte/blob/v1.3.3/docs/index.md) 490 | can be found on GitHub if needed. 
491 | -------------------------------------------------------------------------------- /tests/test_v1/test_cte.py: -------------------------------------------------------------------------------- 1 | from unittest import SkipTest 2 | 3 | from django.db.models import IntegerField, TextField 4 | from django.db.models.aggregates import Count, Max, Min, Sum 5 | from django.db.models.expressions import ( 6 | Exists, ExpressionWrapper, F, OuterRef, Subquery, 7 | ) 8 | from django.db.models.sql.constants import LOUTER 9 | from django.test import TestCase 10 | 11 | from django_cte import With 12 | from django_cte import CTEManager 13 | 14 | from .models import Order, Region, User 15 | 16 | int_field = IntegerField() 17 | text_field = TextField() 18 | 19 | 20 | class TestCTE(TestCase): 21 | 22 | def test_simple_cte_query(self): 23 | cte = With( 24 | Order.objects 25 | .values("region_id") 26 | .annotate(total=Sum("amount")) 27 | ) 28 | 29 | orders = ( 30 | # FROM orders INNER JOIN cte ON orders.region_id = cte.region_id 31 | cte.join(Order, region=cte.col.region_id) 32 | 33 | # Add `WITH ...` before `SELECT ... 
FROM orders ...` 34 | .with_cte(cte) 35 | 36 | # Annotate each Order with a "region_total" 37 | .annotate(region_total=cte.col.total) 38 | ) 39 | print(orders.query) 40 | 41 | data = sorted((o.amount, o.region_id, o.region_total) for o in orders) 42 | self.assertEqual(data, [ 43 | (1, 'moon', 6), 44 | (2, 'moon', 6), 45 | (3, 'moon', 6), 46 | (10, 'mercury', 33), 47 | (10, 'proxima centauri b', 33), 48 | (11, 'mercury', 33), 49 | (11, 'proxima centauri b', 33), 50 | (12, 'mercury', 33), 51 | (12, 'proxima centauri b', 33), 52 | (20, 'venus', 86), 53 | (21, 'venus', 86), 54 | (22, 'venus', 86), 55 | (23, 'venus', 86), 56 | (30, 'earth', 126), 57 | (31, 'earth', 126), 58 | (32, 'earth', 126), 59 | (33, 'earth', 126), 60 | (40, 'mars', 123), 61 | (41, 'mars', 123), 62 | (42, 'mars', 123), 63 | (1000, 'sun', 1000), 64 | (2000, 'proxima centauri', 2000), 65 | ]) 66 | 67 | def test_cte_name_escape(self): 68 | totals = With( 69 | Order.objects 70 | .filter(region__parent="sun") 71 | .values("region_id") 72 | .annotate(total=Sum("amount")), 73 | name="mixedCaseCTEName" 74 | ) 75 | orders = ( 76 | totals 77 | .join(Order, region=totals.col.region_id) 78 | .with_cte(totals) 79 | .annotate(region_total=totals.col.total) 80 | .order_by("amount") 81 | ) 82 | self.assertTrue( 83 | str(orders.query).startswith('WITH RECURSIVE "mixedCaseCTEName"')) 84 | 85 | def test_cte_queryset(self): 86 | sub_totals = With( 87 | Order.objects 88 | .values(region_parent=F("region__parent_id")) 89 | .annotate(total=Sum("amount")), 90 | ) 91 | regions = ( 92 | Region.objects.all() 93 | .with_cte(sub_totals) 94 | .annotate( 95 | child_regions_total=Subquery( 96 | sub_totals.queryset() 97 | .filter(region_parent=OuterRef("name")) 98 | .values("total"), 99 | ), 100 | ) 101 | .order_by("name") 102 | ) 103 | print(regions.query) 104 | 105 | data = [(r.name, r.child_regions_total) for r in regions] 106 | self.assertEqual(data, [ 107 | ("bernard's star", None), 108 | ('deimos', None), 109 | ('earth', 6), 
110 | ('mars', None), 111 | ('mercury', None), 112 | ('moon', None), 113 | ('phobos', None), 114 | ('proxima centauri', 33), 115 | ('proxima centauri b', None), 116 | ('sun', 368), 117 | ('venus', None) 118 | ]) 119 | 120 | def test_cte_queryset_with_model_result(self): 121 | cte = With( 122 | Order.objects 123 | .annotate(region_parent=F("region__parent_id")), 124 | ) 125 | orders = cte.queryset().with_cte(cte) 126 | print(orders.query) 127 | 128 | data = sorted( 129 | (x.region_id, x.amount, x.region_parent) for x in orders)[:5] 130 | self.assertEqual(data, [ 131 | ("earth", 30, "sun"), 132 | ("earth", 31, "sun"), 133 | ("earth", 32, "sun"), 134 | ("earth", 33, "sun"), 135 | ("mars", 40, "sun"), 136 | ]) 137 | self.assertTrue( 138 | all(isinstance(x, Order) for x in orders), 139 | repr([x for x in orders]), 140 | ) 141 | 142 | def test_cte_queryset_with_join(self): 143 | cte = With( 144 | Order.objects 145 | .annotate(region_parent=F("region__parent_id")), 146 | ) 147 | orders = ( 148 | cte.queryset() 149 | .with_cte(cte) 150 | .annotate(parent=F("region__parent_id")) 151 | .order_by("region_id", "amount") 152 | ) 153 | print(orders.query) 154 | 155 | data = [(x.region_id, x.region_parent, x.parent) for x in orders][:5] 156 | self.assertEqual(data, [ 157 | ("earth", "sun", "sun"), 158 | ("earth", "sun", "sun"), 159 | ("earth", "sun", "sun"), 160 | ("earth", "sun", "sun"), 161 | ("mars", "sun", "sun"), 162 | ]) 163 | 164 | def test_cte_queryset_with_values_result(self): 165 | cte = With( 166 | Order.objects 167 | .values( 168 | "region_id", 169 | region_parent=F("region__parent_id"), 170 | ) 171 | .distinct() 172 | ) 173 | values = ( 174 | cte.queryset() 175 | .with_cte(cte) 176 | .filter(region_parent__isnull=False) 177 | ) 178 | print(values.query) 179 | 180 | def key(item): 181 | return item["region_parent"], item["region_id"] 182 | 183 | data = sorted(values, key=key)[:5] 184 | self.assertEqual(data, [ 185 | {'region_id': 'moon', 'region_parent': 'earth'}, 186 
| { 187 | 'region_id': 'proxima centauri b', 188 | 'region_parent': 'proxima centauri', 189 | }, 190 | {'region_id': 'earth', 'region_parent': 'sun'}, 191 | {'region_id': 'mars', 'region_parent': 'sun'}, 192 | {'region_id': 'mercury', 'region_parent': 'sun'}, 193 | ]) 194 | 195 | def test_named_simple_ctes(self): 196 | totals = With( 197 | Order.objects 198 | .filter(region__parent="sun") 199 | .values("region_id") 200 | .annotate(total=Sum("amount")), 201 | name="totals", 202 | ) 203 | region_count = With( 204 | Region.objects 205 | .filter(parent="sun") 206 | .values("parent_id") 207 | .annotate(num=Count("name")), 208 | name="region_count", 209 | ) 210 | orders = ( 211 | region_count.join( 212 | totals.join(Order, region=totals.col.region_id), 213 | region__parent=region_count.col.parent_id 214 | ) 215 | .with_cte(totals) 216 | .with_cte(region_count) 217 | .annotate(region_total=totals.col.total) 218 | .annotate(region_count=region_count.col.num) 219 | .order_by("amount") 220 | ) 221 | print(orders.query) 222 | 223 | data = [( 224 | o.amount, 225 | o.region_id, 226 | o.region_count, 227 | o.region_total, 228 | ) for o in orders] 229 | self.assertEqual(data, [ 230 | (10, 'mercury', 4, 33), 231 | (11, 'mercury', 4, 33), 232 | (12, 'mercury', 4, 33), 233 | (20, 'venus', 4, 86), 234 | (21, 'venus', 4, 86), 235 | (22, 'venus', 4, 86), 236 | (23, 'venus', 4, 86), 237 | (30, 'earth', 4, 126), 238 | (31, 'earth', 4, 126), 239 | (32, 'earth', 4, 126), 240 | (33, 'earth', 4, 126), 241 | (40, 'mars', 4, 123), 242 | (41, 'mars', 4, 123), 243 | (42, 'mars', 4, 123), 244 | ]) 245 | 246 | def test_named_ctes(self): 247 | def make_root_mapping(rootmap): 248 | return Region.objects.filter( 249 | parent__isnull=True 250 | ).values( 251 | "name", 252 | root=F("name"), 253 | ).union( 254 | rootmap.join(Region, parent=rootmap.col.name).values( 255 | "name", 256 | root=rootmap.col.root, 257 | ), 258 | all=True, 259 | ) 260 | rootmap = With.recursive(make_root_mapping, 
name="rootmap") 261 | 262 | totals = With( 263 | rootmap.join(Order, region_id=rootmap.col.name) 264 | .values( 265 | root=rootmap.col.root, 266 | ).annotate( 267 | orders_count=Count("id"), 268 | region_total=Sum("amount"), 269 | ), 270 | name="totals", 271 | ) 272 | 273 | root_regions = ( 274 | totals.join(Region, name=totals.col.root) 275 | .with_cte(rootmap) 276 | .with_cte(totals) 277 | .annotate( 278 | # count of orders in this region and all subregions 279 | orders_count=totals.col.orders_count, 280 | # sum of order amounts in this region and all subregions 281 | region_total=totals.col.region_total, 282 | ) 283 | ) 284 | print(root_regions.query) 285 | 286 | data = sorted( 287 | (r.name, r.orders_count, r.region_total) for r in root_regions 288 | ) 289 | self.assertEqual(data, [ 290 | ('proxima centauri', 4, 2033), 291 | ('sun', 18, 1374), 292 | ]) 293 | 294 | def test_materialized_option(self): 295 | totals = With( 296 | Order.objects 297 | .filter(region__parent="sun") 298 | .values("region_id") 299 | .annotate(total=Sum("amount")), 300 | materialized=True 301 | ) 302 | orders = ( 303 | totals 304 | .join(Order, region=totals.col.region_id) 305 | .with_cte(totals) 306 | .annotate(region_total=totals.col.total) 307 | .order_by("amount") 308 | ) 309 | self.assertTrue( 310 | str(orders.query).startswith( 311 | 'WITH RECURSIVE "cte" AS MATERIALIZED' 312 | ) 313 | ) 314 | 315 | def test_update_cte_query(self): 316 | cte = With( 317 | Order.objects 318 | .values(region_parent=F("region__parent_id")) 319 | .annotate(total=Sum("amount")) 320 | .filter(total__isnull=False) 321 | ) 322 | # not the most efficient query, but it exercises CTEUpdateQuery 323 | Order.objects.all().with_cte(cte).filter(region_id__in=Subquery( 324 | cte.queryset() 325 | .filter(region_parent=OuterRef("region_id")) 326 | .values("region_parent") 327 | )).update(amount=Subquery( 328 | cte.queryset() 329 | .filter(region_parent=OuterRef("region_id")) 330 | .values("total") 331 | )) 332 | 333 
| data = set((o.region_id, o.amount) for o in Order.objects.filter( 334 | region_id__in=["earth", "sun", "proxima centauri", "mars"] 335 | )) 336 | self.assertEqual(data, { 337 | ('earth', 6), 338 | ('mars', 40), 339 | ('mars', 41), 340 | ('mars', 42), 341 | ('proxima centauri', 33), 342 | ('sun', 368), 343 | }) 344 | 345 | def test_update_with_subquery(self): 346 | # Test for issue: https://github.com/dimagi/django-cte/issues/9 347 | # The issue is not reproducible on sqlite3; use postgres to run. 348 | # To reproduce the problem, the select-query must contain a join 349 | # so that the compiler will turn it into a subquery. 350 | # To add a join, filter on a field of a related model. 351 | orders = Order.objects.filter(region__parent_id='sun') 352 | orders.update(amount=0) 353 | data = {(order.region_id, order.amount) for order in orders} 354 | self.assertEqual(data, { 355 | ('mercury', 0), 356 | ('venus', 0), 357 | ('earth', 0), 358 | ('mars', 0), 359 | }) 360 | 361 | def test_delete_cte_query(self): 362 | raise SkipTest( 363 | "this test will not work until `QuerySet.delete` (Django method) " 364 | "calls `self.query.chain(sql.DeleteQuery)` instead of " 365 | "`sql.DeleteQuery(self.model)`" 366 | ) 367 | cte = With( 368 | Order.objects 369 | .values(region_parent=F("region__parent_id")) 370 | .annotate(total=Sum("amount")) 371 | .filter(total__isnull=False) 372 | ) 373 | Order.objects.all().with_cte(cte).annotate( 374 | cte_has_order=Exists( 375 | cte.queryset() 376 | .values("total") 377 | .filter(region_parent=OuterRef("region_id")) 378 | ) 379 | ).filter(cte_has_order=False).delete() 380 | 381 | data = [(o.region_id, o.amount) for o in Order.objects.all()] 382 | self.assertEqual(data, [ 383 | ('sun', 1000), 384 | ('earth', 30), 385 | ('earth', 31), 386 | ('earth', 32), 387 | ('earth', 33), 388 | ('proxima centauri', 2000), 389 | ]) 390 | 391 | def test_outerref_in_cte_query(self): 392 | # This query is meant to return the difference between min and max 
def test_experimental_left_outer_join(self):
    """LEFT OUTER JOIN a filtered totals CTE onto all orders (v1 API).

    Only regions whose order total exceeds 100 appear in the CTE, so
    orders from the remaining regions carry ``None`` as their total.
    """
    big_totals = With(
        Order.objects
        .values("region_id")
        .annotate(total=Sum("amount"))
        .filter(total__gt=100)
    )
    joined = big_totals.join(
        Order,
        region=big_totals.col.region_id,
        _join_type=LOUTER,
    )
    orders = (
        joined
        .with_cte(big_totals)
        .annotate(region_total=big_totals.col.total)
    )
    print(orders.query)
    self.assertIn("LEFT OUTER JOIN", str(orders.query))
    self.assertNotIn("INNER JOIN", str(orders.query))

    expected = [
        ('earth', 30, 126),
        ('earth', 31, 126),
        ('earth', 32, 126),
        ('earth', 33, 126),
        ('mars', 40, 123),
        ('mars', 41, 123),
        ('mars', 42, 123),
        ('mercury', 10, None),
        ('mercury', 11, None),
        ('mercury', 12, None),
        ('moon', 1, None),
        ('moon', 2, None),
        ('moon', 3, None),
        ('proxima centauri', 2000, 2000),
        ('proxima centauri b', 10, None),
        ('proxima centauri b', 11, None),
        ('proxima centauri b', 12, None),
        ('sun', 1000, 1000),
        ('venus', 20, None),
        ('venus', 21, None),
        ('venus', 22, None),
        ('venus', 23, None),
    ]
    rows = sorted(
        (o.region_id, o.amount, o.region_total) for o in orders
    )
    self.assertEqual(rows, expected)


def test_non_cte_subquery(self):
    """
    A CTE may embed a subquery over a model whose manager is not a
    CTEManager; the outer query must still return the expected totals.
    """
    self.assertNotIsInstance(User.objects, CTEManager)

    # trivial subquery example testing existence of a user for the order
    user_exists = Exists(User.objects.filter(pk=OuterRef("user_id")))
    sub_totals = With(
        Order.objects
        .values(region_parent=F("region__parent_id"))
        .annotate(
            total=Sum("amount"),
            non_cte_subquery=user_exists,
        ),
    )
    child_total = Subquery(
        sub_totals.queryset()
        .filter(region_parent=OuterRef("name"))
        .values("total"),
    )
    regions = (
        Region.objects.all()
        .with_cte(sub_totals)
        .annotate(child_regions_total=child_total)
        .order_by("name")
    )
    print(regions.query)

    expected = [
        ("bernard's star", None),
        ('deimos', None),
        ('earth', 6),
        ('mars', None),
        ('mercury', None),
        ('moon', None),
        ('phobos', None),
        ('proxima centauri', 33),
        ('proxima centauri b', None),
        ('sun', 368),
        ('venus', None)
    ]
    rows = [(r.name, r.child_regions_total) for r in regions]
    self.assertEqual(rows, expected)


def test_explain(self):
    """
    ``.explain()`` on a queryset that carries two named CTEs must
    produce a string (i.e. EXPLAIN is emitted in a valid position).
    """
    totals = With(
        Order.objects
        .filter(region__parent="sun")
        .values("region_id")
        .annotate(total=Sum("amount")),
        name="totals",
    )
    region_count = With(
        Region.objects
        .filter(parent="sun")
        .values("parent_id")
        .annotate(num=Count("name")),
        name="region_count",
    )
    orders_with_totals = totals.join(Order, region=totals.col.region_id)
    orders = (
        region_count.join(
            orders_with_totals,
            region__parent=region_count.col.parent_id,
        )
        .with_cte(totals)
        .with_cte(region_count)
        .annotate(region_total=totals.col.total)
        .annotate(region_count=region_count.col.num)
        .order_by("amount")
    )
    print(orders.query)

    self.assertIsInstance(orders.explain(), str)


def test_empty_result_set_cte(self):
    """
    The CTE query compiler must cope with a CTE whose queryset is
    known to be empty (``id__in=[]``): the inner join yields no rows.
    """
    empty_totals = With(
        Order.objects
        .filter(id__in=[])
        .values("region_id")
        .annotate(total=Sum("amount")),
        name="totals",
    )
    joined = empty_totals.join(Order, region=empty_totals.col.region_id)
    orders = (
        joined
        .with_cte(empty_totals)
        .annotate(region_total=empty_totals.col.total)
        .order_by("amount")
    )

    self.assertEqual(len(orders), 0)


def test_left_outer_join_on_empty_result_set_cte(self):
    """A LEFT OUTER JOIN against an empty CTE still returns all orders."""
    empty_totals = With(
        Order.objects
        .filter(id__in=[])
        .values("region_id")
        .annotate(total=Sum("amount")),
        name="totals",
    )
    joined = empty_totals.join(
        Order,
        region=empty_totals.col.region_id,
        _join_type=LOUTER,
    )
    orders = (
        joined
        .with_cte(empty_totals)
        .annotate(region_total=empty_totals.col.total)
        .order_by("amount")
    )

    self.assertEqual(len(orders), 22)
import pytest
import django
from django.db.models import IntegerField, TextField
from django.db.models.aggregates import Count, Max, Min, Sum
from django.db.models.expressions import (
    Exists, ExpressionWrapper, F, OuterRef, Subquery,
)
from django.db.models.sql.constants import LOUTER
from django.db.utils import OperationalError, ProgrammingError
from django.test import TestCase

from django_cte import CTE, with_cte

from .models import Order, Region, User

int_field = IntegerField()
text_field = TextField()


class TestCTE(TestCase):
    """Tests for queries built with ``CTE`` and ``with_cte``."""

    def test_simple_cte_query(self):
        """Annotate every order with the total amount of its region."""
        region_totals = CTE(
            Order.objects
            .values("region_id")
            .annotate(total=Sum("amount"))
        )

        # SELECT ... FROM orders
        # INNER JOIN cte ON orders.region_id = cte.region_id
        joined = region_totals.join(
            Order, region=region_totals.col.region_id
        )
        # WITH cte ...
        orders = with_cte(region_totals, select=joined).annotate(
            region_total=region_totals.col.total,
        )
        print(orders.query)

        expected = [
            (1, 'moon', 6),
            (2, 'moon', 6),
            (3, 'moon', 6),
            (10, 'mercury', 33),
            (10, 'proxima centauri b', 33),
            (11, 'mercury', 33),
            (11, 'proxima centauri b', 33),
            (12, 'mercury', 33),
            (12, 'proxima centauri b', 33),
            (20, 'venus', 86),
            (21, 'venus', 86),
            (22, 'venus', 86),
            (23, 'venus', 86),
            (30, 'earth', 126),
            (31, 'earth', 126),
            (32, 'earth', 126),
            (33, 'earth', 126),
            (40, 'mars', 123),
            (41, 'mars', 123),
            (42, 'mars', 123),
            (1000, 'sun', 1000),
            (2000, 'proxima centauri', 2000),
        ]
        rows = sorted(
            (o.amount, o.region_id, o.region_total) for o in orders
        )
        self.assertEqual(rows, expected)

    def test_cte_name_escape(self):
        """A mixed-case CTE name must be quoted in the generated SQL."""
        totals = CTE(
            Order.objects
            .filter(region__parent="sun")
            .values("region_id")
            .annotate(total=Sum("amount")),
            name="mixedCaseCTEName"
        )
        select = (
            totals.join(Order, region=totals.col.region_id)
            .annotate(region_total=totals.col.total)
            .order_by("amount")
        )
        orders = with_cte(totals, select=select)
        sql = str(orders.query)
        self.assertTrue(
            sql.startswith('WITH RECURSIVE "mixedCaseCTEName"'))

    def test_cte_queryset(self):
        """Use ``cte.queryset()`` inside a Subquery annotation."""
        sub_totals = CTE(
            Order.objects
            .values(region_parent=F("region__parent_id"))
            .annotate(total=Sum("amount")),
        )
        child_total = Subquery(
            sub_totals.queryset()
            .filter(region_parent=OuterRef("name"))
            .values("total"),
        )
        select = (
            Region.objects
            .annotate(child_regions_total=child_total)
            .order_by("name")
        )
        regions = with_cte(sub_totals, select=select)
        print(regions.query)

        expected = [
            ("bernard's star", None),
            ('deimos', None),
            ('earth', 6),
            ('mars', None),
            ('mercury', None),
            ('moon', None),
            ('phobos', None),
            ('proxima centauri', 33),
            ('proxima centauri b', None),
            ('sun', 368),
            ('venus', None)
        ]
        rows = [(r.name, r.child_regions_total) for r in regions]
        self.assertEqual(rows, expected)

    def test_cte_queryset_with_model_result(self):
        """Selecting straight from a model CTE yields model instances."""
        annotated_orders = CTE(
            Order.objects
            .annotate(region_parent=F("region__parent_id")),
        )
        orders = with_cte(
            annotated_orders,  # WITH cte AS (...)
            select=annotated_orders,  # SELECT ... FROM cte
        )
        print(orders.query)

        fetched = list(orders)
        first_five = sorted(
            (x.region_id, x.amount, x.region_parent) for x in fetched
        )[:5]
        self.assertEqual(first_five, [
            ("earth", 30, "sun"),
            ("earth", 31, "sun"),
            ("earth", 32, "sun"),
            ("earth", 33, "sun"),
            ("mars", 40, "sun"),
        ])
        self.assertTrue(
            all(isinstance(x, Order) for x in fetched),
            repr(fetched),
        )

    def test_cte_queryset_with_join(self):
        """A CTE queryset can still follow relations of the model."""
        annotated_orders = CTE(
            Order.objects
            .annotate(region_parent=F("region__parent_id")),
        )
        select = (
            annotated_orders.queryset()
            .annotate(parent=F("region__parent_id"))
            .order_by("region_id", "amount")
        )
        orders = with_cte(annotated_orders, select=select)
        print(orders.query)

        rows = [(x.region_id, x.region_parent, x.parent) for x in orders]
        self.assertEqual(rows[:5], [
            ("earth", "sun", "sun"),
            ("earth", "sun", "sun"),
            ("earth", "sun", "sun"),
            ("earth", "sun", "sun"),
            ("mars", "sun", "sun"),
        ])

    def test_cte_queryset_with_values_result(self):
        """Selecting from a ``values()`` CTE yields dictionaries."""
        region_parents = CTE(
            Order.objects
            .values(
                "region_id",
                region_parent=F("region__parent_id"),
            )
            .distinct()
        )
        values = with_cte(region_parents, select=region_parents).filter(
            region_parent__isnull=False,
        )
        print(values.query)

        ordered = sorted(
            values,
            key=lambda item: (item["region_parent"], item["region_id"]),
        )
        self.assertEqual(ordered[:5], [
            {'region_id': 'moon', 'region_parent': 'earth'},
            {
                'region_id': 'proxima centauri b',
                'region_parent': 'proxima centauri',
            },
            {'region_id': 'earth', 'region_parent': 'sun'},
            {'region_id': 'mars', 'region_parent': 'sun'},
            {'region_id': 'mercury', 'region_parent': 'sun'},
        ])

    def test_named_simple_ctes(self):
        """Two explicitly named CTEs can be joined in a single query."""
        totals = CTE(
            Order.objects
            .filter(region__parent="sun")
            .values("region_id")
            .annotate(total=Sum("amount")),
            name="totals",
        )
        region_count = CTE(
            Region.objects
            .filter(parent="sun")
            .values("parent_id")
            .annotate(num=Count("name")),
            name="region_count",
        )
        orders_with_totals = totals.join(
            Order, region=totals.col.region_id
        )
        select = (
            region_count.join(
                orders_with_totals,
                region__parent=region_count.col.parent_id,
            )
            .annotate(region_total=totals.col.total)
            .annotate(region_count=region_count.col.num)
            .order_by("amount")
        )
        orders = with_cte(totals, region_count, select=select)
        print(orders.query)

        expected = [
            (10, 'mercury', 4, 33),
            (11, 'mercury', 4, 33),
            (12, 'mercury', 4, 33),
            (20, 'venus', 4, 86),
            (21, 'venus', 4, 86),
            (22, 'venus', 4, 86),
            (23, 'venus', 4, 86),
            (30, 'earth', 4, 126),
            (31, 'earth', 4, 126),
            (32, 'earth', 4, 126),
            (33, 'earth', 4, 126),
            (40, 'mars', 4, 123),
            (41, 'mars', 4, 123),
            (42, 'mars', 4, 123),
        ]
        rows = [
            (o.amount, o.region_id, o.region_count, o.region_total)
            for o in orders
        ]
        self.assertEqual(rows, expected)

    def test_named_ctes(self):
        """A recursive CTE can feed a second, aggregating CTE."""

        def make_root_mapping(rootmap):
            # Anchor: regions without a parent are their own root.
            roots = Region.objects.filter(
                parent__isnull=True
            ).values(
                "name",
                root=F("name"),
            )
            # Recursive step: children inherit the root of their parent.
            children = rootmap.join(
                Region, parent=rootmap.col.name
            ).values(
                "name",
                root=rootmap.col.root,
            )
            return roots.union(children, all=True)

        rootmap = CTE.recursive(make_root_mapping, name="rootmap")

        totals = CTE(
            rootmap.join(Order, region_id=rootmap.col.name)
            .values(
                root=rootmap.col.root,
            ).annotate(
                orders_count=Count("id"),
                region_total=Sum("amount"),
            ),
            name="totals",
        )

        select = totals.join(Region, name=totals.col.root).annotate(
            # count of orders in this region and all subregions
            orders_count=totals.col.orders_count,
            # sum of order amounts in this region and all subregions
            region_total=totals.col.region_total,
        )
        root_regions = with_cte(rootmap, totals, select=select)
        print(root_regions.query)

        rows = sorted(
            (r.name, r.orders_count, r.region_total) for r in root_regions
        )
        self.assertEqual(rows, [
            ('proxima centauri', 4, 2033),
            ('sun', 18, 1374),
        ])

    def test_materialized_option(self):
        """``materialized=True`` adds ``AS MATERIALIZED`` to the SQL."""
        totals = CTE(
            Order.objects
            .filter(region__parent="sun")
            .values("region_id")
            .annotate(total=Sum("amount")),
            materialized=True
        )
        select = (
            totals.join(Order, region=totals.col.region_id)
            .annotate(region_total=totals.col.total)
            .order_by("amount")
        )
        orders = with_cte(totals, select=select)
        sql = str(orders.query)
        self.assertTrue(
            sql.startswith('WITH RECURSIVE "cte" AS MATERIALIZED')
        )

    def test_update_cte_query(self):
        """An UPDATE can read both its filter and new value from a CTE."""
        parent_totals = CTE(
            Order.objects
            .values(region_parent=F("region__parent_id"))
            .annotate(total=Sum("amount"))
            .filter(total__isnull=False)
        )

        def rows_for_region():
            # CTE rows whose parent region is the outer order's region
            return parent_totals.queryset().filter(
                region_parent=OuterRef("region_id")
            )

        # not the most efficient query, but it exercises CTEUpdateQuery
        to_update = with_cte(parent_totals, select=Order).filter(
            region_id__in=Subquery(rows_for_region().values("region_parent"))
        )
        to_update.update(
            amount=Subquery(rows_for_region().values("total"))
        )

        checked = Order.objects.filter(
            region_id__in=["earth", "sun", "proxima centauri", "mars"]
        )
        rows = {(o.region_id, o.amount) for o in checked}
        self.assertEqual(rows, {
            ('earth', 6),
            ('mars', 40),
            ('mars', 41),
            ('mars', 42),
            ('proxima centauri', 33),
            ('sun', 368),
        })

    def test_update_with_subquery(self):
        """A plain UPDATE whose filter needs a join must keep working."""
        # Test for issue: https://github.com/dimagi/django-cte/issues/9
        # Issue is not reproduced on sqlite3, use postgres to run.
        # To reproduce the problem it's required to have some join
        # in the select-query so the compiler will turn it into a subquery.
        # To add a join use a filter over field of related model
        sun_children = Order.objects.filter(region__parent_id='sun')
        sun_children.update(amount=0)
        rows = {(order.region_id, order.amount) for order in sun_children}
        self.assertEqual(rows, {
            ('mercury', 0),
            ('venus', 0),
            ('earth', 0),
            ('mars', 0),
        })

    @pytest.mark.xfail(
        reason="this test will not work until `QuerySet.delete` "
               "(Django method) calls `self.query.chain(sql.DeleteQuery)` "
               "instead of `sql.DeleteQuery(self.model)`",
        raises=(OperationalError, ProgrammingError),
        strict=True,
    )
    def test_delete_cte_query(self):
        """Delete orders that have no matching row in the CTE."""
        parent_totals = CTE(
            Order.objects
            .values(region_parent=F("region__parent_id"))
            .annotate(total=Sum("amount"))
            .filter(total__isnull=False)
        )
        has_order = Exists(
            parent_totals.queryset()
            .values("total")
            .filter(region_parent=OuterRef("region_id"))
        )
        with_cte(parent_totals, select=Order).annotate(
            cte_has_order=has_order
        ).filter(cte_has_order=False).delete()

        rows = [(o.region_id, o.amount) for o in Order.objects.all()]
        self.assertEqual(rows, [
            ('sun', 1000),
            ('earth', 30),
            ('earth', 31),
            ('earth', 32),
            ('earth', 33),
            ('proxima centauri', 2000),
        ])

    def test_outerref_in_cte_query(self):
        """A CTE used inside a Subquery may reference the outer query."""
        # This query is meant to return the difference between min and max
        # order of each region, through a subquery
        min_and_max = CTE(
            Order.objects
            .filter(region=OuterRef("pk"))
            .values('region')  # This is to force group by region_id
            .annotate(
                amount_min=Min("amount"),
                amount_max=Max("amount"),
            )
            .values('amount_min', 'amount_max')
        )
        spread = ExpressionWrapper(
            F('amount_max') - F('amount_min'),
            output_field=int_field,
        )
        regions = (
            Region.objects
            .annotate(
                difference=Subquery(
                    with_cte(min_and_max, select=min_and_max)
                    .annotate(difference=spread)
                    .values('difference')[:1],
                    output_field=IntegerField()
                )
            )
            .order_by("name")
        )
        print(regions.query)

        expected = [
            ("bernard's star", None),
            ('deimos', None),
            ('earth', 3),
            ('mars', 2),
            ('mercury', 2),
            ('moon', 2),
            ('phobos', None),
            ('proxima centauri', 0),
            ('proxima centauri b', 2),
            ('sun', 0),
            ('venus', 3)
        ]
        rows = [(r.name, r.difference) for r in regions]
        self.assertEqual(rows, expected)

    def test_experimental_left_outer_join(self):
        """LEFT OUTER JOIN a filtered totals CTE onto all orders.

        Only regions whose order total exceeds 100 appear in the CTE, so
        orders from the remaining regions carry ``None`` as their total.
        """
        big_totals = CTE(
            Order.objects
            .values("region_id")
            .annotate(total=Sum("amount"))
            .filter(total__gt=100)
        )
        select = big_totals.join(
            Order,
            region=big_totals.col.region_id,
            _join_type=LOUTER,
        ).annotate(region_total=big_totals.col.total)
        orders = with_cte(big_totals, select=select)
        print(orders.query)
        self.assertIn("LEFT OUTER JOIN", str(orders.query))
        self.assertNotIn("INNER JOIN", str(orders.query))

        expected = [
            ('earth', 30, 126),
            ('earth', 31, 126),
            ('earth', 32, 126),
            ('earth', 33, 126),
            ('mars', 40, 123),
            ('mars', 41, 123),
            ('mars', 42, 123),
            ('mercury', 10, None),
            ('mercury', 11, None),
            ('mercury', 12, None),
            ('moon', 1, None),
            ('moon', 2, None),
            ('moon', 3, None),
            ('proxima centauri', 2000, 2000),
            ('proxima centauri b', 10, None),
            ('proxima centauri b', 11, None),
            ('proxima centauri b', 12, None),
            ('sun', 1000, 1000),
            ('venus', 20, None),
            ('venus', 21, None),
            ('venus', 22, None),
            ('venus', 23, None),
        ]
        rows = sorted(
            (o.region_id, o.amount, o.region_total) for o in orders
        )
        self.assertEqual(rows, expected)

    def test_non_cte_subquery(self):
        """
        A CTE may embed a subquery over another model (here ``User``);
        the outer query must still return the expected totals.
        """
        # trivial subquery example testing existence of
        # a user for the order
        user_exists = Exists(User.objects.filter(pk=OuterRef("user_id")))
        sub_totals = CTE(
            Order.objects
            .values(region_parent=F("region__parent_id"))
            .annotate(
                total=Sum("amount"),
                non_cte_subquery=user_exists,
            ),
        )
        child_total = Subquery(
            sub_totals.queryset()
            .filter(region_parent=OuterRef("name"))
            .values("total"),
        )
        select = (
            Region.objects
            .annotate(child_regions_total=child_total)
            .order_by("name")
        )
        regions = with_cte(sub_totals, select=select)
        print(regions.query)

        expected = [
            ("bernard's star", None),
            ('deimos', None),
            ('earth', 6),
            ('mars', None),
            ('mercury', None),
            ('moon', None),
            ('phobos', None),
            ('proxima centauri', 33),
            ('proxima centauri b', None),
            ('sun', 368),
            ('venus', None)
        ]
        rows = [(r.name, r.child_regions_total) for r in regions]
        self.assertEqual(rows, expected)

    def test_explain(self):
        """
        ``.explain()`` on a queryset that carries two named CTEs must
        produce a string (i.e. EXPLAIN is emitted in a valid position).
        """
        totals = CTE(
            Order.objects
            .filter(region__parent="sun")
            .values("region_id")
            .annotate(total=Sum("amount")),
            name="totals",
        )
        region_count = CTE(
            Region.objects
            .filter(parent="sun")
            .values("parent_id")
            .annotate(num=Count("name")),
            name="region_count",
        )
        orders_with_totals = totals.join(
            Order, region=totals.col.region_id
        )
        select = (
            region_count.join(
                orders_with_totals,
                region__parent=region_count.col.parent_id,
            )
            .annotate(region_total=totals.col.total)
            .annotate(region_count=region_count.col.num)
            .order_by("amount")
        )
        orders = with_cte(totals, region_count, select=select)
        print(orders.query)

        self.assertIsInstance(orders.explain(), str)

    def test_empty_result_set_cte(self):
        """
        The CTE query compiler must cope with a CTE whose queryset is
        known to be empty (``id__in=[]``): the inner join yields no rows.
        """
        empty_totals = CTE(
            Order.objects
            .filter(id__in=[])
            .values("region_id")
            .annotate(total=Sum("amount")),
            name="totals",
        )
        select = (
            empty_totals.join(Order, region=empty_totals.col.region_id)
            .annotate(region_total=empty_totals.col.total)
            .order_by("amount")
        )
        orders = with_cte(empty_totals, select=select)

        self.assertEqual(len(orders), 0)

    def test_left_outer_join_on_empty_result_set_cte(self):
        """A LEFT OUTER JOIN against an empty CTE returns all orders."""
        empty_totals = CTE(
            Order.objects
            .filter(id__in=[])
            .values("region_id")
            .annotate(total=Sum("amount")),
            name="totals",
        )
        select = (
            empty_totals.join(
                Order,
                region=empty_totals.col.region_id,
                _join_type=LOUTER,
            )
            .annotate(region_total=empty_totals.col.total)
            .order_by("amount")
        )
        orders = with_cte(empty_totals, select=select)

        self.assertEqual(len(orders), 22)

    def test_union_query_with_cte(self):
        """Both branches of a UNION can select from the same CTE."""
        sun_children = (
            Order.objects
            .filter(region__parent="sun")
            .only("region", "amount")
        )
        orders_cte = CTE(sun_children, name="orders_cte")
        from_cte = orders_cte.queryset()

        earth_orders = from_cte.filter(region="earth")
        mars_orders = from_cte.filter(region="mars")

        earth_mars = earth_orders.union(mars_orders, all=True)
        select = (
            earth_mars
            .order_by("region", "amount")
            .values_list("region", "amount")
        )
        earth_mars_cte = with_cte(orders_cte, select=select)
        print(earth_mars_cte.query)

        self.assertEqual(list(earth_mars_cte), [
            ('earth', 30),
            ('earth', 31),
            ('earth', 32),
            ('earth', 33),
            ('mars', 40),
            ('mars', 41),
            ('mars', 42),
        ])

    def test_cte_select_pk(self):
        """Join a ``values("pk")`` queryset to a CTE on the ``pk`` alias."""
        earth_pks = Order.objects.filter(region="earth").values("pk")
        pk_cte = CTE(earth_pks)
        joined = pk_cte.join(earth_pks, pk=pk_cte.col.pk)
        queryset = with_cte(pk_cte, select=joined).order_by("pk")
        print(queryset.query)
        self.assertEqual(list(queryset), [
            {'pk': 9},
            {'pk': 10},
            {'pk': 11},
            {'pk': 12},
        ])

    def test_django52_resolve_ref_regression(self):
        """CTE columns resolve when the CTE selects several fields."""
        order_details = CTE(
            Order.objects.annotate(
                pnt_id=F("region__parent_id"),
                region_name=F("region__name"),
            ).values(
                # important: more than one query.select field
                "region_id",
                "amount",
                # important: more than one query.annotations field
                "pnt_id",
                "region_name",
            )
        )
        select = (
            order_details.queryset()
            .values(
                amt=order_details.col.amount,
                pnt_id=order_details.col.pnt_id,
                region_name=order_details.col.region_name,
            )
            .filter(region_id="earth")
            .order_by("amount")
        )
        qs = with_cte(order_details, select=select)
        print(qs.query)
        self.assertEqual(list(qs), [
            {'amt': 30, 'region_name': 'earth', 'pnt_id': 'sun'},
            {'amt': 31, 'region_name': 'earth', 'pnt_id': 'sun'},
            {'amt': 32, 'region_name': 'earth', 'pnt_id': 'sun'},
            {'amt': 33, 'region_name': 'earth', 'pnt_id': 'sun'},
        ])

    def test_django52_queryset_regression(self):
        """``values()`` on a CTE queryset must not raise."""
        id_and_region = CTE(Order.objects.values("id", "region_id"))
        from_cte = id_and_region.queryset()
        # Raises an exception before the fix
        from_cte.values("id", "region_id")

    def test_django52_ambiguous_column_names(self):
        """Two CTEs exposing ``user_id`` must not clash when joined."""
        orders_cte = CTE(Order.objects.values("region", "amount", "user_id"))
        users_cte = CTE(User.objects.annotate(user_id=F("id")), name="cte2")
        select = (
            users_cte.join(
                orders_cte.queryset(), user_id=users_cte.col.user_id
            )
            .annotate(user_name=users_cte.col.name)
            .order_by("region", "amount")
            .values_list("region", "amount", "user_name")
        )
        qs = with_cte(orders_cte, users_cte, select=select)
        # Executing this query should not raise a
        # django.db.utils.OperationalError: ambiguous column name: user_id
        expected = [
            ('earth', 30, "admin"),
            ('earth', 31, "admin"),
            ('earth', 32, "admin"),
            ('earth', 33, "admin"),
            ('mars', 40, "admin"),
            ('mars', 41, "admin"),
            ('mars', 42, "admin"),
            ('mercury', 10, "admin"),
            ('mercury', 11, "admin"),
            ('mercury', 12, "admin"),
            ('moon', 1, "admin"),
            ('moon', 2, "admin"),
            ('moon', 3, "admin"),
            ('proxima centauri', 2000, "admin"),
            ('proxima centauri b', 10, "admin"),
            ('proxima centauri b', 11, "admin"),
            ('proxima centauri b', 12, "admin"),
            ('sun', 1000, "admin"),
            ('venus', 20, "admin"),
            ('venus', 21, "admin"),
            ('venus', 22, "admin"),
            ('venus', 23, "admin"),
        ]
        self.assertEqual(list(qs), expected)

    def test_django52_queryset_aggregates_klass_error(self):
        """Selecting from an aggregating ``values()`` CTE must work."""
        per_user = CTE(
            Order.objects.annotate(user_name=F("user__name"))
            .values("user_name")
            .annotate(c=Count("user_name"))
            .values("user_name", "c")
        )
        qs = with_cte(per_user, select=per_user)
        # Executing the query should not raise
        # TypeError: 'NoneType' object is not subscriptable
        self.assertEqual(list(qs), [{"user_name": "admin", "c": 22}])

    def test_django52_annotate_model_field_name_after_queryset(self):
        """An annotation may reuse a field name selected by the CTE."""
        # Select the `id` field in one CTE
        orders_cte = CTE(Order.objects.values("id", "region", "user_id"))
        # In the next query, when querying from the CTE we reassign the
        # `id` field. Previously, this would have thrown an exception
        qs = (
            with_cte(orders_cte, select=orders_cte)
            .annotate(id=F('user_id'))
            .values_list('id', 'region')
            .order_by('id', 'region')
            .distinct()
        )
        expected = [
            (1, 'earth'),
            (1, 'mars'),
            (1, 'mercury'),
            (1, 'moon'),
            (1, 'proxima centauri'),
            (1, 'proxima centauri b'),
            (1, 'sun'),
            (1, 'venus'),
        ]
        self.assertEqual(list(qs), expected)

    @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
    def test_queryset_after_values_list(self):
        """Selecting from a ``values_list()`` CTE yields tuples."""
        ordered_pairs = (
            Order.objects
            .values_list("region", "amount")
            .order_by("region", "amount")
        )
        pairs_cte = CTE(ordered_pairs)
        qs = with_cte(pairs_cte, select=pairs_cte)
        expected = [
            ('earth', 30),
            ('earth', 31),
            ('earth', 32),
            ('earth', 33),
            ('mars', 40),
            ('mars', 41),
            ('mars', 42),
            ('mercury', 10),
            ('mercury', 11),
            ('mercury', 12),
            ('moon', 1),
            ('moon', 2),
            ('moon', 3),
            ('proxima centauri', 2000),
            ('proxima centauri b', 10),
            ('proxima centauri b', 11),
            ('proxima centauri b', 12),
            ('sun', 1000),
            ('venus', 20),
            ('venus', 21),
            ('venus', 22),
            ('venus', 23),
        ]
        self.assertEqual(list(qs), expected)

    @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
    def test_queryset_after_values_list_flat(self):
        """Selecting from a flat ``values_list()`` CTE yields scalars."""
        regions_cte = CTE(
            Order.objects.values_list("region", flat=True)
            .order_by("region")
            .distinct()
        )
        qs = with_cte(regions_cte, select=regions_cte)
        expected = [
            'earth',
            'mars',
            'mercury',
            'moon',
            'proxima centauri',
            'proxima centauri b',
            'sun',
            'venus'
        ]
        self.assertEqual(list(qs), expected)

    @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
    def test_queryset_values_list_order1(self):
        """Column order follows ``values_list`` (aggregate, then field)."""
        counts_cte = CTE(
            Order.objects.values("region")
            .annotate(c=Count("region"))
            .values_list("c", "region")
            .order_by("region")
        )
        qs = with_cte(counts_cte, select=counts_cte)
        # Ensure the column order of queried fields is the specified one:
        # c, region
        # Before the fix, the order would have been this one: region, c
        expected = [
            (4, 'earth'),
            (3, 'mars'),
            (3, 'mercury'),
            (3, 'moon'),
            (1, 'proxima centauri'),
            (3, 'proxima centauri b'),
            (1, 'sun'),
            (4, 'venus'),
        ]
        self.assertEqual(list(qs), expected)

    @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
    def test_queryset_values_list_order2(self):
        """Column order follows ``values_list`` for two annotations."""
        counts_cte = CTE(
            Order.objects.values("region")
            .annotate(r=F("region"), c=Count("region"))
            .values_list("c", "r")
            .order_by("r")
        )
        qs = with_cte(counts_cte, select=counts_cte)
        # Ensure the column order of queried fields is the specified one:
        # c, r
        # Before the fix, the order would have been this one: r, c
        expected = [
            (4, 'earth'),
            (3, 'mars'),
            (3, 'mercury'),
            (3, 'moon'),
            (1, 'proxima centauri'),
            (3, 'proxima centauri b'),
            (1, 'sun'),
            (4, 'venus'),
        ]
        self.assertEqual(list(qs), expected)

    @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
    def test_left_outer_join_invalid_innerjoin(self):
        """A LOUTER CTE join must not turn relation joins into INNER."""
        totals = CTE(
            Order.objects
            .values("region_id")
            .annotate(total=Sum("amount"))
        )
        # Query all regions but only show the total order amount for
        # regions that are grandchildren of "sun".
        # This requires the __parent to be joined using LEFT OUTER JOIN
        # too. Otherwise, if django did an INNER JOIN for __parent then
        # only regions that have a parent would be included in the result
        # set.
        select = (
            totals.join(
                Region,
                parent__parent="sun",
                name=totals.col.region_id,
                _join_type=LOUTER
            )
            .values("name", total=totals.col.total)
            .order_by("name")
        )
        qs = with_cte(totals, select=select)
        expected = [
            {'name': "bernard's star", 'total': None},
            {'name': 'deimos', 'total': None},
            {'name': 'earth', 'total': None},
            {'name': 'mars', 'total': None},
            {'name': 'mercury', 'total': None},
            {'name': 'moon', 'total': 6},
            {'name': 'phobos', 'total': None},
            {'name': 'proxima centauri', 'total': None},
            {'name': 'proxima centauri b', 'total': None},
            {'name': 'sun', 'total': None},
            {'name': 'venus', 'total': None}
        ]
        self.assertEqual(list(qs), expected)