├── .github └── workflows │ ├── coverage.yml │ ├── github_release.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── changelog.md ├── conf.py ├── contibuting.md ├── customization.md ├── error_response.md ├── faq.md ├── gotchas.md ├── index.md ├── make.bat ├── openapi.md ├── openapi_sample_description.md ├── quickstart.md ├── release.md └── settings.md ├── drf_standardized_errors ├── __init__.py ├── apps.py ├── formatter.py ├── handler.py ├── openapi.py ├── openapi_hooks.py ├── openapi_serializers.py ├── openapi_utils.py ├── openapi_validation_errors.py ├── py.typed ├── settings.py └── types.py ├── pyproject.toml ├── release └── update_changelog.py ├── tests ├── __init__.py ├── conftest.py ├── models.py ├── settings.py ├── test_exception_handler.py ├── test_flatten_errors.py ├── test_openapi.py ├── test_openapi_utils.py ├── test_openapi_validation_errors.py ├── test_settings.py ├── urls.py ├── utils.py └── views.py └── tox.ini /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | run: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Setup Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: 3.8 16 | - name: Generate coverage report 17 | run: | 18 | pip install . 
19 | pip install django~=3.2 pytest pytest-django pytest-cov drf-spectacular django-filter 20 | pytest --cov --cov-report=xml 21 | - name: Upload coverage to Codecov 22 | uses: codecov/codecov-action@v4 23 | with: 24 | token: ${{ secrets.CODECOV_TOKEN }} 25 | fail_ci_if_error: true 26 | verbose: true 27 | -------------------------------------------------------------------------------- /.github/workflows/github_release.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/actions/changelog-reader#example-workflow---create-a-release-from-changelog 2 | on: 3 | push: 4 | tags: 5 | - "v*" 6 | 7 | name: Create GitHub Release 8 | 9 | jobs: 10 | build: 11 | name: Parse release notes from changelog and create a GitHub release 12 | runs-on: ubuntu-latest 13 | permissions: 14 | contents: write 15 | steps: 16 | - name: Get version from tag 17 | id: tag_name 18 | run: | 19 | echo "current_version=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT" 20 | shell: bash 21 | - name: Checkout code 22 | uses: actions/checkout@v4 23 | - name: Get Changelog Entry 24 | id: changelog_reader 25 | uses: mindsers/changelog-reader-action@v2 26 | with: 27 | version: ${{ steps.tag_name.outputs.current_version }} 28 | path: docs/changelog.md 29 | - name: Create/update release 30 | uses: ncipollo/release-action@v1 31 | with: 32 | name: Release v${{ steps.changelog_reader.outputs.version }} 33 | body: ${{ steps.changelog_reader.outputs.changes }} 34 | allowUpdates: true 35 | token: ${{ secrets.GITHUB_TOKEN }} 36 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ${{ matrix.platform }} 10 | strategy: 11 | matrix: 12 | platform: [ ubuntu-latest, macos-latest, windows-latest ] 13 | python-version: [ "3.8",
"3.9", "3.10", "3.11", "3.12", "3.13" ] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: pip install .[test] 23 | - name: Run tox 24 | run: tox 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | env/ 9 | build/ 10 | develop-eggs/ 11 | dist/ 12 | downloads/ 13 | eggs/ 14 | .eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # Unit test / coverage reports 26 | htmlcov/ 27 | .tox/ 28 | .coverage 29 | .coverage.* 30 | .cache 31 | nosetests.xml 32 | coverage.xml 33 | *.cover 34 | .pytest_cache/ 35 | 36 | # Sphinx documentation 37 | docs/_build/ 38 | 39 | # virtualenv 40 | .venv 41 | venv/ 42 | ENV/ 43 | 44 | # IDE settings 45 | .vscode/ 46 | .idea/ 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^docs/|\.tox|\.git|venv|^dist' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | 11 | - repo: https://github.com/Zac-HD/shed 12 | rev: 2025.6.1 13 | hooks: 14 | - id: shed 15 | additional_dependencies: [ 'black~=25.1' ] 16 | types_or: [ python, pyi, markdown, rst ] 17 | 18 | - repo: https://github.com/PyCQA/flake8 19 | rev: 7.3.0 20 | hooks: 21 | - id: flake8 22 | args: [ --max-line-length, '120' , --ignore, 'E,W'] 23 | 24 | - repo: 
https://github.com/pre-commit/mirrors-mypy 25 | rev: v1.16.1 26 | hooks: 27 | - id: mypy 28 | args: [ --ignore-missing-imports, --check-untyped-defs, --disallow-incomplete-defs ] 29 | exclude: '^tests' 30 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | install: 5 | - method: pip 6 | path: . 7 | extra_requirements: 8 | - doc 9 | - openapi 10 | 11 | build: 12 | os: ubuntu-22.04 13 | tools: 14 | python: "3.12" 15 | 16 | sphinx: 17 | configuration: docs/conf.py 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 Ghazi Abbassi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRF Standardized Errors 2 | 3 | Standardize your [DRF](https://www.django-rest-framework.org/) API error responses. 4 | 5 | [![Read the Docs](https://img.shields.io/readthedocs/drf-standardized-errors)](https://drf-standardized-errors.readthedocs.io/en/latest/) 6 | [![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/ghazi-git/drf-standardized-errors/tests.yml?branch=main&label=Tests&logo=GitHub)](https://github.com/ghazi-git/drf-standardized-errors/actions/workflows/tests.yml) 7 | [![codecov](https://codecov.io/gh/ghazi-git/drf-standardized-errors/branch/main/graph/badge.svg?token=JXTTT1KVBR)](https://codecov.io/gh/ghazi-git/drf-standardized-errors) 8 | [![PyPI](https://img.shields.io/pypi/v/drf-standardized-errors)](https://pypi.org/project/drf-standardized-errors/) 9 | [![PyPI - License](https://img.shields.io/pypi/l/drf-standardized-errors)](https://github.com/ghazi-git/drf-standardized-errors/blob/main/LICENSE) 10 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 11 | 12 | By default, the package will convert all API error responses (4xx and 5xx) to the following standardized format: 13 | ```json 14 | { 15 | "type": "validation_error", 16 | "errors": [ 17 | { 18 | "code": "required", 19 | "detail": "This field is required.", 20 | "attr": "name" 21 | }, 22 | { 23 | "code": "max_length", 24 | "detail": "Ensure this value has at most 100 characters.", 25 | "attr": "title" 26 | } 27 | ] 28 | } 29 | ``` 30 | ```json 31 | { 32 | "type": "client_error", 33 | "errors": [ 34 | { 35 | "code": "authentication_failed", 36 | "detail": "Incorrect authentication credentials.", 37 | "attr": null 38 | } 39 | ] 40 | } 41 | ``` 42 | ```json 43 | { 44 | "type": "server_error", 45 | "errors": [ 46 | { 47 | 
"code": "error", 48 | "detail": "A server error occurred.", 49 | "attr": null 50 | } 51 | ] 52 | } 53 | ``` 54 | 55 | 56 | ## Features 57 | 58 | - Highly customizable: gives you flexibility to define your own standardized error responses and override 59 | specific aspects the exception handling process without having to rewrite everything. 60 | - Supports nested serializers and ListSerializer errors 61 | - Plays nicely with error monitoring tools (like Sentry, ...) 62 | 63 | 64 | ## Requirements 65 | 66 | - python >= 3.8 67 | - Django >= 3.2 68 | - DRF >= 3.12 69 | 70 | 71 | ## Quickstart 72 | 73 | Install with `pip` 74 | ```shell 75 | pip install drf-standardized-errors 76 | ``` 77 | 78 | Add drf-standardized-errors to your installed apps 79 | ```python 80 | INSTALLED_APPS = [ 81 | # other apps 82 | "drf_standardized_errors", 83 | ] 84 | ``` 85 | 86 | Register the exception handler 87 | ```python 88 | REST_FRAMEWORK = { 89 | # other settings 90 | "EXCEPTION_HANDLER": "drf_standardized_errors.handler.exception_handler" 91 | } 92 | ``` 93 | 94 | ### Notes 95 | - This package is a DRF exception handler, so it standardizes errors that reach a DRF API view. That means it cannot 96 | handle errors that happen at the middleware level for example. To handle those as well, you can customize 97 | the necessary [django error views](https://docs.djangoproject.com/en/dev/topics/http/views/#customizing-error-views). 98 | You can find more about that in [this issue](https://github.com/ghazi-git/drf-standardized-errors/issues/44). 99 | 100 | - Standardized error responses when `DEBUG=True` for **unhandled exceptions** are disabled by default. That is 101 | to allow you to get more information out of the traceback. 
You can enable standardized errors instead with: 102 | ```python 103 | DRF_STANDARDIZED_ERRORS = {"ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": True} 104 | ``` 105 | 106 | ## Integration with DRF spectacular 107 | If you plan to use [drf-spectacular](https://github.com/tfranzel/drf-spectacular) to generate an OpenAPI 3 schema, 108 | install with `pip install drf-standardized-errors[openapi]`. After that, check the [doc page](https://drf-standardized-errors.readthedocs.io/en/latest/openapi.html) 109 | for configuring the integration. 110 | 111 | ## Links 112 | 113 | - Documentation: https://drf-standardized-errors.readthedocs.io/en/latest/ 114 | - Changelog: https://github.com/ghazi-git/drf-standardized-errors/releases 115 | - Code & issues: https://github.com/ghazi-git/drf-standardized-errors 116 | - PyPI: https://pypi.org/project/drf-standardized-errors/ 117 | 118 | 119 | ## License 120 | 121 | This project is [MIT licensed](LICENSE). 122 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | github_checks: 2 | annotations: false 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = ./_build 10 | APP = ../drf_standardized_errors 11 | 12 | 13 | .PHONY: help livehtml Makefile 14 | 15 | # Put it first so that "make" without argument is like "make help". 16 | help: 17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -c .
18 | 19 | # Build, watch and serve docs with live reload 20 | livehtml: 21 | sphinx-autobuild -b html --host 0.0.0.0 --port 9000 --watch $(APP) -c . $(SOURCEDIR) $(BUILDDIR)/html 22 | 23 | # Catch-all target: route all unknown targets to Sphinx using the new 24 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 25 | %: Makefile 26 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -c . 27 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [UNRELEASED] 8 | 9 | ## [0.15.0] - 2025-06-09 10 | ### Added 11 | - add support for python 3.13 12 | - add support for django 5.2 13 | - add support for DRF 3.16 14 | 15 | ### Changed (backward-incompatible) 16 | - Unhandled exceptions now return a generic error message by default. This avoids unintentionally leaking 17 | sensitive data included in the exception message. To revert to the old behavior or change the default error 18 | message: 19 | - create a custom exception handler class 20 | ```python 21 | from rest_framework.exceptions import APIException 22 | from drf_standardized_errors.handler import ExceptionHandler 23 | 24 | class MyExceptionHandler(ExceptionHandler): 25 | def convert_unhandled_exceptions(self, exc: Exception) -> APIException: 26 | if not isinstance(exc, APIException): 27 | # `return APIException(detail=str(exc))` restores the old behavior 28 | return APIException(detail="New error message") 29 | else: 30 | return exc 31 | ``` 32 | - Then, update the settings to point to your exception handler class 33 | ```python 34 | DRF_STANDARDIZED_ERRORS = { 35 | # ... 
36 | "EXCEPTION_HANDLER_CLASS": "path.to.MyExceptionHandler" 37 | } 38 | ``` 39 | - set minimum version of drf-spectacular to 0.27.1 40 | - `drf_standardized_errors.types.ErrorType` is now the following type hint 41 | ```python 42 | from typing import Literal 43 | ErrorType = Literal["validation_error", "client_error", "server_error"] 44 | ``` 45 | `ErrorType` was previously an enum. If you referenced its members in your code, make sure to replace their 46 | use cases with the newly added constants: 47 | ``` 48 | from drf_standardized_errors.types import VALIDATION_ERROR, CLIENT_ERROR, SERVER_ERROR 49 | ErrorType.VALIDATION_ERROR --> VALIDATION_ERROR 50 | ErrorType.CLIENT_ERROR --> CLIENT_ERROR 51 | ErrorType.SERVER_ERROR --> SERVER_ERROR 52 | ``` 53 | 54 | ## [0.14.1] - 2024-08-10 55 | ### Added 56 | - declare support for django 5.1 57 | 58 | ### Fixed 59 | - stop ignoring exceptions with detail as an empty string when returning api errors. 60 | 61 | ## [0.14.0] - 2024-06-19 62 | ### Added 63 | - declare support for DRF 3.15 64 | 65 | ### Fixed 66 | - enforce support of only drf-spectacular 0.27 and newer in pyproject.toml 67 | - ensure examples from `@extend_schema_serializer` are not ignored when adding error response examples 68 | - show default error response examples only when the corresponding status code is allowed 69 | - add `"null"` to the error code enum of `non_field_errors` validation errors 70 | 71 | ## [0.13.0] - 2024-02-28 72 | ### Changed 73 | - If you're using drf-spectacular 0.27.0 or newer, update `ENUM_NAME_OVERRIDES` entries to reference `choices` 74 | rather than `values`. 
The list of overrides specific to this package should become like this: 75 | ```python 76 | SPECTACULAR_SETTINGS = { 77 | # other settings 78 | "ENUM_NAME_OVERRIDES": { 79 | "ValidationErrorEnum": "drf_standardized_errors.openapi_serializers.ValidationErrorEnum.choices", 80 | "ClientErrorEnum": "drf_standardized_errors.openapi_serializers.ClientErrorEnum.choices", 81 | "ServerErrorEnum": "drf_standardized_errors.openapi_serializers.ServerErrorEnum.choices", 82 | "ErrorCode401Enum": "drf_standardized_errors.openapi_serializers.ErrorCode401Enum.choices", 83 | "ErrorCode403Enum": "drf_standardized_errors.openapi_serializers.ErrorCode403Enum.choices", 84 | "ErrorCode404Enum": "drf_standardized_errors.openapi_serializers.ErrorCode404Enum.choices", 85 | "ErrorCode405Enum": "drf_standardized_errors.openapi_serializers.ErrorCode405Enum.choices", 86 | "ErrorCode406Enum": "drf_standardized_errors.openapi_serializers.ErrorCode406Enum.choices", 87 | "ErrorCode415Enum": "drf_standardized_errors.openapi_serializers.ErrorCode415Enum.choices", 88 | "ErrorCode429Enum": "drf_standardized_errors.openapi_serializers.ErrorCode429Enum.choices", 89 | "ErrorCode500Enum": "drf_standardized_errors.openapi_serializers.ErrorCode500Enum.choices", 90 | # other overrides 91 | }, 92 | } 93 | ``` 94 | 95 | ### Added 96 | - add compatibility with drf-spectacular 0.27.x 97 | - add support for django 5.0 98 | 99 | ### Fixed 100 | - Ensure accurate traceback inclusion in 500 error emails sent to ADMINS by capturing the original exception information using `self.exc`. This fixes the issue where tracebacks were previously showing as None for `django version >= 4.1`. 
101 | - Handle error responses with more than 1000 errors 102 | 103 | ## [0.12.6] - 2023-10-25 104 | ### Added 105 | - declare support for type checking 106 | - add support for django 4.2 107 | - add support for python 3.12 108 | 109 | ### Fixed 110 | - Avoid calling `AutoSchema.get_request_serializer` when inspecting a get operation for possible error responses. 111 | 112 | ## [0.12.5] - 2023-01-14 113 | ### Added 114 | - allow adding extra validation errors on an operation-basis using the new `@extend_validation_errors` decorator. 115 | You can find [more information about that in the documentation](openapi.md#customize-error-codes-on-an-operation-basis). 116 | 117 | ### Fixed 118 | - use `model._default_manager` instead of `model.objects`. 119 | - Don't generate error responses for OpenAPI callbacks. 120 | - Make `_should_add_http403_error_response` check if permission is `IsAuthenticated` and 121 | `AllowAny` via `type` instead of `isinstance` 122 | - Don't collect error codes from nested `read_only` fields 123 | 124 | ## [0.12.4] - 2022-12-11 125 | ### Fixed 126 | - account for specifying the request serializer as a basic type (like `OpenApiTypes.STR`) or as a 127 | `PolymorphicProxySerializer` using `@extend_schema(request=...)` when determining error codes for validation errors. 128 | 129 | ## [0.12.3] - 2022-11-13 130 | ### Added 131 | - add support for python 3.11 132 | 133 | ## [0.12.2] - 2022-09-25 134 | ### Added 135 | - When a custom validator class defines a `code` attribute, add it to the list of error codes raised by 136 | the corresponding field. 137 | - add support for DRF 3.14 138 | 139 | ## [0.12.1] - 2022-09-03 140 | ### Fixed 141 | - generate the mapping for discriminator fields properly instead of showing a "null" value in the generated schema (#12). 142 | 143 | ## [0.12.0] - 2022-08-27 144 | ### Added 145 | - add support for automatically generating error responses schema with [drf-spectacular](https://github.com/tfranzel/drf-spectacular).
146 | Check out the [corresponding documentation page](https://drf-standardized-errors.readthedocs.io/en/latest/openapi.html) 147 | to know more about the integration with drf-spectacular. 148 | - add support for django 4.1 149 | 150 | ## [0.11.0] - 2022-06-24 151 | ### Changed (Backward-incompatible) 152 | - Removed all imports from `drf_standardized_errors.__init__.py`. This avoids facing the `AppRegistryNotReady` error 153 | in certain situations (fixes #7). This change **only affects where functions/classes are imported from**, there are 154 | **no changes to how they work**. To upgrade to this version, you need to: 155 | - Update the `"EXCEPTION_HANDLER"` setting in `REST_FRAMEWORK` to `"drf_standardized_errors.handler.exception_handler"`. 156 | - If you imported the exception handler directly, make sure the import looks like this 157 | `from drf_standardized_errors.handler import exception_handler`. 158 | - If you imported the exception handler class, make sure the import looks like this 159 | `from drf_standardized_errors.handler import ExceptionHandler`. 160 | - If you imported the exception formatter class, make sure the import looks like this 161 | `from drf_standardized_errors.formatter import ExceptionFormatter`. 162 | 163 | ## [0.10.2] - 2022-05-08 164 | ### Fixed 165 | - disable tag creation by the "create GitHub release" action since it is already created by tbump 166 | 167 | ## [0.10.1] - 2022-05-08 168 | ### Fixed 169 | - add write permission to create release action, so it can push release notes to GitHub 170 | - fix license badge link so it works on PyPI 171 | 172 | ## [0.10.0] - 2022-05-08 173 | ### Added 174 | 175 | - Build the documentation automatically on every commit to the main branch. The docs are 176 | [hosted on readthedocs](https://drf-standardized-errors.readthedocs.io/en/latest/). 
177 | - Add package metadata 178 | - add a GitHub workflow to create a GitHub release when a new tag is pushed 179 | - add a GitHub workflow to run tests on every push and pull request 180 | - add test coverage 181 | 182 | ## [0.9.0] - 2022-05-07 183 | ### Added 184 | 185 | - Common error response format for DRF-based APIs 186 | - Easily customize the error response format. 187 | - Handle error responses for list serializers and nested serializers. 188 | - Add documentation 189 | - Add tests 190 | - Automate release steps 191 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | import os 13 | import sys 14 | 15 | from django.conf import settings 16 | 17 | settings.configure() 18 | 19 | sys.path.insert(0, os.path.abspath("..")) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "DRF Standardized Errors" 24 | copyright = "2022, Ghazi Abbassi" 25 | author = "Ghazi Abbassi" 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | "sphinx.ext.autodoc", 35 | "sphinx_rtd_theme", 36 | "myst_parser", 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | # templates_path = ["_templates"] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = "sphinx_rtd_theme" 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | # html_static_path = ["_static"] 58 | 59 | myst_heading_anchors = 3 60 | -------------------------------------------------------------------------------- /docs/contibuting.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | ## Contributing 4 | 5 | - Fork this repository. 6 | - Clone the forked repository locally. 7 | - Create a virtual environment, activate it and install development dependencies using `pip install '.[dev]'` 8 | - Install pre-commit hooks with `pre-commit install` 9 | - Create a branch, make changes and commit them. 10 | - Push the changes to your forked repository. 11 | - Create a pull request against **this** repository. 12 | 13 | ## Run tests locally 14 | 15 | You can run tests using tox. 16 | 17 | ```shell 18 | pip install .[test] 19 | tox 20 | ``` 21 | 22 | By default, tests will run against all supported environments. 
However, if a supported python version is not available 23 | on your machine, you should see an `InterpreterNotFound` error. You can use pyenv to install the needed python versions. 24 | 25 | ## Documentation 26 | 27 | The documentation is built using Sphinx and is written using markdown thanks to MyST Parser. In many cases, knowing 28 | markdown is enough to update the docs, but if that's not the case, please check the 29 | [MyST syntax guide](https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html) or 30 | [this cheatsheet](https://jupyterbook.org/reference/cheatsheet.html). 31 | 32 | For building the documentation locally, first you'll need to install the documentation dependencies with 33 | ```shell 34 | pip install .[doc] 35 | cd docs 36 | make livehtml 37 | ``` 38 | The last command will start a server that makes the docs available at `http://localhost:9000` and rebuilds it 39 | on any change. That is done thanks to [sphinx-autobuild](https://github.com/executablebooks/sphinx-autobuild). 40 | -------------------------------------------------------------------------------- /docs/customization.md: -------------------------------------------------------------------------------- 1 | # Customization 2 | 3 | The idea behind this package is to standardize error responses and make it easier to customize. To accomplish that, 4 | the exception handler was rewritten as a class, so it's easy to subclass and make small customizations. First, we'll 5 | go through a brief description of the flow for generating error response, then we'll check some customizations that 6 | you might want to make. 7 | 8 | ## Exception handling flow 9 | 10 | You're encouraged to read the source code since it's not that much but here's a quick overview. 11 | 12 | - The flow starts with converting known exceptions like `django.core.exceptions.PermissionDenied` and 13 | `django.http.Http404` to [DRF exceptions](https://www.django-rest-framework.org/api-guide/exceptions/#api-reference). 
14 | - Any unhandled exception is then converted to an instance of `rest_framework.exceptions.APIException`. 15 | - Afterwards, the exception data is extracted and formatted, and the error response is generated with 16 | the correct headers. 17 | - Finally, if the exception is a server error (status code is 5xx) then it is logged and the signal 18 | `got_request_exception` is sent out. This helps the 19 | [django test client](https://github.com/django/django/blob/1b3c0d3b54d4ff5f75af57d3130180b1d22468e9/django/test/client.py#L712) 20 | or an [error monitoring tool like Sentry](https://github.com/getsentry/sentry-python/blob/d880f47add3876d5cedefb4178a1dcd4d85b5d1b/sentry_sdk/integrations/django/__init__.py#L138) 21 | capture exception details. 22 | 23 | 24 | ## Sample customizations 25 | 26 | ### Handle a non-DRF exception 27 | 28 | This can be done the same way as what [DRF recommends](https://www.django-rest-framework.org/api-guide/exceptions/#apiexception): 29 | - Create a new exception class by inheriting from `APIException` and setting the `default_detail` and `default_code` 30 | attributes. 31 | - Also, set the `status_code` attribute, but keep in mind that the status code is used to determine the error type. 32 | A 4xx status code results in a `client_error` and a 5xx results in a `server_error`. 33 | - In your view, you can now raise the new exception, and it will be handled appropriately. 34 | 35 | Also, you can customize the exception handler instead of raising the new exception in your code: 36 | - Assuming the example from DRF docs for a `ServiceUnavailable` exception 37 | ```python 38 | from rest_framework.exceptions import APIException 39 | 40 | class ServiceUnavailable(APIException): 41 | status_code = 503 42 | default_detail = 'Service temporarily unavailable, try again later.' 
43 | default_code = 'service_unavailable' 44 | ``` 45 | - You need to subclass `drf_standardized_errors.handler.ExceptionHandler` and override `convert_known_exceptions` 46 | ```python 47 | import requests 48 | from drf_standardized_errors.handler import ExceptionHandler 49 | 50 | class MyExceptionHandler(ExceptionHandler): 51 | def convert_known_exceptions(self, exc: Exception) -> Exception: 52 | if isinstance(exc, requests.Timeout): 53 | return ServiceUnavailable() 54 | else: 55 | return super().convert_known_exceptions(exc) 56 | ``` 57 | Then, update the setting to point to your exception handler class 58 | ```python 59 | DRF_STANDARDIZED_ERRORS = {"EXCEPTION_HANDLER_CLASS": "path.to.MyExceptionHandler"} 60 | ``` 61 | 62 | ### Change the format of the error response 63 | 64 | Let's say you don't need to return multiple errors, and you don't like some key names in the error response: specifically, 65 | you want to change `detail` to `message` and `attr` to `field_name`. 66 | 67 | You'll need to subclass `ExceptionFormatter` and override `format_error_response`.
68 | ```python 69 | from drf_standardized_errors.formatter import ExceptionFormatter 70 | from drf_standardized_errors.types import ErrorResponse 71 | 72 | class MyExceptionFormatter(ExceptionFormatter): 73 | def format_error_response(self, error_response: ErrorResponse): 74 | error = error_response.errors[0] 75 | return { 76 | "type": error_response.type, 77 | "code": error.code, 78 | "message": error.detail, 79 | "field_name": error.attr 80 | } 81 | ``` 82 | Then, update the corresponding setting 83 | ```python 84 | DRF_STANDARDIZED_ERRORS = {"EXCEPTION_FORMATTER_CLASS": "path.to.MyExceptionFormatter"} 85 | ``` 86 | -------------------------------------------------------------------------------- /docs/error_response.md: -------------------------------------------------------------------------------- 1 | # Error Response Format 2 | 3 | The default error response format looks like this 4 | ```json 5 | { 6 | "type": "validation_error", 7 | "errors": [ 8 | { 9 | "code": "required", 10 | "detail": "This field is required.", 11 | "attr": "name" 12 | } 13 | ] 14 | } 15 | ``` 16 | 17 | - `type`: can be `validation_error`, `client_error` or `server_error` 18 | - `code`: short string describing the error. Can be used by API consumers to customize their behavior. 19 | - `detail`: User-friendly text describing the error. 20 | - `attr`: set only when the error type is a `validation_error` and maps to the serializer field name or `settings.NON_FIELD_ERRORS_KEY`. 21 | 22 | ## Error Types 23 | 24 | ### Validation Errors 25 | 26 | These are ones caused by raising a `rest_framework.exceptions.ValidationError`. They are the only error type 27 | that can have more than 1 error. The list of corresponding error codes depends on the serializer and its fields. 28 | 29 | ### Client Errors 30 | 31 | Covers all 4xx errors other than validation errors. 
Here's a reference to all possible error codes and corresponding exceptions: 32 | 33 | | Error Code | Status Code | DRF Exception | 34 | | ---------------------- | ----------- | -------------------------------------------------------------------------------------------------------- | 35 | | parse_error | 400 | [ParseError](https://www.django-rest-framework.org/api-guide/exceptions/#parseerror) | 36 | | authentication_failed | 401 | [AuthenticationFailed](https://www.django-rest-framework.org/api-guide/exceptions/#authenticationfailed) | 37 | | not_authenticated | 401 | [NotAuthenticated](https://www.django-rest-framework.org/api-guide/exceptions/#notauthenticated) | 38 | | permission_denied | 403 | [PermissionDenied](https://www.django-rest-framework.org/api-guide/exceptions/#permissiondenied) | 39 | | not_found | 404 | [NotFound](https://www.django-rest-framework.org/api-guide/exceptions/#notfound) | 40 | | method_not_allowed | 405 | [MethodNotAllowed](https://www.django-rest-framework.org/api-guide/exceptions/#methodnotallowed) | 41 | | not_acceptable | 406 | [NotAcceptable](https://www.django-rest-framework.org/api-guide/exceptions/#notacceptable) | 42 | | unsupported_media_type | 415 | [UnsupportedMediaType](https://www.django-rest-framework.org/api-guide/exceptions/#unsupportedmediatype) | 43 | | throttled | 429 | [Throttled](https://www.django-rest-framework.org/api-guide/exceptions/#throttled) | 44 | 45 | ### Server Errors 46 | 47 | These are ones caused by raising a `rest_framework.exceptions.APIException` or by unhandled exceptions. 48 | The corresponding error code is `error` and the status code is 500. 49 | 50 | ## Multiple Errors Support 51 | 52 | Out of the box, only validation errors would result in an error response with multiple errors. The errors 53 | can be for the same field or for different fields. 
So, a sample DRF error dict like this: 54 | ``` 55 | { 56 | "phone": [ 57 | ErrorDetail("The phone number entered is not valid.", code="invalid_phone_number") 58 | ], 59 | "password": [ 60 | ErrorDetail("This password is too short.", code="password_too_short"), 61 | ErrorDetail("The password is too similar to the username.", code="password_too_similar"), 62 | ], 63 | } 64 | ``` 65 | would be converted to: 66 | ```json 67 | { 68 | "type": "validation_error", 69 | "errors": [ 70 | { 71 | "code": "invalid_phone_number", 72 | "detail": "The phone number entered is not valid.", 73 | "attr": "phone" 74 | }, 75 | { 76 | "code": "password_too_short", 77 | "detail": "This password is too short.", 78 | "attr": "password" 79 | }, 80 | { 81 | "code": "password_too_similar", 82 | "detail": "The password is too similar to the username.", 83 | "attr": "password" 84 | } 85 | ] 86 | } 87 | ``` 88 | 89 | ## Nested Serializers Support 90 | 91 | Taking this example 92 | ``` 93 | { 94 | "shipping_address": { 95 | "non_field_errors": [ErrorDetail("We do not support shipping to the provided address.", code="unsupported")] 96 | } 97 | } 98 | ``` 99 | It will be converted to: 100 | ```json 101 | { 102 | "code": "unsupported", 103 | "detail": "We do not support shipping to the provided address.", 104 | "attr": "shipping_address.non_field_errors" 105 | } 106 | ``` 107 | Note how the `attr` is the combined value of parent serializer field name and nested one. They are separated by 108 | a `.` by default, but that can be changed through the setting `NESTED_FIELD_SEPARATOR`. 
109 | 110 | ## List Serializers Support 111 | This example 112 | ``` 113 | { 114 | "recipients": [ 115 | {"name": [ErrorDetail("This field is required.", code="required")]}, 116 | {"email": [ErrorDetail("Enter a valid email address.", code="invalid")]}, 117 | ] 118 | } 119 | ``` 120 | would be converted to 121 | ```json 122 | { 123 | "type": "validation_error", 124 | "errors": [ 125 | { 126 | "code": "required", 127 | "detail": "This field is required.", 128 | "attr": "recipients.0.name" 129 | }, 130 | { 131 | "code": "invalid", 132 | "detail": "Enter a valid email address.", 133 | "attr": "recipients.1.email" 134 | } 135 | ] 136 | } 137 | ``` 138 | Note that distinguishing between errors in different objects in the nested list serializer is done using 139 | 0-based indexing. 140 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # FAQs 2 | 3 | ## Standardized errors are not shown in local development 4 | 5 | By default, standardized error responses when `DEBUG=True` for unhandled exceptions are disabled. 6 | That is to allow you to get more information out of the traceback. You can enable standardized errors 7 | instead with: 8 | 9 | ```python 10 | DRF_STANDARDIZED_ERRORS = {"ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": True} 11 | ``` 12 | 13 | 14 | ## Some exceptions are not converted to the standardized format 15 | 16 | This package is a DRF exception handler, so it only standardizes errors that reach a DRF API view. 17 | That means it cannot handle errors that happen at the middleware level for example. To handle those 18 | as well, you can customize the necessary [django error views](https://docs.djangoproject.com/en/dev/topics/http/views/#customizing-error-views). 19 | You can find more about that in [this issue](https://github.com/ghazi-git/drf-standardized-errors/issues/44). 
20 | 21 | 22 | ## I want to let exceptions propagate up the middleware stack 23 | 24 | This might be needed when code written in middleware adds custom logic based on raised exceptions 25 | (either by you or by a third-party package). In that case, it is possible to allow the exception 26 | to pass through the DRF exception handler and later convert it to the corresponding error response 27 | in Django error views. You can check [this issue](https://github.com/ghazi-git/drf-standardized-errors/issues/91#issuecomment-2397956441) for sample code. 28 | 29 | 30 | ## How can I add extra details about the exception in the error response 31 | 32 | This can be done using a custom exception along with a custom exception formatter. You can find sample 33 | code in [this issue](https://github.com/ghazi-git/drf-standardized-errors/issues/95#issuecomment-2661633736). 34 | Note that this does not work with `ValidationError`s or their subclasses raised in a serializer. That's 35 | because DRF creates new `ValidationError` instances when they are raised. See 36 | [here](https://github.com/encode/django-rest-framework/blob/f30c0e2eedda410a7e6a0d1b351377a9084361b4/rest_framework/serializers.py#L221-L231) 37 | and [here](https://github.com/encode/django-rest-framework/blob/f30c0e2eedda410a7e6a0d1b351377a9084361b4/rest_framework/serializers.py#L443-L448). 38 | 39 | 40 | ## How to integrate this package with djangorestframework-camel-case 41 | 42 | You can check this [issue](https://github.com/ghazi-git/drf-standardized-errors/issues/59#issuecomment-1889826918) 43 | for a possible solution. Still, `djangorestframework-camel-case` is built to work specifically with 44 | the default exception handler from DRF. It assumes that field names are the keys in the returned dict. 45 | So, it does not work well with this package.
46 | 47 | 48 | ## How can I change the default error message for unhandled exceptions 49 | 50 | You need to create a custom exception handler class 51 | ```python 52 | from rest_framework.exceptions import APIException 53 | from drf_standardized_errors.handler import ExceptionHandler 54 | 55 | class MyExceptionHandler(ExceptionHandler): 56 | def convert_unhandled_exceptions(self, exc: Exception) -> APIException: 57 | if not isinstance(exc, APIException): 58 | return APIException(detail="New error message") 59 | else: 60 | return exc 61 | ``` 62 | Then, update the settings to point to your exception handler class 63 | ```python 64 | DRF_STANDARDIZED_ERRORS = { 65 | # ... 66 | "EXCEPTION_HANDLER_CLASS": "path.to.MyExceptionHandler" 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /docs/gotchas.md: -------------------------------------------------------------------------------- 1 | # Gotchas 2 | 3 | ## Writing tests 4 | 5 | ### TL;DR 6 | 7 | If: 8 | - you've customized the exception handler and, 9 | - the view raises an exception that causes a 5xx status code and, 10 | - you're writing a test that ensures that the view will return the proper response when the exception is raised 11 | 12 | Then, make sure to pass `raise_request_exception=False`, otherwise the test will keep failing. `raise_request_exception=False` 13 | [allows returning a 500 response instead of raising an exception](https://docs.djangoproject.com/en/stable/topics/testing/tools/#exceptions). 14 | 15 | 16 | ### The long version 17 | 18 | I faced this while writing a test for this package, so I wanted to share it in case someone else stumbles upon it. 19 | I was testing a custom exception formatter to make sure it's used when set in settings and that the error response 20 | format matches my expectation. 
So, here's the test 21 | 22 | ```python 23 | # views.py 24 | from rest_framework.views import APIView 25 | 26 | class ErrorView(APIView): 27 | def get(self, request, *args, **kwargs): 28 | raise Exception("Internal server error.") 29 | ``` 30 | 31 | ```python 32 | # urls.py 33 | from django.urls import path 34 | 35 | from .views import ErrorView 36 | 37 | urlpatterns = [ 38 | path("error/", ErrorView.as_view()), 39 | ] 40 | ``` 41 | 42 | ```python 43 | # tests.py 44 | import pytest 45 | from rest_framework.test import APIClient 46 | 47 | from drf_standardized_errors.formatter import ExceptionFormatter 48 | from drf_standardized_errors.types import ErrorResponse 49 | 50 | 51 | @pytest.fixture 52 | def api_client(): 53 | return APIClient() 54 | 55 | 56 | def test_custom_exception_formatter_class(settings, api_client): 57 | settings.DRF_STANDARDIZED_ERRORS = { 58 | "EXCEPTION_FORMATTER_CLASS": "tests.CustomExceptionFormatter" 59 | } 60 | response = api_client.get("/error/") 61 | assert response.status_code == 500 62 | assert response.data["type"] == "server_error" 63 | assert response.data["code"] == "error" 64 | assert response.data["message"] == "Server Error (500)" 65 | assert response.data["field_name"] is None 66 | 67 | 68 | class CustomExceptionFormatter(ExceptionFormatter): 69 | def format_error_response(self, error_response: ErrorResponse): 70 | """return one error at a time and change error response key names""" 71 | error = error_response.errors[0] 72 | return { 73 | "type": error_response.type, 74 | "code": error.code, 75 | "message": error.detail, 76 | "field_name": error.attr, 77 | } 78 | ``` 79 | 80 | This test kept failing and showing a traceback including `raise Exception("Internal server error.")`. 81 | To me, it seemed like the exception handler was not doing its job. 82 | 83 | Running the test in debug mode, I was able to see that the response returned by the view was indeed what I expected, yet 84 | the test was still failing.
85 | 86 | Looking again at the test traceback and after reading the [relevant code in the Django test client](https://github.com/django/django/blob/0b31e024873681e187b574fe1c4afe5e48aeeecf/django/test/client.py#L803-L810), 87 | that's when I realized what's going on: the test client defines a receiver for the signal `got_request_exception` 88 | and if that signal is sent, it concludes that an issue happened and raises the exception. 89 | In my test, I was raising an `Exception("Internal server error.")`, which is considered a server error, so 90 | the signal is sent out by the exception handler and Django fails the test since it receives the signal. 91 | 92 | As for why the signal is sent out by the exception handler in the first place, it's because error monitoring tools 93 | (like Sentry) rely on it to collect exception information and make it available through their UI. Also, and as found 94 | during the debugging of this issue, the Django test client needs it to determine if the view in question has raised an 95 | exception or not and notify the developer. 96 | 97 | The nice thing is that [the Django test client allows retrieving the response without raising the exception](https://docs.djangoproject.com/en/stable/topics/testing/tools/#exceptions). 98 | That's possible by passing `raise_request_exception=False` when instantiating the test client. 99 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # DRF Standardized Errors 2 | 3 | Standardize your [DRF](https://www.django-rest-framework.org/) API error responses.
4 | 5 | ```{toctree} 6 | :caption: Table of Contents 7 | quickstart.md 8 | settings.md 9 | error_response.md 10 | customization.md 11 | faq.md 12 | gotchas.md 13 | openapi.md 14 | openapi_sample_description.md 15 | contibuting.md 16 | release.md 17 | changelog.md 18 | ``` 19 | 20 | ## Credits 21 | 22 | This package was inspired by [DRF Exceptions Hog](https://github.com/PostHog/drf-exceptions-hog) but, with an emphasis 23 | on the ability to customize the error response format to your liking while keeping the implementation of the exception 24 | handler as close as possible to the original DRF implementation. 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | 8 | if "%SPHINXBUILD%" == "" ( 9 | set SPHINXBUILD=sphinx-build 10 | ) 11 | set SOURCEDIR=. 12 | set BUILDDIR=_build 13 | set APP=..\drf_standardized_errors 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.Install sphinx-autobuild for live serving. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | if "%1" == "" goto help 30 | if "%1" == "livehtml" goto livehtml 31 | 32 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | goto end 34 | 35 | :livehtml 36 | sphinx-autobuild -b html --open-browser --port 9000 --watch %APP% -c . 
%SOURCEDIR% %BUILDDIR%/html 37 | GOTO :EOF 38 | 39 | :help 40 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 41 | :end 42 | 43 | popd -------------------------------------------------------------------------------- /docs/openapi_sample_description.md: -------------------------------------------------------------------------------- 1 | # Sample API Description 2 | 3 | Check the [Tips and Tricks](openapi.md#hide-error-responses-that-show-in-every-operation) to understand how to use this sample API description. 4 | 5 | ## Overview 6 | Here you will probably give an overview of the API and add a description for it. 7 | 8 | ## Authentication 9 | Since the API requires authentication, you might want to add a section to describe the authentication flow as well. 10 | 11 | ## Errors 12 | Now this is the important section in this example. In this section, you can list the error responses that appear 13 | in every operation with some explanation. It could go something like this: 14 | 15 | ### 401 Unauthorized 16 | These errors are returned with the status code 401 whenever the authentication fails or a request is made to an 17 | endpoint without providing the authentication information as part of the request. Here are the 2 possible errors 18 | that can be returned. 19 | ```json 20 | { 21 | "type": "client_error", 22 | "errors": [ 23 | { 24 | "code": "authentication_failed", 25 | "detail": "Incorrect authentication credentials.", 26 | "attr": null 27 | } 28 | ] 29 | } 30 | ``` 31 | ```json 32 | { 33 | "type": "client_error", 34 | "errors": [ 35 | { 36 | "code": "not_authenticated", 37 | "detail": "Authentication credentials were not provided.", 38 | "attr": null 39 | } 40 | ] 41 | } 42 | ``` 43 | 44 | ### 405 Method Not Allowed 45 | This is returned when an endpoint is called with an unexpected HTTP method. For example, if updating a user requires 46 | a POST request and a PATCH is issued instead, this error is returned.
Here's what it looks like: 47 | 48 | ```json 49 | { 50 | "type": "client_error", 51 | "errors": [ 52 | { 53 | "code": "method_not_allowed", 54 | "detail": "Method “PATCH” not allowed.", 55 | "attr": null 56 | } 57 | ] 58 | } 59 | ``` 60 | 61 | ### 406 Not Acceptable 62 | This is returned if the `Accept` header is submitted and contains a value other than `application/json`. Here's how the response would look: 63 | 64 | ```json 65 | { 66 | "type": "client_error", 67 | "errors": [ 68 | { 69 | "code": "not_acceptable", 70 | "detail": "Could not satisfy the request Accept header.", 71 | "attr": null 72 | } 73 | ] 74 | } 75 | ``` 76 | 77 | ### 415 Unsupported Media Type 78 | This is returned when the request content type is not JSON. Here's how the response would look: 79 | 80 | ```json 81 | { 82 | "type": "client_error", 83 | "errors": [ 84 | { 85 | "code": "unsupported_media_type", 86 | "detail": "Unsupported media type “application/xml” in request.", 87 | "attr": null 88 | } 89 | ] 90 | } 91 | ``` 92 | 93 | ### 500 Internal Server Error 94 | This is returned when the API server encounters an unexpected error.
Here's how the response would look: 95 | 96 | ```json 97 | { 98 | "type": "server_error", 99 | "errors": [ 100 | { 101 | "code": "error", 102 | "detail": "A server error occurred.", 103 | "attr": null 104 | } 105 | ] 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | Install with `pip` 4 | ```shell 5 | pip install drf-standardized-errors 6 | ``` 7 | 8 | Add drf-standardized-errors to your installed apps 9 | ```python 10 | INSTALLED_APPS = [ 11 | # other apps 12 | "drf_standardized_errors", 13 | ] 14 | ``` 15 | 16 | Set the exception handler for all API views 17 | ```python 18 | REST_FRAMEWORK = { 19 | # other settings 20 | "EXCEPTION_HANDLER": "drf_standardized_errors.handler.exception_handler" 21 | } 22 | ``` 23 | 24 | or on a view basis (especially if you're introducing this to a versioned API) 25 | ```python 26 | from drf_standardized_errors.handler import exception_handler 27 | from rest_framework.views import APIView 28 | 29 | class MyAPIView(APIView): 30 | def get_exception_handler(self): 31 | return exception_handler 32 | ``` 33 | 34 | Now, your API error responses for 4xx and 5xx errors, will look like this 35 | ```json 36 | { 37 | "type": "validation_error", 38 | "errors": [ 39 | { 40 | "code": "required", 41 | "detail": "This field is required.", 42 | "attr": "name" 43 | }, 44 | { 45 | "code": "max_length", 46 | "detail": "Ensure this value has at most 100 characters.", 47 | "attr": "title" 48 | } 49 | ] 50 | } 51 | ``` 52 | or 53 | ```json 54 | { 55 | "type": "server_error", 56 | "errors": [ 57 | { 58 | "code": "error", 59 | "detail": "A server error occurred.", 60 | "attr": null 61 | } 62 | ] 63 | } 64 | ``` 65 | 66 | ## Important Notes 67 | 68 | - Standardized error responses when `DEBUG=True` for **unhandled exceptions** are disabled by default. 
That is 69 | to allow you to get more information out of the traceback. You can enable standardized errors instead with: 70 | ```python 71 | DRF_STANDARDIZED_ERRORS = {"ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": True} 72 | ``` 73 | 74 | - Cases where you explicitly return a response with a 4xx or 5xx status code in your `APIView` do not go through 75 | the exception handler and thus, will not have the standardized error format. So, we recommend that you raise an 76 | exception with 77 | `raise APIException("Service temporarily unavailable.", code="service_unavailable")` 78 | instead of `return Response(data, status=500)`. That way, error response formatting is handled automatically for you. 79 | But, keep in mind that exceptions that result in a 5xx response are reported to error monitoring tools (like Sentry) 80 | if you're using one. 81 | 82 | ## Integration with DRF spectacular 83 | If you plan to use [drf-spectacular](https://github.com/tfranzel/drf-spectacular) to generate an OpenAPI 3 schema, 84 | install with `pip install drf-standardized-errors[openapi]`. After that, check the doc page for configuring the 85 | integration. 86 | -------------------------------------------------------------------------------- /docs/release.md: -------------------------------------------------------------------------------- 1 | # Release 2 | 3 | ## Automated Release 4 | The release is automated using [tbump](https://github.com/dmerejkowsky/tbump). So, to create a new release: 5 | - run `pip install .[release]` (preferably inside a virtualenv) 6 | - run `tbump <new_version>` (e.g. `tbump 0.6.0`). 7 | 8 | That's it; the package should then be available on [PyPI](https://pypi.org/project/drf-standardized-errors/). 9 | 10 | As part of the previous step, a GitHub release is created since a new tag is pushed. That's automated 11 | [using GitHub actions](https://github.com/ghazi-git/drf-standardized-errors/actions/workflows/github_release.yml).
12 | 13 | For the documentation, it is built automatically on every commit to the main branch. It can be found 14 | [here](https://drf-standardized-errors.readthedocs.io/en/latest/). 15 | 16 | ## Manual Release Steps 17 | 18 | This is kept as docs in case the release flow needs to change or someone new is trying to understand what's going on 19 | to make some change or improve it. 20 | 21 | - `pip install .[release]` 22 | - update the changelog: 23 | - replace unreleased with the current version and date 24 | - create a new unreleased section at the top. 25 | - update the version in `drf_standardized_errors.__init__.__version__` 26 | - commit changes 27 | - create a tag with the new version as `v{version}` 28 | - push changes 29 | - publish release to pypi: the command `flit publish` takes care of that. 30 | - build the docs 31 | - create a new release on GitHub. 32 | -------------------------------------------------------------------------------- /docs/settings.md: -------------------------------------------------------------------------------- 1 | # Settings 2 | 3 | Here are all available settings with their defaults, you can override them in your project settings 4 | ```python 5 | DRF_STANDARDIZED_ERRORS = { 6 | # class responsible for handling the exceptions. Can be subclassed to change 7 | # which exceptions are handled by default, to update which exceptions are 8 | # reported to error monitoring tools (like Sentry), ... 9 | "EXCEPTION_HANDLER_CLASS": "drf_standardized_errors.handler.ExceptionHandler", 10 | # class responsible for generating error response output. Can be subclassed 11 | # to change the format of the error response. 12 | "EXCEPTION_FORMATTER_CLASS": "drf_standardized_errors.formatter.ExceptionFormatter", 13 | # enable the standardized errors when DEBUG=True for unhandled exceptions. 14 | # By default, this is set to False so you're able to view the traceback in 15 | # the terminal and get more information about the exception. 
16 | "ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": False, 17 | # When a validation error is raised in a nested serializer, the 'attr' key 18 | # of the error response will look like: 19 | # {field}{NESTED_FIELD_SEPARATOR}{nested_field} 20 | # for example: 'shipping_address.zipcode' 21 | "NESTED_FIELD_SEPARATOR": ".", 22 | 23 | # The below settings are for OpenAPI 3 schema generation 24 | 25 | # ONLY the responses that correspond to these status codes will appear 26 | # in the API schema. 27 | "ALLOWED_ERROR_STATUS_CODES": [ 28 | "400", 29 | "401", 30 | "403", 31 | "404", 32 | "405", 33 | "406", 34 | "415", 35 | "429", 36 | "500", 37 | ], 38 | 39 | # A mapping used to override the default serializers used to describe 40 | # the error response. The key is the status code and the value is a string 41 | # that represents the path to the serializer class that describes the 42 | # error response. 43 | "ERROR_SCHEMAS": None, 44 | 45 | # When there is a validation error in list serializers, the "attr" returned 46 | # will be something like "0.email", "1.email", "2.email", ... So, to describe 47 | # the error codes linked to the same field in a list serializer, the field 48 | # will appear in the schema with the name "INDEX.email" 49 | "LIST_INDEX_IN_API_SCHEMA": "INDEX", 50 | 51 | # When there is a validation error in a DictField with the name "extra_data", 52 | # the "attr" returned will be something like "extra_data.<key1>", "extra_data.<key2>", 53 | # "extra_data.<key3>", ... Since the keys of a DictField are not predetermined, 54 | # this setting is used as a common name to be used in the API schema. So, the 55 | # corresponding "attr" value for the previous example will be "extra_data.KEY" 56 | "DICT_KEY_IN_API_SCHEMA": "KEY", 57 | 58 | # should be unique to error components since it is used to identify error 59 | # components generated dynamically to exclude them from being processed by 60 | # the postprocessing hook.
This avoids raising warnings for "code" and "attr" 61 | # which can have the same choices across multiple serializers. 62 | "ERROR_COMPONENT_NAME_SUFFIX": "ErrorComponent", 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /drf_standardized_errors/__init__.py: -------------------------------------------------------------------------------- 1 | """Standardize your API error responses.""" 2 | 3 | __version__ = "0.15.0" 4 | -------------------------------------------------------------------------------- /drf_standardized_errors/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class StandardizedErrorsConfig(AppConfig): 5 | name = "drf_standardized_errors" 6 | verbose_name = "drf-standardized-errors" 7 | -------------------------------------------------------------------------------- /drf_standardized_errors/formatter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Any, List, Optional, Union 3 | 4 | from rest_framework import exceptions 5 | from rest_framework.status import is_client_error 6 | 7 | from .settings import package_settings 8 | from .types import ( 9 | CLIENT_ERROR, 10 | SERVER_ERROR, 11 | VALIDATION_ERROR, 12 | Error, 13 | ErrorResponse, 14 | ErrorType, 15 | ExceptionHandlerContext, 16 | ) 17 | 18 | 19 | class ExceptionFormatter: 20 | def __init__( 21 | self, 22 | exc: exceptions.APIException, 23 | context: ExceptionHandlerContext, 24 | original_exc: Exception, 25 | ): 26 | self.exc = exc 27 | self.context = context 28 | self.original_exc = original_exc 29 | 30 | def run(self) -> Any: 31 | """ 32 | Entrypoint for formatting the error response.
33 | 34 | The default error response format is as follows: 35 | - type: validation_error, client_error or server_error 36 | - errors: list of errors where each one has: 37 | - code: short string describing the error. Can be used by API consumers 38 | to customize their behavior. 39 | - detail: User-friendly text describing the error. 40 | - attr: set only when the error type is a validation error and maps 41 | to the serializer field name or NON_FIELD_ERRORS_KEY. 42 | 43 | Only validation errors can have multiple errors. Other error types have only 44 | one error. 45 | """ 46 | error_type = self.get_error_type() 47 | errors = self.get_errors() 48 | error_response = self.get_error_response(error_type, errors) 49 | return self.format_error_response(error_response) 50 | 51 | def get_error_type(self) -> ErrorType: 52 | if isinstance(self.exc, exceptions.ValidationError): 53 | return VALIDATION_ERROR 54 | elif is_client_error(self.exc.status_code): 55 | return CLIENT_ERROR 56 | else: 57 | return SERVER_ERROR 58 | 59 | def get_errors(self) -> List[Error]: 60 | """ 61 | Account for validation errors in nested serializers by returning a list 62 | of errors instead of a nested dict 63 | """ 64 | return flatten_errors(self.exc.detail) 65 | 66 | def get_error_response( 67 | self, error_type: ErrorType, errors: List[Error] 68 | ) -> ErrorResponse: 69 | return ErrorResponse(error_type, errors) 70 | 71 | def format_error_response(self, error_response: ErrorResponse) -> Any: 72 | return asdict(error_response) 73 | 74 | 75 | def flatten_errors( 76 | detail: Union[list, dict, exceptions.ErrorDetail], 77 | attr: Optional[str] = None, 78 | index: Optional[int] = None, 79 | ) -> List[Error]: 80 | """ 81 | convert this: 82 | { 83 | "password": [ 84 | ErrorDetail("This password is too short.", code="password_too_short"), 85 | ErrorDetail("The password is too similar to the username.", code="password_too_similar"), 86 | ], 87 | "linked_accounts": [ 88 | {}, 89 | {"email":
[ErrorDetail("Enter a valid email address.", code="invalid")]}, 90 | ] 91 | } 92 | to: 93 | { 94 | "type": "validation_error", 95 | "errors": [ 96 | { 97 | "code": "password_too_short", 98 | "detail": "This password is too short.", 99 | "attr": "password" 100 | }, 101 | { 102 | "code": "password_too_similar", 103 | "detail": "The password is too similar to the username.", 104 | "attr": "password" 105 | }, 106 | { 107 | "code": "invalid", 108 | "detail": "Enter a valid email address.", 109 | "attr": "linked_accounts.1.email" 110 | } 111 | ] 112 | } 113 | """ 114 | 115 | # preserve the order of the previous implementation with a fifo queue 116 | fifo = [(detail, attr, index)] 117 | errors = [] 118 | while fifo: 119 | detail, attr, index = fifo.pop(0) 120 | if not detail and detail != "": 121 | continue 122 | elif isinstance(detail, list): 123 | for item in detail: 124 | if not isinstance(item, exceptions.ErrorDetail): 125 | index = 0 if index is None else index + 1 126 | if attr: 127 | new_attr = ( 128 | f"{attr}{package_settings.NESTED_FIELD_SEPARATOR}{index}" 129 | ) 130 | else: 131 | new_attr = str(index) 132 | fifo.append((item, new_attr, None)) # nested lists restart at 0 133 | else: 134 | fifo.append((item, attr, index)) 135 | 136 | elif isinstance(detail, dict): 137 | for key, value in detail.items(): 138 | if attr: 139 | key = f"{attr}{package_settings.NESTED_FIELD_SEPARATOR}{key}" 140 | fifo.append((value, key, None)) 141 | 142 | else: 143 | errors.append(Error(detail.code, str(detail), attr)) 144 | 145 | return errors 146 | -------------------------------------------------------------------------------- /drf_standardized_errors/handler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | 4 | import django 5 | from django.conf import settings 6 | from django.core import signals 7 | from django.core.exceptions import PermissionDenied 8 | from django.http import Http404 9 | from django.utils.log import
log_response 10 | from rest_framework import exceptions 11 | from rest_framework.request import Request 12 | from rest_framework.response import Response 13 | from rest_framework.status import is_server_error 14 | from rest_framework.views import set_rollback 15 | 16 | from .formatter import ExceptionFormatter 17 | from .settings import package_settings 18 | from .types import ExceptionHandlerContext 19 | 20 | 21 | def exception_handler( 22 | exc: Exception, context: ExceptionHandlerContext 23 | ) -> Optional[Response]: 24 | exception_handler_class = package_settings.EXCEPTION_HANDLER_CLASS 25 | msg = "`EXCEPTION_HANDLER_CLASS` should be a subclass of ExceptionHandler." 26 | assert issubclass(exception_handler_class, ExceptionHandler), msg 27 | return exception_handler_class(exc, context).run() 28 | 29 | 30 | class ExceptionHandler: 31 | def __init__(self, exc: Exception, context: ExceptionHandlerContext): 32 | self.exc = exc 33 | self.context = context 34 | 35 | def run(self) -> Optional[Response]: 36 | """entrypoint for handling an exception""" 37 | exc = self.convert_known_exceptions(self.exc) 38 | if self.should_not_handle(exc): 39 | return None 40 | 41 | exc = self.convert_unhandled_exceptions(exc) 42 | data = self.format_exception(exc) 43 | self.set_rollback() 44 | response = self.get_response(exc, data) 45 | self.report_exception(exc, response) 46 | return response 47 | 48 | def convert_known_exceptions(self, exc: Exception) -> Exception: 49 | """ 50 | By default, Django's built-in `Http404` and `PermissionDenied` are converted 51 | to their DRF equivalent. 52 | """ 53 | if isinstance(exc, Http404): 54 | return exceptions.NotFound() 55 | elif isinstance(exc, PermissionDenied): 56 | return exceptions.PermissionDenied() 57 | else: 58 | return exc 59 | 60 | def should_not_handle(self, exc: Exception) -> bool: 61 | """ 62 | By default, don't handle non-DRF errors in DEBUG mode. 
That's because 63 | handling the exception means the developer will not see the exception 64 | traceback. 65 | """ 66 | return ( 67 | getattr(settings, "DEBUG", False) 68 | and not package_settings.ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS 69 | and not isinstance(exc, exceptions.APIException) 70 | ) 71 | 72 | def convert_unhandled_exceptions(self, exc: Exception) -> exceptions.APIException: 73 | """ 74 | Any non-DRF unhandled exception is converted to an APIException which 75 | has a 500 status code. 76 | """ 77 | if not isinstance(exc, exceptions.APIException): 78 | # return a generic error message to avoid potentially leaking sensitive 79 | # data and match DRF/django behavior (same generic error message returned 80 | # by django.views.defaults.server_error) 81 | return exceptions.APIException(detail="Server Error (500)") 82 | else: 83 | return exc 84 | 85 | def format_exception(self, exc: exceptions.APIException) -> dict: 86 | exception_formatter_class = package_settings.EXCEPTION_FORMATTER_CLASS 87 | msg = "`EXCEPTION_FORMATTER_CLASS` should be a subclass of ExceptionFormatter." 
88 | assert issubclass(exception_formatter_class, ExceptionFormatter), msg 89 | return exception_formatter_class(exc, self.context, self.exc).run() 90 | 91 | def set_rollback(self) -> None: 92 | set_rollback() 93 | 94 | def get_response(self, exc: exceptions.APIException, data: dict) -> Response: 95 | headers = self.get_headers(exc) 96 | return Response(data, status=exc.status_code, headers=headers) 97 | 98 | def get_headers(self, exc: exceptions.APIException) -> dict: 99 | headers = {} 100 | if getattr(exc, "auth_header", None): 101 | headers["WWW-Authenticate"] = exc.auth_header 102 | if getattr(exc, "wait", None): 103 | headers["Retry-After"] = "%d" % exc.wait 104 | return headers 105 | 106 | def report_exception( 107 | self, exc: exceptions.APIException, response: Response 108 | ) -> None: 109 | """ 110 | Normally, when an exception is unhandled (non-DRF exception), DRF delegates 111 | handling it to Django. Django, then, takes care of returning the appropriate 112 | response. That is done in: django.core.handlers.exception.convert_exception_to_response 113 | 114 | However, this package handles all exceptions. So, to stay in line with Django's 115 | default behavior, the got_request_exception signal is sent and the response is 116 | also logged. Sending the signal should allow error monitoring tools (like Sentry) 117 | to work as usual (error is captured and sent to their servers). 
118 | """ 119 | if is_server_error(exc.status_code): 120 | try: 121 | drf_request: Request = self.context["request"] 122 | request = drf_request._request 123 | except AttributeError: 124 | request = None 125 | signals.got_request_exception.send(sender=None, request=request) 126 | if django.VERSION < (4, 1): 127 | log_response( 128 | "%s: %s", 129 | exc.detail, 130 | getattr(request, "path", ""), 131 | response=response, 132 | request=request, 133 | exc_info=sys.exc_info(), 134 | ) 135 | else: 136 | log_response( 137 | "%s: %s", 138 | exc.detail, 139 | getattr(request, "path", ""), 140 | response=response, 141 | request=request, 142 | exception=self.exc, 143 | ) 144 | -------------------------------------------------------------------------------- /drf_standardized_errors/openapi.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections import defaultdict 3 | from typing import Dict, List, Set, Type, Union 4 | 5 | from drf_spectacular.drainage import warn 6 | from drf_spectacular.extensions import OpenApiFilterExtension 7 | from drf_spectacular.openapi import AutoSchema as BaseAutoSchema 8 | from drf_spectacular.utils import ( 9 | Direction, 10 | OpenApiExample, 11 | PolymorphicProxySerializer, 12 | _SchemaType, 13 | ) 14 | from inflection import camelize 15 | from rest_framework import serializers 16 | from rest_framework.negotiation import DefaultContentNegotiation 17 | from rest_framework.pagination import CursorPagination, PageNumberPagination 18 | from rest_framework.parsers import FileUploadParser, JSONParser, MultiPartParser 19 | from rest_framework.permissions import AllowAny, IsAuthenticated 20 | from rest_framework.versioning import ( 21 | AcceptHeaderVersioning, 22 | HostNameVersioning, 23 | NamespaceVersioning, 24 | QueryParameterVersioning, 25 | URLPathVersioning, 26 | ) 27 | 28 | from .handler import exception_handler as standardized_errors_handler 29 | from .openapi_serializers import 
( 30 | ClientErrorEnum, 31 | ErrorResponse401Serializer, 32 | ErrorResponse403Serializer, 33 | ErrorResponse404Serializer, 34 | ErrorResponse405Serializer, 35 | ErrorResponse406Serializer, 36 | ErrorResponse415Serializer, 37 | ErrorResponse429Serializer, 38 | ErrorResponse500Serializer, 39 | ParseErrorResponseSerializer, 40 | ValidationErrorEnum, 41 | ) 42 | from .openapi_utils import ( 43 | InputDataField, 44 | get_django_filter_backends, 45 | get_error_examples, 46 | get_filter_forms, 47 | get_flat_serializer_fields, 48 | get_form_fields_with_error_codes, 49 | get_serializer_fields_with_error_codes, 50 | get_validation_error_serializer, 51 | ) 52 | from .openapi_validation_errors import get_validation_errors 53 | from .settings import package_settings 54 | 55 | S = Union[Type[serializers.Serializer], serializers.Serializer] 56 | 57 | 58 | class AutoSchema(BaseAutoSchema): 59 | def _get_response_bodies(self, direction: Direction = "response") -> _SchemaType: 60 | responses = super()._get_response_bodies(direction=direction) 61 | if direction == "response": 62 | error_responses = {} 63 | 64 | status_codes = self._get_allowed_error_status_codes() 65 | for status_code in status_codes: 66 | if self._should_add_error_response(responses, status_code): 67 | serializer = self._get_error_response_serializer(status_code) 68 | if not serializer: 69 | warn( 70 | f"drf-standardized-errors: The status code '{status_code}' " 71 | "is one of the allowed error status codes in the setting " 72 | "'ALLOWED_ERROR_STATUS_CODES'. However, a corresponding " 73 | "error response serializer could not be determined. Make " 74 | "sure to add one to the 'ERROR_SCHEMAS' setting: this " 75 | "setting is a dict where the key is the status code and " 76 | "the value is the serializer." 
77 | ) 78 | continue 79 | error_responses[status_code] = self._get_response_for_code( 80 | serializer, status_code 81 | ) 82 | 83 | return {**error_responses, **responses} 84 | else: 85 | # for callbacks (direction=request), we should not add the error responses 86 | return responses 87 | 88 | def _get_allowed_error_status_codes(self) -> List[str]: 89 | allowed_status_codes = package_settings.ALLOWED_ERROR_STATUS_CODES or [] 90 | return [str(status_code) for status_code in allowed_status_codes] 91 | 92 | def _should_add_error_response(self, responses: dict, status_code: str) -> bool: 93 | if ( 94 | self.view.get_exception_handler() is not standardized_errors_handler 95 | or status_code in responses 96 | ): 97 | # this means that the exception handler has been overridden for this view 98 | # or the error response has already been added via extend_schema, so we 99 | # should not override that 100 | return False 101 | 102 | if status_code == "400": 103 | return ( 104 | self._should_add_parse_error_response() 105 | or self._should_add_validation_error_response() 106 | ) 107 | elif status_code == "401": 108 | return self._should_add_http401_error_response() 109 | elif status_code == "403": 110 | return self._should_add_http403_error_response() 111 | elif status_code == "404": 112 | return self._should_add_http404_error_response() 113 | elif status_code == "405": 114 | return self._should_add_http405_error_response() 115 | elif status_code == "406": 116 | return self._should_add_http406_error_response() 117 | elif status_code == "415": 118 | return self._should_add_http415_error_response() 119 | elif status_code == "429": 120 | return self._should_add_http429_error_response() 121 | elif status_code == "500": 122 | return self._should_add_http500_error_response() 123 | else: 124 | # user might add extra status codes and their serializers, so we 125 | # should always add corresponding error responses 126 | return True 127 | 128 | def _should_add_parse_error_response(self) 
-> bool: 129 | parsers = self.view.get_parsers() 130 | parsers_that_raise_parse_errors = ( 131 | JSONParser, 132 | MultiPartParser, 133 | FileUploadParser, 134 | ) 135 | return any( 136 | isinstance(parser, parsers_that_raise_parse_errors) for parser in parsers 137 | ) 138 | 139 | def _should_add_validation_error_response(self) -> bool: 140 | """ 141 | add a validation error response when unsafe methods have a request body 142 | or when a list view implements filtering with django-filters. 143 | """ 144 | 145 | has_request_body = False 146 | if self.method in ("PUT", "PATCH", "POST"): 147 | request_serializer = self.get_request_serializer() 148 | has_request_body = isinstance(request_serializer, serializers.Field) or ( 149 | inspect.isclass(request_serializer) 150 | and issubclass(request_serializer, serializers.Field) 151 | ) 152 | 153 | filter_backends = get_django_filter_backends(self.get_filter_backends()) 154 | filter_extensions = [ 155 | OpenApiFilterExtension.get_match(backend) for backend in filter_backends 156 | ] 157 | has_filters = any( 158 | filter_extension.get_schema_operation_parameters(self) 159 | for filter_extension in filter_extensions 160 | if filter_extension 161 | ) 162 | has_extra_validation_errors = bool(self._get_extra_validation_errors()) 163 | return has_request_body or has_filters or has_extra_validation_errors 164 | 165 | def _should_add_http401_error_response(self) -> bool: 166 | # empty dicts are appended to auth methods if AllowAny or 167 | # IsAuthenticatedOrReadOnly are in permission classes, so 168 | # we need to account for that. 
169 | auth_methods = [auth_method for auth_method in self.get_auth() if auth_method] 170 | return bool(auth_methods) 171 | 172 | def _should_add_http403_error_response(self) -> bool: 173 | permissions = self.view.get_permissions() 174 | is_allow_any = len(permissions) == 1 and type(permissions[0]) == AllowAny 175 | # if the only permission class is IsAuthenticated and there are auth classes 176 | # in the view, then the error raised is a 401 not a 403 (check implementation 177 | # of rest_framework.views.APIView.permission_denied) 178 | is_authenticated = ( 179 | len(permissions) == 1 180 | and type(permissions[0]) == IsAuthenticated 181 | and self.view.get_authenticators() 182 | ) 183 | return bool(permissions) and not is_allow_any and not is_authenticated 184 | 185 | def _should_add_http404_error_response(self) -> bool: 186 | paginator = self._get_paginator() 187 | paginator_can_raise_404 = isinstance( 188 | paginator, (PageNumberPagination, CursorPagination) 189 | ) 190 | versioning_scheme_can_raise_404 = self.view.versioning_class and issubclass( 191 | self.view.versioning_class, 192 | ( 193 | URLPathVersioning, 194 | NamespaceVersioning, 195 | HostNameVersioning, 196 | QueryParameterVersioning, 197 | ), 198 | ) 199 | has_path_parameters = bool( 200 | [ 201 | parameter 202 | for parameter in self._get_parameters() 203 | if parameter["in"] == "path" 204 | ] 205 | ) 206 | return ( 207 | paginator_can_raise_404 208 | or versioning_scheme_can_raise_404 209 | or has_path_parameters 210 | ) 211 | 212 | def _should_add_http405_error_response(self) -> bool: 213 | # API consumers can at all ties use the wrong method against any endpoint 214 | return True 215 | 216 | def _should_add_http406_error_response(self) -> bool: 217 | content_negotiator = self.view.get_content_negotiator() 218 | return isinstance(content_negotiator, DefaultContentNegotiation) or ( 219 | self.view.versioning_class 220 | and issubclass(self.view.versioning_class, AcceptHeaderVersioning) 221 | ) 222 
| 223 | def _should_add_http415_error_response(self) -> bool: 224 | """ 225 | This is raised whenever the default content negotiator is unable to 226 | determine a parser. So, if the view does not have a parser that 227 | handles everything (media type "*/*"), then this error can be raised. 228 | """ 229 | content_negotiator = self.view.get_content_negotiator() 230 | parsers_that_handle_everything = [ 231 | parser for parser in self.view.get_parsers() if parser.media_type == "*/*" 232 | ] 233 | return ( 234 | isinstance(content_negotiator, DefaultContentNegotiation) 235 | and not parsers_that_handle_everything 236 | ) 237 | 238 | def _should_add_http429_error_response(self) -> bool: 239 | return bool(self.view.get_throttles()) 240 | 241 | def _should_add_http500_error_response(self) -> bool: 242 | # bugs are inevitable 243 | return True 244 | 245 | def _get_error_response_serializer(self, status_code: str) -> S: 246 | error_schemas = package_settings.ERROR_SCHEMAS or {} 247 | error_schemas = { 248 | str(status_code): schema for status_code, schema in error_schemas.items() 249 | } 250 | if serializer := error_schemas.get(status_code): 251 | return serializer 252 | 253 | # the user did not provide a serializer for the status code so we will 254 | # fall back to the default error serializers 255 | if status_code == "400": 256 | return self._get_http400_serializer() 257 | else: 258 | error_serializers = { 259 | "401": ErrorResponse401Serializer, 260 | "403": ErrorResponse403Serializer, 261 | "404": ErrorResponse404Serializer, 262 | "405": ErrorResponse405Serializer, 263 | "406": ErrorResponse406Serializer, 264 | "415": ErrorResponse415Serializer, 265 | "429": ErrorResponse429Serializer, 266 | "500": ErrorResponse500Serializer, 267 | } 268 | return error_serializers.get(status_code) 269 | 270 | def _get_http400_serializer(self) -> S: 271 | # using the operation id (which is unique) to generate a unique 272 | # component name 273 | operation_id = self.get_operation_id() 
274 | component_name = f"{camelize(operation_id)}ErrorResponse400" 275 | 276 | http400_serializers = {} 277 | if self._should_add_validation_error_response(): 278 | serializer = self._get_serializer_for_validation_error_response() 279 | http400_serializers[ValidationErrorEnum.VALIDATION_ERROR.value] = serializer # type: ignore[attr-defined] 280 | if self._should_add_parse_error_response(): 281 | serializer = ParseErrorResponseSerializer 282 | http400_serializers[ClientErrorEnum.CLIENT_ERROR.value] = serializer # type: ignore[attr-defined] 283 | 284 | return PolymorphicProxySerializer( 285 | component_name=component_name, 286 | serializers=http400_serializers, 287 | resource_type_field_name="type", 288 | ) 289 | 290 | def _get_serializer_for_validation_error_response(self) -> S: 291 | fields_with_error_codes = self._determine_fields_with_error_codes() 292 | error_codes_by_field = self._get_validation_error_codes_by_field( 293 | fields_with_error_codes 294 | ) 295 | 296 | operation_id = self.get_operation_id() 297 | return get_validation_error_serializer(operation_id, error_codes_by_field) 298 | 299 | def _determine_fields_with_error_codes(self) -> "List[InputDataField]": 300 | if self.method in ("PUT", "PATCH", "POST"): 301 | serializer = self.get_request_serializer() 302 | fields = get_flat_serializer_fields(serializer) 303 | return get_serializer_fields_with_error_codes(fields) 304 | else: 305 | filter_backends = get_django_filter_backends(self.get_filter_backends()) 306 | filter_forms = get_filter_forms(self.view, filter_backends) 307 | fields_with_error_codes = [] 308 | for form in filter_forms: 309 | fields = get_form_fields_with_error_codes(form) 310 | fields_with_error_codes.extend(fields) 311 | return fields_with_error_codes 312 | 313 | def _get_validation_error_codes_by_field( 314 | self, data_fields: "List[InputDataField]" 315 | ) -> Dict[str, Set[str]]: 316 | # When there are multiple fields with the same name in the list of data_fields, 317 | # their 
error codes are combined. This can happen when using a PolymorphicProxySerializer 318 | error_codes_by_field = defaultdict(set) 319 | for field in data_fields: 320 | error_codes_by_field[field.name].update(field.error_codes) 321 | 322 | # add error codes set by the @extend_validation_errors decorator 323 | extra_errors = self._get_extra_validation_errors() 324 | for field_name, error_codes in extra_errors.items(): 325 | error_codes_by_field[field_name].update(error_codes) 326 | 327 | return error_codes_by_field 328 | 329 | def _get_extra_validation_errors(self) -> Dict[str, Set[str]]: 330 | extra_codes_by_field = {} 331 | validation_errors = get_validation_errors(self.view) 332 | for field_name, field_errors in validation_errors.items(): 333 | # Get the first encountered error when looping in reverse order. 334 | # That will return errors defined in child views over ones 335 | # defined in the parent class. 336 | for err in reversed(field_errors): 337 | if err.is_in_scope(self): 338 | extra_codes_by_field[field_name] = err.error_codes 339 | break 340 | 341 | return extra_codes_by_field 342 | 343 | def _get_examples( 344 | self, serializer, direction, media_type, status_code=None, extras=None 345 | ): 346 | if direction == "response": 347 | all_examples = (extras or []) + self._get_error_response_examples() 348 | else: 349 | all_examples = extras 350 | return super()._get_examples( 351 | serializer, 352 | direction, 353 | media_type, 354 | status_code=status_code, 355 | extras=all_examples, 356 | ) 357 | 358 | def _get_error_response_examples(self) -> List[OpenApiExample]: 359 | status_codes = set(self._get_allowed_error_status_codes()) 360 | examples = get_error_examples() 361 | return [ 362 | example 363 | for example in examples 364 | if status_codes.intersection(example.status_codes) 365 | ] 366 | -------------------------------------------------------------------------------- /drf_standardized_errors/openapi_hooks.py: 
--------------------------------------------------------------------------------
# mypy: ignore-errors
# since it's a copy of drf-spectacular postprocessing hook
import re
from collections import defaultdict

from drf_spectacular.hooks import postprocess_schema_enum_id_removal
from drf_spectacular.plumbing import (
    ResolvedComponent,
    list_hash,
    load_enum_name_overrides,
    safe_ref,
    warn,
)
from drf_spectacular.settings import spectacular_settings
from inflection import camelize

from .settings import package_settings


def postprocess_schema_enums(result, generator, **kwargs):
    """
    This is a copy of the postprocessing hook for enums provided by drf-spectacular
    with only one change in `iter_prop_containers`. The change allows excluding
    components that have a certain suffix (the ERROR_COMPONENT_NAME_SUFFIX setting)
    from enum component auto-generation. This excludes certain validation error
    components from postprocessing. The excluded enum components are for
    dynamically created error serializers where "attr" and "code" fields might
    have the same choices across multiple serializers.

    Simple replacement of Enum/Choices that globally share the same name and have
    the same choices. Aids client generation to not generate a separate enum for
    every occurrence. Only takes effect when replacement is guaranteed to be correct.

    :param result: the generated schema dict; its components are rebuilt and it is
        returned.
    :param generator: the schema generator; its registry receives the new enum
        components.
    """

    # Yield (component_name, properties) for every schema dict that has
    # "properties", recursing into oneOf/allOf/anyOf. On the first call
    # (no component_name), `schema` is the components/schemas mapping and
    # Patched/Request affixes are stripped from names when the split settings are on.
    def iter_prop_containers(schema, component_name=None):
        if not component_name:
            for component_name, schema in schema.items():
                if spectacular_settings.COMPONENT_SPLIT_PATCH:
                    component_name = re.sub("^Patched(.+)", r"\1", component_name)
                if spectacular_settings.COMPONENT_SPLIT_REQUEST:
                    component_name = re.sub("(.+)Request$", r"\1", component_name)
                yield from iter_prop_containers(schema, component_name)
        elif isinstance(schema, list):
            for item in schema:
                yield from iter_prop_containers(item, component_name)
        elif isinstance(schema, dict):
            # This is the only change made (the suffix check in the condition just
            # below): exclude error components from postprocessing. That's because
            # the components are for dynamically created error serializers where
            # "attr" and "code" fields might have the same choices across
            # multiple serializers.
            # NOTE(review): str.endswith("") is always True, so an empty suffix
            # setting would exclude every component; presumably it is never empty.
            suffix = package_settings.ERROR_COMPONENT_NAME_SUFFIX
            if schema.get("properties") and not component_name.endswith(suffix):
                yield component_name, schema["properties"]
            yield from iter_prop_containers(schema.get("oneOf", []), component_name)
            yield from iter_prop_containers(schema.get("allOf", []), component_name)
            yield from iter_prop_containers(schema.get("anyOf", []), component_name)

    # Register a schema component named `name` in the generator registry
    # (only if it is not registered yet) and return it.
    def create_enum_component(name, schema):
        component = ResolvedComponent(
            name=name, type=ResolvedComponent.SCHEMA, schema=schema, object=name
        )
        generator.registry.register_on_missing(component)
        return component

    # Return an identifier for an enum schema's choice set.
    def extract_hash(schema):
        if "x-spec-enum-id" in schema:
            # try to use the injected enum hash first as it generated from (name, value) tuples,
            # which prevents collisions on choice sets only differing in labels not values.
            return schema["x-spec-enum-id"]
        else:
            # fall back to actual list hashing when we encounter enums not generated by us.
            # remove blank/null entry for hashing. will be reconstructed in the last step
            return list_hash([(i, i) for i in schema["enum"] if i not in ("", None)])

    schemas = result.get("components", {}).get("schemas", {})

    overrides = load_enum_name_overrides()

    # prop name -> set of choice-set hashes used under that name
    prop_hash_mapping = defaultdict(set)
    # choice-set hash -> set of (component_name, prop_name) where it appears
    hash_name_mapping = defaultdict(set)
    # collect all enums, their names and choice sets
    for component_name, props in iter_prop_containers(schemas):
        for prop_name, prop_schema in props.items():
            if prop_schema.get("type") == "array":
                prop_schema = prop_schema.get("items", {})
            if "enum" not in prop_schema:
                continue

            prop_enum_cleaned_hash = extract_hash(prop_schema)
            prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash)
            hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name))

    # get the suffix to be used for enums from settings
    enum_suffix = spectacular_settings.ENUM_SUFFIX

    # traverse all enum properties and generate a name for the choice set. naming collisions
    # are resolved and a warning is emitted. giving a choice set multiple names is technically
    # correct but potentially unwanted. also emit a warning there to make the user aware.
    # Keys are either a hash (name unique per choice set) or (hash, prop_name).
    enum_name_mapping = {}
    for prop_name, prop_hash_set in prop_hash_mapping.items():
        for prop_hash in prop_hash_set:
            if prop_hash in overrides:
                enum_name = overrides[prop_hash]
            elif len(prop_hash_set) == 1:
                # prop_name has been used exclusively for one choice set (best case)
                enum_name = f"{camelize(prop_name)}{enum_suffix}"
            elif len(hash_name_mapping[prop_hash]) == 1:
                # prop_name has multiple choice sets, but each one limited to one component only
                component_name, _ = next(iter(hash_name_mapping[prop_hash]))
                enum_name = (
                    f"{camelize(component_name)}{camelize(prop_name)}{enum_suffix}"
                )
            else:
                enum_name = (
                    f"{camelize(prop_name)}{prop_hash[:3].capitalize()}{enum_suffix}"
                )
                warn(
                    f"enum naming encountered a non-optimally resolvable collision for fields "
                    f'named "{prop_name}". The same name has been used for multiple choice sets '
                    f'in multiple components. The collision was resolved with "{enum_name}". '
                    f"add an entry to ENUM_NAME_OVERRIDES to fix the naming."
                )
            if enum_name_mapping.get(prop_hash, enum_name) != enum_name:
                warn(
                    f"encountered multiple names for the same choice set ({enum_name}). This "
                    f"may be unwanted even though the generated schema is technically correct. "
                    f"Add an entry to ENUM_NAME_OVERRIDES to fix the naming."
                )
                del enum_name_mapping[prop_hash]
            else:
                enum_name_mapping[prop_hash] = enum_name
            enum_name_mapping[(prop_hash, prop_name)] = enum_name

    # replace all enum occurrences with a enum schema component. cut out the
    # enum, replace it with a reference and add a corresponding component.
    for _, props in iter_prop_containers(schemas):
        for prop_name, prop_schema in props.items():
            is_array = prop_schema.get("type") == "array"
            if is_array:
                prop_schema = prop_schema.get("items", {})

            if "enum" not in prop_schema:
                continue

            prop_enum_original_list = prop_schema["enum"]
            prop_schema["enum"] = [
                i for i in prop_schema["enum"] if i not in ["", None]
            ]
            prop_hash = extract_hash(prop_schema)
            # when choice sets are reused under multiple names, the generated name cannot be
            # resolved from the hash alone. fall back to prop_name and hash for resolution.
            enum_name = (
                enum_name_mapping.get(prop_hash)
                or enum_name_mapping[prop_hash, prop_name]
            )

            # split property into remaining property and enum component parts
            enum_schema = {
                k: v for k, v in prop_schema.items() if k in ["type", "enum"]
            }
            prop_schema = {
                k: v
                for k, v in prop_schema.items()
                if k not in ["type", "enum", "x-spec-enum-id"]
            }

            # separate actual description from name-value tuples
            if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
                if prop_schema.get("description", "").startswith("*"):
                    enum_schema["description"] = prop_schema.pop("description")
                elif "\n\n*" in prop_schema.get("description", ""):
                    _, _, post = prop_schema["description"].partition("\n\n*")
                    enum_schema["description"] = "*" + post

            components = [create_enum_component(enum_name, schema=enum_schema)]
            if spectacular_settings.ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE:
                if "" in prop_enum_original_list:
                    components.append(
                        create_enum_component(
                            f"Blank{enum_suffix}", schema={"enum": [""]}
                        )
                    )
                if None in prop_enum_original_list:
                    if spectacular_settings.OAS_VERSION.startswith("3.1"):
                        components.append(
                            create_enum_component(
                                f"Null{enum_suffix}", schema={"type": "null"}
                            )
                        )
                    else:
                        components.append(
                            create_enum_component(
                                f"Null{enum_suffix}", schema={"enum": [None]}
                            )
                        )

            # undo OAS 3.1 type list NULL construction as we cover this in a separate component already
            if spectacular_settings.OAS_VERSION.startswith("3.1") and isinstance(
                enum_schema["type"], list
            ):
                enum_schema["type"] = [t for t in enum_schema["type"] if t != "null"][0]

            if len(components) == 1:
                prop_schema.update(components[0].ref)
            else:
                prop_schema.update({"oneOf": [c.ref for c in components]})

            if is_array:
                props[prop_name]["items"] = safe_ref(prop_schema)
            else:
                props[prop_name] = safe_ref(prop_schema)

    # sort again with additional components
    result["components"] = generator.registry.build(
        spectacular_settings.APPEND_COMPONENTS
    )

    # remove remaining ids that were not part of this hook (operation parameters mainly)
    postprocess_schema_enum_id_removal(result, generator)

    return result
--------------------------------------------------------------------------------
/drf_standardized_errors/openapi_serializers.py:
--------------------------------------------------------------------------------
from django.db import models
from rest_framework import serializers

# Values of the top-level "type" field of an error response.
# (Comments rather than docstrings are used throughout this module so that no
# extra descriptions end up in the generated schema components.)


class ValidationErrorEnum(models.TextChoices):
    VALIDATION_ERROR = "validation_error"


class ClientErrorEnum(models.TextChoices):
    CLIENT_ERROR = "client_error"


class ServerErrorEnum(models.TextChoices):
    SERVER_ERROR = "server_error"


# Values of the "code" field of each error item, one enum per status code.


class ParseErrorCodeEnum(models.TextChoices):
    PARSE_ERROR = "parse_error"


class ErrorCode401Enum(models.TextChoices):
    AUTHENTICATION_FAILED = "authentication_failed"
    NOT_AUTHENTICATED = "not_authenticated"


class ErrorCode403Enum(models.TextChoices):
    PERMISSION_DENIED = "permission_denied"


class ErrorCode404Enum(models.TextChoices):
    NOT_FOUND = "not_found"


class ErrorCode405Enum(models.TextChoices):
    METHOD_NOT_ALLOWED = "method_not_allowed"


class ErrorCode406Enum(models.TextChoices):
    NOT_ACCEPTABLE = "not_acceptable"


class ErrorCode415Enum(models.TextChoices):
    UNSUPPORTED_MEDIA_TYPE = "unsupported_media_type"


class ErrorCode429Enum(models.TextChoices):
    THROTTLED = "throttled"


class ErrorCode500Enum(models.TextChoices):
    ERROR = "error"


# Serializers describing error response bodies. Each "Error...Serializer"
# describes one item of the "errors" list (code/detail/attr) and each
# "...ResponseSerializer" pairs it with the response "type".


# Unlike the other error item serializers, "code" is a free CharField and
# "attr" is not nullable here.
class ValidationErrorSerializer(serializers.Serializer):
    code = serializers.CharField()
    detail = serializers.CharField()
    attr = serializers.CharField()


class ValidationErrorResponseSerializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ValidationErrorEnum.choices)
    errors = ValidationErrorSerializer(many=True)


class ParseErrorSerializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ParseErrorCodeEnum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ParseErrorResponseSerializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = ParseErrorSerializer(many=True)


class Error401Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode401Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse401Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error401Serializer(many=True)


class Error403Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode403Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse403Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error403Serializer(many=True)


class Error404Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode404Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse404Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error404Serializer(many=True)


class Error405Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode405Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse405Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error405Serializer(many=True)


class Error406Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode406Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse406Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error406Serializer(many=True)


class Error415Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode415Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse415Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error415Serializer(many=True)


class Error429Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode429Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


class ErrorResponse429Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ClientErrorEnum.choices)
    errors = Error429Serializer(many=True)


class Error500Serializer(serializers.Serializer):
    code = serializers.ChoiceField(choices=ErrorCode500Enum.choices)
    detail = serializers.CharField()
    attr = serializers.CharField(allow_null=True)


# 500 responses use the "server_error" type instead of "client_error".
class ErrorResponse500Serializer(serializers.Serializer):
    type = serializers.ChoiceField(choices=ServerErrorEnum.choices)
    errors = Error500Serializer(many=True)
--------------------------------------------------------------------------------
/drf_standardized_errors/openapi_utils.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Dict, List, Optional, Set, Type, Union

from django import forms
from django.core.validators import (
    DecimalValidator,
    validate_image_file_extension,
    validate_integer,
    validate_ipv4_address,
    validate_ipv6_address,
    validate_ipv46_address,
)
from drf_spectacular.plumbing import (
    force_instance,
    get_view_model,
    is_basic_serializer,
    is_list_serializer,
    is_serializer,
)
from drf_spectacular.utils import OpenApiExample, PolymorphicProxySerializer
from inflection import camelize
from rest_framework import exceptions, serializers
from rest_framework.settings import api_settings as drf_settings
from rest_framework.status import is_client_error
from rest_framework.validators import (
    BaseUniqueForValidator,
| UniqueTogetherValidator, 28 | UniqueValidator, 29 | ) 30 | from rest_framework.views import APIView 31 | 32 | from .openapi_serializers import ValidationErrorEnum 33 | from .settings import package_settings 34 | 35 | 36 | def get_flat_serializer_fields( 37 | field: Union[serializers.Field, List[serializers.Field]], 38 | prefix: Optional[str] = None, 39 | ) -> "List[InputDataField]": 40 | """ 41 | return a flat list of serializer fields. The fields list will later be used 42 | to identify error codes that can be raised by each field. So, it contains 43 | at least one field representing "non field errors" and accounts properly 44 | for composite fields by returning 2 fields: one for the errors linked to 45 | the parent field and another one for errors linked to the child field. 46 | """ 47 | if not field or getattr(field, "read_only", False): 48 | return [] 49 | 50 | field = force_instance(field) 51 | if is_list_serializer(field): 52 | prefix = get_prefix(prefix, field.field_name) 53 | non_field_errors_name = get_prefix(prefix, drf_settings.NON_FIELD_ERRORS_KEY) 54 | f = InputDataField(non_field_errors_name, field) 55 | prefix = get_prefix(prefix, package_settings.LIST_INDEX_IN_API_SCHEMA) 56 | return [f] + get_flat_serializer_fields(field.child, prefix) 57 | elif isinstance(field, PolymorphicProxySerializer): 58 | if isinstance(field.serializers, dict): 59 | return get_flat_serializer_fields(list(field.serializers.values()), prefix) 60 | else: 61 | return get_flat_serializer_fields(field.serializers, prefix) 62 | elif is_serializer(field): 63 | prefix = get_prefix(prefix, field.field_name) 64 | non_field_errors_name = get_prefix(prefix, drf_settings.NON_FIELD_ERRORS_KEY) 65 | f = InputDataField(non_field_errors_name, field) 66 | return [f] + get_flat_serializer_fields(list(field.fields.values()), prefix) 67 | elif isinstance(field, (list, tuple)): 68 | first, *remaining = field 69 | return get_flat_serializer_fields(first, prefix) + get_flat_serializer_fields( 70 
| remaining, prefix 71 | ) 72 | elif hasattr(field, "child"): 73 | # composite field (List or Dict fields) 74 | prefix = get_prefix(prefix, field.field_name) 75 | f = InputDataField(prefix, field) 76 | if isinstance(field, serializers.ListField): 77 | child_prefix = get_prefix(prefix, package_settings.LIST_INDEX_IN_API_SCHEMA) 78 | else: 79 | child_prefix = get_prefix(prefix, package_settings.DICT_KEY_IN_API_SCHEMA) 80 | return [f] + get_flat_serializer_fields(field.child, child_prefix) 81 | else: 82 | name = get_prefix(prefix, field.field_name) 83 | return [InputDataField(name, field)] 84 | 85 | 86 | def get_prefix(prefix: Optional[str], name: str) -> str: 87 | if prefix and name: 88 | return f"{prefix}{package_settings.NESTED_FIELD_SEPARATOR}{name}" 89 | elif prefix: 90 | return prefix 91 | else: 92 | return name 93 | 94 | 95 | def get_serializer_fields_with_error_codes( 96 | serializer_fields: "List[InputDataField]", 97 | ) -> "List[InputDataField]": 98 | fields_with_error_codes = [] 99 | for sfield in serializer_fields: 100 | if error_codes := get_serializer_field_error_codes(sfield.field, sfield.name): 101 | sfield.error_codes = error_codes 102 | fields_with_error_codes.append(sfield) 103 | 104 | # add error codes that correspond to unique together and unique for date validators 105 | sfields_with_unique_together_validators = [ 106 | sfield 107 | for sfield in fields_with_error_codes 108 | if is_basic_serializer(sfield.field) 109 | and has_validator(sfield.field, UniqueTogetherValidator) 110 | ] 111 | add_unique_together_error_codes( 112 | sfields_with_unique_together_validators, fields_with_error_codes 113 | ) 114 | 115 | sfields_with_unique_for_validators = [ 116 | sfield 117 | for sfield in fields_with_error_codes 118 | if is_basic_serializer(sfield.field) 119 | and has_validator(sfield.field, BaseUniqueForValidator) 120 | ] 121 | add_unique_for_error_codes( 122 | sfields_with_unique_for_validators, fields_with_error_codes 123 | ) 124 | 125 | return 
fields_with_error_codes 126 | 127 | 128 | def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> Set[str]: 129 | if field.read_only or isinstance(field, serializers.HiddenField): 130 | return set() 131 | 132 | error_codes = set() 133 | if field.required: 134 | error_codes.add("required") 135 | if not field.allow_null: 136 | error_codes.add("null") 137 | if ( 138 | hasattr(field, "allow_blank") 139 | and not field.allow_blank 140 | and not isinstance(field, serializers.ChoiceField) 141 | ): 142 | error_codes.add("blank") 143 | if getattr(field, "max_digits", None) is not None: 144 | error_codes.add("max_digits") 145 | if getattr(field, "decimal_places", None) is not None: 146 | error_codes.add("max_decimal_places") 147 | if getattr(field, "max_whole_digits", None) is not None: 148 | error_codes.add("max_whole_digits") 149 | if isinstance(field, serializers.DateTimeField): 150 | field_timezone = getattr(field, "timezone", field.default_timezone()) 151 | if field_timezone is not None: 152 | error_codes.update(["overflow", "make_aware"]) 153 | if (hasattr(field, "allow_empty") and not field.allow_empty) or ( 154 | hasattr(field, "allow_empty_file") and not field.allow_empty_file 155 | ): 156 | error_codes.add("empty") 157 | if isinstance(field, serializers.FileField) and field.max_length is not None: 158 | error_codes.add("max_length") 159 | if isinstance(field, serializers.IPAddressField) and field.protocol in ( 160 | "both", 161 | "ipv6", 162 | ): 163 | error_codes.add("invalid") 164 | 165 | # identify error codes based on DRF and django built-in validators 166 | error_codes.update(get_error_codes_from_validators(field)) 167 | if has_validator(field, UniqueValidator): 168 | error_codes.add("unique") 169 | 170 | error_codes_with_specific_conditions = [ 171 | "required", 172 | "null", 173 | "blank", 174 | "max_length", 175 | "min_length", 176 | "max_value", 177 | "min_value", 178 | "max_digits", 179 | "max_decimal_places", 180 | 
"max_whole_digits", 181 | "overflow", 182 | "make_aware", 183 | "empty", 184 | # for slug field, "invalid_unicode" is added to error_messages but it is 185 | # not set as the validator code. "invalid" is the code used instead. 186 | "invalid_unicode", 187 | ] 188 | fields_where_invalid_is_enforced_by_validators = ( 189 | serializers.EmailField, 190 | serializers.RegexField, 191 | serializers.SlugField, 192 | serializers.URLField, 193 | serializers.IPAddressField, 194 | ) 195 | if isinstance(field, fields_where_invalid_is_enforced_by_validators): 196 | # the "invalid" error code is enforced by a validator and is also added 197 | # to error messages, so it should not be added automatically to error codes 198 | error_codes_with_specific_conditions.append("invalid") 199 | 200 | remaining_error_codes = set(field.error_messages).difference( 201 | error_codes_with_specific_conditions 202 | ) 203 | error_codes.update(remaining_error_codes) 204 | 205 | # for top-level (as opposed to nested) serializer non_field_errors, 206 | # "required" and "null" errors are not raised 207 | if attr == drf_settings.NON_FIELD_ERRORS_KEY: 208 | error_codes = set(error_codes).difference(["required"]) 209 | 210 | # for ManyRelatedFields, add the error codes from the child_relation 211 | # to the parent error codes. That's because DRF raises child_relation 212 | # errors as if raised by the parent (which is a different behavior 213 | # from ListSerializer and ListField). 
For example, ManyRelatedField 214 | # would return the errors like this: 215 | # {'zones': [ErrorDetail(string='Invalid pk "0" - object does not exist.', code='does_not_exist')]} 216 | # while ListField returns them like this: 217 | # {'zones': {0: [ErrorDetail(string='A valid integer is required.', code='invalid')]}} 218 | if isinstance(field, serializers.ManyRelatedField): 219 | # required and null are added depending on the ManyRelatedField definition 220 | child_error_codes = set(field.child_relation.error_messages).difference( 221 | ["required", "null"] 222 | ) 223 | error_codes.update(child_error_codes) 224 | 225 | return error_codes 226 | 227 | 228 | def add_unique_together_error_codes( 229 | sfields_with_unique_together_validators: "List[InputDataField]", 230 | sfields_with_error_codes: "List[InputDataField]", 231 | ) -> None: 232 | for sfield in sfields_with_unique_together_validators: 233 | sfield.error_codes.add("unique") 234 | unique_together_validators = [ 235 | validator 236 | for validator in sfield.field.validators 237 | if isinstance(validator, UniqueTogetherValidator) 238 | ] 239 | # fields involved in a unique together constraint have an implied 240 | # "required" state, so we're adding the "required" error code to them 241 | implicitly_required_fields = set() 242 | for validator in unique_together_validators: 243 | implicitly_required_fields.update(validator.fields) 244 | for field in implicitly_required_fields: 245 | add_error_code(sfield.name, field, "required", sfields_with_error_codes) 246 | 247 | 248 | def add_unique_for_error_codes( 249 | sfields_with_unique_for_validators: "List[InputDataField]", 250 | sfields_with_error_codes: "List[InputDataField]", 251 | ) -> None: 252 | for sfield in sfields_with_unique_for_validators: 253 | unique_for_validators = [ 254 | validator 255 | for validator in sfield.field.validators 256 | if isinstance(validator, BaseUniqueForValidator) 257 | ] 258 | for v in unique_for_validators: 259 | add_error_code( 
260 | sfield.name, v.date_field, "required", sfields_with_error_codes 261 | ) 262 | add_error_code(sfield.name, v.field, "required", sfields_with_error_codes) 263 | add_error_code(sfield.name, v.field, "unique", sfields_with_error_codes) 264 | 265 | 266 | def add_error_code( 267 | attr: str, field_name: str, error_code: str, sfields: "List[InputDataField]" 268 | ) -> None: 269 | """ 270 | To add the error code to the right serializer field, we need to 271 | determine the full field name taking into account nested serializers. 272 | attr ends with drf_settings.NON_FIELD_ERRORS_KEY, so we remove that 273 | and replace it with the field_name. 274 | """ 275 | parts = attr.split(package_settings.NESTED_FIELD_SEPARATOR) 276 | parts[-1] = field_name 277 | full_field_name = package_settings.NESTED_FIELD_SEPARATOR.join(parts) 278 | 279 | for sfield in sfields: 280 | if sfield.name == full_field_name: 281 | sfield.error_codes.add(error_code) 282 | break 283 | 284 | 285 | def get_filter_forms(view: APIView, filter_backends: list) -> List[forms.Form]: 286 | filter_forms = [] 287 | for backend in filter_backends: 288 | model = get_view_model(view) 289 | if not model: 290 | continue 291 | filterset = backend.get_filterset( 292 | view.request, model._default_manager.none(), view 293 | ) 294 | if filterset: 295 | filter_forms.append(filterset.form) 296 | return filter_forms 297 | 298 | 299 | def get_form_fields_with_error_codes(form: forms.Form) -> "List[InputDataField]": 300 | data_fields = [] 301 | for field_name, field in form.fields.items(): 302 | error_codes = set() 303 | fields = get_form_fields(field) 304 | for f in fields: 305 | error_codes.update(get_form_field_error_codes(f)) 306 | if error_codes: 307 | data_fields.append(InputDataField(field_name, field, error_codes)) 308 | return data_fields 309 | 310 | 311 | def get_form_fields(field: Union[forms.Field, List[forms.Field]]) -> List[forms.Field]: 312 | if not field: 313 | return [] 314 | 315 | if isinstance(field, 
(list, tuple)): 316 | first, *rest = field 317 | return get_form_fields(first) + get_form_fields(rest) 318 | elif isinstance(field, (forms.ComboField, forms.MultiValueField)): 319 | return [field] + get_form_fields(field.fields) 320 | else: 321 | return [field] 322 | 323 | 324 | def get_form_field_error_codes(field: forms.Field) -> Set[str]: 325 | if field.disabled: 326 | return set() 327 | 328 | error_codes = set() 329 | if field.required: 330 | error_codes.add("required") 331 | if isinstance(field, forms.FileField) and field.max_length is not None: 332 | error_codes.add("max_length") 333 | if isinstance(field, forms.FileField) and not field.allow_empty_file: 334 | error_codes.add("empty") 335 | if isinstance(field, forms.GenericIPAddressField): 336 | # because to_python calls clean_ipv6_address which can raise an error 337 | # with this code 338 | error_codes.add("invalid") 339 | 340 | # add the error codes of built-in django validators 341 | error_codes.update(get_error_codes_from_validators(field)) 342 | 343 | # add the error codes defined in error_messages after excluding the ones 344 | # that are conditionally raised 345 | error_codes_with_specific_conditions = ["required", "max_length", "empty"] 346 | remaining_error_codes = set(field.error_messages).difference( 347 | error_codes_with_specific_conditions 348 | ) 349 | error_codes.update(remaining_error_codes) 350 | 351 | # the "missing" error code is defined but never used by FileField 352 | # the "incomplete" error code is not used when raising the related 353 | # ValidationError in forms.MultiValueField 354 | return error_codes.difference(["missing", "incomplete"]) 355 | 356 | 357 | def has_validator( 358 | field: Union[serializers.Field, forms.Field], validator: Type 359 | ) -> bool: 360 | return any(isinstance(v, validator) for v in field.validators) 361 | 362 | 363 | def get_error_codes_from_validators( 364 | field: Union[serializers.Field, forms.Field], 365 | ) -> Set[str]: 366 | error_codes = set() 
367 | 368 | for validator in field.validators: 369 | if code := getattr(validator, "code", None): 370 | error_codes.add(code) 371 | 372 | if validators := [v for v in field.validators if isinstance(v, DecimalValidator)]: 373 | validator = validators[0] 374 | if validator.max_digits is not None: 375 | error_codes.add("max_digits") 376 | if validator.decimal_places is not None: 377 | error_codes.add("max_decimal_places") 378 | if validator.decimal_places is not None and validator.max_digits is not None: 379 | error_codes.add("max_whole_digits") 380 | 381 | if ( 382 | validate_ipv4_address in field.validators 383 | or validate_ipv6_address in field.validators 384 | or validate_ipv46_address in field.validators 385 | or validate_integer in field.validators 386 | ): 387 | error_codes.add("invalid") 388 | 389 | if validate_image_file_extension in field.validators: 390 | error_codes.add("invalid_extension") 391 | 392 | return error_codes 393 | 394 | 395 | def get_validation_error_serializer( 396 | operation_id: str, error_codes_by_field: Dict[str, Set[str]] 397 | ) -> Type[serializers.Serializer]: 398 | validation_error_component_name = f"{camelize(operation_id)}ValidationError" 399 | errors_component_name = f"{camelize(operation_id)}Error" 400 | 401 | sub_serializers = { 402 | field_name: get_error_serializer(operation_id, field_name, error_codes) 403 | for field_name, error_codes in error_codes_by_field.items() 404 | } 405 | 406 | class ValidationErrorSerializer(serializers.Serializer): 407 | type = serializers.ChoiceField(choices=ValidationErrorEnum.choices) 408 | errors = PolymorphicProxySerializer( 409 | component_name=errors_component_name, 410 | resource_type_field_name="attr", 411 | serializers=sub_serializers, 412 | many=True, 413 | ) 414 | 415 | class Meta: 416 | ref_name = validation_error_component_name 417 | 418 | return ValidationErrorSerializer 419 | 420 | 421 | def get_error_serializer( 422 | operation_id: str, attr: Optional[str], error_codes: Set[str] 
423 | ) -> Type[serializers.Serializer]: 424 | attr_kwargs: Dict[str, Any] = {"choices": [(attr, attr)]} 425 | if not attr: 426 | attr_kwargs["allow_null"] = True 427 | error_code_choices = sorted(zip(error_codes, error_codes)) 428 | 429 | camelcase_operation_id = camelize(operation_id) 430 | attr_with_underscores = (attr or "").replace( 431 | package_settings.NESTED_FIELD_SEPARATOR, "_" 432 | ) 433 | camelcase_attr = camelize(attr_with_underscores) 434 | suffix = package_settings.ERROR_COMPONENT_NAME_SUFFIX 435 | component_name = f"{camelcase_operation_id}{camelcase_attr}{suffix}" 436 | 437 | class ErrorSerializer(serializers.Serializer): 438 | attr = serializers.ChoiceField(**attr_kwargs) 439 | code = serializers.ChoiceField(choices=error_code_choices) 440 | detail = serializers.CharField() 441 | 442 | class Meta: 443 | ref_name = component_name 444 | 445 | return ErrorSerializer 446 | 447 | 448 | @dataclass 449 | class InputDataField: 450 | name: str 451 | field: Union[serializers.Field, forms.Field] 452 | error_codes: Set[str] = dataclass_field(default_factory=set) 453 | 454 | 455 | def get_django_filter_backends(backends: list) -> list: 456 | """determine django filter backends that raise validation errors""" 457 | try: 458 | from django_filters.rest_framework import DjangoFilterBackend 459 | except ImportError: 460 | return [] 461 | 462 | filter_backends = [filter_backend() for filter_backend in backends] 463 | return [ 464 | backend 465 | for backend in filter_backends 466 | if isinstance(backend, DjangoFilterBackend) and backend.raise_exception 467 | ] 468 | 469 | 470 | def get_error_examples() -> List[OpenApiExample]: 471 | """ 472 | error examples for media type "application/json". 
The main reason for 473 | adding them is that they will show `"attr": null` instead of the 474 | auto-generated `"attr": "string"` 475 | """ 476 | errors = [ 477 | exceptions.AuthenticationFailed(), 478 | exceptions.NotAuthenticated(), 479 | exceptions.PermissionDenied(), 480 | exceptions.NotFound(), 481 | exceptions.MethodNotAllowed("get"), 482 | exceptions.NotAcceptable(), 483 | exceptions.UnsupportedMediaType("application/json"), 484 | exceptions.Throttled(), 485 | exceptions.APIException(), 486 | ] 487 | return [get_example_from_exception(error) for error in errors] 488 | 489 | 490 | def get_example_from_exception(exc: exceptions.APIException) -> OpenApiExample: 491 | if is_client_error(exc.status_code): 492 | type_ = "client_error" 493 | else: 494 | type_ = "server_error" 495 | return OpenApiExample( 496 | exc.__class__.__name__, 497 | value={ 498 | "type": type_, 499 | "errors": [{"code": exc.get_codes(), "detail": exc.detail, "attr": None}], 500 | }, 501 | response_only=True, 502 | status_codes=[str(exc.status_code)], 503 | ) 504 | -------------------------------------------------------------------------------- /drf_standardized_errors/openapi_validation_errors.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union 6 | 7 | from drf_spectacular.drainage import error, warn 8 | from drf_spectacular.openapi import AutoSchema 9 | from rest_framework.views import APIView 10 | from rest_framework.viewsets import ViewSetMixin 11 | 12 | from .types import SetValidationErrorsKwargs 13 | 14 | V = TypeVar("V", bound=Union[Type[APIView], Callable[..., Any]]) 15 | 16 | 17 | def extend_validation_errors( 18 | error_codes: List[str], 19 | field_name: Optional[str] = None, 20 | actions: Optional[List[str]] = None, 21 | methods: Optional[List[str]] = 
None, 22 | versions: Optional[List[str]] = None, 23 | ) -> Callable[[V], V]: 24 | """ 25 | A view/viewset decorator for adding extra error codes to validation errors. 26 | This decorator does not override error codes already collected by 27 | drf-standardized-errors. 28 | 29 | :param error_codes: list of error codes to add. 30 | :param field_name: name of serializer or form field to which the error codes 31 | will be added. It can be set to ``"non_field_errors"`` when the error codes 32 | correspond to validation inside ``Serializer.validate`` or ``"__all__"`` when 33 | they correspond to validation inside ``Form.clean``. It can also be left 34 | as ``None`` when the validation is not linked to any serializer or form 35 | (for example, raising ``serializers.ValidationError`` inside the view 36 | or viewset directly). 37 | :param actions: can be set when decorating a viewset. Limits the added error 38 | codes to the specified actions. Defaults to adding the error codes to all 39 | actions. 40 | :param methods: Limits the added error codes to the specified methods (get, 41 | post, ...). Defaults to adding the error codes regardless of the method. 42 | :param versions: Limits the added error codes to the specified versions. 43 | Defaults to adding the error codes regardless of the version. 44 | """ 45 | if methods: 46 | methods = [method.lower() for method in methods] 47 | 48 | def wrapper(view): # type: ignore 49 | # special case for @api_view. Decorate the WrappedAPIView class 50 | if callable(view) and hasattr(view, "cls"): 51 | extend_validation_errors( 52 | error_codes, field_name, actions, methods, versions 53 | )(view.cls) 54 | return view 55 | 56 | if not inspect.isclass(view) or ( 57 | inspect.isclass(view) and not issubclass(view, APIView) 58 | ): 59 | error( 60 | "`@extend_validation_errors` can only be applied to APIViews or " 61 | "ViewSets or function-based views already decorated with @api_view. " 62 | f"{view.__name__} is none of these." 
63 | ) 64 | return view 65 | 66 | if not error_codes: 67 | error( 68 | "No error codes are passed to the `@extend_validation_errors` " 69 | f"decorator that is applied to {view.__name__}." 70 | ) 71 | return view 72 | 73 | kwargs: SetValidationErrorsKwargs = { 74 | "error_codes": error_codes, 75 | "field_name": field_name, 76 | "actions": actions, 77 | "methods": methods, 78 | "versions": versions, 79 | } 80 | if actions and issubclass(view, ViewSetMixin): 81 | # validate the actions provided are indeed defined on the viewset class 82 | possible_actions = get_action_names(view) 83 | unknown_actions = set(actions).difference(possible_actions) 84 | if unknown_actions: 85 | is_or_are = "is" if len(unknown_actions) == 1 else "are" 86 | warn( 87 | f"'{', '.join(unknown_actions)}' {is_or_are} not in the list of " 88 | f"actions defined on the viewset {view.__name__}. The actions " 89 | "specified will be ignored." 90 | ) 91 | kwargs["actions"] = None 92 | elif actions: 93 | warn( 94 | "The 'actions' argument of 'extend_validation_errors' should only be " 95 | f"set when decorating viewsets. '{view.__name__}' is not a viewset. " 96 | "The actions specified will be ignored." 97 | ) 98 | kwargs["actions"] = None 99 | 100 | if methods: 101 | # validate that the methods are in the list of allowed methods 102 | allowed_methods = get_allowed_http_methods(view) 103 | unknown_methods = set(methods).difference(allowed_methods) 104 | if unknown_methods: 105 | is_or_are = "is" if len(unknown_methods) == 1 else "are" 106 | warn( 107 | f"'{', '.join(unknown_methods)}' {is_or_are} not in the list of " 108 | f"allowed http methods of {view.__name__}. The methods specified " 109 | "will be ignored." 
110 | ) 111 | kwargs["methods"] = None 112 | 113 | # now that all checks are done, let's set the extra validation error 114 | # on the view to later add them to the schema 115 | set_validation_errors(view, **kwargs) 116 | 117 | return view 118 | 119 | return wrapper 120 | 121 | 122 | def get_action_names(viewset: Type[ViewSetMixin]) -> List[str]: 123 | # based on drf_spectacular.drainage.get_view_method_names 124 | builtin_action_names = ["list"] + list(viewset.schema.method_mapping.values()) 125 | return [ 126 | item 127 | for item in dir(viewset) 128 | if callable(getattr(viewset, item)) 129 | and (item in builtin_action_names or is_custom_action(viewset, item)) 130 | ] 131 | 132 | 133 | def is_custom_action(viewset: Type[ViewSetMixin], method_name: str) -> bool: 134 | # i.e. defined using the @action decorator 135 | return hasattr(getattr(viewset, method_name), "mapping") 136 | 137 | 138 | def get_allowed_http_methods(view: Type[APIView]) -> List[str]: 139 | if issubclass(view, ViewSetMixin): 140 | return view.http_method_names 141 | else: 142 | # based on drf_spectacular.drainage.get_view_method_names 143 | return [ 144 | item 145 | for item in dir(view) 146 | if callable(getattr(view, item)) and item in view.http_method_names 147 | ] 148 | 149 | 150 | def set_validation_errors( 151 | view: Type[APIView], 152 | error_codes: List[str], 153 | field_name: Optional[str], 154 | actions: Optional[List[str]], 155 | methods: Optional[List[str]], 156 | versions: Optional[List[str]], 157 | ) -> None: 158 | if hasattr(view, "_standardized_errors"): 159 | if "_standardized_errors" not in vars(view): 160 | # that means it is defined on a parent class, so we first create 161 | # a copy of it to avoid the validation error showing for the parent 162 | # class as well 163 | view._standardized_errors = copy.deepcopy(view._standardized_errors) 164 | else: 165 | view._standardized_errors = defaultdict(list) 166 | 167 | errors = generate_standardized_errors( 168 | error_codes, 
field_name, actions, methods, versions 169 | ) 170 | 171 | # errors are stored in a list to preserve order. When determining the error 172 | # codes for each field for a specific operation, we will traverse this list 173 | # in reverse order and pick the first encountered error that is in scope 174 | # of the operation in question. The reason we do this in reverse order is 175 | # to account the ability to override error codes in a child view. 176 | view._standardized_errors[field_name].extend(errors) 177 | 178 | 179 | def generate_standardized_errors( 180 | error_codes: List[str], 181 | field_name: Optional[str], 182 | actions: Optional[List[str]], 183 | methods: Optional[List[str]], 184 | versions: Optional[List[str]], 185 | ) -> "List[StandardizedError]": 186 | actions = actions or [None] # type: ignore 187 | methods = methods or [None] # type: ignore 188 | versions = versions or [None] # type: ignore 189 | 190 | return [ 191 | StandardizedError(set(error_codes), field_name, action, method, version) 192 | for action in actions 193 | for method in methods 194 | for version in versions 195 | ] 196 | 197 | 198 | def get_validation_errors(view: APIView) -> "Dict[str, List[StandardizedError]]": 199 | return getattr(view, "_standardized_errors", {}) 200 | 201 | 202 | @dataclass 203 | class StandardizedError: 204 | error_codes: Set[str] 205 | field_name: Optional[str] = None 206 | action: Optional[str] = None 207 | method: Optional[str] = None 208 | version: Optional[str] = None 209 | 210 | def is_in_scope(self, schema: AutoSchema) -> bool: 211 | """Determine if the error is in scope of the current operation""" 212 | view = schema.view 213 | api_version, _ = view.determine_version(view.request, **view.kwargs) 214 | version_in_scope = self.version is None or self.version == api_version 215 | method_in_scope = self.method is None or self.method == schema.method.lower() 216 | 217 | if isinstance(view, ViewSetMixin): 218 | action_in_scope = self.action is None or 
self.action == view.action 219 | return action_in_scope and method_in_scope and version_in_scope 220 | else: 221 | return method_in_scope and version_in_scope 222 | -------------------------------------------------------------------------------- /drf_standardized_errors/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghazi-git/drf-standardized-errors/2245a078ae845bc15437d7847e7ee754e8fc8a17/drf_standardized_errors/py.typed -------------------------------------------------------------------------------- /drf_standardized_errors/settings.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Set, Tuple 2 | 3 | from django.conf import settings 4 | from django.core.signals import setting_changed 5 | from django.dispatch import receiver 6 | from rest_framework.settings import import_from_string, perform_import 7 | 8 | 9 | class PackageSettings: 10 | """ 11 | Copy of DRF APISettings class with support for importing settings that 12 | are dicts with value as a string representing the path to the class 13 | to be imported. 
14 | """ 15 | 16 | setting_name = "DRF_STANDARDIZED_ERRORS" 17 | 18 | def __init__( 19 | self, 20 | defaults: Optional[Dict[str, Any]] = None, 21 | import_strings: Optional[Tuple[str, ...]] = None, 22 | ): 23 | self.defaults = defaults or DEFAULTS 24 | self.import_strings = import_strings or IMPORT_STRINGS 25 | self._cached_attrs: Set[str] = set() 26 | 27 | @property 28 | def user_settings(self) -> Dict[str, Any]: 29 | if not hasattr(self, "_user_settings"): 30 | self._user_settings = getattr(settings, self.setting_name, {}) 31 | return self._user_settings 32 | 33 | def __getattr__(self, attr: str) -> Any: 34 | if attr not in self.defaults: 35 | raise AttributeError(f"Invalid API setting: '{attr}'") 36 | 37 | try: 38 | # Check if present in user settings 39 | val = self.user_settings[attr] 40 | except KeyError: 41 | # Fall back to defaults 42 | val = self.defaults[attr] 43 | 44 | # Coerce import strings into classes 45 | if attr in self.import_strings: 46 | if isinstance(val, dict): 47 | val = { 48 | status_code: import_from_string(error_schema, attr) 49 | for status_code, error_schema in val.items() 50 | } 51 | else: 52 | val = perform_import(val, attr) 53 | 54 | # Cache the result 55 | self._cached_attrs.add(attr) 56 | setattr(self, attr, val) 57 | return val 58 | 59 | def reload(self) -> None: 60 | for attr in self._cached_attrs: 61 | delattr(self, attr) 62 | self._cached_attrs.clear() 63 | if hasattr(self, "_user_settings"): 64 | delattr(self, "_user_settings") 65 | 66 | 67 | DEFAULTS: Dict[str, Any] = { 68 | "EXCEPTION_HANDLER_CLASS": "drf_standardized_errors.handler.ExceptionHandler", 69 | "EXCEPTION_FORMATTER_CLASS": "drf_standardized_errors.formatter.ExceptionFormatter", 70 | "ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": False, 71 | "NESTED_FIELD_SEPARATOR": ".", 72 | "ALLOWED_ERROR_STATUS_CODES": [ 73 | "400", 74 | "401", 75 | "403", 76 | "404", 77 | "405", 78 | "406", 79 | "415", 80 | "429", 81 | "500", 82 | ], 83 | "ERROR_SCHEMAS": None, 84 | 
"LIST_INDEX_IN_API_SCHEMA": "INDEX", 85 | "DICT_KEY_IN_API_SCHEMA": "KEY", 86 | "ERROR_COMPONENT_NAME_SUFFIX": "ErrorComponent", 87 | } 88 | 89 | IMPORT_STRINGS = ( 90 | "EXCEPTION_FORMATTER_CLASS", 91 | "EXCEPTION_HANDLER_CLASS", 92 | "ERROR_SCHEMAS", 93 | ) 94 | 95 | package_settings = PackageSettings(DEFAULTS, IMPORT_STRINGS) 96 | 97 | 98 | @receiver(setting_changed) 99 | def reload_package_settings(*args: Any, **kwargs: Any) -> None: 100 | setting = kwargs["setting"] 101 | if setting == package_settings.setting_name: 102 | package_settings.reload() 103 | -------------------------------------------------------------------------------- /drf_standardized_errors/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Final, List, Literal, Optional, TypedDict 3 | 4 | from rest_framework.request import Request 5 | from rest_framework.views import APIView 6 | 7 | 8 | class ExceptionHandlerContext(TypedDict): 9 | view: APIView 10 | args: tuple 11 | kwargs: dict 12 | request: Optional[Request] 13 | 14 | 15 | VALIDATION_ERROR: Final = "validation_error" 16 | CLIENT_ERROR: Final = "client_error" 17 | SERVER_ERROR: Final = "server_error" 18 | ErrorType = Literal["validation_error", "client_error", "server_error"] 19 | 20 | 21 | @dataclass 22 | class Error: 23 | code: str 24 | detail: str 25 | attr: Optional[str] 26 | 27 | 28 | @dataclass 29 | class ErrorResponse: 30 | type: ErrorType 31 | errors: List[Error] 32 | 33 | 34 | class SetValidationErrorsKwargs(TypedDict): 35 | error_codes: List[str] 36 | field_name: Optional[str] 37 | actions: Optional[List[str]] 38 | methods: Optional[List[str]] 39 | versions: Optional[List[str]] 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend 
= "flit_core.buildapi" 4 | 5 | [project] 6 | name = "drf-standardized-errors" 7 | keywords = [ 8 | "standardized errors", 9 | "errors formatter", 10 | "django rest framework", 11 | "exception handler", 12 | ] 13 | authors = [{ name = "Ghazi Abbassi" }] 14 | license = { file = "LICENSE" } 15 | readme = "README.md" 16 | classifiers = [ 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.8", 21 | "Programming Language :: Python :: 3.9", 22 | "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Programming Language :: Python :: 3.13", 26 | ] 27 | dynamic = ["version", "description"] 28 | requires-python = ">=3.8" 29 | dependencies = [ 30 | "django >=3.2", 31 | "djangorestframework >=3.12", 32 | ] 33 | 34 | [project.urls] 35 | Homepage = "https://github.com/ghazi-git/drf-standardized-errors" 36 | Documentation = "https://drf-standardized-errors.readthedocs.io/en/latest/" 37 | Code = "https://github.com/ghazi-git/drf-standardized-errors" 38 | Issues = "https://github.com/ghazi-git/drf-standardized-errors/issues" 39 | Changelog = "https://github.com/ghazi-git/drf-standardized-errors/releases" 40 | 41 | [tool.flit.module] 42 | name = "drf_standardized_errors" 43 | 44 | [project.optional-dependencies] 45 | dev = ["pre-commit"] 46 | doc = [ 47 | "sphinx!=5.2.0.post0", 48 | "sphinx-autobuild", 49 | "sphinx-rtd-theme>=1.1.0", 50 | "myst-parser", 51 | ] 52 | test = [ 53 | "tox", 54 | "tox-gh-actions", 55 | ] 56 | release = [ 57 | "flit", 58 | "keyring", 59 | "tbump", 60 | ] 61 | openapi = [ 62 | "drf-spectacular>=0.27.1", 63 | "inflection", 64 | ] 65 | 66 | [tool.tbump] 67 | 68 | [tool.tbump.version] 69 | current = "0.15.0" 70 | regex = ''' 71 | (?P\d+) 72 | \. 73 | (?P\d+) 74 | \. 
75 | (?P\d+) 76 | ''' 77 | 78 | [tool.tbump.git] 79 | message_template = "Bump to {new_version}" 80 | tag_template = "v{new_version}" 81 | 82 | [[tool.tbump.file]] 83 | src = "drf_standardized_errors/__init__.py" 84 | search = '__version__ = "{current_version}"' 85 | 86 | [[tool.tbump.before_commit]] 87 | name = "Update the changelog" 88 | cmd = "python release/update_changelog.py --new-version {new_version}" 89 | 90 | [[tool.tbump.after_push]] 91 | name = "Publish to PyPI" 92 | cmd = "flit publish" 93 | -------------------------------------------------------------------------------- /release/update_changelog.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from argparse import ArgumentParser 3 | from datetime import date 4 | 5 | CHANGELOG_RELATIVE_PATH = "docs/changelog.md" 6 | CURRENT_DIR = pathlib.Path(__file__).parent.resolve() 7 | CHANGELOG_ABS_PATH = CURRENT_DIR.parent / CHANGELOG_RELATIVE_PATH 8 | UNRELEASED_LINE = "## [UNRELEASED]" 9 | NEW_VERSION_LINE_TEMPLATE = "## [{version}] - {date}\n" 10 | 11 | 12 | def main(version): 13 | lines = get_changelog_file_content() 14 | unreleased_line_number = find_unreleased_line_number(lines) 15 | lines = add_new_version_to_changelog(lines, unreleased_line_number, version) 16 | update_changelog_file(lines) 17 | 18 | 19 | def get_changelog_file_content(): 20 | with open(CHANGELOG_ABS_PATH, encoding="utf-8") as f: 21 | return f.readlines() 22 | 23 | 24 | def find_unreleased_line_number(lines): 25 | unreleased_line_number = None 26 | for i, line in enumerate(lines): 27 | if line.strip().lower() == UNRELEASED_LINE.lower(): 28 | unreleased_line_number = i 29 | break 30 | 31 | if unreleased_line_number is None: 32 | # Abort the release process as we're not able to update the changelog. 
33 | # If, for some reason, there is no need to update the changelog, comment 34 | # the step to update the changelog in pyproject.toml 35 | raise ChangelogUpdateError( 36 | f"Unable to find a line with the text '{UNRELEASED_LINE}' when " 37 | f"trying to update the changelog at '{CHANGELOG_RELATIVE_PATH}'." 38 | ) 39 | return unreleased_line_number 40 | 41 | 42 | def add_new_version_to_changelog(lines, unreleased_line_number, version): 43 | new_version_line = NEW_VERSION_LINE_TEMPLATE.format( 44 | version=version, date=date.today() 45 | ) 46 | return ( 47 | lines[: unreleased_line_number + 1] 48 | + ["\n", new_version_line] 49 | + lines[unreleased_line_number + 1 :] 50 | ) 51 | 52 | 53 | def update_changelog_file(lines): 54 | with open(CHANGELOG_ABS_PATH, "w", encoding="utf-8") as f: 55 | f.writelines(lines) 56 | 57 | 58 | class ChangelogUpdateError(Exception): 59 | """Unable to update the changelog with the new version""" 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = ArgumentParser( 64 | description=( 65 | "Replace UNRELEASED with the new version and current date and add " 66 | "a new unreleased section to the changelog." 
67 | ) 68 | ) 69 | parser.add_argument( 70 | "-n", 71 | "--new-version", 72 | dest="version", 73 | help="The new version to be released", 74 | metavar="NEW_VERSION", 75 | ) 76 | args = parser.parse_args() 77 | 78 | main(args.version) 79 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghazi-git/drf-standardized-errors/2245a078ae845bc15437d7847e7ee754e8fc8a17/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rest_framework.test import APIClient, APIRequestFactory 3 | 4 | from .views import ErrorView 5 | 6 | 7 | @pytest.fixture 8 | def api_client(): 9 | return APIClient(raise_request_exception=False) 10 | 11 | 12 | @pytest.fixture 13 | def api_request(): 14 | factory = APIRequestFactory() 15 | return factory.get("/error/") 16 | 17 | 18 | @pytest.fixture 19 | def exception_context(api_request): 20 | return {"view": ErrorView(), "args": (), "kwargs": {}, "request": api_request} 21 | 22 | 23 | @pytest.fixture 24 | def exc(): 25 | return Exception("Internal server error.") 26 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Post(models.Model): 5 | title = models.CharField(max_length=200, unique_for_date="published_at") 6 | body = models.TextField() 7 | published_at = models.DateField() 8 | -------------------------------------------------------------------------------- /tests/settings.py: -------------------------------------------------------------------------------- 1 | SECRET_KEY = "some_secret_key" 2 | 3 | USE_TZ = True 4 | 5 | INSTALLED_APPS = ( 
6 | "django.contrib.auth", 7 | "django.contrib.contenttypes", 8 | "rest_framework", 9 | "drf_standardized_errors", 10 | "tests", 11 | ) 12 | 13 | DATABASES = {"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}} 14 | 15 | MIDDLEWARE_CLASSES = ("django.middleware.common.CommonMiddleware",) 16 | 17 | PASSWORD_HASHERS = ("django.contrib.auth.hashers.MD5PasswordHasher",) 18 | 19 | ROOT_URLCONF = "tests.urls" 20 | 21 | REST_FRAMEWORK = { 22 | "EXCEPTION_HANDLER": "drf_standardized_errors.handler.exception_handler", 23 | "DEFAULT_AUTHENTICATION_CLASSES": [], 24 | "DEFAULT_PERMISSION_CLASSES": [], 25 | "TEST_REQUEST_DEFAULT_FORMAT": "json", 26 | "DEFAULT_SCHEMA_CLASS": "drf_standardized_errors.openapi.AutoSchema", 27 | } 28 | 29 | SPECTACULAR_SETTINGS = { 30 | "TITLE": "API", 31 | "DESCRIPTION": "Amazing API", 32 | "VERSION": "1.0.0", 33 | "SERVE_INCLUDE_SCHEMA": False, 34 | "ENUM_NAME_OVERRIDES": { 35 | "ValidationErrorEnum": "drf_standardized_errors.openapi_serializers.ValidationErrorEnum.choices", 36 | "ClientErrorEnum": "drf_standardized_errors.openapi_serializers.ClientErrorEnum.choices", 37 | "ServerErrorEnum": "drf_standardized_errors.openapi_serializers.ServerErrorEnum.choices", 38 | "ErrorCode401Enum": "drf_standardized_errors.openapi_serializers.ErrorCode401Enum.choices", 39 | "ErrorCode403Enum": "drf_standardized_errors.openapi_serializers.ErrorCode403Enum.choices", 40 | "ErrorCode404Enum": "drf_standardized_errors.openapi_serializers.ErrorCode404Enum.choices", 41 | "ErrorCode405Enum": "drf_standardized_errors.openapi_serializers.ErrorCode405Enum.choices", 42 | "ErrorCode406Enum": "drf_standardized_errors.openapi_serializers.ErrorCode406Enum.choices", 43 | "ErrorCode415Enum": "drf_standardized_errors.openapi_serializers.ErrorCode415Enum.choices", 44 | "ErrorCode429Enum": "drf_standardized_errors.openapi_serializers.ErrorCode429Enum.choices", 45 | "ErrorCode500Enum": "drf_standardized_errors.openapi_serializers.ErrorCode500Enum.choices", 46 | 
}, 47 | "POSTPROCESSING_HOOKS": [ 48 | "drf_standardized_errors.openapi_hooks.postprocess_schema_enums" 49 | ], 50 | } 51 | -------------------------------------------------------------------------------- /tests/test_exception_handler.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | from django.core.exceptions import PermissionDenied as DjangoPermissionDenied 5 | from django.core.signals import got_request_exception 6 | from django.http import Http404 7 | from rest_framework.exceptions import ( 8 | APIException, 9 | ErrorDetail, 10 | PermissionDenied, 11 | ValidationError, 12 | ) 13 | 14 | from drf_standardized_errors.handler import exception_handler 15 | 16 | 17 | @pytest.fixture 18 | def validation_error(): 19 | return ValidationError( 20 | {"name": [ErrorDetail("This field is required.", code="required")]} 21 | ) 22 | 23 | 24 | def test_validation_error(validation_error, exception_context): 25 | response = exception_handler(validation_error, exception_context) 26 | assert response.status_code == 400 27 | assert response.data["type"] == "validation_error" 28 | assert len(response.data["errors"]) == 1 29 | error = response.data["errors"][0] 30 | assert error["code"] == "required" 31 | assert error["detail"] == "This field is required." 32 | assert error["attr"] == "name" 33 | 34 | 35 | @pytest.fixture 36 | def permission_denied_error(): 37 | return PermissionDenied() 38 | 39 | 40 | def test_permission_denied_error(permission_denied_error, exception_context): 41 | response = exception_handler(permission_denied_error, exception_context) 42 | assert response.status_code == 403 43 | assert response.data["type"] == "client_error" 44 | assert len(response.data["errors"]) == 1 45 | error = response.data["errors"][0] 46 | assert error["code"] == "permission_denied" 47 | assert error["detail"] == "You do not have permission to perform this action." 
48 | assert error["attr"] is None 49 | 50 | 51 | @pytest.fixture 52 | def server_error(): 53 | return APIException() 54 | 55 | 56 | def test_server_error(server_error, exception_context): 57 | response = exception_handler(server_error, exception_context) 58 | assert response.status_code == 500 59 | assert response.data["type"] == "server_error" 60 | assert len(response.data["errors"]) == 1 61 | error = response.data["errors"][0] 62 | assert error["code"] == "error" 63 | assert error["detail"] == "A server error occurred." 64 | assert error["attr"] is None 65 | 66 | 67 | @pytest.fixture 68 | def service_unavailable_error(): 69 | return ServiceUnavailable() 70 | 71 | 72 | class ServiceUnavailable(APIException): 73 | status_code = 503 74 | default_detail = "Service temporarily unavailable, try again later." 75 | default_code = "service_unavailable" 76 | 77 | 78 | def test_custom_exception(service_unavailable_error, exception_context): 79 | response = exception_handler(service_unavailable_error, exception_context) 80 | assert response.status_code == 503 81 | assert response.data["type"] == "server_error" 82 | assert len(response.data["errors"]) == 1 83 | error = response.data["errors"][0] 84 | assert error["code"] == "service_unavailable" 85 | assert error["detail"] == "Service temporarily unavailable, try again later." 
86 | assert error["attr"] is None 87 | 88 | 89 | @pytest.fixture 90 | def unhandled_exception(): 91 | return Exception() 92 | 93 | 94 | def test_unhandled_exception(settings, unhandled_exception, exception_context): 95 | settings.DEBUG = True 96 | response = exception_handler(unhandled_exception, exception_context) 97 | assert response is None 98 | 99 | 100 | def test_got_request_exception_signal_sent(server_error, exception_context): 101 | # register a signal handler 102 | mock = MagicMock() 103 | got_request_exception.connect(mock) 104 | 105 | exception_handler(server_error, exception_context) 106 | assert mock.called 107 | 108 | 109 | def test_got_request_exception_signal_not_sent(validation_error, exception_context): 110 | # register a signal handler 111 | mock = MagicMock() 112 | got_request_exception.connect(mock) 113 | 114 | exception_handler(validation_error, exception_context) 115 | assert mock.called is False 116 | 117 | 118 | @pytest.fixture 119 | def django_permission_denied(): 120 | return DjangoPermissionDenied() 121 | 122 | 123 | def test_django_permission_denied_conversion( 124 | django_permission_denied, exception_context 125 | ): 126 | response = exception_handler(django_permission_denied, exception_context) 127 | assert response.status_code == 403 128 | assert response.data["type"] == "client_error" 129 | assert len(response.data["errors"]) == 1 130 | error = response.data["errors"][0] 131 | assert error["code"] == "permission_denied" 132 | 133 | 134 | @pytest.fixture 135 | def http404_error(): 136 | return Http404() 137 | 138 | 139 | def test_django_http404_conversion(http404_error, exception_context): 140 | response = exception_handler(http404_error, exception_context) 141 | assert response.status_code == 404 142 | assert response.data["type"] == "client_error" 143 | assert len(response.data["errors"]) == 1 144 | error = response.data["errors"][0] 145 | assert error["code"] == "not_found" 146 | 147 | 148 | def 
test_auth_header_is_set(api_client): 149 | response = api_client.get("/auth-error/") 150 | assert response.status_code == 401 151 | auth_header = response.headers.get("WWW-Authenticate") 152 | assert auth_header == 'Basic realm="api"' 153 | 154 | 155 | def test_retry_after_header_is_set(api_client): 156 | response = api_client.get("/rate-limit-error/") 157 | assert response.status_code == 429 158 | retry_after_header = response.headers.get("Retry-After") 159 | assert retry_after_header == "600" 160 | -------------------------------------------------------------------------------- /tests/test_flatten_errors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rest_framework.exceptions import ErrorDetail 3 | from rest_framework.test import APIClient 4 | 5 | from drf_standardized_errors.formatter import flatten_errors 6 | 7 | 8 | @pytest.fixture 9 | def required_name_error(): 10 | return {"name": [ErrorDetail("This field is required.", code="required")]} 11 | 12 | 13 | def test_one_error(required_name_error): 14 | errors = flatten_errors(required_name_error) 15 | assert len(errors) == 1 16 | assert errors[0].code == "required" 17 | assert errors[0].detail == "This field is required." 
18 | assert errors[0].attr == "name" 19 | 20 | 21 | @pytest.fixture 22 | def multiple_errors(): 23 | return { 24 | "phone": [ 25 | ErrorDetail( 26 | "The phone number entered is not valid.", code="invalid_phone_number" 27 | ) 28 | ], 29 | "password": [ 30 | ErrorDetail("This password is too short.", code="password_too_short"), 31 | ErrorDetail( 32 | "The password is too similar to the username.", 33 | code="password_too_similar", 34 | ), 35 | ], 36 | } 37 | 38 | 39 | def test_multiple_errors(multiple_errors): 40 | errors = flatten_errors(multiple_errors) 41 | assert len(errors) == 3 42 | assert errors[0].code == "invalid_phone_number" 43 | assert errors[0].attr == "phone" 44 | assert errors[1].code == "password_too_short" 45 | assert errors[1].attr == "password" 46 | assert errors[2].code == "password_too_similar" 47 | assert errors[2].attr == "password" 48 | 49 | 50 | @pytest.fixture 51 | def nested_error(): 52 | return { 53 | "shipping_address": { 54 | "non_field_errors": [ 55 | ErrorDetail( 56 | "We do not support shipping to the provided address.", 57 | code="unsupported", 58 | ) 59 | ] 60 | } 61 | } 62 | 63 | 64 | def test_nested_error(nested_error): 65 | errors = flatten_errors(nested_error) 66 | assert len(errors) == 1 67 | assert errors[0].code == "unsupported" 68 | assert errors[0].detail == "We do not support shipping to the provided address." 
69 | assert errors[0].attr == "shipping_address.non_field_errors" 70 | 71 | 72 | @pytest.fixture 73 | def list_serializer_errors(): 74 | return [ 75 | { 76 | "name": [ErrorDetail("This field is required.", code="required")], 77 | "email": [ErrorDetail("Enter a valid email address.", code="invalid")], 78 | }, 79 | {"email": [ErrorDetail("Enter a valid email address.", code="invalid")]}, 80 | ] 81 | 82 | 83 | def test_list_serializer_errors(list_serializer_errors): 84 | errors = flatten_errors(list_serializer_errors) 85 | assert len(errors) == 3 86 | assert errors[0].code == "required" 87 | assert errors[0].attr == "0.name" 88 | assert errors[1].code == "invalid" 89 | assert errors[1].attr == "0.email" 90 | assert errors[2].code == "invalid" 91 | assert errors[2].attr == "1.email" 92 | 93 | 94 | @pytest.fixture 95 | def nested_list_serializer_error(): 96 | return { 97 | "recipients": [ 98 | {}, 99 | {"email": [ErrorDetail("Enter a valid email address.", code="invalid")]}, 100 | ] 101 | } 102 | 103 | 104 | def test_nested_list_serializer_error(nested_list_serializer_error): 105 | errors = flatten_errors(nested_list_serializer_error) 106 | assert len(errors) == 1 107 | assert errors[0].code == "invalid" 108 | assert errors[0].detail == "Enter a valid email address." 109 | assert errors[0].attr == "recipients.1.email" 110 | 111 | 112 | def test_does_not_raise_recursion_error(): 113 | client = APIClient() 114 | try: 115 | client.get("/recursion-error/") 116 | except RecursionError: 117 | pytest.fail( 118 | "Failed due to a recursion error. Use an iterative approach rather than " 119 | "a recursive one to avoid reaching the maximum recursion depth in python." 
120 | ) 121 | 122 | 123 | def test_exception_with_detail_empty(): 124 | detail = {"some_field": [ErrorDetail("", code="invalid")]} 125 | errors = flatten_errors(detail) 126 | assert len(errors) == 1 127 | assert errors[0].attr == "some_field" 128 | assert errors[0].detail == "" 129 | -------------------------------------------------------------------------------- /tests/test_openapi_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest import mock 3 | 4 | import django 5 | import pytest 6 | from django import forms 7 | from django.contrib.auth.models import User 8 | from django.contrib.contenttypes.models import ContentType 9 | from django.core.validators import FileExtensionValidator 10 | from django_filters import CharFilter 11 | from django_filters.rest_framework import DjangoFilterBackend, FilterSet 12 | from drf_spectacular.plumbing import build_mock_request 13 | from rest_framework import serializers 14 | from rest_framework.generics import ListAPIView 15 | from rest_framework.schemas.openapi import SchemaGenerator 16 | 17 | from drf_standardized_errors.openapi_utils import ( 18 | InputDataField, 19 | get_django_filter_backends, 20 | get_error_serializer, 21 | get_filter_forms, 22 | get_flat_serializer_fields, 23 | get_form_fields_with_error_codes, 24 | get_serializer_fields_with_error_codes, 25 | ) 26 | 27 | from .models import Post 28 | 29 | 30 | class NestedSerializer(serializers.Serializer): 31 | nested_field1 = serializers.DictField(child=serializers.CharField()) 32 | nested_field2 = serializers.IntegerField() 33 | 34 | 35 | class CustomSerializer(serializers.Serializer): 36 | field1 = serializers.CharField() 37 | field2 = serializers.ListField(child=serializers.IntegerField()) 38 | field3 = NestedSerializer() 39 | field4 = NestedSerializer(many=True) 40 | 41 | 42 | class CustomSerializerWithNestedReadOnly(serializers.Serializer): 43 | field1 = serializers.CharField() 44 | field2 = 
serializers.ListField(child=serializers.IntegerField()) 45 | field3 = NestedSerializer(read_only=True) 46 | field4 = NestedSerializer(read_only=True, many=True) 47 | 48 | 49 | def test_get_flat_serializer_fields(): 50 | fields = get_flat_serializer_fields(CustomSerializer(many=True)) 51 | expected_fields = { 52 | "non_field_errors", 53 | "INDEX.non_field_errors", 54 | "INDEX.field1", 55 | "INDEX.field2", 56 | "INDEX.field2.INDEX", 57 | "INDEX.field3.non_field_errors", 58 | "INDEX.field3.nested_field1", 59 | "INDEX.field3.nested_field1.KEY", 60 | "INDEX.field3.nested_field2", 61 | "INDEX.field4.non_field_errors", 62 | "INDEX.field4.INDEX.non_field_errors", 63 | "INDEX.field4.INDEX.nested_field1", 64 | "INDEX.field4.INDEX.nested_field1.KEY", 65 | "INDEX.field4.INDEX.nested_field2", 66 | } 67 | assert {field.name for field in fields} == expected_fields 68 | 69 | 70 | def test_get_flat_serializer_fields_with_nested_read_only(): 71 | """Check case when NestedSerializer is read-only with non read-only fields.""" 72 | fields = get_flat_serializer_fields(CustomSerializerWithNestedReadOnly(many=True)) 73 | expected_fields = { 74 | "non_field_errors", 75 | "INDEX.non_field_errors", 76 | "INDEX.field1", 77 | "INDEX.field2", 78 | "INDEX.field2.INDEX", 79 | } 80 | assert {field.name for field in fields} == expected_fields 81 | 82 | 83 | @pytest.fixture 84 | def char_field(): 85 | return InputDataField( 86 | name="name", field=serializers.CharField(min_length=1, max_length=200) 87 | ) 88 | 89 | 90 | def test_char_field_error_codes(char_field): 91 | (field,) = get_serializer_fields_with_error_codes([char_field]) 92 | assert field.error_codes == { 93 | "null", 94 | "required", 95 | "invalid", 96 | "blank", 97 | "max_length", 98 | "min_length", 99 | "surrogate_characters_not_allowed", 100 | "null_characters_not_allowed", 101 | } 102 | 103 | 104 | @pytest.fixture 105 | def slug_field(): 106 | return InputDataField( 107 | name="title", 108 | field=serializers.SlugField( 109 | 
allow_null=True, allow_blank=True, allow_unicode=True 110 | ), 111 | ) 112 | 113 | 114 | def test_slug_field_error_codes(slug_field): 115 | (field,) = get_serializer_fields_with_error_codes([slug_field]) 116 | assert field.error_codes == { 117 | "required", 118 | "invalid", 119 | "surrogate_characters_not_allowed", 120 | "null_characters_not_allowed", 121 | } 122 | 123 | 124 | @pytest.fixture 125 | def ip_field(): 126 | return InputDataField(name="ip", field=serializers.IPAddressField(required=False)) 127 | 128 | 129 | def test_ip_field_error_codes(ip_field): 130 | (field,) = get_serializer_fields_with_error_codes([ip_field]) 131 | assert field.error_codes == { 132 | "null", 133 | "blank", 134 | "invalid", 135 | "surrogate_characters_not_allowed", 136 | "null_characters_not_allowed", 137 | } 138 | 139 | 140 | @pytest.fixture 141 | def integer_field(): 142 | return InputDataField( 143 | name="age", 144 | field=serializers.IntegerField(required=False, min_value=1, max_value=120), 145 | ) 146 | 147 | 148 | def test_integer_field_error_codes(integer_field): 149 | (field,) = get_serializer_fields_with_error_codes([integer_field]) 150 | assert field.error_codes == { 151 | "null", 152 | "invalid", 153 | "max_value", 154 | "min_value", 155 | "max_string_length", 156 | } 157 | 158 | 159 | @pytest.fixture 160 | def decimal_field(): 161 | return InputDataField( 162 | name="rate", 163 | field=serializers.DecimalField( 164 | max_digits=3, decimal_places=2, required=False, allow_null=True 165 | ), 166 | ) 167 | 168 | 169 | def test_decimal_field_error_codes(decimal_field): 170 | (field,) = get_serializer_fields_with_error_codes([decimal_field]) 171 | assert field.error_codes == { 172 | "invalid", 173 | "max_digits", 174 | "max_decimal_places", 175 | "max_whole_digits", 176 | "max_string_length", 177 | } 178 | 179 | 180 | @pytest.fixture 181 | def datetime_field(): 182 | return InputDataField( 183 | name="dt", field=serializers.DateTimeField(required=False, allow_null=True) 184 | 
) 185 | 186 | 187 | def test_datetime_field_error_codes(datetime_field): 188 | (field,) = get_serializer_fields_with_error_codes([datetime_field]) 189 | assert field.error_codes == {"invalid", "date", "make_aware", "overflow"} 190 | 191 | 192 | def test_naive_datetime_field_error_codes(settings, datetime_field): 193 | settings.USE_TZ = False 194 | 195 | (field,) = get_serializer_fields_with_error_codes([datetime_field]) 196 | assert field.error_codes == {"invalid", "date"} 197 | 198 | 199 | @pytest.fixture 200 | def date_field(): 201 | return InputDataField( 202 | name="date", field=serializers.DateField(required=False, allow_null=True) 203 | ) 204 | 205 | 206 | def test_date_field_error_codes(date_field): 207 | (field,) = get_serializer_fields_with_error_codes([date_field]) 208 | assert field.error_codes == {"invalid", "datetime"} 209 | 210 | 211 | @pytest.fixture 212 | def multiple_choice_field(): 213 | return InputDataField( 214 | name="colors", 215 | field=serializers.MultipleChoiceField( 216 | required=False, 217 | allow_null=True, 218 | allow_empty=False, 219 | allow_blank=False, 220 | choices=[("blue", "Blue"), ("red", "Red")], 221 | ), 222 | ) 223 | 224 | 225 | def test_multiple_choice_field_error_codes(multiple_choice_field): 226 | (field,) = get_serializer_fields_with_error_codes([multiple_choice_field]) 227 | assert field.error_codes == {"invalid_choice", "not_a_list", "empty"} 228 | 229 | 230 | @pytest.fixture 231 | def image_field(): 232 | return InputDataField( 233 | name="image", 234 | field=serializers.ImageField( 235 | required=False, 236 | allow_null=True, 237 | max_length=100, 238 | validators=[FileExtensionValidator(allowed_extensions=["png"])], 239 | ), 240 | ) 241 | 242 | 243 | def test_image_field_error_codes(image_field): 244 | (field,) = get_serializer_fields_with_error_codes([image_field]) 245 | assert field.error_codes == { 246 | "invalid", 247 | "no_name", 248 | "empty", 249 | "max_length", 250 | "invalid_image", 251 | 
"invalid_extension", 252 | } 253 | 254 | 255 | @pytest.fixture 256 | def list_field(): 257 | return InputDataField( 258 | name="items", field=serializers.ListField(required=False, allow_null=True) 259 | ) 260 | 261 | 262 | def test_list_field_error_codes(list_field): 263 | (field,) = get_serializer_fields_with_error_codes([list_field]) 264 | assert field.error_codes == {"not_a_list"} 265 | 266 | 267 | @pytest.fixture 268 | def m2m_field(): 269 | return InputDataField( 270 | name="items", 271 | field=serializers.PrimaryKeyRelatedField( 272 | many=True, required=False, allow_empty=False, queryset=User.objects.all() 273 | ), 274 | ) 275 | 276 | 277 | def test_m2m_field_error_codes(m2m_field): 278 | (field,) = get_serializer_fields_with_error_codes([m2m_field]) 279 | assert field.error_codes == { 280 | "null", 281 | "not_a_list", 282 | "empty", 283 | "incorrect_type", 284 | "does_not_exist", 285 | } 286 | 287 | 288 | class UserSerializer(serializers.Serializer): 289 | name = serializers.CharField() 290 | 291 | 292 | @pytest.fixture 293 | def serializer(): 294 | return InputDataField( 295 | name="non_field_errors", field=UserSerializer(required=True, allow_null=False) 296 | ) 297 | 298 | 299 | def test_top_level_non_field_errors_error_codes(serializer): 300 | """required and null should NOT be listed as error codes""" 301 | (field,) = get_serializer_fields_with_error_codes([serializer]) 302 | assert field.error_codes == {"invalid", "null"} 303 | 304 | 305 | @pytest.fixture 306 | def read_only_field(): 307 | return InputDataField(name="id", field=serializers.IntegerField(read_only=True)) 308 | 309 | 310 | def test_read_only_field_error_codes(read_only_field): 311 | """required and null should NOT be listed as error codes""" 312 | fields = get_serializer_fields_with_error_codes([read_only_field]) 313 | assert not fields 314 | 315 | 316 | class UniqueUserSerializer(serializers.ModelSerializer): 317 | class Meta: 318 | model = User 319 | fields = ["username"] 320 | 321 | 
322 | @pytest.fixture 323 | def unique_field(): 324 | s = UniqueUserSerializer() 325 | field = list(s.fields.values())[0] 326 | return InputDataField(name="username", field=field) 327 | 328 | 329 | def test_unique_field_error_codes(unique_field): 330 | (field,) = get_serializer_fields_with_error_codes([unique_field]) 331 | assert field.error_codes == { 332 | "null", 333 | "required", 334 | "invalid", 335 | "blank", 336 | "max_length", 337 | "surrogate_characters_not_allowed", 338 | "null_characters_not_allowed", 339 | "unique", 340 | } 341 | 342 | 343 | class ContentTypeSerializer(serializers.ModelSerializer): 344 | """ 345 | The field redefinition is intentional to set fields as not required 346 | and test that the required error code is added to fields because 347 | of the unique together constraint 348 | """ 349 | 350 | app_label = serializers.CharField(max_length=100, required=False) 351 | model = serializers.CharField( 352 | label="Python model class name", max_length=100, required=False 353 | ) 354 | 355 | class Meta: 356 | model = ContentType 357 | fields = ["app_label", "model"] 358 | 359 | 360 | @pytest.fixture 361 | def unique_together(): 362 | return get_flat_serializer_fields(ContentTypeSerializer()) 363 | 364 | 365 | def test_unique_together_error_codes(unique_together): 366 | non_field_errors, app_label, model = get_serializer_fields_with_error_codes( 367 | unique_together 368 | ) 369 | 370 | assert "unique" in non_field_errors.error_codes 371 | assert "required" in app_label.error_codes 372 | assert "required" in model.error_codes 373 | 374 | 375 | class PostSerializer(serializers.ModelSerializer): 376 | """ 377 | Intentional required=False to test that the 'required' error code is added 378 | despite that since the fields involved in a unique for date constraint 379 | are enforced as required by the unique for date validator. 
380 | """ 381 | 382 | title = serializers.CharField(max_length=200, required=False) 383 | published_at = serializers.DateField(required=False) 384 | 385 | class Meta: 386 | model = Post 387 | fields = ["title", "published_at"] 388 | 389 | 390 | @pytest.fixture 391 | def unique_for_date(): 392 | return get_flat_serializer_fields(PostSerializer()) 393 | 394 | 395 | def test_unique_for_date_error_codes(unique_for_date): 396 | _, title, published_at = get_serializer_fields_with_error_codes(unique_for_date) 397 | 398 | assert "unique" in title.error_codes 399 | assert "required" in title.error_codes 400 | assert "required" in published_at.error_codes 401 | 402 | 403 | class OddNumberField(serializers.IntegerField): 404 | default_error_messages = {"even_number": "Please provide an odd number"} 405 | 406 | def to_internal_value(self, data): 407 | data = super().to_internal_value(data) 408 | if data % 2 == 0: 409 | self.fail("even_number") 410 | 411 | return data 412 | 413 | 414 | @pytest.fixture 415 | def custom_serializer_field(): 416 | return InputDataField(name="odd", field=OddNumberField()) 417 | 418 | 419 | def test_custom_serializer_field_error_codes(custom_serializer_field): 420 | (field,) = get_serializer_fields_with_error_codes([custom_serializer_field]) 421 | 422 | assert "even_number" in field.error_codes 423 | 424 | 425 | class DiagnosisValidator: 426 | message = "Unknown diagnosis code." 
427 | code = "unknown_diagnosis" 428 | 429 | def __init__(self, known_diagnosis_codes=("G00", "G01", "G02")): 430 | self.known_diagnosis_codes = known_diagnosis_codes 431 | 432 | def __call__(self, diagnosis_code): 433 | if diagnosis_code not in self.known_diagnosis_codes: 434 | raise serializers.ValidationError(self.message, code=self.code) 435 | 436 | 437 | @pytest.fixture 438 | def field_with_custom_validator(): 439 | return InputDataField( 440 | name="diagnosis_code", 441 | field=serializers.CharField(validators=[DiagnosisValidator()]), 442 | ) 443 | 444 | 445 | def test_field_with_custom_validator(field_with_custom_validator): 446 | (field,) = get_serializer_fields_with_error_codes([field_with_custom_validator]) 447 | assert "unknown_diagnosis" in field.error_codes 448 | 449 | 450 | def test_django_filter_not_installed(monkeypatch): 451 | with mock.patch.dict(sys.modules, {"django_filters.rest_framework": None}): 452 | backends = get_django_filter_backends([DjangoFilterBackend]) 453 | assert not backends 454 | 455 | 456 | class AdminSerializer(serializers.ModelSerializer): 457 | class Meta: 458 | model = User 459 | fields = ["id", "username"] 460 | 461 | 462 | class UserFilterSet(FilterSet): 463 | username = CharFilter() 464 | 465 | 466 | class FilterView(ListAPIView): 467 | queryset = User.objects.filter(is_superuser=True) 468 | serializer_class = AdminSerializer 469 | filter_backends = [DjangoFilterBackend] 470 | filterset_class = UserFilterSet 471 | 472 | 473 | @pytest.fixture 474 | def filter_view(): 475 | generator = SchemaGenerator() 476 | view = generator.create_view(FilterView.as_view(), "get") 477 | view.request = build_mock_request("get", "filter/", view, None) 478 | return view 479 | 480 | 481 | def test_get_filter_forms(filter_view): 482 | (form,) = get_filter_forms(filter_view, [DjangoFilterBackend()]) 483 | assert "username" in form.fields 484 | 485 | 486 | @pytest.fixture 487 | def filter_view_no_model(): 488 | generator = SchemaGenerator() 489 
| view = generator.create_view(FilterView.as_view(queryset=None), "get") 490 | view.request = build_mock_request("get", "filter/", view, None) 491 | return view 492 | 493 | 494 | def test_no_filter_forms_returned(filter_view_no_model): 495 | filter_forms = get_filter_forms(filter_view_no_model, [DjangoFilterBackend()]) 496 | assert not filter_forms 497 | 498 | 499 | class CharForm(forms.Form): 500 | char = forms.CharField(max_length=100, min_length=2) 501 | slug = forms.SlugField(required=False) 502 | regex = forms.RegexField(r"^go") 503 | uuid = forms.UUIDField() 504 | ip = forms.GenericIPAddressField(required=False) 505 | 506 | 507 | def test_char_fields_with_error_codes(): 508 | (char, slug, regex, uuid, ip) = get_form_fields_with_error_codes(CharForm()) 509 | 510 | assert char.error_codes == { 511 | "required", 512 | "null_characters_not_allowed", 513 | "min_length", 514 | "max_length", 515 | } 516 | assert slug.error_codes == {"invalid", "null_characters_not_allowed"} 517 | assert regex.error_codes == {"invalid", "required", "null_characters_not_allowed"} 518 | assert uuid.error_codes == {"invalid", "required", "null_characters_not_allowed"} 519 | if django.VERSION >= (4, 2): 520 | assert ip.error_codes == { 521 | "invalid", 522 | "null_characters_not_allowed", 523 | "max_length", 524 | } 525 | else: 526 | assert ip.error_codes == {"invalid", "null_characters_not_allowed"} 527 | 528 | 529 | class NumberForm(forms.Form): 530 | integer = forms.IntegerField(max_value=100, min_value=2) 531 | dec1 = forms.DecimalField(required=False, max_digits=4, decimal_places=2) 532 | dec2 = forms.DecimalField(required=False, decimal_places=2) 533 | dec3 = forms.DecimalField(required=False, max_digits=4) 534 | dec4 = forms.DecimalField(required=False) 535 | 536 | 537 | def test_number_fields_with_error_codes(): 538 | (integer, dec1, dec2, dec3, dec4) = get_form_fields_with_error_codes(NumberForm()) 539 | 540 | assert integer.error_codes == {"required", "max_value", "min_value", 
"invalid"} 541 | assert dec1.error_codes == { 542 | "invalid", 543 | "max_digits", 544 | "max_decimal_places", 545 | "max_whole_digits", 546 | } 547 | assert dec2.error_codes == {"invalid", "max_decimal_places"} 548 | assert dec3.error_codes == {"invalid", "max_digits"} 549 | assert dec4.error_codes == {"invalid"} 550 | 551 | 552 | class TemporalForm(forms.Form): 553 | date = forms.DateField() 554 | datetime = forms.DateTimeField(required=False) 555 | duration = forms.DurationField(required=False) 556 | 557 | 558 | def test_temporal_fields_with_error_codes(): 559 | (date, datetime, duration) = get_form_fields_with_error_codes(TemporalForm()) 560 | 561 | assert date.error_codes == {"required", "invalid"} 562 | assert datetime.error_codes == {"invalid"} 563 | assert duration.error_codes == {"invalid", "overflow"} 564 | 565 | 566 | class ImageForm(forms.Form): 567 | image = forms.ImageField(required=False, max_length=100) 568 | 569 | 570 | def test_image_fields_with_error_codes(): 571 | (image,) = get_form_fields_with_error_codes(ImageForm()) 572 | 573 | assert image.error_codes == { 574 | "invalid_image", 575 | "invalid_extension", 576 | "invalid", 577 | "empty", 578 | "max_length", 579 | "contradiction", 580 | } 581 | 582 | 583 | class ChoiceForm(forms.Form): 584 | COLORS = [[("red", "Red"), ("blue", "Blue")]] 585 | choice = forms.ChoiceField(choices=COLORS) 586 | multiple_choice = forms.MultipleChoiceField(required=False, choices=COLORS) 587 | 588 | 589 | def test_choice_fields_with_error_codes(): 590 | (choice, multiple_choice) = get_form_fields_with_error_codes(ChoiceForm()) 591 | 592 | assert choice.error_codes == {"required", "invalid_choice"} 593 | assert multiple_choice.error_codes == {"invalid_choice", "invalid_list"} 594 | 595 | 596 | class MultiValueForm(forms.Form): 597 | split = forms.SplitDateTimeField() 598 | if django.VERSION >= (5, 0): 599 | disabled = forms.URLField(disabled=True, assume_scheme="https") 600 | else: 601 | disabled = 
forms.URLField(disabled=True) 602 | 603 | 604 | def test_multi_value_fields_with_error_codes(): 605 | (split,) = get_form_fields_with_error_codes(MultiValueForm()) 606 | 607 | assert split.error_codes == {"required", "invalid", "invalid_date", "invalid_time"} 608 | 609 | 610 | def test_error_component_name_suffix(): 611 | serializer = get_error_serializer("users_create", "first_name", {"required"}) 612 | assert serializer.Meta.ref_name.endswith("ErrorComponent") 613 | 614 | 615 | def test_updated_error_component_name_suffix(settings): 616 | settings.DRF_STANDARDIZED_ERRORS = {"ERROR_COMPONENT_NAME_SUFFIX": "FaultComponent"} 617 | 618 | serializer = get_error_serializer("users_create", "first_name", {"required"}) 619 | assert serializer.Meta.ref_name.endswith("FaultComponent") 620 | 621 | 622 | def test_list_index_in_api_schema(): 623 | fields = get_flat_serializer_fields(UserSerializer(many=True)) 624 | expected_fields = {"non_field_errors", "INDEX.non_field_errors", "INDEX.name"} 625 | assert {field.name for field in fields} == expected_fields 626 | 627 | 628 | def test_updated_list_index_in_api_schema(settings): 629 | settings.DRF_STANDARDIZED_ERRORS = {"LIST_INDEX_IN_API_SCHEMA": "IDX"} 630 | 631 | fields = get_flat_serializer_fields(UserSerializer(many=True)) 632 | expected_fields = {"non_field_errors", "IDX.non_field_errors", "IDX.name"} 633 | assert {field.name for field in fields} == expected_fields 634 | 635 | 636 | class DictSerializer(serializers.Serializer): 637 | d = serializers.DictField(child=serializers.IntegerField()) 638 | 639 | 640 | def test_dict_index_in_api_schema(): 641 | fields = get_flat_serializer_fields(DictSerializer()) 642 | expected_fields = {"non_field_errors", "d", "d.KEY"} 643 | assert {field.name for field in fields} == expected_fields 644 | 645 | 646 | def test_updated_dict_index_in_api_schema(settings): 647 | settings.DRF_STANDARDIZED_ERRORS = {"DICT_KEY_IN_API_SCHEMA": "DICT_KEY"} 648 | 649 | fields = 
get_flat_serializer_fields(DictSerializer()) 650 | expected_fields = {"non_field_errors", "d", "d.DICT_KEY"} 651 | assert {field.name for field in fields} == expected_fields 652 | -------------------------------------------------------------------------------- /tests/test_openapi_validation_errors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from django.contrib.auth.models import Group, User 3 | from django.views.generic import UpdateView 4 | from drf_spectacular.utils import extend_schema 5 | from rest_framework import serializers 6 | from rest_framework.decorators import action, api_view 7 | from rest_framework.generics import DestroyAPIView, UpdateAPIView 8 | from rest_framework.response import Response 9 | from rest_framework.versioning import URLPathVersioning 10 | from rest_framework.viewsets import ModelViewSet 11 | 12 | from drf_standardized_errors.openapi_validation_errors import extend_validation_errors 13 | 14 | from .utils import generate_versioned_view_schema, generate_view_schema, get_error_codes 15 | 16 | 17 | class UserSerializer(serializers.ModelSerializer): 18 | class Meta: 19 | fields = ["first_name"] 20 | model = User 21 | 22 | 23 | @pytest.fixture 24 | def viewset_with_extra_errors(): 25 | @extend_validation_errors( 26 | ["extra_error"], field_name="first_name", actions=["create"] 27 | ) 28 | class ValidationViewSet(ModelViewSet): 29 | serializer_class = UserSerializer 30 | queryset = User.objects.all() 31 | 32 | return ValidationViewSet 33 | 34 | 35 | def test_extra_validation_errors_to_viewset(viewset_with_extra_errors): 36 | """simple test for using @extend_validation_errors with ViewSets""" 37 | 38 | route = "validate/" 39 | view = viewset_with_extra_errors.as_view({"post": "create"}) 40 | schema = generate_view_schema(route, view) 41 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 42 | assert "extra_error" in error_codes 43 | 44 | 45 | 
@pytest.fixture 46 | def view_with_extra_errors(): 47 | @extend_validation_errors(["extra_error"], field_name="first_name", methods=["put"]) 48 | class ValidationView(UpdateAPIView): 49 | serializer_class = UserSerializer 50 | queryset = User.objects.all() 51 | 52 | return ValidationView 53 | 54 | 55 | def test_extra_validation_errors_to_view(view_with_extra_errors): 56 | """simple test for using @extend_validation_errors with APIViews""" 57 | 58 | route = "validate/" 59 | view = view_with_extra_errors.as_view() 60 | schema = generate_view_schema(route, view) 61 | error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 62 | assert "extra_error" in error_codes 63 | error_codes = get_error_codes( 64 | schema, "ValidatePartialUpdateFirstNameErrorComponent" 65 | ) 66 | assert "extra_error" not in error_codes 67 | 68 | 69 | @pytest.fixture 70 | def function_based_view_with_extra_errors(): 71 | @extend_validation_errors( 72 | ["extra_error"], field_name="first_name", methods=["post"] 73 | ) 74 | @extend_schema(request=UserSerializer, responses={201: None}) 75 | @api_view(http_method_names=["post"]) 76 | def validate(request): 77 | serializer = UserSerializer(data=request.data) 78 | serializer.is_valid(raise_exception=True) 79 | return Response(status=201) 80 | 81 | return validate 82 | 83 | 84 | def test_extra_validation_errors_to_function_based_api_view( 85 | function_based_view_with_extra_errors, 86 | ): 87 | route = "validate/" 88 | view = function_based_view_with_extra_errors 89 | schema = generate_view_schema(route, view) 90 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 91 | assert "extra_error" in error_codes 92 | 93 | 94 | @pytest.fixture 95 | def validation_viewset(): 96 | class ValidationViewSet(ModelViewSet): 97 | serializer_class = UserSerializer 98 | queryset = User.objects.all() 99 | 100 | return ValidationViewSet 101 | 102 | 103 | def test_methods_case_sensitivity(validation_viewset): 104 | """make 
sure it doesn't matter if we pass 'post' or 'POST' or 'PosT'""" 105 | extend_validation_errors( 106 | ["another_code"], field_name="first_name", methods=["PosT"] 107 | )(validation_viewset) 108 | 109 | route = "validate/" 110 | view = validation_viewset.as_view({"post": "create"}) 111 | schema = generate_view_schema(route, view) 112 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 113 | assert "another_code" in error_codes 114 | 115 | 116 | @pytest.fixture 117 | def function_based_view(): 118 | def get_users(request): 119 | serializer = UserSerializer(instance=User.objects.all()) 120 | return Response(serializer.data) 121 | 122 | return get_users 123 | 124 | 125 | def test_decorating_non_api_view_functions(function_based_view, capsys): 126 | extend_validation_errors(["new_code"])(function_based_view) 127 | stderr = capsys.readouterr().err 128 | assert "`@extend_validation_errors` can only be applied to APIViews" in stderr 129 | 130 | 131 | @pytest.fixture 132 | def django_class_based_view(): 133 | class UserView(UpdateView): 134 | model = User 135 | fields = ["first_name"] 136 | 137 | return UserView 138 | 139 | 140 | def test_decorating_non_api_view_classes(django_class_based_view, capsys): 141 | extend_validation_errors(["new_code"])(django_class_based_view) 142 | stderr = capsys.readouterr().err 143 | assert "`@extend_validation_errors` can only be applied to APIViews" in stderr 144 | 145 | 146 | def test_not_passing_error_codes(validation_viewset, capsys): 147 | extend_validation_errors([])(validation_viewset) 148 | stderr = capsys.readouterr().err 149 | assert "No error codes are passed to the `@extend_validation_errors`" in stderr 150 | 151 | 152 | def test_passing_field_name_as_none(validation_viewset): 153 | extend_validation_errors(["some_code"], methods=["post"])(validation_viewset) 154 | 155 | route = "validate/" 156 | view = validation_viewset.as_view({"post": "create"}) 157 | schema = generate_view_schema(route, view) 
158 | error_codes = get_error_codes(schema, "ValidateCreateErrorComponent") 159 | assert "some_code" in error_codes 160 | 161 | 162 | def test_passing_incorrect_action(validation_viewset, capsys): 163 | extend_validation_errors(["some_code"], actions=["no_action"])(validation_viewset) 164 | stderr = capsys.readouterr().err 165 | assert "not in the list of actions defined on the viewset" in stderr 166 | 167 | 168 | @pytest.fixture 169 | def function_based_api_view(): 170 | @extend_schema(request=UserSerializer, responses={201: None}) 171 | @api_view(http_method_names=["post"]) 172 | def validate(request): 173 | serializer = UserSerializer(data=request.data) 174 | serializer.is_valid(raise_exception=True) 175 | return Response(status=201) 176 | 177 | return validate 178 | 179 | 180 | def test_passing_action_for_api_view(function_based_api_view, capsys): 181 | extend_validation_errors(["some_error"], actions=["some_action"])( 182 | function_based_api_view 183 | ) 184 | 185 | stderr = capsys.readouterr().err 186 | warning_msg = ( 187 | "The 'actions' argument of 'extend_validation_errors' should " 188 | "only be set when decorating viewsets." 
189 | ) 190 | assert warning_msg in stderr 191 | 192 | 193 | @pytest.fixture 194 | def validation_view(): 195 | class ValidationView(UpdateAPIView): 196 | serializer_class = UserSerializer 197 | queryset = User.objects.all() 198 | 199 | return ValidationView 200 | 201 | 202 | def test_passing_incorrect_method(validation_view, capsys): 203 | extend_validation_errors(["some_code"], methods=["get"])(validation_view) 204 | stderr = capsys.readouterr().err 205 | assert "not in the list of allowed http methods" in stderr 206 | 207 | 208 | def test_passing_multiple_actions(validation_viewset): 209 | extend_validation_errors( 210 | ["some_error"], field_name="first_name", actions=["create", "partial_update"] 211 | )(validation_viewset) 212 | 213 | route = "validate/" 214 | view = validation_viewset.as_view( 215 | {"post": "create", "put": "update", "patch": "partial_update"} 216 | ) 217 | schema = generate_view_schema(route, view) 218 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 219 | assert "some_error" in error_codes 220 | error_codes = get_error_codes( 221 | schema, "ValidatePartialUpdateFirstNameErrorComponent" 222 | ) 223 | assert "some_error" in error_codes 224 | error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 225 | assert "some_error" not in error_codes 226 | 227 | 228 | def test_passing_actions_as_none(validation_viewset): 229 | extend_validation_errors(["some_error"], field_name="first_name")( 230 | validation_viewset 231 | ) 232 | 233 | route = "validate/" 234 | view = validation_viewset.as_view( 235 | {"post": "create", "put": "update", "patch": "partial_update"} 236 | ) 237 | schema = generate_view_schema(route, view) 238 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 239 | assert "some_error" in error_codes 240 | error_codes = get_error_codes( 241 | schema, "ValidatePartialUpdateFirstNameErrorComponent" 242 | ) 243 | assert "some_error" in error_codes 244 | 
error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 245 | assert "some_error" in error_codes 246 | 247 | 248 | def test_passing_multiple_methods(validation_viewset): 249 | extend_validation_errors( 250 | ["some_error"], field_name="first_name", methods=["post", "put"] 251 | )(validation_viewset) 252 | 253 | route = "validate/" 254 | view = validation_viewset.as_view( 255 | {"post": "create", "put": "update", "patch": "partial_update"} 256 | ) 257 | schema = generate_view_schema(route, view) 258 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 259 | assert "some_error" in error_codes 260 | error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 261 | assert "some_error" in error_codes 262 | error_codes = get_error_codes( 263 | schema, "ValidatePartialUpdateFirstNameErrorComponent" 264 | ) 265 | assert "some_error" not in error_codes 266 | 267 | 268 | def test_passing_methods_as_none(validation_viewset): 269 | extend_validation_errors(["some_error"], field_name="first_name")( 270 | validation_viewset 271 | ) 272 | 273 | route = "validate/" 274 | view = validation_viewset.as_view( 275 | {"post": "create", "put": "update", "patch": "partial_update"} 276 | ) 277 | schema = generate_view_schema(route, view) 278 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 279 | assert "some_error" in error_codes 280 | error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 281 | assert "some_error" in error_codes 282 | error_codes = get_error_codes( 283 | schema, "ValidatePartialUpdateFirstNameErrorComponent" 284 | ) 285 | assert "some_error" in error_codes 286 | 287 | 288 | @pytest.fixture 289 | def versioned_view(): 290 | class ValidationView(UpdateAPIView): 291 | serializer_class = UserSerializer 292 | queryset = User.objects.all() 293 | versioning_class = URLPathVersioning 294 | 295 | return ValidationView 296 | 297 | 298 | def 
test_passing_multiple_versions(versioned_view): 299 | extend_validation_errors( 300 | ["some_error"], field_name="first_name", versions=["v1", "v2"] 301 | )(versioned_view) 302 | 303 | view = versioned_view.as_view() 304 | 305 | versioned_schema = generate_versioned_view_schema(view, "v1") 306 | error_codes = get_error_codes( 307 | versioned_schema, "V1ValidateUpdateFirstNameErrorComponent" 308 | ) 309 | assert "some_error" in error_codes 310 | 311 | versioned_schema = generate_versioned_view_schema(view, "v2") 312 | error_codes = get_error_codes( 313 | versioned_schema, "V2ValidateUpdateFirstNameErrorComponent" 314 | ) 315 | assert "some_error" in error_codes 316 | 317 | versioned_schema = generate_versioned_view_schema(view, "v3") 318 | error_codes = get_error_codes( 319 | versioned_schema, "V3ValidateUpdateFirstNameErrorComponent" 320 | ) 321 | assert "some_error" not in error_codes 322 | 323 | 324 | def test_passing_versions_as_none(versioned_view): 325 | extend_validation_errors(["some_error"], field_name="first_name")(versioned_view) 326 | 327 | view = versioned_view.as_view() 328 | 329 | versioned_schema = generate_versioned_view_schema(view, "v1") 330 | error_codes = get_error_codes( 331 | versioned_schema, "V1ValidateUpdateFirstNameErrorComponent" 332 | ) 333 | assert "some_error" in error_codes 334 | 335 | versioned_schema = generate_versioned_view_schema(view, "v2") 336 | error_codes = get_error_codes( 337 | versioned_schema, "V2ValidateUpdateFirstNameErrorComponent" 338 | ) 339 | assert "some_error" in error_codes 340 | 341 | versioned_schema = generate_versioned_view_schema(view, "v3") 342 | error_codes = get_error_codes( 343 | versioned_schema, "V3ValidateUpdateFirstNameErrorComponent" 344 | ) 345 | assert "some_error" in error_codes 346 | 347 | 348 | def test_applying_decorator_multiple_times(validation_view): 349 | """all error codes should be added to corresponding fields""" 350 | extend_first_name_errors = extend_validation_errors( 351 | 
["short_name"], field_name="first_name" 352 | ) 353 | extend_non_field_errors = extend_validation_errors( 354 | ["some_error"], field_name="non_field_errors" 355 | ) 356 | extend_non_field_errors(extend_first_name_errors(validation_view)) 357 | 358 | route = "validate/" 359 | view = validation_view.as_view() 360 | schema = generate_view_schema(route, view) 361 | error_codes = get_error_codes(schema, "ValidateUpdateFirstNameErrorComponent") 362 | assert "short_name" in error_codes 363 | 364 | error_codes = get_error_codes(schema, "ValidateUpdateNonFieldErrorsErrorComponent") 365 | assert "some_error" in error_codes 366 | 367 | 368 | def test_applying_decorator_multiple_times_same_field(validation_viewset): 369 | """only second_error should appear in the resulting schema""" 370 | add_first_error = extend_validation_errors(["first_error"], field_name="first_name") 371 | add_second_error = extend_validation_errors( 372 | ["second_error"], field_name="first_name" 373 | ) 374 | add_second_error(add_first_error(validation_viewset)) 375 | 376 | route = "validate/" 377 | view = validation_viewset.as_view({"post": "create"}) 378 | schema = generate_view_schema(route, view) 379 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 380 | assert "second_error" in error_codes 381 | 382 | 383 | @pytest.fixture 384 | def child_viewset(): 385 | @extend_validation_errors(["parent_error"], field_name="first_name") 386 | class ParentViewSet(ModelViewSet): 387 | serializer_class = UserSerializer 388 | queryset = User.objects.all() 389 | 390 | class ChildViewSet(ParentViewSet): 391 | pass 392 | 393 | return ChildViewSet 394 | 395 | 396 | def test_inherited_validation_errors(child_viewset): 397 | """ 398 | errors defined on a parent are found on the child and parent errors are 399 | not affected 400 | """ 401 | extend_validation_errors(["child_error"], field_name="non_field_errors")( 402 | child_viewset 403 | ) 404 | 405 | route = "validate/" 406 | view = 
child_viewset.as_view({"post": "create"}) 407 | schema = generate_view_schema(route, view) 408 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 409 | assert "parent_error" in error_codes 410 | 411 | error_codes = get_error_codes(schema, "ValidateCreateNonFieldErrorsErrorComponent") 412 | assert "child_error" in error_codes 413 | 414 | 415 | def test_overriding_inherited_validation_errors(child_viewset): 416 | extend_validation_errors(["child_error"], field_name="first_name")(child_viewset) 417 | 418 | route = "validate/" 419 | view = child_viewset.as_view({"post": "create"}) 420 | schema = generate_view_schema(route, view) 421 | error_codes = get_error_codes(schema, "ValidateCreateFirstNameErrorComponent") 422 | assert "child_error" in error_codes 423 | assert "parent_error" not in error_codes 424 | 425 | 426 | @pytest.fixture 427 | def delete_view(): 428 | class ValidationView(DestroyAPIView): 429 | serializer_class = UserSerializer 430 | queryset = User.objects.all() 431 | 432 | return ValidationView 433 | 434 | 435 | def test_extra_validation_errors_for_unexpected_method(delete_view): 436 | """ 437 | Test that it is possible to add validation errors even for delete even though 438 | validation errors are auto-generated only for post,put,patch or get on a list action 439 | """ 440 | extend_validation_errors( 441 | ["some_error"], field_name="first_name", methods=["delete"] 442 | )(delete_view) 443 | 444 | route = "validate/" 445 | view = delete_view.as_view() 446 | schema = generate_view_schema(route, view) 447 | error_codes = get_error_codes(schema, "ValidateDestroyFirstNameErrorComponent") 448 | assert "some_error" in error_codes 449 | 450 | 451 | @pytest.fixture 452 | def viewset_with_custom_action(): 453 | class CustomActionViewSet(ModelViewSet): 454 | serializer_class = UserSerializer 455 | queryset = User.objects.all() 456 | 457 | @action(methods=["get"], detail=False) 458 | def fetch_superusers(self, request, *args, 
**kwargs): 459 | serializer = UserSerializer(instance=User.objects.filter(is_superuser=True)) 460 | return Response(serializer.data) 461 | 462 | return CustomActionViewSet 463 | 464 | 465 | def test_extra_validation_errors_for_unexpected_action(viewset_with_custom_action): 466 | """ 467 | Test that it is possible to add validation errors even for get on custom action 468 | even though validation errors are auto-generated only for post,put,patch or get 469 | on a list action 470 | """ 471 | extend_validation_errors(["some_error"], field_name="first_name", methods=["get"])( 472 | viewset_with_custom_action 473 | ) 474 | 475 | route = "superusers/" 476 | view = viewset_with_custom_action.as_view({"get": "fetch_superusers"}) 477 | schema = generate_view_schema(route, view) 478 | error_codes = get_error_codes(schema, "SuperusersRetrieveFirstNameErrorComponent") 479 | assert "some_error" in error_codes 480 | 481 | 482 | @pytest.fixture 483 | def viewset_with_nested_serializer(): 484 | class GroupSerializer(serializers.ModelSerializer): 485 | class Meta: 486 | fields = ["name"] 487 | model = Group 488 | 489 | class UserSerializer(serializers.ModelSerializer): 490 | groups = GroupSerializer(many=True) 491 | 492 | class Meta: 493 | fields = ["first_name", "groups"] 494 | model = User 495 | 496 | class NestedViewSet(ModelViewSet): 497 | serializer_class = UserSerializer 498 | queryset = User.objects.all() 499 | 500 | return NestedViewSet 501 | 502 | 503 | def test_extra_validation_errors_for_nested_list_serializer_field( 504 | viewset_with_nested_serializer, 505 | ): 506 | extend_validation_errors(["some_error"], field_name="groups.INDEX.name")( 507 | viewset_with_nested_serializer 508 | ) 509 | 510 | route = "validate/" 511 | view = viewset_with_nested_serializer.as_view({"post": "create"}) 512 | schema = generate_view_schema(route, view) 513 | error_codes = get_error_codes(schema, "ValidateCreateGroupsINDEXNameErrorComponent") 514 | assert "some_error" in error_codes 515 | 
-------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | from django.db import IntegrityError 2 | from rest_framework.exceptions import APIException 3 | 4 | from drf_standardized_errors.formatter import ExceptionFormatter 5 | from drf_standardized_errors.handler import ExceptionHandler, exception_handler 6 | from drf_standardized_errors.types import ErrorResponse 7 | 8 | 9 | def test_custom_exception_handler_class(settings, api_client): 10 | settings.DRF_STANDARDIZED_ERRORS = { 11 | "EXCEPTION_HANDLER_CLASS": "tests.test_settings.CustomExceptionHandler" 12 | } 13 | response = api_client.post("/integrity-error/") 14 | assert response.status_code == 409 15 | assert response.data["type"] == "client_error" 16 | assert len(response.data["errors"]) == 1 17 | error = response.data["errors"][0] 18 | assert error["code"] == "conflict" 19 | assert error["detail"] == "Concurrent update prevented." 
20 | assert error["attr"] is None 21 | 22 | 23 | class CustomExceptionHandler(ExceptionHandler): 24 | def convert_known_exceptions(self, exc: Exception) -> Exception: 25 | if isinstance(exc, IntegrityError): 26 | return ConcurrentUpdateError(str(exc)) 27 | else: 28 | return super().convert_known_exceptions(exc) 29 | 30 | 31 | class ConcurrentUpdateError(APIException): 32 | status_code = 409 33 | default_code = "conflict" 34 | 35 | 36 | def test_custom_exception_formatter_class(settings, api_client): 37 | settings.DRF_STANDARDIZED_ERRORS = { 38 | "EXCEPTION_FORMATTER_CLASS": "tests.test_settings.CustomExceptionFormatter" 39 | } 40 | response = api_client.get("/error/") 41 | assert response.status_code == 500 42 | assert response.data["type"] == "server_error" 43 | assert response.data["code"] == "error" 44 | assert response.data["message"] == "Server Error (500)" 45 | assert response.data["field_name"] is None 46 | 47 | 48 | class CustomExceptionFormatter(ExceptionFormatter): 49 | def format_error_response(self, error_response: ErrorResponse): 50 | """return one error at a time and change error response key names""" 51 | error = error_response.errors[0] 52 | return { 53 | "type": error_response.type, 54 | "code": error.code, 55 | "message": error.detail, 56 | "field_name": error.attr, 57 | } 58 | 59 | 60 | def test_enable_in_debug_for_unhandled_exception_is_false( 61 | settings, exc, exception_context 62 | ): 63 | settings.DEBUG = True 64 | settings.DRF_STANDARDIZED_ERRORS = { 65 | "ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": False 66 | } 67 | response = exception_handler(exc, exception_context) 68 | assert response is None 69 | 70 | 71 | def test_enable_in_debug_for_unhandled_exception_is_true( 72 | settings, exc, exception_context 73 | ): 74 | settings.DEBUG = True 75 | settings.DRF_STANDARDIZED_ERRORS = { 76 | "ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": True 77 | } 78 | response = exception_handler(exc, exception_context) 79 | assert response is not None 80 | 
assert response.status_code == 500 81 | assert response.data["type"] == "server_error" 82 | assert len(response.data["errors"]) == 1 83 | error = response.data["errors"][0] 84 | assert error["code"] == "error" 85 | assert error["detail"] == "Server Error (500)" 86 | assert error["attr"] is None 87 | 88 | 89 | def test_nested_field_separator(settings, api_client): 90 | settings.DRF_STANDARDIZED_ERRORS = {"NESTED_FIELD_SEPARATOR": "__"} 91 | address = { 92 | "street_address": "123 Main street", 93 | "city": "Floral Park", 94 | "state": "NY", 95 | "zipcode": "11001", 96 | } 97 | response = api_client.post("/order-error/", data={"shipping_address": address}) 98 | assert response.status_code == 400 99 | assert response.data["type"] == "validation_error" 100 | error = response.data["errors"][0] 101 | assert error["code"] == "unsupported" 102 | assert error["attr"] == "shipping_address__state" 103 | -------------------------------------------------------------------------------- /tests/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from drf_spectacular.views import SpectacularAPIView 3 | from rest_framework.permissions import IsAdminUser 4 | 5 | from .views import ( 6 | AuthErrorView, 7 | ErrorView, 8 | IntegrityErrorView, 9 | OrderErrorView, 10 | RateLimitErrorView, 11 | RecursionView, 12 | ) 13 | 14 | urlpatterns = [ 15 | path("integrity-error/", IntegrityErrorView.as_view()), 16 | path("error/", ErrorView.as_view()), 17 | path("order-error/", OrderErrorView.as_view()), 18 | path("auth-error/", AuthErrorView.as_view()), 19 | path("rate-limit-error/", RateLimitErrorView.as_view()), 20 | path("recursion-error/", RecursionView.as_view()), 21 | path("schema/", SpectacularAPIView.as_view(), name="api-schema"), 22 | path( 23 | "protected-schema/", 24 | SpectacularAPIView.as_view(permission_classes=[IsAdminUser]), 25 | ), 26 | ] 27 | 
-------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from django.urls import path, re_path 2 | from drf_spectacular.generators import SchemaGenerator 3 | from drf_spectacular.validation import validate_schema 4 | 5 | 6 | def generate_view_schema(route, view): 7 | patterns = [path(route, view)] 8 | 9 | generator = SchemaGenerator(patterns=patterns) 10 | schema = generator.get_schema(request=None, public=True) 11 | validate_schema(schema) 12 | return schema 13 | 14 | 15 | def generate_versioned_view_schema(view, version): 16 | patterns = [re_path(r"^(?P<version>(v1|v2|v3))/validate/", view)] 17 | generator = SchemaGenerator(patterns=patterns, api_version=version) 18 | schema = generator.get_schema(request=None, public=True) 19 | validate_schema(schema) 20 | return schema 21 | 22 | 23 | def get_responses(schema: dict, route: str, method="get"): 24 | return schema["paths"][f"/{route}"][method]["responses"] 25 | 26 | 27 | def get_error_codes(api_schema, schema_name): 28 | return api_schema["components"]["schemas"][schema_name]["properties"]["code"][ 29 | "enum" 30 | ] 31 | -------------------------------------------------------------------------------- /tests/views.py: -------------------------------------------------------------------------------- 1 | from django.db import IntegrityError 2 | from rest_framework import serializers 3 | from rest_framework.authentication import BasicAuthentication 4 | from rest_framework.generics import GenericAPIView 5 | from rest_framework.permissions import IsAuthenticated 6 | from rest_framework.response import Response 7 | from rest_framework.throttling import BaseThrottle 8 | from rest_framework.views import APIView 9 | 10 | 11 | class IntegrityErrorView(APIView): 12 | def post(self, request, *args, **kwargs): 13 | raise IntegrityError("Concurrent update prevented.") 14 | 15 | 16 | class ErrorView(APIView): 17 |
def get(self, request, *args, **kwargs): 18 | raise Exception("Internal server error.") 19 | 20 | 21 | class ShippingAddressSerializer(serializers.Serializer): 22 | street_address = serializers.CharField() 23 | city = serializers.CharField() 24 | state = serializers.CharField() 25 | zipcode = serializers.CharField() 26 | 27 | def validate_state(self, value): 28 | if value != "CA": 29 | raise serializers.ValidationError( 30 | "We do not support shipping to the provided address.", 31 | code="unsupported", 32 | ) 33 | return value 34 | 35 | 36 | class OrderSerializer(serializers.Serializer): 37 | shipping_address = ShippingAddressSerializer() 38 | 39 | 40 | class OrderErrorView(GenericAPIView): 41 | serializer_class = OrderSerializer 42 | 43 | def post(self, request, *args, **kwargs): 44 | serializer = self.get_serializer(data=request.data) 45 | serializer.is_valid(raise_exception=True) 46 | return Response(status=204) 47 | 48 | 49 | class AuthErrorView(APIView): 50 | authentication_classes = [BasicAuthentication] 51 | permission_classes = [IsAuthenticated] 52 | 53 | def get(self, request, *args, **kwargs): 54 | return Response(status=204) 55 | 56 | 57 | class CustomThrottle(BaseThrottle): 58 | def allow_request(self, request, view): 59 | return False 60 | 61 | def wait(self): 62 | return 600 63 | 64 | 65 | class RateLimitErrorView(APIView): 66 | throttle_classes = [CustomThrottle] 67 | 68 | def get(self, request, *args, **kwargs): 69 | return Response(status=204) 70 | 71 | 72 | class RecursionView(APIView): 73 | def get(self, request, *args, **kwargs): 74 | errors = [{"field": ["Some Error"]} for _ in range(1, 1000)] 75 | raise serializers.ValidationError(errors) 76 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # tox (https://tox.readthedocs.io/) is a tool for running tests 2 | # in multiple virtualenvs. 
This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | isolated_build = True 8 | envlist = 9 | py{38,39,310,311}-dj32-drf312 10 | py{38,39,310,311,312}-dj{32,40,41}-drf{313,314,315} 11 | py{38,39,310,311,312,313}-dj42-drf{314,315} 12 | py{310,311,312,313}-dj50-drf{314,315} 13 | py{310,311,312,313}-dj{51,52}-drf315 14 | py{310,311,312,313}-dj{42,50,51,52}-drf316 15 | lint 16 | docs 17 | 18 | [gh-actions] 19 | python = 20 | 3.8: py38 21 | 3.9: py39 22 | 3.10: py310 23 | 3.11: py311 24 | 3.12: py312, lint, docs 25 | 3.13: py313 26 | 27 | [testenv] 28 | deps = 29 | pytest 30 | pytest-django 31 | drf-spectacular>=0.27.1 32 | django-filter 33 | dj32: Django>=3.2,<4.0 34 | dj40: Django>=4.0,<4.1 35 | dj41: Django>=4.1,<4.2 36 | dj42: Django>=4.2,<5.0 37 | dj50: Django>=5.0,<5.1 38 | dj51: Django>=5.1,<5.2 39 | dj52: Django>=5.2,<6.0 40 | drf312: djangorestframework>=3.12,<3.13 41 | drf313: djangorestframework>=3.13,<3.14 42 | drf314: djangorestframework>=3.14,<3.15 43 | drf315: djangorestframework>=3.15,<3.16 44 | drf316: djangorestframework>=3.16,<3.17 45 | commands = 46 | pytest 47 | 48 | [testenv:lint] 49 | skip_install = true 50 | deps = pre-commit 51 | commands = pre-commit run --all-files --show-diff-on-failure 52 | 53 | [testenv:docs] 54 | extras = doc 55 | commands = sphinx-build -d "{toxworkdir}/docs_doctree" docs "{toxworkdir}/docs_out" --color -W -bhtml {posargs} 56 | python -c 'import pathlib; print("documentation available under file://\{0\}".format(pathlib.Path(r"{toxworkdir}") / "docs_out" / "index.html"))' 57 | 58 | [pytest] 59 | DJANGO_SETTINGS_MODULE = tests.settings 60 | testpaths = tests 61 | pythonpath = . 
drf_standardized_errors 62 | 63 | [coverage:run] 64 | branch = True 65 | source = drf_standardized_errors 66 | 67 | [coverage:report] 68 | omit = 69 | # the hook code is a copy from drf-spectacular with one change 70 | # to exclude error components from being processed by the hook 71 | drf_standardized_errors/openapi_hooks.py 72 | exclude_lines = 73 | # Have to re-enable the standard pragma 74 | pragma: no cover 75 | 76 | # Don't complain about missing debug-only code: 77 | def __repr__ 78 | if self\.debug 79 | 80 | # Don't complain if tests don't hit defensive assertion code: 81 | raise AssertionError 82 | raise NotImplementedError 83 | 84 | # Don't complain if non-runnable code isn't run: 85 | if 0: 86 | if __name__ == .__main__.: 87 | 88 | # Don't complain about abstract methods, they aren't run: 89 | @(abc\.)?abstractmethod 90 | ignore_errors = True 91 | --------------------------------------------------------------------------------