├── .github └── workflows │ ├── benchmark.yaml │ ├── release-please.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .tool-versions ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── hooks ├── pre-commit └── pre-push ├── pyproject.toml ├── requirements-dev.in ├── requirements-dev.txt ├── src └── zeal │ ├── __init__.py │ ├── apps.py │ ├── constants.py │ ├── errors.py │ ├── listeners.py │ ├── middleware.py │ ├── patch.py │ ├── signals.py │ └── util.py └── tests ├── __init__.py ├── conftest.py ├── djangoproject ├── __init__.py ├── manage.py ├── settings.py ├── social │ ├── __init__.py │ ├── apps.py │ ├── models.py │ └── views.py └── urls.py ├── factories.py ├── test_listeners.py ├── test_nplusones.py ├── test_patch.py ├── test_performance.py └── test_signals.py /.github/workflows/benchmark.yaml: -------------------------------------------------------------------------------- 1 | name: Benchmark 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | # `workflow_dispatch` allows CodSpeed to trigger backtest 10 | # performance analysis in order to generate initial data. 
11 | workflow_dispatch: 12 | 13 | jobs: 14 | benchmark: 15 | runs-on: ubuntu-latest 16 | name: Benchmark 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.12 22 | cache: "pip" 23 | - run: make ci 24 | - name: Run benchmarks 25 | uses: CodSpeedHQ/action@v3 26 | with: 27 | token: ${{ secrets.CODSPEED_TOKEN }} 28 | run: pytest tests/ --codspeed 29 | -------------------------------------------------------------------------------- /.github/workflows/release-please.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | 6 | permissions: 7 | contents: write 8 | pull-requests: write 9 | 10 | name: release-please 11 | 12 | jobs: 13 | release-please: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: googleapis/release-please-action@v4 17 | with: 18 | token: ${{ secrets.RELEASE_PLEASE_TOKEN }} 19 | release-type: python 20 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | publish: 10 | name: Build & publish 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/zealot 15 | permissions: 16 | id-token: write 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.12 22 | - name: Install pypa/build 23 | run: >- 24 | python3 -m 25 | pip install 26 | build 27 | --user 28 | - name: Build 29 | run: python3 -m build 30 | - name: Publish to PyPI 31 | uses: pypa/gh-action-pypi-publish@release/v1 32 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | 
push: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.9", "3.10", "3.11", "3.12"] 14 | django-version: ["4.2", "5.0"] 15 | exclude: 16 | # django 5 requires python >=3.10 17 | - python-version: 3.9 18 | django-version: 5.0 19 | - python-version: 3.9 20 | django-version: 5.1 21 | name: Test (Python ${{ matrix.python-version }}, Django ${{ matrix.django-version }}) 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | cache: 'pip' 30 | - run: make ci 31 | - run: pip install Django~=${{ matrix.django-version }} 32 | - run: make test 33 | 34 | test-django-prerelease: 35 | runs-on: ubuntu-latest 36 | strategy: 37 | matrix: 38 | python-version: ["3.10", "3.11", "3.12"] 39 | name: Test Django prerelease (Python ${{ matrix.python-version }}) 40 | 41 | steps: 42 | - uses: actions/checkout@v4 43 | - name: Set up Python ${{ matrix.python-version }} 44 | uses: actions/setup-python@v5 45 | with: 46 | python-version: ${{ matrix.python-version }} 47 | cache: 'pip' 48 | - run: make ci 49 | - run: pip install --pre django 50 | - run: make test 51 | 52 | typecheck: 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: actions/checkout@v4 56 | - uses: actions/setup-python@v5 57 | with: 58 | python-version: 3.12 59 | cache: 'pip' 60 | - run: make ci 61 | - run: make typecheck 62 | 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 
20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | .ruff_cache/ 165 | .codspeed/ -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | python 3.9.19 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [2.0.4](https://github.com/taobojlen/django-zeal/compare/v2.0.3...v2.0.4) (2025-01-26) 4 | 5 | 6 | ### Bug Fixes 7 | 8 | * handle empty apps list ([6aa6113](https://github.com/taobojlen/django-zeal/commit/6aa611367f8fffd59508c56384e7613611ddf39a)) 9 | 10 | 11 | ### Performance Improvements 12 | 13 | * optimize stack fetching ([#50](https://github.com/taobojlen/django-zeal/issues/50)) ([ea88ffd](https://github.com/taobojlen/django-zeal/commit/ea88ffd295cf914e1ca7ceb011578e30ba2c8044)) 14 | 15 | ## [2.0.3](https://github.com/taobojlen/django-zeal/compare/v2.0.2...v2.0.3) (2025-01-23) 16 | 17 | 18 | ### Bug Fixes 19 | 20 | * allow zeal_ignore even when zeal is disabled ([#46](https://github.com/taobojlen/django-zeal/issues/46)) ([a412208](https://github.com/taobojlen/django-zeal/commit/a41220879b7f8f985a5b2088b106f66c8e587418)) 21 | 22 | ## [2.0.2](https://github.com/taobojlen/django-zeal/compare/v2.0.1...v2.0.2) (2024-11-22) 23 | 24 | 25 | ### Bug Fixes 26 | 27 | * **#28:** prevent infinite recursion when custom __eq__ is used. ([#43](https://github.com/taobojlen/django-zeal/issues/43)) ([d157059](https://github.com/taobojlen/django-zeal/commit/d1570593bde02cd5f020fcbfb21350df03e43026)) (thanks @bradleyess!) 
28 | 29 | ## [2.0.1](https://github.com/taobojlen/django-zeal/compare/v2.0.0...v2.0.1) (2024-11-13) 30 | 31 | 32 | ### Bug Fixes 33 | 34 | * use correct field name in forward many-to-many fields ([#39](https://github.com/taobojlen/django-zeal/issues/39)) ([3ada66f](https://github.com/taobojlen/django-zeal/commit/3ada66fa8f9b9c79acd9f2c45b35cb4967770e04)) 35 | 36 | ## [2.0.0](https://github.com/taobojlen/django-zeal/compare/v1.4.1...v2.0.0) (2024-11-12) 37 | 38 | 39 | ### Features 40 | 41 | * add app name to error messages ([#34](https://github.com/taobojlen/django-zeal/issues/34)) ([ad6fe5f](https://github.com/taobojlen/django-zeal/commit/ad6fe5f6599de26ee9adf30cd372cd3fcb7cded0)) 42 | * Add Django signal for zeal errors ([#31](https://github.com/taobojlen/django-zeal/issues/31)) ([1190ca7](https://github.com/taobojlen/django-zeal/commit/1190ca73b5e714c3ded3b979e0fb09935928abca)) (thanks @MaxTet1703!) 43 | * validate the allowlist ([#33](https://github.com/taobojlen/django-zeal/issues/33)) ([2aa1b42](https://github.com/taobojlen/django-zeal/commit/2aa1b42c3a441d6b66f52dd2ba70abc6ffe5efee)) 44 | 45 | 46 | ### Bug Fixes 47 | 48 | * handle related field names in allowlist validation ([#36](https://github.com/taobojlen/django-zeal/issues/36)) ([654eed6](https://github.com/taobojlen/django-zeal/commit/654eed692b9b6e0a17d5c5edc4ec74a2ae0783c9)) 49 | * use custom error class for validation errors ([#37](https://github.com/taobojlen/django-zeal/issues/37)) ([035ae35](https://github.com/taobojlen/django-zeal/commit/035ae3574e3bf29e1c24d896d5c3cd4100d1002b)) 50 | 51 | 52 | ### Miscellaneous Chores 53 | 54 | * make breaking change ([5eed8ec](https://github.com/taobojlen/django-zeal/commit/5eed8ec26e89f657e659d37acbf51c4ef8c4bed4)) 55 | 56 | ## [1.4.1](https://github.com/taobojlen/django-zeal/compare/v1.4.0...v1.4.1) (2024-09-22) 57 | 58 | 59 | ### Performance Improvements 60 | 61 | * don't load context in callstack ([#26](https://github.com/taobojlen/django-zeal/issues/26)) 
([5ade002](https://github.com/taobojlen/django-zeal/commit/5ade0023be95173506167e5cd50388a8dbb5b5e9)) 62 | 63 | ## [1.4.0](https://github.com/taobojlen/django-zeal/compare/v1.3.0...v1.4.0) (2024-09-03) 64 | 65 | **NOTE**: In versions 1.1.0 - 1.3.0, there was a bug that caused `zeal` to be active 66 | in all code, even outside of a `zeal_context` block. That is fixed in 1.4.0. When updating, 67 | make sure that you have installed zeal correctly as per the README. 68 | 69 | ### Features 70 | 71 | * add async support to middleware ([#23](https://github.com/taobojlen/django-zeal/issues/23)) ([815bc16](https://github.com/taobojlen/django-zeal/commit/815bc1651e98a4519a42dfa088dcac4320350a1c)) 72 | 73 | 74 | ### Bug Fixes 75 | 76 | * only run zeal inside context ([#21](https://github.com/taobojlen/django-zeal/issues/21)) ([6c88fd2](https://github.com/taobojlen/django-zeal/commit/6c88fd247388cf58a3c2291917623b7e8094442b)) 77 | 78 | ## [1.3.0](https://github.com/taobojlen/django-zeal/compare/v1.2.0...v1.3.0) (2024-07-25) 79 | 80 | 81 | ### Features 82 | 83 | * add ZEAL_SHOW_ALL_CALLERS to aid in debugging ([#17](https://github.com/taobojlen/django-zeal/issues/17)) ([7fdaf36](https://github.com/taobojlen/django-zeal/commit/7fdaf36db50fed6dee0b0544205e71035c977541)) 84 | 85 | ## [1.2.0](https://github.com/taobojlen/django-zeal/compare/v1.1.0...v1.2.0) (2024-07-22) 86 | 87 | 88 | ### Features 89 | 90 | * use warnings instead of logging ([#15](https://github.com/taobojlen/django-zeal/issues/15)) ([df2c841](https://github.com/taobojlen/django-zeal/commit/df2c841b21fae664c14356d00a7a2f6ecbb7fd61)) 91 | 92 | ## [1.1.0](https://github.com/taobojlen/django-zeal/compare/v1.0.0...v1.1.0) (2024-07-20) 93 | 94 | 95 | ### Features 96 | 97 | * allow ignoring specific models/fields in zeal_ignore ([#13](https://github.com/taobojlen/django-zeal/issues/13)) ([e51413b](https://github.com/taobojlen/django-zeal/commit/e51413ba5fe4d9a3c34409863e9888d873ff84fa)) 98 | 99 | ## 
[1.0.0](https://github.com/taobojlen/zealot/compare/v0.2.3...v1.0.0) (2024-07-20) 100 | 101 | 102 | ### ⚠ BREAKING CHANGES 103 | 104 | This project has been renamed to `zeal`. To migrate, replace `zealot` with `zeal` in your 105 | project's requirements. In your Django settings, replace `ZEALOT_ALLOWLIST`, `ZEALOT_RAISE`, etc. 106 | with `ZEAL_ALLOWLIST`, `ZEAL_RAISE`, and so on. 107 | In your code, replace `from zealot import ...` with `from zeal import ...`. 108 | 109 | 110 | ### Miscellaneous Chores 111 | 112 | * rename to zeal ([cc429a2](https://github.com/taobojlen/zealot/commit/cc429a26bfede770db69429e8a11fc9e98fbb2a9)) 113 | 114 | ## [0.2.3](https://github.com/taobojlen/zeal/compare/v0.2.2...v0.2.3) (2024-07-18) 115 | 116 | 117 | ### Bug Fixes 118 | 119 | * ensure context is reset after leaving ([#8](https://github.com/taobojlen/zeal/issues/8)) ([f45cabb](https://github.com/taobojlen/zeal/commit/f45cabb2abcabce34cd5aed163f7f95c71256e2c)) 120 | 121 | ## [0.2.2](https://github.com/taobojlen/zeal/compare/v0.2.1...v0.2.2) (2024-07-15) 122 | 123 | 124 | ### Bug Fixes 125 | 126 | * don't alert from calls on different lines ([7f7bda7](https://github.com/taobojlen/zeal/commit/7f7bda709e5fff2e953ddac0277d684255732e7c)) 127 | 128 | ## [0.2.1](https://github.com/taobojlen/zeal/compare/v0.2.0...v0.2.1) (2024-07-08) 129 | 130 | 131 | ### Bug Fixes 132 | 133 | * zeal_ignore always takes precedence ([e61d060](https://github.com/taobojlen/zeal/commit/e61d060c74ed32193c2c86f1b7f20929a37402a1)) 134 | 135 | ## [0.2.0](https://github.com/taobojlen/zeal/compare/v0.1.2...v0.2.0) (2024-07-06) 136 | 137 | 138 | ### Features 139 | 140 | * add support for python 3.9 ([#2](https://github.com/taobojlen/zeal/issues/2)) ([44e5f41](https://github.com/taobojlen/zeal/commit/44e5f41fc247e98683a1dd283ae70322a32445d6)) 141 | 142 | ## 0.1.2 - 2024-07-06 143 | 144 | ### Fixed 145 | 146 | - Handle empty querysets 147 | - Handle incorrectly-used `.prefetch_related()` when `.select_related()` 
should have been used 148 | - Don't raise an exception when using `.values(...).get()` 149 | 150 | ## 0.1.1 - 2024-07-05 151 | 152 | ### Fixed 153 | 154 | - Ignore N+1s from singly-loaded records 155 | 156 | ## 0.1.0 - 2024-05-03 157 | 158 | Initial release. 159 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2016 Joshua Carp 2 | Copyright 2024 Tao Bojlén 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install-hooks: 2 | git config core.hooksPath hooks/ 3 | 4 | install: 5 | $(MAKE) install-hooks 6 | uv pip compile requirements-dev.in -o requirements-dev.txt && uv pip sync requirements-dev.txt 7 | 8 | ci: 9 | pip install -r requirements-dev.txt 10 | 11 | test: 12 | pytest -s --tb=native --random-order -m "not benchmark" $(ARGS) 13 | 14 | benchmark: 15 | pytest -s $(ARGS) --codspeed 16 | 17 | format-check: 18 | ruff format --check && ruff check 19 | 20 | format: 21 | ruff format && ruff check --fix 22 | 23 | typecheck: 24 | pyright . 25 | 26 | build: 27 | python -m build --installer uv 28 | 29 | publish: 30 | twine upload dist/* 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # django-zeal 2 | 3 | Catch N+1 queries in your Django project. 4 | 5 | [![Static Badge](https://img.shields.io/badge/license-MIT-brightgreen)](https://github.com/taobojlen/django-zeal/blob/main/LICENSE) 6 | [![PyPI - Version](https://img.shields.io/pypi/v/django-zeal?color=lightgrey)](https://pypi.org/project/django-zeal/) 7 | 8 | 🔥 Battle-tested at [Cinder](https://www.cinder.co/) 9 | 10 | ## Features 11 | 12 | - Detects N+1s from missing prefetches and from use of `.defer()`/`.only()` 13 | - Friendly error messages like `N+1 detected on social.User.followers at myapp/views.py:25 in get_user` 14 | - Configurable thresholds 15 | - Allow-list 16 | - Well-tested 17 | - No dependencies 18 | 19 | ## Acknowledgements 20 | 21 | This library draws heavily from jmcarp's [nplusone](https://github.com/jmcarp/nplusone/). 22 | It's not a fork, but a lot of the central concepts and initial code came from nplusone. 
23 | 24 | ## Installation 25 | 26 | First: 27 | 28 | ``` 29 | pip install django-zeal 30 | ``` 31 | 32 | Then, add zeal to your `INSTALLED_APPS` and `MIDDLEWARE`. 33 | 34 | ```python 35 | if DEBUG: 36 | INSTALLED_APPS.append("zeal") 37 | MIDDLEWARE.append("zeal.middleware.zeal_middleware") 38 | ``` 39 | 40 | This will detect N+1s that happen in web requests. To catch N+1s in more places, 41 | read on! 42 | 43 | > [!WARNING] 44 | > You probably don't want to run zeal in production: 45 | > there is significant overhead to detecting N+1s, and my benchmarks show that it 46 | > can make your code 2.5x slower in some cases. 47 | 48 | ### Celery 49 | 50 | If you use Celery, you can configure this using [signals](https://docs.celeryq.dev/en/stable/userguide/signals.html): 51 | 52 | ```python 53 | from celery.signals import task_prerun, task_postrun 54 | from zeal import setup, teardown 55 | from django.conf import settings 56 | 57 | @task_prerun.connect() 58 | def setup_zeal(*args, **kwargs): 59 | setup() 60 | 61 | @task_postrun.connect() 62 | def teardown_zeal(*args, **kwargs): 63 | teardown() 64 | ``` 65 | 66 | ### Tests 67 | 68 | Django [runs tests with `DEBUG=False`](https://docs.djangoproject.com/en/5.0/topics/testing/overview/#other-test-conditions), 69 | so to run zeal in your tests, you'll first need to ensure it's added to your 70 | `INSTALLED_APPS` and `MIDDLEWARE`. You could do something like: 71 | 72 | ```python 73 | import sys 74 | 75 | TEST = "test" in sys.argv 76 | if DEBUG or TEST: 77 | INSTALLED_APPS.append("zeal") 78 | MIDDLEWARE.append("zeal.middleware.zeal_middleware") 79 | ``` 80 | 81 | This will enable zeal in any tests that go through your middleware. If you want to enable 82 | it in _all_ tests, you need to do a bit more work. 
83 | 84 | If you use pytest, use a fixture in your `conftest.py`: 85 | 86 | ```python 87 | import pytest 88 | from zeal import zeal_context 89 | 90 | @pytest.fixture(scope="function", autouse=True) 91 | def use_zeal(): 92 | with zeal_context(): 93 | yield 94 | ``` 95 | 96 | If you use unittest, add custom test cases and inherit from these rather than directly from Django's test cases: 97 | 98 | ```python 99 | # In e.g. `myapp/testing/test_cases.py` 100 | from zeal import setup as zeal_setup, teardown as zeal_teardown 101 | import unittest 102 | from django.test import SimpleTestCase, TestCase, TransactionTestCase 103 | 104 | class ZealTestMixin(unittest.TestCase): 105 | def setUp(self): 106 | zeal_setup() 107 | super().setUp() 108 | 109 | def tearDown(self) -> None: 110 | zeal_teardown() 111 | super().tearDown() 112 | 113 | class CustomSimpleTestCase(ZealTestMixin, SimpleTestCase): 114 | pass 115 | 116 | class CustomTestCase(ZealTestMixin, TestCase): 117 | pass 118 | 119 | class CustomTransactionTestCase(ZealTestMixin, TransactionTestCase): 120 | pass 121 | ``` 122 | 123 | ### Generic setup 124 | 125 | If you also want to detect N+1s in other places not covered here, you can use the `setup` and 126 | `teardown` functions, or the `zeal_context` context manager: 127 | 128 | ```python 129 | from zeal import setup, teardown, zeal_context 130 | 131 | 132 | def foo(): 133 | setup() 134 | try: 135 | ...  # your code goes here 136 | finally: 137 | teardown() 138 | 139 | 140 | @zeal_context() 141 | def bar(): 142 | ...  # your code goes here 143 | 144 | 145 | def baz(): 146 | with zeal_context(): 147 | ...  # your code goes here 148 | ``` 149 | 150 | ## Configuration 151 | 152 | By default, any issues detected by zeal will raise a `ZealError`. If you'd 153 | rather log any detected N+1s as warnings, you can set: 154 | 155 | ```python 156 | ZEAL_RAISE = False 157 | ``` 158 | 159 | N+1s will be reported when the same query is executed twice.
To configure this 160 | threshold, set the following in your Django settings. 161 | 162 | ```python 163 | ZEAL_NPLUSONE_THRESHOLD = 3 164 | ``` 165 | 166 | To handle false positives, you can temporarily disable zeal in parts of your code 167 | using a context manager: 168 | 169 | ```python 170 | from zeal import zeal_ignore 171 | 172 | with zeal_ignore(): 173 | # code in this block will not log/raise zeal errors 174 | ``` 175 | 176 | If you only want to ignore a specific N+1, you can pass in a list of models/fields to ignore: 177 | 178 | ```python 179 | with zeal_ignore([{"model": "polls.Question", "field": "options"}]): 180 | # code in this block will ignore N+1s on Question.options 181 | ``` 182 | 183 | If you want to listen to N+1 exceptions globally and do something with them, you can listen to the Django signal that zeal emits: 184 | 185 | ```python 186 | from zeal.signals import nplusone_detected 187 | from django.dispatch import receiver 188 | 189 | @receiver(nplusone_detected) 190 | def handle_nplusone(sender, exception): 191 | # do something 192 | ``` 193 | 194 | Finally, if you want to ignore N+1 alerts from a specific model/field globally, you can 195 | add it to your settings: 196 | 197 | ```python 198 | ZEAL_ALLOWLIST = [ 199 | {"model": "polls.Question", "field": "options"}, 200 | 201 | # you can use fnmatch syntax in the model/field, too 202 | {"model": "polls.*", "field": "options"}, 203 | 204 | # if you don't pass in a field, all N+1s arising from the model will be ignored 205 | {"model": "polls.Question"}, 206 | ] 207 | ``` 208 | 209 | 210 | ## Debugging N+1s 211 | 212 | By default, zeal's alerts will tell you the line of your code that executed the same query 213 | multiple times. If you'd like to see the full call stack from each time the query was executed, 214 | you can set: 215 | 216 | ```python 217 | ZEAL_SHOW_ALL_CALLERS = True 218 | ``` 219 | 220 | in your settings. 
This will give you the full call stack from each time the query was executed. 221 | 222 | ## Comparison to nplusone 223 | 224 | zeal borrows heavily from [nplusone](https://github.com/jmcarp/nplusone), but has some differences: 225 | 226 | - zeal also detects N+1 caused by using `.only()` and `.defer()` 227 | - it lets you configure your own threshold for what constitutes an N+1 228 | - it has slightly more helpful error messages that tell you where the N+1 occurred 229 | - nplusone patches the Django ORM even in production when it's not enabled. zeal does not! 230 | - nplusone appears to be abandoned at this point. 231 | - however, zeal only works with Django, whereas nplusone can also be used with SQLAlchemy. 232 | - zeal does not (yet) detect unused prefetches, but nplusone does. 233 | 234 | ## Contributing 235 | 236 | 1. First, install [uv](https://github.com/astral-sh/uv). 237 | 2. Create a virtual env using `uv venv` and activate it with `source .venv/bin/activate`. 238 | 3. Run `make install` to install dev dependencies. 239 | 4. To run tests, run `make test`. 
240 | -------------------------------------------------------------------------------- /hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | make format-check || exit 1 4 | make typecheck || exit 1 5 | 6 | -------------------------------------------------------------------------------- /hooks/pre-push: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | make test || exit 1 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "django-zeal" 3 | version = "2.0.4" 4 | description = "Detect N+1s in your Django app" 5 | readme = "README.md" 6 | license = { file = "LICENSE" } 7 | requires-python = ">=3.9" 8 | 9 | [tool.setuptools] 10 | package-dir = { "" = "src" } 11 | 12 | [tool.setuptools.packages.find] 13 | where = ["src"] 14 | 15 | [tool.ruff] 16 | line-length = 79 17 | 18 | [tool.ruff.lint] 19 | extend-select = [ 20 | "I", # isort 21 | "N", # naming 22 | "B", # bugbear 23 | "FIX", # disallow FIXME/TODO comments 24 | "F", # pyflakes 25 | "T20", # flake8-print 26 | "ERA", # commented-out code 27 | "UP", # pyupgrade 28 | ] 29 | 30 | [tool.pyright] 31 | include = ["src", "tests"] 32 | 33 | [tool.pytest.ini_options] 34 | DJANGO_SETTINGS_MODULE = "djangoproject.settings" 35 | pythonpath = ["src", "tests"] 36 | testpaths = ["tests"] 37 | addopts = "--nomigrations" 38 | markers = ["nozeal: disable the auto-setup of zeal in a test"] 39 | -------------------------------------------------------------------------------- /requirements-dev.in: -------------------------------------------------------------------------------- 1 | Django~=4.2 2 | pytest~=8.2.2 3 | pytest-django~=4.8.0 4 | factory-boy~=3.3.0 5 | ruff~=0.5.0 6 | django-stubs~=5.0 7 | pyright 8 | build 9 | twine 10 | pytest-random-order 11 | 
pytest-mock 12 | pytest-codspeed 13 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile requirements-dev.in -o requirements-dev.txt 3 | asgiref==3.8.1 4 | # via 5 | # django 6 | # django-stubs 7 | build==1.2.1 8 | # via -r requirements-dev.in 9 | certifi==2024.6.2 10 | # via requests 11 | cffi==1.17.1 12 | # via pytest-codspeed 13 | charset-normalizer==3.3.2 14 | # via requests 15 | django==4.2.13 16 | # via 17 | # -r requirements-dev.in 18 | # django-stubs 19 | # django-stubs-ext 20 | django-stubs==5.0.2 21 | # via -r requirements-dev.in 22 | django-stubs-ext==5.0.2 23 | # via django-stubs 24 | docutils==0.21.2 25 | # via readme-renderer 26 | factory-boy==3.3.0 27 | # via -r requirements-dev.in 28 | faker==26.0.0 29 | # via factory-boy 30 | filelock==3.16.1 31 | # via pytest-codspeed 32 | idna==3.7 33 | # via requests 34 | importlib-metadata==8.5.0 35 | # via twine 36 | iniconfig==2.0.0 37 | # via pytest 38 | jaraco-classes==3.4.0 39 | # via keyring 40 | jaraco-context==5.3.0 41 | # via keyring 42 | jaraco-functools==4.0.1 43 | # via keyring 44 | keyring==25.2.1 45 | # via twine 46 | markdown-it-py==3.0.0 47 | # via rich 48 | mdurl==0.1.2 49 | # via markdown-it-py 50 | more-itertools==10.3.0 51 | # via 52 | # jaraco-classes 53 | # jaraco-functools 54 | nh3==0.2.17 55 | # via readme-renderer 56 | nodeenv==1.9.1 57 | # via pyright 58 | packaging==24.1 59 | # via 60 | # build 61 | # pytest 62 | pkginfo==1.10.0 63 | # via twine 64 | pluggy==1.5.0 65 | # via pytest 66 | pycparser==2.22 67 | # via cffi 68 | pygments==2.18.0 69 | # via 70 | # readme-renderer 71 | # rich 72 | pyproject-hooks==1.1.0 73 | # via build 74 | pyright==1.1.369 75 | # via -r requirements-dev.in 76 | pytest==8.2.2 77 | # via 78 | # -r requirements-dev.in 79 | # pytest-codspeed 
80 | # pytest-django 81 | # pytest-mock 82 | # pytest-random-order 83 | pytest-codspeed==3.0.0 84 | # via -r requirements-dev.in 85 | pytest-django==4.8.0 86 | # via -r requirements-dev.in 87 | pytest-mock==3.14.0 88 | # via -r requirements-dev.in 89 | pytest-random-order==1.1.1 90 | # via -r requirements-dev.in 91 | python-dateutil==2.9.0.post0 92 | # via faker 93 | readme-renderer==43.0 94 | # via twine 95 | requests==2.32.3 96 | # via 97 | # requests-toolbelt 98 | # twine 99 | requests-toolbelt==1.0.0 100 | # via twine 101 | rfc3986==2.0.0 102 | # via twine 103 | rich==13.9.4 104 | # via 105 | # pytest-codspeed 106 | # twine 107 | ruff==0.5.0 108 | # via -r requirements-dev.in 109 | setuptools==75.8.0 110 | # via pytest-codspeed 111 | six==1.16.0 112 | # via python-dateutil 113 | sqlparse==0.5.0 114 | # via django 115 | twine==5.1.1 116 | # via -r requirements-dev.in 117 | types-pyyaml==6.0.12.20240311 118 | # via django-stubs 119 | typing-extensions==4.12.2 120 | # via 121 | # django-stubs 122 | # django-stubs-ext 123 | urllib3==2.2.2 124 | # via 125 | # requests 126 | # twine 127 | zipp==3.21.0 128 | # via importlib-metadata 129 | -------------------------------------------------------------------------------- /src/zeal/__init__.py: -------------------------------------------------------------------------------- 1 | from .errors import NPlusOneError, ZealError 2 | from .listeners import setup, teardown, zeal_context, zeal_ignore 3 | 4 | __all__ = [ 5 | "ZealError", 6 | "NPlusOneError", 7 | "setup", 8 | "teardown", 9 | "zeal_context", 10 | "zeal_ignore", 11 | ] 12 | -------------------------------------------------------------------------------- /src/zeal/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | from .patch import patch 4 | 5 | 6 | class ZealConfig(AppConfig): 7 | name = "zeal" 8 | 9 | def ready(self): 10 | from .constants import initialize_app_registry 11 | 12 |
initialize_app_registry() 13 | patch() 14 | -------------------------------------------------------------------------------- /src/zeal/constants.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from django.apps import apps 4 | 5 | ALL_APPS = {}  # "app_label.ModelName" -> set of field/accessor names; filled by initialize_app_registry() 6 | 7 | 8 | def initialize_app_registry():  # populate ALL_APPS from the Django app registry (called from ZealConfig.ready) 9 | if TYPE_CHECKING: 10 | # pyright is unhappy with model._meta.related_objects below, 11 | # so we need to skip this code path in type checking 12 | return 13 | 14 | for model in apps.get_models(): 15 | # Field names from get_fields(); this includes reverse relations, but under their query name, not their accessor 16 | fields = set( 17 | field.name for field in model._meta.get_fields(include_hidden=True) 18 | ) 19 | 20 | # Add reverse accessor names (e.g. `book_set`), which can differ from the query names above 21 | reverse_fields = set( 22 | rel.get_accessor_name() for rel in model._meta.related_objects 23 | ) 24 | 25 | ALL_APPS[f"{model._meta.app_label}.{model.__name__}"] = ( 26 | fields | reverse_fields 27 | ) 28 | -------------------------------------------------------------------------------- /src/zeal/errors.py: -------------------------------------------------------------------------------- 1 | class ZealError(Exception): 2 | pass 3 | 4 | 5 | class NPlusOneError(ZealError): 6 | pass 7 | 8 | 9 | class ZealConfigError(ZealError): 10 | pass 11 | -------------------------------------------------------------------------------- /src/zeal/listeners.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import warnings 4 | from abc import ABC, abstractmethod 5 | from collections import defaultdict 6 | from contextlib import contextmanager 7 | from contextvars import ContextVar, Token 8 | from dataclasses import dataclass, field 9 | from fnmatch 
import fnmatch 10 | from typing import Optional, TypedDict 11 | 12 | from django.conf import settings 13 | from django.db import models 14 | 15 | from zeal.util import get_caller, get_stack 16 | 17 | from .constants import
ALL_APPS 18 | from .errors import NPlusOneError, ZealConfigError, ZealError 19 | from .signals import nplusone_detected 20 | 21 | 22 | class QuerySource(TypedDict): 23 | model: type[models.Model] 24 | field: str 25 | instance_key: Optional[str] # e.g. `User:123` 26 | 27 | 28 | # tuple of (model, field, caller) 29 | CountsKey = tuple[type[models.Model], str, str] 30 | 31 | 32 | class AllowListEntry(TypedDict): 33 | model: str 34 | field: Optional[str] 35 | 36 | 37 | def _validate_allowlist(allowlist: list[AllowListEntry]): 38 | for entry in allowlist: 39 | fnmatch_chars = "*?[]" 40 | # if this is an fnmatch, don't do anything 41 | if any(char in entry["model"] for char in fnmatch_chars): 42 | continue 43 | if not ALL_APPS: 44 | # zeal has not been initialized yet 45 | continue 46 | if entry["model"] not in ALL_APPS: 47 | raise ZealConfigError( 48 | f"Model '{entry['model']}' not found in installed Django models" 49 | ) 50 | 51 | if not entry["field"]: 52 | continue 53 | 54 | if any(char in entry["field"] for char in fnmatch_chars): 55 | continue 56 | 57 | if entry["field"] not in ALL_APPS[entry["model"]]: 58 | raise ZealConfigError( 59 | f"Field '{entry['field']}' not found on '{entry['model']}'" 60 | ) 61 | 62 | 63 | @dataclass 64 | class NPlusOneContext: 65 | enabled: bool = False 66 | calls: dict[CountsKey, list[list[inspect.FrameInfo]]] = field( 67 | default_factory=lambda: defaultdict(list) 68 | ) 69 | ignored: set[str] = field(default_factory=set) 70 | allowlist: list[AllowListEntry] = field(default_factory=list) 71 | 72 | 73 | _nplusone_context: ContextVar[NPlusOneContext] = ContextVar( 74 | "nplusone", 75 | default=NPlusOneContext(), 76 | ) 77 | 78 | logger = logging.getLogger("zeal") 79 | 80 | 81 | class Listener(ABC): 82 | @abstractmethod 83 | def notify(self, *args, **kwargs): ... 84 | 85 | @property 86 | @abstractmethod 87 | def error_class(self) -> type[ZealError]: ... 
88 | 89 | @property 90 | def _allowlist(self) -> list[AllowListEntry]: 91 | if hasattr(settings, "ZEAL_ALLOWLIST"): 92 | settings_allowlist = settings.ZEAL_ALLOWLIST 93 | else: 94 | settings_allowlist = [] 95 | 96 | return [*settings_allowlist, *_nplusone_context.get().allowlist] 97 | 98 | def _alert( 99 | self, 100 | model: type[models.Model], 101 | field: str, 102 | message: str, 103 | calls: list[list[inspect.FrameInfo]], 104 | ): 105 | should_error = ( 106 | settings.ZEAL_RAISE if hasattr(settings, "ZEAL_RAISE") else True 107 | ) 108 | should_include_all_callers = ( 109 | settings.ZEAL_SHOW_ALL_CALLERS 110 | if hasattr(settings, "ZEAL_SHOW_ALL_CALLERS") 111 | else False 112 | ) 113 | is_allowlisted = False 114 | for entry in self._allowlist: 115 | model_match = fnmatch( 116 | f"{model._meta.app_label}.{model.__name__}", entry["model"] 117 | ) 118 | field_match = fnmatch(field, entry.get("field") or "*") 119 | if model_match and field_match: 120 | is_allowlisted = True 121 | break 122 | 123 | if is_allowlisted: 124 | return 125 | 126 | stack = get_stack() 127 | final_caller = get_caller(stack) 128 | if should_include_all_callers: 129 | message = f"{message} with calls:\n" 130 | for i, caller in enumerate(calls): 131 | message += f"CALL {i+1}:\n" 132 | for frame in caller: 133 | message += f" {frame.filename}:{frame.lineno} in {frame.function}\n" 134 | else: 135 | message = f"{message} at {final_caller.filename}:{final_caller.lineno} in {final_caller.function}" 136 | if should_error: 137 | raise self.error_class(message) 138 | else: 139 | warnings.warn_explicit( 140 | message, 141 | UserWarning, 142 | filename=final_caller.filename, 143 | lineno=final_caller.lineno, 144 | ) 145 | 146 | 147 | class NPlusOneListener(Listener): 148 | @property 149 | def error_class(self): 150 | return NPlusOneError 151 | 152 | def notify( 153 | self, 154 | model: type[models.Model], 155 | field: str, 156 | instance_key: Optional[str], 157 | ): 158 | context = _nplusone_context.get() 
159 | if not context.enabled: 160 | return 161 | stack = get_stack() 162 | caller = get_caller(stack) 163 | key = (model, field, f"{caller.filename}:{caller.lineno}") 164 | context.calls[key].append(stack) 165 | count = len(context.calls[key]) 166 | if count >= self._threshold and instance_key not in context.ignored: 167 | message = f"N+1 detected on {model._meta.app_label}.{model.__name__}.{field}" 168 | self._alert(model, field, message, context.calls[key]) 169 | _nplusone_context.set(context) 170 | 171 | def ignore(self, instance_key: Optional[str]): 172 | """ 173 | Tells the listener to ignore N+1s arising from this instance. 174 | 175 | This is used when the given instance is singly-loaded, e.g. via `.first()` 176 | or `.get()`. This is to prevent false positives. 177 | """ 178 | context = _nplusone_context.get() 179 | if not instance_key: 180 | return 181 | context.ignored.add(instance_key) 182 | _nplusone_context.set(context) 183 | 184 | @property 185 | def _threshold(self) -> int: 186 | if hasattr(settings, "ZEAL_NPLUSONE_THRESHOLD"): 187 | return settings.ZEAL_NPLUSONE_THRESHOLD 188 | else: 189 | return 2 190 | 191 | def _alert( 192 | self, 193 | model: type[models.Model], 194 | field: str, 195 | message: str, 196 | calls: list[list[inspect.FrameInfo]], 197 | ): 198 | super()._alert(model, field, message, calls) 199 | nplusone_detected.send( 200 | sender=self, 201 | exception=self.error_class(message), 202 | ) 203 | 204 | 205 | n_plus_one_listener = NPlusOneListener() 206 | 207 | 208 | def setup() -> Optional[Token]: 209 | # if we're already in an ignore-context, we don't want to override 210 | # it. 
211 | context = _nplusone_context.get() 212 | if hasattr(settings, "ZEAL_ALLOWLIST"): 213 | _validate_allowlist(settings.ZEAL_ALLOWLIST) 214 | return _nplusone_context.set( 215 | NPlusOneContext(enabled=True, allowlist=context.allowlist) 216 | ) 217 | 218 | 219 | def teardown(token: Optional[Token] = None): 220 | if token: 221 | _nplusone_context.reset(token) 222 | else: 223 | _nplusone_context.set(NPlusOneContext()) 224 | 225 | 226 | @contextmanager 227 | def zeal_context(): 228 | token = setup() 229 | try: 230 | yield 231 | finally: 232 | teardown(token) 233 | 234 | 235 | @contextmanager 236 | def zeal_ignore(allowlist: Optional[list[AllowListEntry]] = None): 237 | old_context = _nplusone_context.get() 238 | if allowlist is None: 239 | allowlist = [{"model": "*", "field": "*"}] 240 | elif old_context.enabled: 241 | _validate_allowlist(allowlist) 242 | 243 | old_context = _nplusone_context.get() 244 | new_context = NPlusOneContext( 245 | enabled=old_context.enabled, 246 | calls=old_context.calls.copy(), 247 | ignored=old_context.ignored.copy(), 248 | allowlist=[*old_context.allowlist, *allowlist], 249 | ) 250 | token = _nplusone_context.set(new_context) 251 | try: 252 | yield 253 | finally: 254 | _nplusone_context.reset(token) 255 | -------------------------------------------------------------------------------- /src/zeal/middleware.py: -------------------------------------------------------------------------------- 1 | from asgiref.sync import iscoroutinefunction 2 | from django.utils.decorators import sync_and_async_middleware 3 | 4 | from .listeners import zeal_context 5 | 6 | 7 | @sync_and_async_middleware 8 | def zeal_middleware(get_response): 9 | if iscoroutinefunction(get_response): 10 | 11 | async def async_middleware(request): 12 | with zeal_context(): 13 | response = await get_response(request) 14 | return response 15 | 16 | return async_middleware 17 | 18 | else: 19 | 20 | def middleware(request): 21 | with zeal_context(): 22 | response = 
get_response(request) 23 | return response 24 | 25 | return middleware 26 | -------------------------------------------------------------------------------- /src/zeal/patch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import inspect 4 | from typing import Any, Callable, Optional, TypedDict, Union 5 | 6 | from django.db import models 7 | from django.db.models.fields.related_descriptors import ( 8 | ForwardManyToOneDescriptor, 9 | ReverseOneToOneDescriptor, 10 | create_forward_many_to_many_manager, 11 | create_reverse_many_to_one_manager, 12 | ) 13 | from django.db.models.query import QuerySet 14 | from django.db.models.query_utils import DeferredAttribute 15 | 16 | from zeal.util import is_single_query 17 | 18 | from .listeners import QuerySource, n_plus_one_listener 19 | 20 | 21 | class QuerysetContext(TypedDict): 22 | args: Optional[Any] 23 | kwargs: Optional[Any] 24 | 25 | # This is only used for many-to-many relations. It contains the call args 26 | # when `create_forward_many_to_many_manager` is called. 27 | manager_call_args: Optional[dict[str, Any]] 28 | 29 | # used by ReverseManyToOne. a django model instance. 30 | instance: Optional[models.Model] 31 | 32 | 33 | Parser = Callable[[QuerysetContext], QuerySource] 34 | 35 | 36 | def get_instance_key( 37 | instance: Union[models.Model, dict[str, Any]], 38 | ) -> Optional[str]: 39 | if isinstance(instance, models.Model): 40 | return f"{instance.__class__.__name__}:{instance.pk}" 41 | else: 42 | # when calling a queryset with `.values(...).get()`, the instance 43 | # we get here may be a dict. we don't handle that case formally, 44 | # so we return None to ignore that instance in our listeners. 
45 | return None 46 | 47 | 48 | def patch_module_function(original, patched): 49 | module = importlib.import_module(original.__module__) 50 | setattr(module, original.__name__, patched) 51 | 52 | 53 | def patch_queryset_fetch_all( 54 | queryset: models.QuerySet, parser: Parser, context: QuerysetContext 55 | ): 56 | fetch_all = queryset._fetch_all 57 | 58 | @functools.wraps(fetch_all) 59 | def wrapper(*args, **kwargs): 60 | if queryset._result_cache is None: 61 | parsed = parser(context) 62 | n_plus_one_listener.notify( 63 | parsed["model"], 64 | parsed["field"], 65 | parsed["instance_key"], 66 | ) 67 | return fetch_all(*args, **kwargs) 68 | 69 | return wrapper 70 | 71 | 72 | def patch_queryset_function( 73 | queryset_func: Callable[..., models.QuerySet], 74 | parser: Parser, 75 | context: Optional[QuerysetContext] = None, 76 | ): 77 | if context is None: 78 | context = { 79 | "args": None, 80 | "kwargs": None, 81 | "manager_call_args": None, 82 | "instance": None, 83 | } 84 | 85 | @functools.wraps(queryset_func) 86 | def wrapper(*args, **kwargs): 87 | queryset = queryset_func(*args, **kwargs) 88 | 89 | # don't patch the same queryset more than once 90 | if ( 91 | hasattr(queryset, "__zeal_patched") and queryset.__zeal_patched # type: ignore 92 | ): 93 | return queryset 94 | 95 | if args and args != context.get("args"): 96 | context["args"] = args 97 | 98 | # When comparing kwargs, we must use id() rather than == because 99 | # __eq__ methods on model instances can trigger infinite recursion. 
100 | if kwargs: 101 | existing_kwargs = context.get("kwargs") 102 | if existing_kwargs is None or any( 103 | id(v) != id(existing_kwargs.get(k)) for k, v in kwargs.items() 104 | ): 105 | context["kwargs"] = kwargs 106 | 107 | queryset._clone = patch_queryset_function( # type: ignore 108 | queryset._clone, # type: ignore 109 | parser, 110 | context=context, 111 | ) 112 | queryset._fetch_all = patch_queryset_fetch_all( 113 | queryset, parser, context 114 | ) 115 | queryset.__zeal_patched = True # type: ignore 116 | return queryset 117 | 118 | return wrapper 119 | 120 | 121 | def patch_forward_many_to_one_descriptor(): 122 | """ 123 | This also handles ForwardOneToOneDescriptor, which is 124 | a subclass of ForwardManyToOneDescriptor. 125 | """ 126 | 127 | def parser(context: QuerysetContext) -> QuerySource: 128 | assert "args" in context and context["args"] is not None 129 | descriptor = context["args"][0] 130 | 131 | if "kwargs" in context and context["kwargs"] is not None: 132 | instance = context["kwargs"]["instance"] 133 | instance_key = get_instance_key(instance) 134 | else: 135 | # `get_queryset` can in some cases be called without any 136 | # kwargs. In those cases, we ignore the instance. 
137 | instance_key = None 138 | return { 139 | "model": descriptor.field.model, 140 | "field": descriptor.field.name, 141 | "instance_key": instance_key, 142 | } 143 | 144 | ForwardManyToOneDescriptor.get_queryset = patch_queryset_function( 145 | ForwardManyToOneDescriptor.get_queryset, parser=parser 146 | ) 147 | 148 | 149 | def parse_related_parts( 150 | model: type[models.Model], 151 | related_name: Optional[str], 152 | related_model: type[models.Model], 153 | ) -> tuple[type[models.Model], str]: 154 | field_name = related_name or f"{related_model._meta.model_name}_set" 155 | return (model, field_name) 156 | 157 | 158 | def patch_reverse_many_to_one_descriptor(): 159 | def parser(context: QuerysetContext) -> QuerySource: 160 | assert ( 161 | "manager_call_args" in context 162 | and context["manager_call_args"] is not None 163 | and "rel" in context["manager_call_args"] 164 | ) 165 | assert "instance" in context and context["instance"] is not None 166 | rel = context["manager_call_args"]["rel"] 167 | model, field = parse_related_parts( 168 | rel.model, rel.related_name, rel.related_model 169 | ) 170 | return { 171 | "model": model, 172 | "field": field, 173 | "instance_key": get_instance_key(context["instance"]), 174 | } 175 | 176 | def patched_create_reverse_many_to_one_manager(*args, **kwargs): 177 | manager_call_args = inspect.getcallargs( 178 | create_reverse_many_to_one_manager, *args, **kwargs 179 | ) 180 | manager = create_reverse_many_to_one_manager(*args, **kwargs) 181 | 182 | def patch_init_method(func): 183 | @functools.wraps(func) 184 | def wrapper(self, instance): 185 | self.get_queryset = patch_queryset_function( 186 | self.get_queryset, 187 | parser, 188 | context={ 189 | "args": None, 190 | "kwargs": None, 191 | "manager_call_args": manager_call_args, 192 | "instance": instance, 193 | }, 194 | ) 195 | return func(self, instance) 196 | 197 | return wrapper 198 | 199 | manager.__init__ = patch_init_method(manager.__init__) # type: ignore 200 | 
return manager 201 | 202 | patch_module_function( 203 | create_reverse_many_to_one_manager, 204 | patched_create_reverse_many_to_one_manager, 205 | ) 206 | 207 | 208 | def patch_reverse_one_to_one_descriptor(): 209 | def parser(context: QuerysetContext) -> QuerySource: 210 | assert "args" in context and context["args"] is not None 211 | descriptor = context["args"][0] 212 | field = descriptor.related.field 213 | if "kwargs" in context and context["kwargs"] is not None: 214 | instance = context["kwargs"]["instance"] 215 | instance_key = get_instance_key(instance) 216 | else: 217 | instance_key = None 218 | return { 219 | "model": field.related_model, 220 | "field": field.remote_field.name, 221 | "instance_key": instance_key, 222 | } 223 | 224 | ReverseOneToOneDescriptor.get_queryset = patch_queryset_function( 225 | ReverseOneToOneDescriptor.get_queryset, parser 226 | ) 227 | 228 | 229 | def patch_many_to_many_descriptor(): 230 | def parser(context: QuerysetContext) -> QuerySource: 231 | assert ( 232 | "manager_call_args" in context 233 | and context["manager_call_args"] is not None 234 | and "rel" in context["manager_call_args"] 235 | ) 236 | assert "args" in context and context["args"] is not None 237 | rel = context["manager_call_args"]["rel"] 238 | manager = context["args"][0] 239 | model = manager.instance.__class__ 240 | related_model = manager.target_field.related_model 241 | is_reverse = context["manager_call_args"]["reverse"] 242 | field_name = ( 243 | rel.related_name if is_reverse else manager.prefetch_cache_name 244 | ) 245 | 246 | model, field_name = parse_related_parts( 247 | model, field_name, related_model 248 | ) 249 | return { 250 | "model": model, 251 | "field": field_name, 252 | "instance_key": get_instance_key(manager.instance), 253 | } 254 | 255 | def patched_create_forward_many_to_many_manager(*args, **kwargs): 256 | manager_call_args = inspect.getcallargs( 257 | create_forward_many_to_many_manager, *args, **kwargs 258 | ) 259 | manager = 
create_forward_many_to_many_manager(*args, **kwargs) 260 | manager.get_queryset = patch_queryset_function( 261 | manager.get_queryset, 262 | parser, 263 | context={ 264 | "args": None, 265 | "kwargs": None, 266 | "manager_call_args": manager_call_args, 267 | "instance": None, 268 | }, 269 | ) 270 | return manager 271 | 272 | patch_module_function( 273 | create_forward_many_to_many_manager, 274 | patched_create_forward_many_to_many_manager, 275 | ) 276 | 277 | 278 | def patch_deferred_attribute(): 279 | def patched_check_parent_chain(func): 280 | @functools.wraps(func) 281 | def wrapper(self, instance, *args, **kwargs): 282 | result = func(self, instance, *args, **kwargs) 283 | if result is None: 284 | n_plus_one_listener.notify( 285 | instance.__class__, self.field.name, str(instance.pk) 286 | ) 287 | return result 288 | 289 | return wrapper 290 | 291 | DeferredAttribute._check_parent_chain = patched_check_parent_chain( # type: ignore 292 | DeferredAttribute._check_parent_chain # type: ignore 293 | ) 294 | 295 | 296 | def patch_global_queryset(): 297 | """ 298 | We patch `_fetch_all` and `.get()` on querysets to let us ignore singly-loaded 299 | instances. We don't want to trigger N+1 errors from such instances because of 300 | the high false positive rate. 
301 | """ 302 | 303 | def patch_fetch_all(func): 304 | @functools.wraps(func) 305 | def wrapper(self, *args, **kwargs): 306 | should_ignore = ( 307 | is_single_query(self.query) and self._result_cache is None 308 | ) 309 | ret = func(self, *args, **kwargs) # call the original _fetch_all 310 | if should_ignore and len(self) > 0: 311 | n_plus_one_listener.ignore(get_instance_key(self[0])) 312 | return ret 313 | 314 | return wrapper 315 | 316 | QuerySet._fetch_all = patch_fetch_all(QuerySet._fetch_all) 317 | 318 | def patch_get(func): 319 | @functools.wraps(func) 320 | def wrapper(*args, **kwargs): 321 | ret = func(*args, **kwargs) 322 | n_plus_one_listener.ignore(get_instance_key(ret)) 323 | return ret 324 | 325 | return wrapper 326 | 327 | QuerySet.get = patch_get(QuerySet.get) 328 | 329 | 330 | def patch(): 331 | patch_forward_many_to_one_descriptor() 332 | patch_reverse_many_to_one_descriptor() 333 | patch_reverse_one_to_one_descriptor() 334 | patch_many_to_many_descriptor() 335 | patch_deferred_attribute() 336 | patch_global_queryset() 337 | -------------------------------------------------------------------------------- /src/zeal/signals.py: -------------------------------------------------------------------------------- 1 | from django.dispatch import Signal 2 | 3 | nplusone_detected = Signal() 4 | -------------------------------------------------------------------------------- /src/zeal/util.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from django.db.models.sql import Query 4 | 5 | PATTERNS = [ 6 | "site-packages", 7 | "zeal/listeners.py", 8 | "zeal/patch.py", 9 | "zeal/util.py", 10 | ] 11 | 12 | 13 | def get_stack() -> list[inspect.FrameInfo]: 14 | """ 15 | Returns the current call stack, excluding any code in site-packages or zeal. 
16 | """ 17 | return [ 18 | frame 19 | for frame in inspect.stack(context=0)[1:] 20 | if not any(pattern in frame.filename for pattern in PATTERNS) 21 | ] 22 | 23 | 24 | def get_caller(stack: list[inspect.FrameInfo]) -> inspect.FrameInfo: 25 | """ 26 | Returns the filename and line number of the current caller, 27 | excluding any code in site-packages or zeal. 28 | """ 29 | return next(frame for frame in stack) 30 | 31 | 32 | def is_single_query(query: Query): 33 | return ( 34 | query.high_mark is not None and query.high_mark - query.low_mark == 1 35 | ) 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from zeal import zeal_context 3 | 4 | 5 | @pytest.fixture(scope="function", autouse=True) 6 | def use_zeal(request): 7 | if "nozeal" in request.keywords: 8 | yield 9 | else: 10 | with zeal_context(): 11 | yield 12 | -------------------------------------------------------------------------------- /tests/djangoproject/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taobojlen/django-zeal/2aabdaa5dc9c1443c662982b97a1c918fa62c1c2/tests/djangoproject/__init__.py -------------------------------------------------------------------------------- /tests/djangoproject/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | 4 | import os 5 | import sys 6 | 7 | 8 | def main(): 9 | """Run administrative tasks.""" 10 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "djangoproject.settings") 11 | try: 12 | from 
django.core.management import execute_from_command_line 13 | except ImportError as exc: 14 | raise ImportError( 15 | "Couldn't import Django. Are you sure it's installed and " 16 | "available on your PYTHONPATH environment variable? Did you " 17 | "forget to activate a virtual environment?" 18 | ) from exc 19 | execute_from_command_line(sys.argv) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /tests/djangoproject/settings.py: -------------------------------------------------------------------------------- 1 | SECRET_KEY = 1 2 | DEBUG = True 3 | USE_TZ = True 4 | TIME_ZONE = "UTC" 5 | 6 | INSTALLED_APPS = [ 7 | "django.contrib.admin", 8 | "django.contrib.auth", 9 | "django.contrib.contenttypes", 10 | "django.contrib.sessions", 11 | "django.contrib.messages", 12 | "django.contrib.staticfiles", 13 | "djangoproject.social", 14 | "zeal", 15 | ] 16 | 17 | MIDDLEWARE = ["zeal.middleware.zeal_middleware"] 18 | 19 | ROOT_URLCONF = "djangoproject.urls" 20 | 21 | DATABASES = { 22 | "default": { 23 | "ENGINE": "django.db.backends.sqlite3", 24 | "NAME": ":memory:", 25 | } 26 | } 27 | 28 | DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" 29 | -------------------------------------------------------------------------------- /tests/djangoproject/social/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taobojlen/django-zeal/2aabdaa5dc9c1443c662982b97a1c918fa62c1c2/tests/djangoproject/social/__init__.py -------------------------------------------------------------------------------- /tests/djangoproject/social/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class SocialConfig(AppConfig): 5 | default_auto_field = "django.db.models.BigAutoField" 6 | name = "djangoproject.social" 7 | 
-------------------------------------------------------------------------------- /tests/djangoproject/social/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class User(models.Model): 5 | username = models.TextField() 6 | # user.followers and user.following are both ManyToManyDescriptor 7 | following = models.ManyToManyField("User", related_name="followers") 8 | 9 | # note that there's no related_name set here, because we want to 10 | # test that case too. 11 | blocked = models.ManyToManyField("user") 12 | 13 | followers: models.Manager["User"] 14 | user_set: models.Manager["User"] 15 | posts: models.Manager["Post"] 16 | profile: "Profile" 17 | 18 | 19 | class Profile(models.Model): 20 | # profile.user is ForwardOneToOne 21 | # user.profile is ReverseOneToOne 22 | user = models.OneToOneField(User, on_delete=models.CASCADE) 23 | display_name = models.TextField() 24 | 25 | 26 | class Post(models.Model): 27 | # post.author is ForwardManyToOne 28 | # user.posts is ReverseManyToOne 29 | author = models.ForeignKey( 30 | User, on_delete=models.CASCADE, related_name="posts" 31 | ) 32 | text = models.TextField() 33 | -------------------------------------------------------------------------------- /tests/djangoproject/social/views.py: -------------------------------------------------------------------------------- 1 | from django.http import HttpRequest, JsonResponse 2 | 3 | from .models import User 4 | 5 | 6 | def single_user_and_profile(request: HttpRequest, id: int): 7 | user = User.objects.get(id=id) 8 | return JsonResponse( 9 | data={ 10 | "username": user.username, 11 | "display_name": user.profile.display_name, 12 | } 13 | ) 14 | 15 | 16 | def all_users_and_profiles(request: HttpRequest): 17 | """ 18 | This view has an N+1. 
19 | """ 20 | return JsonResponse( 21 | data={ 22 | "users": [ 23 | { 24 | "username": user.username, 25 | "display_name": user.profile.display_name, 26 | } 27 | for user in User.objects.all() 28 | ] 29 | } 30 | ) 31 | -------------------------------------------------------------------------------- /tests/djangoproject/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | 3 | from .social.views import all_users_and_profiles, single_user_and_profile 4 | 5 | urlpatterns = [ 6 | path("users/", all_users_and_profiles), 7 | path("user//", single_user_and_profile), 8 | ] 9 | -------------------------------------------------------------------------------- /tests/factories.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | 3 | import factory 4 | from djangoproject.social.models import Post, Profile, User 5 | 6 | T = TypeVar("T") 7 | 8 | 9 | class BaseFactory(Generic[T], factory.django.DjangoModelFactory): 10 | @classmethod 11 | def create(cls, **kwargs) -> T: 12 | return super().create(**kwargs) 13 | 14 | 15 | class UserFactory(BaseFactory[User]): 16 | username = factory.Faker("user_name") 17 | 18 | class Meta: # type: ignore 19 | model = User 20 | 21 | 22 | class ProfileFactory(BaseFactory[Profile]): 23 | display_name = factory.Faker("name") 24 | 25 | class Meta: # type: ignore 26 | model = Profile 27 | 28 | 29 | class PostFactory(BaseFactory[Post]): 30 | text = factory.Faker("sentence") 31 | 32 | class Meta: # type: ignore 33 | model = Post 34 | -------------------------------------------------------------------------------- /tests/test_listeners.py: -------------------------------------------------------------------------------- 1 | import re 2 | import warnings 3 | 4 | import pytest 5 | from djangoproject.social.models import Post, User 6 | from zeal import NPlusOneError, zeal_context, zeal_ignore 7 | from zeal.errors 
import ZealConfigError 8 | from zeal.listeners import _nplusone_context, n_plus_one_listener 9 | 10 | from .factories import PostFactory, UserFactory 11 | 12 | pytestmark = pytest.mark.django_db 13 | 14 | 15 | def test_can_log_errors(settings, caplog): 16 | settings.ZEAL_RAISE = False 17 | 18 | [user_1, user_2] = UserFactory.create_batch(2) 19 | PostFactory.create(author=user_1) 20 | PostFactory.create(author=user_2) 21 | with warnings.catch_warnings(record=True) as w: 22 | warnings.simplefilter("always") 23 | for user in User.objects.all(): 24 | _ = list(user.posts.all()) 25 | assert len(w) == 1 26 | assert issubclass(w[0].category, UserWarning) 27 | assert re.search( 28 | r"N\+1 detected on social\.User\.posts at .*\/test_listeners\.py:24 in test_can_log_errors", 29 | str(w[0].message), 30 | ) 31 | 32 | 33 | def test_can_log_all_traces(settings): 34 | settings.ZEAL_SHOW_ALL_CALLERS = True 35 | settings.ZEAL_RAISE = False 36 | [user_1, user_2] = UserFactory.create_batch(2) 37 | PostFactory.create(author=user_1) 38 | PostFactory.create(author=user_2) 39 | with warnings.catch_warnings(record=True) as w: 40 | warnings.simplefilter("always") 41 | for user in User.objects.all(): 42 | _ = list(user.posts.all()) 43 | assert len(w) == 1 44 | assert issubclass(w[0].category, UserWarning) 45 | expected_lines = [ 46 | "N+1 detected on social.User.posts with calls:", 47 | "CALL 1:", 48 | "tests/test_listeners.py:42 in test_can_log_all_traces", 49 | "CALL 2:", 50 | "tests/test_listeners.py:42 in test_can_log_all_traces", 51 | ] 52 | for line in expected_lines: 53 | assert line in str(w[0].message) 54 | 55 | 56 | def test_errors_include_caller(): 57 | [user_1, user_2] = UserFactory.create_batch(2) 58 | PostFactory.create(author=user_1) 59 | PostFactory.create(author=user_2) 60 | with pytest.raises( 61 | NPlusOneError, 62 | match=r"N\+1 detected on social\.User\.posts at .*\/test_listeners\.py:65 in test_errors_include_caller", 63 | ): 64 | for user in User.objects.all(): 65 | _ 
= list(user.posts.all()) 66 | 67 | 68 | def test_can_exclude_with_allowlist(settings): 69 | settings.ZEAL_ALLOWLIST = [{"model": "social.User", "field": "posts"}] 70 | 71 | [user_1, user_2] = UserFactory.create_batch(2) 72 | PostFactory.create(author=user_1) 73 | PostFactory.create(author=user_2) 74 | 75 | # this will not raise, allow-listed 76 | for user in User.objects.all(): 77 | _ = list(user.posts.all()) 78 | 79 | with pytest.raises( 80 | NPlusOneError, match=re.escape("N+1 detected on social.Post.author") 81 | ): 82 | for post in Post.objects.all(): 83 | _ = post.author 84 | 85 | 86 | def test_can_use_fnmatch_pattern_in_allowlist_model(settings): 87 | settings.ZEAL_ALLOWLIST = [{"model": "social.U*"}] 88 | 89 | [user_1, user_2] = UserFactory.create_batch(2) 90 | PostFactory.create(author=user_1) 91 | PostFactory.create(author=user_2) 92 | 93 | # this will not raise, allow-listed 94 | for user in User.objects.all(): 95 | _ = list(user.posts.all()) 96 | 97 | with pytest.raises( 98 | NPlusOneError, match=re.escape("N+1 detected on social.Post.author") 99 | ): 100 | for post in Post.objects.all(): 101 | _ = post.author 102 | 103 | 104 | def test_can_use_fnmatch_pattern_in_allowlist_field(settings): 105 | settings.ZEAL_ALLOWLIST = [{"model": "social.User", "field": "p*st*"}] 106 | 107 | [user_1, user_2] = UserFactory.create_batch(2) 108 | PostFactory.create(author=user_1) 109 | PostFactory.create(author=user_2) 110 | 111 | # this will not raise, allow-listed 112 | for user in User.objects.all(): 113 | _ = list(user.posts.all()) 114 | 115 | with pytest.raises( 116 | NPlusOneError, match=re.escape("N+1 detected on social.Post.author") 117 | ): 118 | for post in Post.objects.all(): 119 | _ = post.author 120 | 121 | 122 | def test_ignore_context_takes_precedence(): 123 | """ 124 | If you're within a `zeal_ignore` context, then even if some later code adds 125 | a zeal context, then the ignore context should take precedence. 
126 | """ 127 | with zeal_ignore(): 128 | with zeal_context(): 129 | [user_1, user_2] = UserFactory.create_batch(2) 130 | PostFactory.create(author=user_1) 131 | PostFactory.create(author=user_2) 132 | 133 | # this will not raise because we're in the zeal_ignore context 134 | for user in User.objects.all(): 135 | _ = list(user.posts.all()) 136 | 137 | 138 | def test_reverts_to_previous_state_when_leaving_zeal_ignore(): 139 | # we are currently in a zeal context 140 | initial_context = _nplusone_context.get() 141 | with zeal_ignore(): 142 | assert _nplusone_context.get().allowlist == [ 143 | {"model": "*", "field": "*"} 144 | ] 145 | assert _nplusone_context.get() == initial_context 146 | 147 | 148 | def test_resets_state_in_nested_context(): 149 | [user_1, user_2] = UserFactory.create_batch(2) 150 | PostFactory.create(author=user_1) 151 | PostFactory.create(author=user_2) 152 | 153 | # we're already in a zeal_context within each test, so let's set 154 | # some state. 155 | n_plus_one_listener.ignore("Test:1") 156 | n_plus_one_listener.notify(Post, "test_field", "Post:1") 157 | 158 | context = _nplusone_context.get() 159 | assert context.ignored == {"Test:1"} 160 | assert len(context.calls.values()) == 1 161 | caller = list(context.calls.values())[0] 162 | 163 | with zeal_context(): 164 | # new context, fresh state 165 | context = _nplusone_context.get() 166 | assert context.ignored == set() 167 | assert list(context.calls.values()) == [] 168 | 169 | n_plus_one_listener.ignore("NestedTest:1") 170 | n_plus_one_listener.notify(Post, "nested_test_field", "Post:1") 171 | 172 | context = _nplusone_context.get() 173 | assert context.ignored == {"NestedTest:1"} 174 | assert len(list(context.calls.values())) == 1 175 | 176 | # back outside the nested context, we're back to the old state 177 | context = _nplusone_context.get() 178 | assert context.ignored == {"Test:1"} 179 | assert list(context.calls.values()) == [caller] 180 | 181 | 182 | def 
test_can_ignore_specific_models(): 183 | [user_1, user_2] = UserFactory.create_batch(2) 184 | PostFactory.create(author=user_1) 185 | PostFactory.create(author=user_2) 186 | 187 | with zeal_ignore([{"model": "social.User", "field": "post*"}]): 188 | # this will not raise, allow-listed 189 | for user in User.objects.all(): 190 | _ = list(user.posts.all()) 191 | 192 | with pytest.raises( 193 | NPlusOneError, 194 | match=re.escape("N+1 detected on social.Post.author"), 195 | ): 196 | # this is *not* allowlisted 197 | for post in Post.objects.all(): 198 | _ = post.author 199 | 200 | # if we ignore another field, we still raise 201 | with zeal_ignore([{"model": "social.User", "field": "following"}]): 202 | with pytest.raises( 203 | NPlusOneError, match=re.escape("N+1 detected on social.User.posts") 204 | ): 205 | for user in User.objects.all(): 206 | _ = list(user.posts.all()) 207 | 208 | # outside of the context, we're back to normal 209 | with pytest.raises( 210 | NPlusOneError, match=re.escape("N+1 detected on social.User.posts") 211 | ): 212 | for user in User.objects.all(): 213 | _ = list(user.posts.all()) 214 | 215 | 216 | def test_context_specific_allowlist_merges(): 217 | [user_1, user_2] = UserFactory.create_batch(2) 218 | PostFactory.create(author=user_1) 219 | PostFactory.create(author=user_2) 220 | 221 | with zeal_ignore([{"model": "social.User", "field": "post*"}]): 222 | # this will not raise, allow-listed 223 | for user in User.objects.all(): 224 | _ = list(user.posts.all()) 225 | 226 | with pytest.raises( 227 | NPlusOneError, 228 | match=re.escape("N+1 detected on social.Post.author"), 229 | ): 230 | # this is *not* allowlisted 231 | for post in Post.objects.all(): 232 | _ = post.author 233 | 234 | with zeal_ignore([{"model": "social.Post", "field": "author"}]): 235 | for post in Post.objects.all(): 236 | _ = post.author 237 | 238 | # this is still allowlisted 239 | for user in User.objects.all(): 240 | _ = list(user.posts.all()) 241 | 242 | 243 | # 
other tests automatically run in a zeal context, so we need to disable 244 | # that here. 245 | @pytest.mark.nozeal 246 | def test_does_not_run_outside_of_context(): 247 | [user_1, user_2] = UserFactory.create_batch(2) 248 | PostFactory.create(author=user_1) 249 | PostFactory.create(author=user_2) 250 | 251 | # this should not raise since we are outside of a zeal context 252 | for user in User.objects.all(): 253 | _ = list(user.posts.all()) 254 | 255 | with zeal_context(), pytest.raises(NPlusOneError): 256 | # this should raise since we are inside a zeal context 257 | for user in User.objects.all(): 258 | _ = list(user.posts.all()) 259 | 260 | 261 | @pytest.mark.nozeal 262 | def test_validates_global_allowlist_model_name(settings): 263 | settings.ZEAL_ALLOWLIST = [{"model": "foo", "field": "*"}] 264 | with pytest.raises( 265 | ZealConfigError, 266 | match=re.escape("Model 'foo' not found in installed Django models"), 267 | ): 268 | with zeal_context(): 269 | pass 270 | 271 | 272 | @pytest.mark.nozeal 273 | def test_validates_global_allowlist_field_name(settings): 274 | settings.ZEAL_ALLOWLIST = [{"model": "social.User", "field": "foo"}] 275 | with pytest.raises( 276 | ZealConfigError, 277 | match=re.escape("Field 'foo' not found on 'social.User'"), 278 | ): 279 | with zeal_context(): 280 | pass 281 | 282 | 283 | @pytest.mark.nozeal 284 | def test_allows_fnmatch_in_global_allowlist(settings): 285 | settings.ZEAL_ALLOWLIST = [{"model": "social.U[sb]er", "field": "p?st"}] 286 | with zeal_context(): 287 | pass 288 | 289 | 290 | def test_validates_local_allowlist_model_name(): 291 | with pytest.raises( 292 | ZealConfigError, 293 | match=re.escape("Model 'foo' not found in installed Django models"), 294 | ): 295 | with zeal_ignore([{"model": "foo", "field": "*"}]): 296 | pass 297 | 298 | 299 | def test_validates_local_allowlist_field_name(): 300 | with pytest.raises( 301 | ZealConfigError, 302 | match=re.escape("Field 'foo' not found on 'social.User'"), 303 | ): 304 | 
with zeal_ignore([{"model": "social.User", "field": "foo"}]): 305 | pass 306 | 307 | 308 | def test_allows_fnmatch_in_local_allowlist(): 309 | with zeal_ignore([{"model": "social.U[sb]er", "field": "p?st"}]): 310 | pass 311 | 312 | 313 | def test_validates_related_name_field_names(): 314 | # User.following is a M2M field with related_name followers 315 | with zeal_ignore([{"model": "social.User", "field": "followers"}]): 316 | pass 317 | 318 | # User.blocked is a M2M field with an auto-generated related name (user_set) 319 | with zeal_ignore([{"model": "social.User", "field": "user_set"}]): 320 | pass 321 | 322 | 323 | @pytest.mark.nozeal 324 | def test_handles_zeal_ignore_when_disabled(): 325 | with zeal_ignore([{"model": "social.User", "field": "post"}]): 326 | pass 327 | -------------------------------------------------------------------------------- /tests/test_nplusones.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pytest 4 | from django.db import connection 5 | from django.test.utils import CaptureQueriesContext 6 | from djangoproject.social.models import Post, Profile, User 7 | from zeal import NPlusOneError, zeal_context 8 | from zeal.listeners import zeal_ignore 9 | 10 | from .factories import PostFactory, ProfileFactory, UserFactory 11 | 12 | pytestmark = pytest.mark.django_db 13 | 14 | 15 | def test_detects_nplusone_in_forward_many_to_one(): 16 | [user_1, user_2] = UserFactory.create_batch(2) 17 | PostFactory.create(author=user_1) 18 | PostFactory.create(author=user_2) 19 | with pytest.raises( 20 | NPlusOneError, match=re.escape("N+1 detected on social.Post.author") 21 | ): 22 | for post in Post.objects.all(): 23 | _ = post.author.username 24 | 25 | for post in Post.objects.select_related("author").all(): 26 | _ = post.author.username 27 | 28 | 29 | def test_detects_nplusone_in_forward_many_to_one_iterator(): 30 | for _ in range(4): 31 | user = UserFactory.create() 32 | 
PostFactory.create(author=user) 33 | 34 | with pytest.raises( 35 | NPlusOneError, match=re.escape("N+1 detected on social.Post.author") 36 | ): 37 | for post in Post.objects.all().iterator(chunk_size=2): 38 | _ = post.author.username 39 | 40 | for post in Post.objects.select_related("author").iterator(chunk_size=2): 41 | _ = post.author.username 42 | 43 | 44 | def test_handles_prefetch_instead_of_select_related_in_forward_many_to_one(): 45 | user_1, user_2 = UserFactory.create_batch(2) 46 | PostFactory(author=user_1) 47 | PostFactory(author=user_2) 48 | with CaptureQueriesContext(connection) as ctx: 49 | # this should be a select_related! but we need to handle it even if someone 50 | # has accidentally used the wrong method. 51 | for post in Post.objects.prefetch_related("author").all(): 52 | _ = post.author.username 53 | assert len(ctx.captured_queries) == 2 54 | 55 | 56 | def test_no_false_positive_when_loading_single_object_forward_many_to_one(): 57 | user = UserFactory.create() 58 | post_1, post_2 = PostFactory.create_batch(2, author=user) 59 | 60 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 61 | post_1 = Post.objects.filter(pk=post_1.pk).first() 62 | post_2 = Post.objects.filter(pk=post_2.pk).first() 63 | assert post_1 is not None and post_2 is not None 64 | # queries on `post` should not raise an exception, because `post` was 65 | # singly-loaded 66 | _ = post_1.author 67 | _ = post_2.author 68 | assert len(ctx.captured_queries) == 4 69 | 70 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 71 | # same when using a slice to get a single record 72 | post_1 = Post.objects.filter(pk=post_1.pk).all()[0] 73 | post_2 = Post.objects.filter(pk=post_2.pk).all()[0] 74 | _ = post_1.author 75 | _ = post_2.author 76 | assert len(ctx.captured_queries) == 4 77 | 78 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 79 | # similarly, when using `.get()`, no N+1 error 80 | post_1 = Post.objects.get(pk=post_1.pk) 81 | post_2 = 
Post.objects.get(pk=post_2.pk) 82 | _ = post_1.author 83 | _ = post_2.author 84 | assert len(ctx.captured_queries) == 4 85 | 86 | 87 | def test_detects_nplusone_in_reverse_many_to_one(): 88 | [user_1, user_2] = UserFactory.create_batch(2) 89 | PostFactory.create(author=user_1) 90 | PostFactory.create(author=user_2) 91 | with pytest.raises( 92 | NPlusOneError, match=re.escape("N+1 detected on social.User.posts") 93 | ): 94 | for user in User.objects.all(): 95 | _ = list(user.posts.all()) 96 | 97 | for user in User.objects.prefetch_related("posts").all(): 98 | _ = list(user.posts.all()) 99 | 100 | 101 | def test_detects_nplusone_in_reverse_many_to_one_iterator(): 102 | for _ in range(4): 103 | user = UserFactory.create() 104 | PostFactory.create(author=user) 105 | with pytest.raises( 106 | NPlusOneError, match=re.escape("N+1 detected on social.User.posts") 107 | ): 108 | for user in User.objects.all().iterator(chunk_size=2): 109 | _ = list(user.posts.all()) 110 | 111 | for user in User.objects.prefetch_related("posts").iterator(chunk_size=2): 112 | _ = list(user.posts.all()) 113 | 114 | 115 | def test_no_false_positive_when_calling_reverse_many_to_one_twice(): 116 | user = UserFactory.create() 117 | PostFactory.create(author=user) 118 | 119 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 120 | queryset = user.posts.all() 121 | list(queryset) # evaluate queryset once 122 | list(queryset) # evalute again (cached) 123 | assert len(ctx.captured_queries) == 1 124 | 125 | 126 | def test_detects_nplusone_in_forward_one_to_one(): 127 | [user_1, user_2] = UserFactory.create_batch(2) 128 | ProfileFactory.create(user=user_1) 129 | ProfileFactory.create(user=user_2) 130 | with pytest.raises( 131 | NPlusOneError, match=re.escape("N+1 detected on social.Profile.user") 132 | ): 133 | for profile in Profile.objects.all(): 134 | _ = profile.user.username 135 | 136 | for profile in Profile.objects.select_related("user").all(): 137 | _ = profile.user.username 138 | 139 
| 140 | def test_detects_nplusone_in_forward_one_to_one_iterator(): 141 | for _ in range(4): 142 | user = UserFactory.create() 143 | ProfileFactory.create(user=user) 144 | with pytest.raises( 145 | NPlusOneError, match=re.escape("N+1 detected on social.Profile.user") 146 | ): 147 | for profile in Profile.objects.all().iterator(chunk_size=2): 148 | _ = profile.user.username 149 | 150 | for profile in Profile.objects.select_related("user").iterator( 151 | chunk_size=2 152 | ): 153 | _ = profile.user.username 154 | 155 | 156 | def test_handles_prefetch_instead_of_select_related_in_forward_one_to_one(): 157 | user_1, user_2 = UserFactory.create_batch(2) 158 | ProfileFactory.create(user=user_1) 159 | ProfileFactory.create(user=user_2) 160 | with CaptureQueriesContext(connection) as ctx: 161 | # this should be a select_related! but we need to handle it even if someone 162 | # has accidentally used the wrong method. 163 | for profile in Profile.objects.prefetch_related("user").all(): 164 | _ = profile.user.username 165 | assert len(ctx.captured_queries) == 2 166 | 167 | 168 | def test_no_false_positive_when_loading_single_object_forward_one_to_one(): 169 | user_1, user_2 = UserFactory.create_batch(2) 170 | profile_1 = ProfileFactory.create(user=user_1) 171 | profile_2 = ProfileFactory.create(user=user_2) 172 | 173 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 174 | profile_1 = Profile.objects.filter(pk=profile_1.pk).first() 175 | profile_2 = Profile.objects.filter(pk=profile_2.pk).first() 176 | assert profile_1 is not None and profile_2 is not None 177 | _ = profile_1.user.username 178 | _ = profile_2.user.username 179 | assert len(ctx.captured_queries) == 4 180 | 181 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 182 | profile_1 = Profile.objects.filter(pk=profile_1.pk)[0] 183 | profile_2 = Profile.objects.filter(pk=profile_2.pk)[0] 184 | _ = profile_1.user.username 185 | _ = profile_2.user.username 186 | assert 
len(ctx.captured_queries) == 4 187 | 188 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 189 | profile_1 = Profile.objects.get(pk=profile_1.pk) 190 | profile_2 = Profile.objects.get(pk=profile_2.pk) 191 | _ = profile_1.user.username 192 | _ = profile_2.user.username 193 | assert len(ctx.captured_queries) == 4 194 | 195 | 196 | def test_detects_nplusone_in_reverse_one_to_one(): 197 | [user_1, user_2] = UserFactory.create_batch(2) 198 | ProfileFactory.create(user=user_1) 199 | ProfileFactory.create(user=user_2) 200 | with pytest.raises( 201 | NPlusOneError, match=re.escape("N+1 detected on social.User.profile") 202 | ): 203 | for user in User.objects.all(): 204 | _ = user.profile.display_name 205 | 206 | for user in User.objects.select_related("profile").all(): 207 | _ = user.profile.display_name 208 | 209 | 210 | def test_detects_nplusone_in_reverse_one_to_one_iterator(): 211 | for _ in range(4): 212 | user = UserFactory.create() 213 | ProfileFactory.create(user=user) 214 | with pytest.raises( 215 | NPlusOneError, match=re.escape("N+1 detected on social.User.profile") 216 | ): 217 | for user in User.objects.all().iterator(chunk_size=2): 218 | _ = user.profile.display_name 219 | 220 | for user in User.objects.select_related("profile").iterator(chunk_size=2): 221 | _ = user.profile.display_name 222 | 223 | 224 | def test_handles_prefetch_instead_of_select_related_in_reverse_one_to_one(): 225 | [user_1, user_2] = UserFactory.create_batch(2) 226 | ProfileFactory.create(user=user_1) 227 | ProfileFactory.create(user=user_2) 228 | 229 | with CaptureQueriesContext(connection) as ctx: 230 | # this should be a select_related! but we need to handle it even if someone 231 | # has accidentally used the wrong method. 
232 | for user in User.objects.prefetch_related("profile").all(): 233 | _ = user.profile.display_name 234 | assert len(ctx.captured_queries) == 2 235 | 236 | 237 | def test_no_false_positive_when_loading_single_object_reverse_one_to_one(): 238 | user_1, user_2 = UserFactory.create_batch(2) 239 | ProfileFactory.create(user=user_1) 240 | ProfileFactory.create(user=user_2) 241 | 242 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 243 | user_1 = User.objects.filter(pk=user_1.pk).first() 244 | user_2 = User.objects.filter(pk=user_2.pk).first() 245 | assert user_1 is not None and user_2 is not None 246 | _ = user_1.profile.display_name 247 | _ = user_2.profile.display_name 248 | assert len(ctx.captured_queries) == 4 249 | 250 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 251 | user_1 = User.objects.filter(pk=user_1.pk)[0] 252 | user_2 = User.objects.filter(pk=user_2.pk)[0] 253 | _ = user_1.profile.display_name 254 | _ = user_2.profile.display_name 255 | assert len(ctx.captured_queries) == 4 256 | 257 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 258 | user_1 = User.objects.get(pk=user_1.pk) 259 | user_2 = User.objects.get(pk=user_2.pk) 260 | _ = user_1.profile.display_name 261 | _ = user_2.profile.display_name 262 | assert len(ctx.captured_queries) == 4 263 | 264 | 265 | def test_detects_nplusone_in_forward_many_to_many(): 266 | [user_1, user_2] = UserFactory.create_batch(2) 267 | user_1.following.add(user_2) 268 | user_2.following.add(user_1) 269 | with pytest.raises( 270 | NPlusOneError, match=re.escape("N+1 detected on social.User.following") 271 | ): 272 | for user in User.objects.all(): 273 | _ = list(user.following.all()) 274 | 275 | for user in User.objects.prefetch_related("following").all(): 276 | _ = list(user.following.all()) 277 | 278 | 279 | def test_detects_nplusone_in_forward_many_to_many_iterator(): 280 | influencer = UserFactory.create() 281 | users = UserFactory.create_batch(4) 282 | 
influencer.followers.set(users) # type: ignore 283 | 284 | with pytest.raises( 285 | NPlusOneError, match=re.escape("N+1 detected on social.User.following") 286 | ): 287 | for user in User.objects.iterator(chunk_size=2): 288 | _ = list(user.following.all()) 289 | 290 | for user in User.objects.prefetch_related("following").iterator( 291 | chunk_size=2 292 | ): 293 | _ = list(user.following.all()) 294 | 295 | 296 | def test_no_false_positive_when_loading_single_object_forward_many_to_many(): 297 | user_1, user_2 = UserFactory.create_batch(2) 298 | user_1.following.add(user_2) 299 | user_2.following.add(user_1) 300 | 301 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 302 | _ = user_1.following.first().username 303 | _ = user_2.following.first().username 304 | assert len(ctx.captured_queries) == 2 305 | 306 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 307 | _ = user_1.following.all()[0].username 308 | _ = user_2.following.all()[0].username 309 | assert len(ctx.captured_queries) == 2 310 | 311 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 312 | _ = user_1.following.get(pk=user_2.pk).username 313 | _ = user_2.following.get(pk=user_1.pk).username 314 | assert len(ctx.captured_queries) == 2 315 | 316 | 317 | def test_detects_nplusone_in_reverse_many_to_many(): 318 | [user_1, user_2] = UserFactory.create_batch(2) 319 | user_1.following.add(user_2) 320 | user_2.following.add(user_1) 321 | with pytest.raises( 322 | NPlusOneError, match=re.escape("N+1 detected on social.User.followers") 323 | ): 324 | for user in User.objects.all(): 325 | _ = list(user.followers.all()) 326 | 327 | for user in User.objects.prefetch_related("followers").all(): 328 | _ = list(user.followers.all()) 329 | 330 | 331 | def test_detects_nplusone_in_reverse_many_to_many_iterator(): 332 | follower = UserFactory.create() 333 | users = UserFactory.create_batch(4) 334 | follower.following.set(users) # type: ignore 335 | with pytest.raises( 336 | 
NPlusOneError, match=re.escape("N+1 detected on social.User.followers") 337 | ): 338 | for user in User.objects.all().iterator(chunk_size=2): 339 | _ = list(user.followers.all()) 340 | 341 | for user in ( 342 | User.objects.prefetch_related("followers").all().iterator(chunk_size=2) 343 | ): 344 | _ = list(user.followers.all()) 345 | 346 | 347 | def test_no_false_positive_when_loading_single_object_reverse_many_to_many(): 348 | user_1, user_2 = UserFactory.create_batch(2) 349 | user_1.following.add(user_2) 350 | user_2.following.add(user_1) 351 | 352 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 353 | _ = user_1.followers.first().username 354 | _ = user_2.followers.first().username 355 | assert len(ctx.captured_queries) == 2 356 | 357 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 358 | _ = user_1.followers.all()[0].username 359 | _ = user_2.followers.all()[0].username 360 | assert len(ctx.captured_queries) == 2 361 | 362 | with zeal_context(), CaptureQueriesContext(connection) as ctx: 363 | _ = user_1.followers.get(pk=user_2.pk).username 364 | _ = user_2.followers.get(pk=user_1.pk).username 365 | assert len(ctx.captured_queries) == 2 366 | 367 | 368 | def test_detects_nplusone_in_forward_many_to_many_with_no_related_name(): 369 | [user_1, user_2] = UserFactory.create_batch(2) 370 | user_1.blocked.add(user_2) 371 | user_2.blocked.add(user_1) 372 | with pytest.raises( 373 | NPlusOneError, match=re.escape("N+1 detected on social.User.blocked") 374 | ): 375 | for user in User.objects.all(): 376 | _ = list(user.blocked.all()) 377 | 378 | for user in User.objects.prefetch_related("blocked").all(): 379 | _ = list(user.blocked.all()) 380 | 381 | 382 | def test_detects_nplusone_in_reverse_many_to_many_with_no_related_name(): 383 | [user_1, user_2] = UserFactory.create_batch(2) 384 | user_1.blocked.add(user_2) 385 | user_2.blocked.add(user_1) 386 | with pytest.raises( 387 | NPlusOneError, match=re.escape("N+1 detected on 
social.User.user_set") 388 | ): 389 | for user in User.objects.all(): 390 | _ = list(user.user_set.all()) 391 | 392 | for user in User.objects.prefetch_related("user_set").all(): 393 | _ = list(user.user_set.all()) 394 | 395 | 396 | def test_detects_nplusone_due_to_deferred_fields(): 397 | [user_1, user_2] = UserFactory.create_batch(2) 398 | PostFactory.create(author=user_1) 399 | PostFactory.create(author=user_2) 400 | with pytest.raises( 401 | NPlusOneError, match=re.escape("N+1 detected on social.User.username") 402 | ): 403 | for post in ( 404 | Post.objects.all().select_related("author").only("author__id") 405 | ): 406 | _ = post.author.username 407 | 408 | for post in ( 409 | Post.objects.all().select_related("author").only("author__username") 410 | ): 411 | _ = post.author.username 412 | 413 | 414 | def test_detects_nplusone_due_to_deferred_fields_in_iterator(): 415 | for _ in range(4): 416 | user = UserFactory.create() 417 | PostFactory.create(author=user) 418 | with pytest.raises( 419 | NPlusOneError, match=re.escape("N+1 detected on social.User.username") 420 | ): 421 | for post in ( 422 | Post.objects.all() 423 | .select_related("author") 424 | .only("author__id") 425 | .iterator(chunk_size=2) 426 | ): 427 | _ = post.author.username 428 | 429 | for post in ( 430 | Post.objects.all() 431 | .select_related("author") 432 | .only("author__username") 433 | .iterator(chunk_size=2) 434 | ): 435 | _ = post.author.username 436 | 437 | 438 | def test_handles_prefetch_instead_of_select_related_with_deferred_fields(): 439 | [user_1, user_2] = UserFactory.create_batch(2) 440 | PostFactory.create(author=user_1) 441 | PostFactory.create(author=user_2) 442 | 443 | with CaptureQueriesContext(connection) as ctx: 444 | # this should be a select_related! but we need to handle it even if someone 445 | # has accidentally used the wrong method. 
446 | for post in ( 447 | Post.objects.all() 448 | .prefetch_related("author") 449 | .only("author__username") 450 | ): 451 | _ = post.author.username 452 | assert len(ctx.captured_queries) == 2 453 | 454 | 455 | def test_has_configurable_threshold(settings): 456 | settings.ZEAL_NPLUSONE_THRESHOLD = 3 457 | [user_1, user_2] = UserFactory.create_batch(2) 458 | PostFactory.create(author=user_1) 459 | PostFactory.create(author=user_2) 460 | 461 | for post in Post.objects.all(): 462 | _ = post.author.username 463 | 464 | 465 | @zeal_ignore() 466 | def test_does_nothing_if_not_in_middleware(settings, client): 467 | settings.MIDDLEWARE = [] 468 | [user_1, user_2] = UserFactory.create_batch(2) 469 | ProfileFactory.create(user=user_1) 470 | ProfileFactory.create(user=user_2) 471 | 472 | # this does not raise an N+1 error even though the same 473 | # related field gets hit twice 474 | response = client.get(f"/user/{user_1.pk}/") 475 | assert response.status_code == 200 476 | response = client.get(f"/user/{user_2.pk}/") 477 | assert response.status_code == 200 478 | 479 | 480 | def test_works_in_web_requests(client): 481 | [user_1, user_2] = UserFactory.create_batch(2) 482 | ProfileFactory.create(user=user_1) 483 | ProfileFactory.create(user=user_2) 484 | with pytest.raises( 485 | NPlusOneError, match=re.escape(r"N+1 detected on social.User.profile") 486 | ): 487 | response = client.get("/users/") 488 | 489 | # but multiple requests work fine 490 | response = client.get(f"/user/{user_1.pk}/") 491 | assert response.status_code == 200 492 | response = client.get(f"/user/{user_2.pk}/") 493 | assert response.status_code == 200 494 | 495 | 496 | def test_ignores_calls_on_different_lines(): 497 | [user_1, user_2] = UserFactory.create_batch(2) 498 | PostFactory.create(author=user_1) 499 | PostFactory.create(author=user_2) 500 | 501 | # this should *not* raise an exception 502 | _a = list(user_1.posts.all()) 503 | _b = list(user_2.posts.all()) 504 | 
# --------------------------------------------------------------------------------
# /tests/test_patch.py
# --------------------------------------------------------------------------------
import sys

import pytest
from django.db import models
from djangoproject.social.models import User

from tests.factories import UserFactory

pytestmark = pytest.mark.django_db


def test_handles_calling_queryset_many_times():
    """Re-evaluating a prefetched relation more times than the recursion
    limit must not raise a RecursionError."""
    UserFactory.create()
    user = User.objects.prefetch_related("posts").all()[0]
    for _ in range(sys.getrecursionlimit() + 1):
        # this should *not* raise a recursion error
        list(user.posts.all())


def test_handles_empty_querysets():
    """`.first()` on an empty queryset returns None without erroring."""
    assert User.objects.none().first() is None


def test_handles_get_with_values():
    """`.get()` on a `.values()` queryset yields a dict, not a model."""
    user = UserFactory.create()
    result = User.objects.filter(pk=user.pk).values("username").get()
    assert result == {"username": user.username}


class CustomEqualityModel(models.Model):
    """Model whose equality compares its ``relation`` and ``name`` fields.

    ``__eq__`` deliberately dereferences the ``relation`` foreign key so the
    tests below can check that comparing instances does not recurse
    infinitely through related-object lookups.
    """

    name: models.CharField = models.CharField(max_length=100)
    relation: models.ForeignKey[
        "CustomEqualityModel", "CustomEqualityModel"
    ] = models.ForeignKey(
        "self", null=True, on_delete=models.CASCADE, related_name="related"
    )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CustomEqualityModel):
            return NotImplemented
        # Explicitly access relation to trigger potential recursion
        my_rel = self.relation
        other_rel = other.relation
        return my_rel == other_rel and self.name == other.name

    def __hash__(self) -> int:
        # Overriding __eq__ implicitly sets __hash__ to None, which would make
        # instances unusable in sets / as dict keys (Django internals such as
        # the cascade-deletion collector rely on hashing instances). Hash on
        # ``name`` only: instances that compare equal always share a name, so
        # the hash/eq contract holds, whereas Django's pk-based hash would
        # break it (equal objects may have different pks). As with any
        # value-based hash, don't mutate ``name`` while the object is in a set.
        return hash(self.name)

    class Meta:
        app_label = "social"


def test_handles_custom_equality_with_relations():
    """
    Ensure model equality comparisons don't cause infinite recursion
    when __eq__ methods access related fields. This is important because
    Django's lazy loading could trigger repeated relation lookups during
    equality checks.
    """
    # Create test instances
    base = CustomEqualityModel.objects.create(name="base")
    obj1 = CustomEqualityModel.objects.create(name="test1", relation=base)
    obj2 = CustomEqualityModel.objects.create(name="test1", relation=base)
    obj3 = CustomEqualityModel.objects.create(name="test2", relation=base)

    assert obj1 == obj1  # Same object
    assert obj1 == obj2  # Different objects, same values
    assert obj1 != obj3  # Different values

    result = CustomEqualityModel.objects.filter(name="test1").first()
    assert result is not None
    _ = result.relation


def test_handles_nested_relation_equality():
    """
    Ensure deep relation traversal works correctly without infinite recursion.
    This is particularly important for models that compare relations in their
    equality checks, as each comparison could potentially trigger a chain of
    database lookups through the relationship tree.
    """
    root = CustomEqualityModel.objects.create(name="root")
    middle = CustomEqualityModel.objects.create(name="middle", relation=root)
    leaf1 = CustomEqualityModel.objects.create(name="leaf", relation=middle)
    leaf2 = CustomEqualityModel.objects.create(name="leaf", relation=middle)

    assert leaf1 == leaf2
    assert leaf1.relation == leaf2.relation

    result = CustomEqualityModel.objects.filter(name="leaf").first()
    assert result is not None
    _ = result.relation.relation


# --------------------------------------------------------------------------------
# /tests/test_performance.py
# --------------------------------------------------------------------------------
import pytest
from djangoproject.social.models import Post, Profile, User
from zeal import zeal_context, zeal_ignore

from .factories import PostFactory, ProfileFactory, UserFactory

pytestmark = [pytest.mark.nozeal, pytest.mark.django_db]


def test_performance(benchmark):
    """Benchmark relation access inside ``zeal_context()`` + ``zeal_ignore()``.

    Covers forward/reverse many-to-one, one-to-one and many-to-many access,
    plus chained lookups, over 10 users who all follow each other and each
    have a profile and 10 posts.
    """
    users = UserFactory.create_batch(10)

    # everyone follows everyone (but not themselves)
    user_following_relations = [
        User.following.through(from_user_id=user.id, to_user_id=followee.id)
        for user in users
        for followee in users
        if user != followee
    ]
    User.following.through.objects.bulk_create(user_following_relations)

    # give everyone a profile
    for user in users:
        ProfileFactory(user=user)

    # everyone has 10 posts
    for user in users:
        PostFactory.create_batch(10, author=user)

    @benchmark
    def _run_benchmark():
        with (
            zeal_context(),
            zeal_ignore(),
        ):
            # Test forward & reverse many-to-one relationships (Post -> User, User -> Posts)
            posts = Post.objects.all()
            for post in posts:
                _ = post.author.username  # forward many-to-one
                _ = list(post.author.posts.all())  # reverse many-to-one

            # Test forward & reverse one-to-one relationships (Profile -> User, User -> Profile)
            profiles = Profile.objects.all()
            for profile in profiles:
                _ = profile.user.username  # forward one-to-one
                _ = profile.user.profile.display_name  # reverse one-to-one

            # Test forward & reverse many-to-many relationships
            users = User.objects.all()
            for user in users:
                _ = list(user.following.all())  # forward many-to-many
                _ = list(user.followers.all())  # reverse many-to-many
                _ = list(
                    user.blocked.all()
                )  # many-to-many without related_name

                # Test chained relationships
                for follower in user.followers.all():
                    _ = follower.profile.display_name
                    _ = list(follower.posts.all())


# --------------------------------------------------------------------------------
# /tests/test_signals.py
# --------------------------------------------------------------------------------
import warnings

import pytest
import pytest_django
import pytest_mock
from djangoproject.social import models
from zeal import errors
from zeal.listeners import n_plus_one_listener

from . import factories

pytestmark = pytest.mark.django_db


def test_signal_send_message(
    monkeypatch: pytest.MonkeyPatch,
    mocker: pytest_mock.MockerFixture,
    settings: pytest_django.fixtures.SettingsWrapper,
):
    """Test signal send message after detecting N+1 query.

    With ``ZEAL_RAISE`` disabled, an N+1 on ``Post.author`` must send
    ``nplusone_detected`` exactly once, with ``n_plus_one_listener`` as the
    sender and an ``NPlusOneError`` as the ``exception`` payload.
    """
    settings.ZEAL_RAISE = False
    patched_signal = mocker.patch(
        "zeal.listeners.nplusone_detected.send",
    )
    user_1, user_2 = factories.UserFactory.create_batch(2)
    factories.PostFactory.create(author=user_1)
    factories.PostFactory.create(author=user_2)
    # Presumably zeal warns instead of raising when ZEAL_RAISE is False;
    # record warnings so they don't leak into the test output.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        _ = [post.author.username for post in models.Post.objects.all()]
    patched_signal.assert_called_once()
    signal_kwargs = patched_signal.call_args.kwargs
    assert signal_kwargs["sender"] == n_plus_one_listener
    exception = signal_kwargs["exception"]
    assert isinstance(exception, errors.NPlusOneError)
    assert "N+1 detected on social.Post.author" in str(exception)
# --------------------------------------------------------------------------------