├── .circleci └── config.yml ├── .github └── workflows │ ├── publish.yaml │ └── tests.yml ├── .gitignore ├── .sonarcloud.properties ├── CHANGES.md ├── LICENSE.txt ├── README.md ├── example_app ├── app.py └── example.py ├── flask_pydantic ├── __init__.py ├── converters.py ├── core.py ├── exceptions.py └── version.py ├── pyproject.toml ├── requirements ├── build.txt └── test.txt └── tests ├── __init__.py ├── conftest.py ├── func ├── __init__.py └── test_app.py ├── pydantic_v1 ├── __init__.py ├── conftest.py ├── func │ ├── __init__.py │ └── test_app.py └── unit │ ├── __init__.py │ └── test_core.py ├── unit ├── __init__.py └── test_core.py └── util.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | # specify the version you desire here 10 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers` 11 | - image: circleci/python:3.7 12 | 13 | working_directory: ~/repo 14 | 15 | steps: 16 | - checkout 17 | 18 | # Download and cache dependencies 19 | - restore_cache: 20 | keys: 21 | - v1-dependencies-{{ checksum "requirements/base.pip" }}-{{ checksum "requirements/test.pip" }} 22 | # fallback to using the latest cache if no exact match is found 23 | - v1-dependencies- 24 | 25 | - run: 26 | name: install dependencies 27 | command: | 28 | python3 -m venv venv 29 | . venv/bin/activate 30 | pip install -r requirements/test.pip 31 | 32 | - save_cache: 33 | paths: 34 | - ./venv 35 | key: v1-dependencies-{{ checksum "requirements/base.pip" }}-{{ checksum "requirements/test.pip" }} 36 | 37 | # run tests! 
38 | # this example uses Django's built-in test-runner 39 | # other common Python testing frameworks include pytest and nose 40 | # https://pytest.org 41 | # https://nose.readthedocs.io 42 | - run: 43 | name: run tests 44 | command: | 45 | . venv/bin/activate 46 | python -m pytest 47 | 48 | - store_artifacts: 49 | path: htmlcov 50 | destination: test-reports 51 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | outputs: 10 | hash: ${{ steps.hash.outputs.hash }} 11 | steps: 12 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 13 | - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 14 | with: 15 | python-version: '3.x' 16 | cache: pip 17 | cache-dependency-path: requirements*/*.txt 18 | - run: pip install -r requirements/build.txt 19 | # Use the commit date instead of the current date during the build. 20 | - run: echo "SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct)" >> $GITHUB_ENV 21 | - run: python -m build 22 | # Generate hashes used for provenance. 23 | - name: generate hash 24 | id: hash 25 | run: cd dist && echo "hash=$(sha256sum * | base64 -w0)" >> $GITHUB_OUTPUT 26 | - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 27 | with: 28 | path: ./dist 29 | provenance: 30 | needs: [build] 31 | permissions: 32 | actions: read 33 | id-token: write 34 | contents: write 35 | # Can't pin with hash due to how this workflow works. 36 | uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v2.0.0 37 | with: 38 | base64-subjects: ${{ needs.build.outputs.hash }} 39 | create-release: 40 | # Upload the sdist, wheels, and provenance to a GitHub release. 
They remain 41 | # available as build artifacts for a while as well. 42 | needs: [provenance] 43 | runs-on: ubuntu-latest 44 | permissions: 45 | contents: write 46 | steps: 47 | - uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1 48 | - name: create release 49 | run: > 50 | gh release create --draft --repo ${{ github.repository }} 51 | ${{ github.ref_name }} 52 | *.intoto.jsonl/* artifact/* 53 | env: 54 | GH_TOKEN: ${{ github.token }} 55 | publish-pypi: 56 | needs: [provenance] 57 | # Wait for approval before attempting to upload to PyPI. This allows reviewing the 58 | # files in the draft release. 59 | environment: 60 | name: publish 61 | url: https://pypi.org/project/Flask-Pydantic/${{ github.ref_name }} 62 | runs-on: ubuntu-latest 63 | permissions: 64 | id-token: write 65 | steps: 66 | - uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1 67 | - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 68 | with: 69 | packages-dir: artifact/ 70 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] 16 | os: [ubuntu-latest, macOS-latest] 17 | # Python 3.7 is not supported on Apple ARM64, 18 | # or the latest Ubuntu 2404 19 | exclude: 20 | - python-version: "3.7" 21 | os: macos-latest 22 | - python-version: "3.7" 23 | os: ubuntu-latest 24 | include: # Python 3.7 is tested with a x86 macOS version 25 | - python-version: "3.7" 26 | os: macos-13 27 | 28 | 29 | steps: 30 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 31 | - name: Set up Python ${{ 
matrix.python-version }} 32 | uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | cache: pip 36 | cache-dependency-path: requirements/*.txt 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -e . 41 | pip install -r requirements/test.txt 42 | - name: Run tests and check code format 43 | run: | 44 | python3 -m pytest --ruff --ruff-format 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # VS Code settings 132 | .vscode 133 | 134 | docker-compose.yaml 135 | .dockerignore 136 | .idea/ -------------------------------------------------------------------------------- /.sonarcloud.properties: -------------------------------------------------------------------------------- 1 | sonar.tests=tests 2 | sonar.sources=flask_pydantic 3 | sonar.exclusions=example_app/**/* 4 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | ## Version 0.14.0-dev 2 | 3 | Unreleased 4 | 5 | ## Version 0.13.1 6 | 7 | Released 2025-04-23 8 | 9 | Features 10 | 11 | - Better support for iterables in query parameters. (thanks to @oda02) 12 | 13 | ## Version 0.13.0 14 | 15 | Released 2025-04-02 16 | 17 | Features 18 | 19 | - Support for Pydantic v1 namespace. (thanks to @Merinorus) 20 | 21 | ## Version 0.12.0 22 | 23 | Released 2024-01-08 24 | 25 | Features 26 | 27 | - Support Pydantic 2. Drop support for Pydantic 1. 
(thanks to @jkseppan) 28 | 29 | ## Version 0.11.0 30 | 31 | Released 2022-09-25 32 | 33 | Features 34 | 35 | - Allow raising `flask_pydantic.ValidationError` by setting 36 | `FLASK_PYDANTIC_VALIDATION_ERROR_RAISE=True` 37 | 38 | ## Version 0.10.0 39 | 40 | Released 2022-07-31 41 | 42 | Features 43 | 44 | - Add validation for form data 45 | - Handle extra headers returned by route functions 46 | 47 | Internal 48 | 49 | - Cleanup pipelines, drop python 3.6 tests, test on MacOS images 50 | 51 | ## Version 0.9.0 52 | 53 | Released 2021-10-28 54 | 55 | Features 56 | 57 | - Support for passing parameters to [`flask.Request.get_json`](https://tedboy.github.io/flask/generated/generated/flask.Request.get_json.html) function via 58 | `validate`'s `get_json_params` parameter 59 | 60 | Internal 61 | 62 | - Add tests for Python 3.10 to pipeline 63 | 64 | ## Version 0.8.0 65 | 66 | Released 2021-05-09 67 | 68 | Features 69 | 70 | - Return `400` response when model's `__root__` validation fails 71 | 72 | ## Version 0.7.2 73 | 74 | Released 2021-04-26 75 | 76 | Bugfixes 77 | 78 | - ignore return-type annotations 79 | 80 | ## Version 0.7.1 81 | 82 | Released 2021-04-08 83 | 84 | Bugfixes 85 | 86 | - recognize mime types with character encoding standard 87 | 88 | ## Version 0.7.0 89 | 90 | Released 2021-04-05 91 | 92 | Features 93 | 94 | - add support for URL path parameters parsing and validation 95 | 96 | ## Version 0.6.3 97 | 98 | Released 2021-03-26 99 | 100 | - do pin specific versions of required packages 101 | 102 | ## Version 0.6.2 103 | 104 | Released 2021-03-09 105 | 106 | Bugfixes 107 | 108 | - fix type annotations of decorated method 109 | 110 | ## Version 0.6.1 111 | 112 | Released 2021-02-18 113 | 114 | Bugfixes 115 | 116 | - parsing of query parameters in older versions of python 3.6 117 | 118 | ## Version 0.6.0 119 | 120 | Released 2021-01-31 121 | 122 | Features 123 | 124 | - improve README, example app 125 | - add support for pydantic's [custom root 
types](https://pydantic-docs.helpmanual.io/usage/models/#custom-root-types) 126 | 127 | ## Version 0.5.0 128 | 129 | Released 2021-01-17 130 | 131 | Features 132 | 133 | - add `Flask` classifier 134 | 135 | ## Version 0.4.0 136 | 137 | Released 2020-09-10 138 | 139 | Features 140 | 141 | - add support for [alias feature](https://pydantic-docs.helpmanual.io/usage/model_config/#alias-generator) in response models 142 | 143 | ## Version 0.3.0 144 | 145 | Released 2020-09-08 146 | 147 | Features 148 | 149 | - add possibility to specify models using keyword arguments 150 | 151 | ## Version 0.2.0 152 | 153 | Released 2020-08-07 154 | 155 | Features 156 | 157 | - add support for python version `3.6` 158 | 159 | ## Version 0.1.0 160 | 161 | Released 2020-08-02 162 | 163 | Features 164 | 165 | - add proper parsing and validation of array query parameters 166 | 167 | ## Version 0.0.7 168 | 169 | Released 2020-07-20 170 | 171 | - add possibility to configure response status code after 172 | `ValidationError` using flask app config value 173 | `FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE` 174 | 175 | ## Version 0.0.6 176 | 177 | Released 2020-06-11 178 | 179 | Features 180 | 181 | - return 182 | `415 - Unsupported media type` response for requests to endpoints with specified body model with other content type than 183 | `application/json`. 
184 | 185 | ## Version 0.0.5 186 | 187 | Released 2020-01-15 188 | 189 | Bugfixes 190 | 191 | - do not try to access query or body request parameters unless model is provided 192 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jiri Bauer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flask-Pydantic 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/Flask-Pydantic?color=g)](https://pypi.org/project/Flask-Pydantic/) 4 | [![License](https://img.shields.io/badge/license-MIT-purple)](https://github.com/bauerji/flask_pydantic/blob/master/LICENSE) 5 | 6 | Flask extension for integration of the awesome [pydantic package](https://github.com/samuelcolvin/pydantic) with [Flask](https://palletsprojects.com/p/flask/). 7 | 8 | ## Pallets Community Ecosystem 9 | 10 | > [!IMPORTANT]\ 11 | > This project is part of the Pallets Community Ecosystem. Pallets is the open 12 | > source organization that maintains Flask; Pallets-Eco enables community 13 | > maintenance of Flask extensions. If you are interested in helping maintain 14 | > this project, please reach out on [the Pallets Discord server][discord]. 15 | > 16 | > [discord]: https://discord.gg/pallets 17 | 18 | ## Installation 19 | 20 | `python3 -m pip install Flask-Pydantic` 21 | 22 | ## Basics 23 | ### URL query and body parameters 24 | 25 | `validate` decorator validates query, body and form-data request parameters and makes them accessible two ways: 26 | 27 | 1. [Using `validate` arguments, via flask's `request` variable](#basic-example) 28 | 29 | | **parameter type** | **`request` attribute name** | 30 | |:------------------:|:----------------------------:| 31 | | query | `query_params` | 32 | | body | `body_params` | 33 | | form | `form_params` | 34 | 35 | 2. 
[Using the decorated function argument parameters type hints](#using-the-decorated-function-kwargs) 36 | 37 | ### URL path parameter 38 | 39 | If you use annotated URL path parameters as follows 40 | ```python 41 | 42 | @app.route("/users/<user_id>", methods=["GET"]) 43 | @validate() 44 | def get_user(user_id: str): 45 | pass 46 | ``` 47 | flask_pydantic will parse and validate the `user_id` variable in the same manner as for body and query parameters. 48 | 49 | --- 50 | 51 | ### Additional `validate` arguments 52 | 53 | - Success response status code can be modified via `on_success_status` parameter of `validate` decorator. 54 | - `response_many` parameter set to `True` enables serialization of multiple models (route function should therefore return iterable of models). 55 | - `request_body_many` parameter set to `True` similarly enables parsing of multiple models at the root level of the request body. If the request body doesn't contain an array of objects, a `400` response is returned. 56 | - `get_json_params` - parameters to be passed to [`flask.Request.get_json`](https://tedboy.github.io/flask/generated/generated/flask.Request.get_json.html) function 57 | - If validation fails, a `400` response (by default) is returned with failure explanation. 58 | 59 | For more details see in-code docstring or example app. 60 | 61 | ## Usage 62 | 63 | ### Example 1: Query parameters only 64 | 65 | Simply use `validate` decorator on route function. 66 | 67 | :exclamation: Be aware that `@app.route` decorator must precede `@validate` (i. e. `@validate` must be closer to the function declaration).
68 | 69 | ```python 70 | from typing import Optional 71 | from flask import Flask, request 72 | from pydantic import BaseModel 73 | 74 | from flask_pydantic import validate 75 | 76 | app = Flask("flask_pydantic_app") 77 | 78 | class QueryModel(BaseModel): 79 | age: int 80 | 81 | class ResponseModel(BaseModel): 82 | id: int 83 | age: int 84 | name: str 85 | nickname: Optional[str] = None 86 | 87 | # Example 1: query parameters only 88 | @app.route("/", methods=["GET"]) 89 | @validate() 90 | def get(query: QueryModel): 91 | age = query.age 92 | return ResponseModel( 93 | age=age, 94 | id=0, name="abc", nickname="123" 95 | ) 96 | ``` 97 | 98 | 99 | [See the full example app here](example_app/example.py) 100 | 101 | 102 | 103 | - `age` query parameter is a required `int` 104 | - `curl --location --request GET 'http://127.0.0.1:5000/'` 105 | - if none is provided the response contains: 106 | ```json 107 | { 108 | "validation_error": { 109 | "query_params": [ 110 | { 111 | "loc": ["age"], 112 | "msg": "field required", 113 | "type": "value_error.missing" 114 | } 115 | ] 116 | } 117 | } 118 | ``` 119 | - for incompatible type (e. g.
string `/?age=not_a_number`) 120 | - `curl --location --request GET 'http://127.0.0.1:5000/?age=abc'` 121 | ```json 122 | { 123 | "validation_error": { 124 | "query_params": [ 125 | { 126 | "loc": ["age"], 127 | "msg": "value is not a valid integer", 128 | "type": "type_error.integer" 129 | } 130 | ] 131 | } 132 | } 133 | ``` 134 | - likewise for body parameters 135 | - example call with valid parameters: 136 | `curl --location --request GET 'http://127.0.0.1:5000/?age=20'` 137 | 138 | -> `{"id": 0, "age": 20, "name": "abc", "nickname": "123"}` 139 | 140 | 141 | ### Example 2: URL path parameter 142 | 143 | ```python 144 | @app.route("/character/<character_id>/", methods=["GET"]) 145 | @validate() 146 | def get_character(character_id: int): 147 | characters = [ 148 | ResponseModel(id=1, age=95, name="Geralt", nickname="White Wolf"), 149 | ResponseModel(id=2, age=45, name="Triss Merigold", nickname="sorceress"), 150 | ResponseModel(id=3, age=42, name="Julian Alfred Pankratz", nickname="Jaskier"), 151 | ResponseModel(id=4, age=101, name="Yennefer", nickname="Yenn"), 152 | ] 153 | try: 154 | return characters[character_id] 155 | except IndexError: 156 | return {"error": "Not found"}, 400 157 | ``` 158 | 159 | 160 | ### Example 3: Request body only 161 | 162 | ```python 163 | class RequestBodyModel(BaseModel): 164 | name: str 165 | nickname: Optional[str] = None 166 | 167 | # Example 3: request body only 168 | @app.route("/", methods=["POST"]) 169 | @validate() 170 | def post(body: RequestBodyModel): 171 | name = body.name 172 | nickname = body.nickname 173 | return ResponseModel( 174 | name=name, nickname=nickname, id=0, age=1000 175 | ) 176 | ``` 177 | 178 | 179 | [See the full example app here](example_app/example.py) 180 | 181 | 182 | ### Example 4: BOTH query parameters and request body 183 | 184 | ```python 185 | # Example 4: both query parameters and request body 186 | @app.route("/both", methods=["POST"]) 187 | @validate() 188 | def get_and_post(body: RequestBodyModel, query: QueryModel): 189 | name =
body.name # From request body 190 | nickname = body.nickname # From request body 191 | age = query.age # from query parameters 192 | return ResponseModel( 193 | age=age, name=name, nickname=nickname, 194 | id=0 195 | ) 196 | ``` 197 | 198 | 199 | [See the full example app here](example_app/example.py) 200 | 201 | 202 | 203 | ### Example 5: Request form-data only 204 | 205 | ```python 206 | class RequestFormDataModel(BaseModel): 207 | name: str 208 | nickname: Optional[str] = None 209 | 210 | # Example 5: request form-data only 211 | @app.route("/form", methods=["POST"]) 212 | @validate() 213 | def form_post(form: RequestFormDataModel): 214 | name = form.name 215 | nickname = form.nickname 216 | return ResponseModel( 217 | name=name, nickname=nickname, id=0, age=1000 218 | ) 219 | ``` 220 | 221 | 222 | [See the full example app here](example_app/example.py) 223 | 224 | 225 | ### Modify response status code 226 | 227 | The default success status code is `200`. It can be modified in two ways: 228 | 229 | - in return statement 230 | 231 | ```python 232 | # necessary imports, app and models definition 233 | ... 234 | 235 | @app.route("/", methods=["POST"]) 236 | @validate(body=BodyModel, query=QueryModel) 237 | def post(): 238 | return ResponseModel( 239 | id=id_, 240 | age=request.query_params.age, 241 | name=request.body_params.name, 242 | nickname=request.body_params.nickname, 243 | ), 201 244 | ``` 245 | 246 | - in `validate` decorator 247 | 248 | ```python 249 | @app.route("/", methods=["POST"]) 250 | @validate(body=BodyModel, query=QueryModel, on_success_status=201) 251 | def post(): 252 | ... 253 | ``` 254 | 255 | Status code in case of validation error can be modified using `FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE` flask configuration variable. 256 | 257 | ### Using the decorated function `kwargs` 258 | 259 | Instead of passing `body` and `query` to `validate`, it is possible to directly 260 | define them by using type hinting in the decorated function.
261 | 262 | ```python 263 | # necessary imports, app and models definition 264 | ... 265 | 266 | @app.route("/", methods=["POST"]) 267 | @validate() 268 | def post(body: BodyModel, query: QueryModel): 269 | return ResponseModel( 270 | id=id_, 271 | age=query.age, 272 | name=body.name, 273 | nickname=body.nickname, 274 | ) 275 | ``` 276 | 277 | This way, the parsed data will be directly available in `body` and `query`. 278 | Furthermore, your IDE will be able to correctly type them. 279 | 280 | ### Model aliases 281 | 282 | Pydantic's [alias feature](https://pydantic-docs.helpmanual.io/usage/model_config/#alias-generator) is natively supported for query and body models. 283 | To use aliases in response modify response model 284 | ```python 285 | def modify_key(text: str) -> str: 286 | # do whatever you want with model keys 287 | return text 288 | 289 | 290 | class MyModel(BaseModel): 291 | ... 292 | model_config = ConfigDict( 293 | alias_generator=modify_key, 294 | populate_by_name=True 295 | ) 296 | 297 | ``` 298 | 299 | and set `response_by_alias=True` in `validate` decorator 300 | 301 | ```python 302 | @app.route(...) 303 | @validate(response_by_alias=True) 304 | def my_route(): 305 | ... 306 | return MyModel(...) 307 | ``` 308 | 309 | ### Example app 310 | 311 | For more complete examples see [example application](https://github.com/bauerji/flask_pydantic/tree/master/example_app). 312 | 313 | ### Configuration 314 | 315 | The behaviour can be configured using flask's application config 316 | `FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE` - response status code after validation error (defaults to `400`) 317 | 318 | Additionally, you can set `FLASK_PYDANTIC_VALIDATION_ERROR_RAISE` to `True` to cause 319 | `flask_pydantic.ValidationError` to be raised with either `body_params`, 320 | `form_params`, `path_params`, or `query_params` set as a list of error 321 | dictionaries. 
You can use `flask.Flask.register_error_handler` to catch that 322 | exception and fully customize the output response for a validation error. 323 | 324 | ## Contributing 325 | 326 | Feature requests and pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 327 | 328 | - clone repository 329 | ```bash 330 | git clone https://github.com/pallets-eco/flask_pydantic.git 331 | cd flask_pydantic 332 | ``` 333 | - create virtual environment and activate it 334 | ```bash 335 | python3 -m venv venv 336 | source venv/bin/activate 337 | ``` 338 | - install the package in editable mode together with the development requirements 339 | ```bash 340 | python3 -m pip install -e . -r requirements/test.txt 341 | ``` 342 | - checkout new branch and make your desired changes (don't forget to update tests) 343 | ```bash 344 | git checkout -b <your_branch_name> 345 | ``` 346 | - make sure your code style is compliant with [Ruff](https://github.com/astral-sh/ruff). You can check these errors and automatically correct some of them with `ruff check --select I --fix .
` 347 | - run tests and check code format 348 | ```bash 349 | python3 -m pytest --ruff --ruff-format 350 | ``` 351 | - push your changes and create a pull request to master branch 352 | 353 | ## TODOs: 354 | 355 | - header request parameters 356 | - cookie request parameters 357 | -------------------------------------------------------------------------------- /example_app/app.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from flask import Flask, jsonify, request 5 | from flask_pydantic import validate 6 | from pydantic import BaseModel 7 | 8 | app = Flask("flask_pydantic_app") 9 | 10 | 11 | @dataclass 12 | class Config: 13 | FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE: int = 422 14 | 15 | 16 | app.config.from_object(Config) 17 | 18 | 19 | class QueryModel(BaseModel): 20 | age: int 21 | 22 | 23 | class IndexParam(BaseModel): 24 | index: int 25 | 26 | 27 | class BodyModel(BaseModel): 28 | name: str 29 | nickname: Optional[str] = None 30 | 31 | 32 | class FormModel(BaseModel): 33 | name: str 34 | nickname: Optional[str] = None 35 | 36 | 37 | class ResponseModel(BaseModel): 38 | id: int 39 | age: int 40 | name: str 41 | nickname: Optional[str] = None 42 | 43 | 44 | @app.route("/", methods=["POST"]) 45 | @validate(body=BodyModel, query=QueryModel) 46 | def post(): 47 | """ 48 | Basic example with both query and body parameters, response object serialization. 49 | """ 50 | # save model to DB 51 | id_ = 2 52 | 53 | return ResponseModel( 54 | id=id_, 55 | age=request.query_params.age, 56 | name=request.body_params.name, 57 | nickname=request.body_params.nickname, 58 | ) 59 | 60 | 61 | @app.route("/form", methods=["POST"]) 62 | @validate(form=FormModel, query=QueryModel) 63 | def form_post(): 64 | """ 65 | Basic example with both query and form-data parameters, response object serialization. 
66 | """ 67 | # save model to DB 68 | id_ = 2 69 | 70 | return ResponseModel( 71 | id=id_, 72 | age=request.query_params.age, 73 | name=request.form_params.name, 74 | nickname=request.form_params.nickname, 75 | ) 76 | 77 | 78 | @app.route("/kwargs", methods=["POST"]) 79 | @validate() 80 | def post_kwargs(body: BodyModel, query: QueryModel): 81 | """ 82 | Basic example with both query and body parameters, response object serialization. 83 | This time using the decorated function kwargs `body` and `query` type hinting 84 | """ 85 | # save model to DB 86 | id_ = 3 87 | 88 | return ResponseModel(id=id_, age=query.age, name=body.name, nickname=body.nickname) 89 | 90 | 91 | @app.route("/form/kwargs", methods=["POST"]) 92 | @validate() 93 | def form_post_kwargs(form: FormModel, query: QueryModel): 94 | """ 95 | Basic example with both query and form-data parameters, response object serialization. 96 | This time using the decorated function kwargs `form` and `query` type hinting 97 | """ 98 | # save model to DB 99 | id_ = 3 100 | 101 | return ResponseModel(id=id_, age=query.age, name=form.name, nickname=form.nickname) 102 | 103 | 104 | @app.route("/many", methods=["GET"]) 105 | @validate(response_many=True) 106 | def get_many(): 107 | """ 108 | This route returns response containing many serialized objects. 
109 | """ 110 | return [ 111 | ResponseModel(id=1, age=95, name="Geralt", nickname="White Wolf"), 112 | ResponseModel(id=2, age=45, name="Triss Merigold", nickname="sorceress"), 113 | ResponseModel(id=3, age=42, name="Julian Alfred Pankratz", nickname="Jaskier"), 114 | ResponseModel(id=4, age=101, name="Yennefer", nickname="Yenn"), 115 | ] 116 | 117 | 118 | @app.route("/select", methods=["POST"]) 119 | @validate(request_body_many=True, query=IndexParam, body=BodyModel) 120 | def select_from_array(): 121 | """ 122 | This route takes array of objects in request body and returns the object on index 123 | (index is a url query parameter) 124 | """ 125 | try: 126 | return BodyModel(**request.body_params[request.query_params.index].dict()) 127 | except IndexError: 128 | return jsonify({"reason": "index out of bound"}), 400 129 | -------------------------------------------------------------------------------- /example_app/example.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from flask import Flask 4 | from flask_pydantic import validate 5 | from pydantic import BaseModel 6 | 7 | app = Flask("flask_pydantic_app") 8 | 9 | 10 | class RequestBodyModel(BaseModel): 11 | name: str 12 | nickname: Optional[str] = None 13 | 14 | 15 | class QueryModel(BaseModel): 16 | age: int 17 | 18 | 19 | class FormModel(BaseModel): 20 | name: str 21 | nickname: Optional[str] = None 22 | 23 | 24 | @app.route("/", methods=["GET"]) 25 | @validate() 26 | def get(query: QueryModel): 27 | age = query.age 28 | return ResponseModel(age=age, id=0, name="abc", nickname="123") 29 | 30 | 31 | """ 32 | curl --location --request GET 'http://127.0.0.1:5000/' 33 | curl --location --request GET 'http://127.0.0.1:5000/?ageeee=5' 34 | curl --location --request GET 'http://127.0.0.1:5000/?age=abc' 35 | 36 | curl --location --request GET 'http://127.0.0.1:5000/?age=5' 37 | """ 38 | 39 | 40 | class ResponseModel(BaseModel): 41 | id: int 
42 | age: int 43 | name: str 44 | nickname: Optional[str] = None 45 | 46 | 47 | @app.route("/character//", methods=["GET"]) 48 | @validate() 49 | def get_character(character_id: int): 50 | characters = [ 51 | ResponseModel(id=1, age=95, name="Geralt", nickname="White Wolf"), 52 | ResponseModel(id=2, age=45, name="Triss Merigold", nickname="sorceress"), 53 | ResponseModel(id=3, age=42, name="Julian Alfred Pankratz", nickname="Jaskier"), 54 | ResponseModel(id=4, age=101, name="Yennefer", nickname="Yenn"), 55 | ] 56 | try: 57 | return characters[character_id] 58 | except IndexError: 59 | return {"error": "Not found"}, 400 60 | 61 | 62 | """ 63 | curl http://127.0.0.1:5000/character/2/ \ 64 | --header 'Content-Type: application/json' 65 | """ 66 | 67 | 68 | @app.route("/", methods=["POST"]) 69 | @validate() 70 | def post(body: RequestBodyModel): 71 | name = body.name 72 | nickname = body.nickname 73 | return ResponseModel(name=name, nickname=nickname, id=0, age=1000) 74 | 75 | 76 | """ 77 | curl --location --request POST 'http://127.0.0.1:5000/' 78 | 79 | curl --location --request POST 'http://127.0.0.1:5000/' \ 80 | --header 'Content-Type: application/json' \ 81 | --data-raw '{' 82 | 83 | curl --location --request POST 'http://127.0.0.1:5000/' \ 84 | --header 'Content-Type: application/json' \ 85 | --data-raw '{"nameee":123}' 86 | 87 | curl --location --request POST 'http://127.0.0.1:5000/' \ 88 | --header 'Content-Type: application/json' \ 89 | --data-raw '{"name":123}' 90 | """ 91 | 92 | 93 | @app.route("/form", methods=["POST"]) 94 | @validate() 95 | def form_post(form: FormModel): 96 | name = form.name 97 | nickname = form.nickname 98 | return ResponseModel(name=name, nickname=nickname, id=0, age=1000) 99 | 100 | 101 | """ 102 | curl --location --request POST 'http://127.0.0.1:5000/form' 103 | 104 | curl --location --request POST 'http://127.0.0.1:5000/form' \ 105 | -F name=123\ 106 | 107 | curl --location --request POST 'http://127.0.0.1:5000/form' \ 108 | -F 
name=some-name 109 | """ 110 | 111 | 112 | @app.route("/both", methods=["POST"]) 113 | @validate() 114 | def get_and_post(body: RequestBodyModel, query: QueryModel): 115 | name = body.name # From request body 116 | nickname = body.nickname # From request body 117 | age = query.age # from query parameters 118 | return ResponseModel(age=age, name=name, nickname=nickname, id=0) 119 | 120 | 121 | """ 122 | curl --location --request POST 'http://127.0.0.1:5000/both' \ 123 | --header 'Content-Type: application/json' \ 124 | --data-raw '{"name":123}' 125 | 126 | curl --location --request POST 'http://127.0.0.1:5000/both?age=40' \ 127 | --header 'Content-Type: application/json' \ 128 | --data-raw '{"name":123}' 129 | """ 130 | 131 | 132 | if __name__ == "__main__": 133 | app.run() 134 | -------------------------------------------------------------------------------- /flask_pydantic/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import validate # noqa: F401 2 | from .exceptions import ValidationError # noqa: F401 3 | from .version import __version__ # noqa: F401 4 | -------------------------------------------------------------------------------- /flask_pydantic/converters.py: -------------------------------------------------------------------------------- 1 | import types 2 | from collections import deque 3 | from typing import Deque, FrozenSet, List, Sequence, Set, Tuple, Type, Union 4 | 5 | try: 6 | from typing import get_args, get_origin 7 | except ImportError: 8 | from typing_extensions import get_args, get_origin 9 | 10 | from pydantic import BaseModel 11 | from pydantic.v1 import BaseModel as V1BaseModel 12 | from werkzeug.datastructures import ImmutableMultiDict 13 | 14 | V1OrV2BaseModel = Union[BaseModel, V1BaseModel] 15 | UnionType = getattr(types, "UnionType", Union) 16 | 17 | sequence_types = { 18 | Sequence, 19 | List, 20 | list, 21 | Tuple, 22 | tuple, 23 | Set, 24 | set, 25 | FrozenSet, 26 | 
frozenset, 27 | Deque, 28 | deque, 29 | } 30 | 31 | 32 | def _is_sequence(type_: Type) -> bool: 33 | origin = get_origin(type_) or type_ 34 | if origin is Union or origin is UnionType: 35 | return any(_is_sequence(t) for t in get_args(type_)) 36 | 37 | return origin in sequence_types and origin not in (str, bytes) 38 | 39 | 40 | def convert_query_params( 41 | query_params: ImmutableMultiDict, model: Type[V1OrV2BaseModel] 42 | ) -> dict: 43 | """ 44 | group query parameters into lists if model defines them 45 | 46 | :param query_params: flasks request.args 47 | :param model: query parameter's model 48 | :return: resulting parameters 49 | """ 50 | if issubclass(model, BaseModel): 51 | return { 52 | **query_params.to_dict(), 53 | **{ 54 | key: value 55 | for key, value in query_params.to_dict(flat=False).items() 56 | if key in model.model_fields 57 | and _is_sequence(model.model_fields[key].annotation) 58 | }, 59 | } 60 | else: 61 | return { 62 | **query_params.to_dict(), 63 | **{ 64 | key: value 65 | for key, value in query_params.to_dict(flat=False).items() 66 | if key in model.__fields__ and model.__fields__[key].is_complex() 67 | }, 68 | } 69 | -------------------------------------------------------------------------------- /flask_pydantic/core.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union 3 | 4 | from flask import Response, current_app, jsonify, make_response, request 5 | from pydantic import BaseModel, RootModel, TypeAdapter, ValidationError 6 | from pydantic.v1 import BaseModel as V1BaseModel 7 | from pydantic.v1.error_wrappers import ValidationError as V1ValidationError 8 | from pydantic.v1.tools import parse_obj_as 9 | 10 | from .converters import convert_query_params 11 | from .exceptions import ( 12 | InvalidIterableOfModelsException, 13 | JsonBodyParsingError, 14 | ManyModelValidationError, 15 | ) 16 | 
from .exceptions import ValidationError as FailedValidation 17 | 18 | try: 19 | from flask_restful import original_flask_make_response as make_response 20 | except ImportError: 21 | pass 22 | 23 | V1OrV2BaseModel = Union[BaseModel, V1BaseModel] 24 | 25 | 26 | def _model_dump_json(model: V1OrV2BaseModel, **kwargs): 27 | """Adapter to dump a model to json, whether it's a Pydantic V1 or V2 model.""" 28 | if isinstance(model, BaseModel): 29 | return model.model_dump_json(**kwargs) 30 | else: 31 | return model.json(**kwargs) 32 | 33 | 34 | def make_json_response( 35 | content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]], 36 | status_code: int, 37 | by_alias: bool, 38 | exclude_none: bool = False, 39 | many: bool = False, 40 | ) -> Response: 41 | """serializes model, creates JSON response with given status code""" 42 | if many: 43 | js = f"[{', '.join([_model_dump_json(model, exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" 44 | else: 45 | js = _model_dump_json(content, exclude_none=exclude_none, by_alias=by_alias) 46 | response = make_response(js, status_code) 47 | response.mimetype = "application/json" 48 | return response 49 | 50 | 51 | def unsupported_media_type_response(request_cont_type: str) -> Response: 52 | body = { 53 | "detail": f"Unsupported media type '{request_cont_type}' in request. " 54 | "'application/json' is required." 
55 | } 56 | return make_response(jsonify(body), 415) 57 | 58 | 59 | def is_iterable_of_models(content: Any) -> bool: 60 | try: 61 | return all(isinstance(obj, (BaseModel, V1BaseModel)) for obj in content) 62 | except TypeError: 63 | return False 64 | 65 | 66 | def validate_many_models( 67 | model: Type[V1OrV2BaseModel], content: Any 68 | ) -> List[V1OrV2BaseModel]: 69 | try: 70 | return [model(**fields) for fields in content] 71 | except TypeError as te: 72 | # iteration through `content` fails 73 | err = [ 74 | { 75 | "loc": ["root"], 76 | "msg": "is not an array of objects", 77 | "type": "type_error.array", 78 | } 79 | ] 80 | 81 | raise ManyModelValidationError(err) from te 82 | except (ValidationError, V1ValidationError) as ve: 83 | raise ManyModelValidationError(ve.errors()) from ve 84 | 85 | 86 | def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: 87 | errors = [] 88 | validated = {} 89 | for name, type_ in func.__annotations__.items(): 90 | if name in {"query", "body", "form", "return"}: 91 | continue 92 | try: 93 | if not isinstance(type_, V1BaseModel): 94 | adapter = TypeAdapter(type_) 95 | validated[name] = adapter.validate_python(kwargs.get(name)) 96 | else: 97 | value = parse_obj_as(type_, kwargs.get(name)) 98 | validated[name] = value 99 | except (ValidationError, V1ValidationError) as e: 100 | err = e.errors()[0] 101 | err["loc"] = [name] 102 | errors.append(err) 103 | kwargs = {**kwargs, **validated} 104 | return kwargs, errors 105 | 106 | 107 | def get_body_dict(**params): 108 | data = request.get_json(**params) 109 | if data is None and params.get("silent"): 110 | return {} 111 | return data 112 | 113 | 114 | def validate( 115 | body: Optional[Type[V1OrV2BaseModel]] = None, 116 | query: Optional[Type[V1OrV2BaseModel]] = None, 117 | on_success_status: int = 200, 118 | exclude_none: bool = False, 119 | response_many: bool = False, 120 | request_body_many: bool = False, 121 | response_by_alias: bool = False, 122 | 
get_json_params: Optional[dict] = None, 123 | form: Optional[Type[V1OrV2BaseModel]] = None, 124 | ): 125 | """ 126 | Decorator for route methods which will validate query, body and form parameters 127 | as well as serialize the response (if it derives from pydantic's BaseModel 128 | class). 129 | 130 | Request parameters are accessible via flask's `request` variable: 131 | - request.query_params 132 | - request.body_params 133 | - request.form_params 134 | 135 | Or directly as `kwargs`, if you define them in the decorated function. 136 | 137 | `exclude_none` whether to remove None fields from response 138 | `response_many` whether content of response consists of many objects 139 | (e. g. List[BaseModel]). Resulting response will be an array of serialized 140 | models. 141 | `request_body_many` whether response body contains array of given model 142 | (request.body_params then contains list of models i. e. List[BaseModel]) 143 | `response_by_alias` whether Pydantic's alias is used 144 | `get_json_params` - parameters to be passed to Request.get_json() function 145 | 146 | example:: 147 | 148 | from flask import request 149 | from flask_pydantic import validate 150 | from pydantic import BaseModel 151 | 152 | class Query(BaseModel): 153 | query: str 154 | 155 | class Body(BaseModel): 156 | color: str 157 | 158 | class Form(BaseModel): 159 | name: str 160 | 161 | class MyModel(BaseModel): 162 | id: int 163 | color: str 164 | description: str 165 | 166 | ... 167 | 168 | @app.route("/") 169 | @validate(query=Query, body=Body, form=Form) 170 | def test_route(): 171 | query = request.query_params.query 172 | color = request.body_params.query 173 | 174 | return MyModel(...) 175 | 176 | @app.route("/kwargs") 177 | @validate() 178 | def test_route_kwargs(query:Query, body:Body, form:Form): 179 | 180 | return MyModel(...) 
181 | 182 | -> that will render JSON response with serialized MyModel instance 183 | """ 184 | 185 | def decorate(func: Callable) -> Callable: 186 | @wraps(func) 187 | def wrapper(*args, **kwargs): 188 | q, b, f, err = None, None, None, {} 189 | kwargs, path_err = validate_path_params(func, kwargs) 190 | if path_err: 191 | err["path_params"] = path_err 192 | query_in_kwargs = func.__annotations__.get("query") 193 | query_model = query_in_kwargs or query 194 | if query_model: 195 | query_params = convert_query_params(request.args, query_model) 196 | try: 197 | q = query_model(**query_params) 198 | except (ValidationError, V1ValidationError) as ve: 199 | err["query_params"] = ve.errors() 200 | body_in_kwargs = func.__annotations__.get("body") 201 | body_model = body_in_kwargs or body 202 | if body_model: 203 | body_params = get_body_dict(**(get_json_params or {})) 204 | if ( 205 | issubclass(body_model, V1BaseModel) 206 | and "__root__" in body_model.__fields__ 207 | ): 208 | try: 209 | b = body_model(__root__=body_params).__root__ 210 | except (ValidationError, V1ValidationError) as ve: 211 | err["body_params"] = ve.errors() 212 | elif issubclass(body_model, RootModel): 213 | try: 214 | b = body_model(body_params) 215 | except (ValidationError, V1ValidationError) as ve: 216 | err["body_params"] = ve.errors() 217 | elif request_body_many: 218 | try: 219 | b = validate_many_models(body_model, body_params) 220 | except ManyModelValidationError as e: 221 | err["body_params"] = e.errors() 222 | else: 223 | try: 224 | b = body_model(**body_params) 225 | except TypeError as te: 226 | content_type = request.headers.get("Content-Type", "").lower() 227 | media_type = content_type.split(";")[0] 228 | if media_type != "application/json": 229 | return unsupported_media_type_response(content_type) 230 | else: 231 | raise JsonBodyParsingError() from te 232 | except (ValidationError, V1ValidationError) as ve: 233 | err["body_params"] = ve.errors() 234 | form_in_kwargs = 
func.__annotations__.get("form") 235 | form_model = form_in_kwargs or form 236 | if form_model: 237 | form_params = request.form 238 | if ( 239 | isinstance(form, V1BaseModel) 240 | and "__root__" in form_model.__fields__ 241 | ): 242 | try: 243 | f = form_model(form_params) 244 | except (ValidationError, V1ValidationError) as ve: 245 | err["form_params"] = ve.errors() 246 | elif issubclass(form_model, RootModel): 247 | try: 248 | f = form_model(form_params) 249 | except (ValidationError, V1ValidationError) as ve: 250 | err["form_params"] = ve.errors() 251 | else: 252 | try: 253 | f = form_model(**form_params) 254 | except TypeError as te: 255 | content_type = request.headers.get("Content-Type", "").lower() 256 | media_type = content_type.split(";")[0] 257 | if media_type != "multipart/form-data": 258 | return unsupported_media_type_response(content_type) 259 | else: 260 | raise JsonBodyParsingError from te 261 | except (ValidationError, V1ValidationError) as ve: 262 | err["form_params"] = ve.errors() 263 | request.query_params = q 264 | request.body_params = b 265 | request.form_params = f 266 | if query_in_kwargs: 267 | kwargs["query"] = q 268 | if body_in_kwargs: 269 | kwargs["body"] = b 270 | if form_in_kwargs: 271 | kwargs["form"] = f 272 | 273 | if err: 274 | if current_app.config.get( 275 | "FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False 276 | ): 277 | raise FailedValidation(**err) 278 | else: 279 | status_code = current_app.config.get( 280 | "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 281 | ) 282 | return make_response( 283 | jsonify({"validation_error": err}), status_code 284 | ) 285 | res = func(*args, **kwargs) 286 | 287 | if response_many: 288 | if is_iterable_of_models(res): 289 | return make_json_response( 290 | res, 291 | on_success_status, 292 | by_alias=response_by_alias, 293 | exclude_none=exclude_none, 294 | many=True, 295 | ) 296 | else: 297 | raise InvalidIterableOfModelsException(res) 298 | 299 | if isinstance(res, (BaseModel, 
V1BaseModel)): 300 | return make_json_response( 301 | res, 302 | on_success_status, 303 | exclude_none=exclude_none, 304 | by_alias=response_by_alias, 305 | ) 306 | 307 | if ( 308 | isinstance(res, tuple) 309 | and len(res) in [2, 3] 310 | and isinstance(res[0], (BaseModel, V1BaseModel)) 311 | ): 312 | headers = None 313 | status = on_success_status 314 | if isinstance(res[1], (dict, tuple, list)): 315 | headers = res[1] 316 | elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)): 317 | status = res[1] 318 | headers = res[2] 319 | else: 320 | status = res[1] 321 | 322 | ret = make_json_response( 323 | res[0], 324 | status, 325 | exclude_none=exclude_none, 326 | by_alias=response_by_alias, 327 | ) 328 | if headers: 329 | ret.headers.update(headers) 330 | return ret 331 | 332 | return res 333 | 334 | return wrapper 335 | 336 | return decorate 337 | -------------------------------------------------------------------------------- /flask_pydantic/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | 4 | class BaseFlaskPydanticException(Exception): 5 | """Base exc class for all exception from this library""" 6 | 7 | pass 8 | 9 | 10 | class InvalidIterableOfModelsException(BaseFlaskPydanticException): 11 | """This exception is raised if there is a failure during serialization of 12 | response object with `response_many=True`""" 13 | 14 | pass 15 | 16 | 17 | class JsonBodyParsingError(BaseFlaskPydanticException): 18 | """Exception for error occurring during parsing of request body""" 19 | 20 | pass 21 | 22 | 23 | class ManyModelValidationError(BaseFlaskPydanticException): 24 | """This exception is raised if there is a failure during validation of many 25 | models in an iterable""" 26 | 27 | def __init__(self, errors: List[dict], *args): 28 | self._errors = errors 29 | super().__init__(*args) 30 | 31 | def errors(self): 32 | return self._errors 33 | 34 | 35 | class 
ValidationError(BaseFlaskPydanticException): 36 | """This exception is raised if there is a failure during validation if the 37 | user has configured an exception to be raised instead of a response""" 38 | 39 | def __init__( 40 | self, 41 | body_params: Optional[List[dict]] = None, 42 | form_params: Optional[List[dict]] = None, 43 | path_params: Optional[List[dict]] = None, 44 | query_params: Optional[List[dict]] = None, 45 | ): 46 | super().__init__() 47 | self.body_params = body_params 48 | self.form_params = form_params 49 | self.path_params = path_params 50 | self.query_params = query_params 51 | -------------------------------------------------------------------------------- /flask_pydantic/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.13.1" 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "Flask-Pydantic" 3 | version = "0.13.1" 4 | description = "Flask extension for integration with Pydantic library." 
5 | readme = "README.md"
6 | license = { file = "LICENSE.txt" }
7 | authors = [{ name = "Jiri Bauer" }]
8 | maintainers = [{ name = "Pallets", email = "contact@palletsprojects.com" }]
9 | classifiers = [
10 |     "Environment :: Web Environment",
11 |     "Framework :: Flask",
12 |     "Intended Audience :: Developers",
13 |     "License :: OSI Approved :: MIT License",
14 |     "Operating System :: OS Independent",
15 |     "Programming Language :: Python",
16 |     "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
17 |     "Topic :: Software Development :: Libraries :: Python Modules",
18 | ]
19 | requires-python = ">=3.7"
20 | dependencies = [
21 |     "Flask",
22 |     "pydantic>=2.0",
23 |     "typing_extensions>=4.1.1; python_version < '3.8'"
24 | ]
25 |
26 | [project.urls]
27 | Donate = "https://palletsprojects.com/donate"
28 | Source = "https://github.com/pallets-eco/flask-pydantic"
29 | Chat = "https://discord.gg/pallets"
30 |
31 | [build-system]
32 | requires = ["flit_core<4"]
33 | build-backend = "flit_core.buildapi"
34 |
35 | [tool.flit.module]
36 | name = "flask_pydantic"
37 |
38 | [tool.pytest.ini_options]  # the pyproject.toml table pytest reads its settings from
39 | testpaths = ["tests"]
40 | addopts = "-vv --ruff --ruff-format --cov --cov-config=pyproject.toml -s"
41 |
42 | [tool.ruff]
43 | src = ["flask_pydantic"]
44 | lint.select = [
45 |     "B", # flake8-bugbear
46 |     "E", # pycodestyle error
47 |     "F", # pyflakes
48 |     "I", # isort
49 |     "UP", # pyupgrade
50 |     "W", # pycodestyle warning
51 | ]
52 | lint.ignore = ["E501"]
53 |
54 | [tool.coverage.run]
55 | branch = true
56 | omit = [
57 |     "example_app/*"  # the example application lives in example_app/
58 | ]
59 | include = [
60 |     "flask_pydantic/*"
61 | ]
62 |
63 | [tool.coverage.report]
64 | show_missing = true
65 | skip_covered = true
66 |
--------------------------------------------------------------------------------
/requirements/build.txt:
--------------------------------------------------------------------------------
1 | build
2 |
--------------------------------------------------------------------------------
/requirements/test.txt:
-------------------------------------------------------------------------------- 1 | pytest 2 | pytest-flask 3 | pytest-coverage 4 | pytest-mock 5 | pytest-ruff 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Type 2 | 3 | import pytest 4 | from flask import Flask, request 5 | from flask_pydantic import validate 6 | from pydantic import BaseModel 7 | 8 | 9 | @pytest.fixture 10 | def posts() -> List[dict]: 11 | return [ 12 | {"title": "title 1", "text": "random text", "views": 1}, 13 | {"title": "2", "text": "another text", "views": 2}, 14 | {"title": "3", "text": "longer text than usual", "views": 4}, 15 | {"title": "title 13", "text": "nothing", "views": 5}, 16 | ] 17 | 18 | 19 | @pytest.fixture 20 | def query_model() -> Type[BaseModel]: 21 | class Query(BaseModel): 22 | limit: int = 2 23 | min_views: Optional[int] = None 24 | 25 | return Query 26 | 27 | 28 | @pytest.fixture 29 | def body_model() -> Type[BaseModel]: 30 | class Body(BaseModel): 31 | search_term: str 32 | exclude: Optional[str] = None 33 | 34 | return Body 35 | 36 | 37 | @pytest.fixture 38 | def form_model() -> Type[BaseModel]: 39 | class Form(BaseModel): 40 | search_term: str 41 | exclude: Optional[str] = None 42 | 43 | return Form 44 | 45 | 46 | @pytest.fixture 47 | def post_model() -> Type[BaseModel]: 48 | class Post(BaseModel): 49 | title: str 50 | text: str 51 | views: int 52 | 53 | return Post 54 | 55 | 56 | @pytest.fixture 57 | def response_model(post_model: BaseModel) -> Type[BaseModel]: 58 | class Response(BaseModel): 59 
| results: List[post_model] 60 | count: int 61 | 62 | return Response 63 | 64 | 65 | @pytest.fixture 66 | def request_ctx(app): 67 | with app.test_request_context() as ctx: 68 | yield ctx 69 | 70 | 71 | def is_excluded(post: dict, exclude: Optional[str] = None) -> bool: 72 | if exclude is None: 73 | return False 74 | return exclude in post["title"] or exclude in post["text"] 75 | 76 | 77 | def pass_search( 78 | post: dict, 79 | search_term: str, 80 | exclude: Optional[str] = None, 81 | min_views: Optional[int] = None, 82 | ) -> bool: 83 | return ( 84 | (search_term in post["title"] or search_term in post["text"]) 85 | and not is_excluded(post, exclude) 86 | and (min_views is None or post["views"] >= min_views) 87 | ) 88 | 89 | 90 | @pytest.fixture 91 | def app(posts, response_model, query_model, body_model, post_model, form_model): 92 | app = Flask("test_app") 93 | app.config["DEBUG"] = True 94 | app.config["TESTING"] = True 95 | 96 | @app.route("/search", methods=["POST"]) 97 | @validate(query=query_model, body=body_model) 98 | def post(): 99 | query_params = request.query_params 100 | body = request.body_params 101 | results = [ 102 | post_model(**p) 103 | for p in posts 104 | if pass_search(p, body.search_term, body.exclude, query_params.min_views) 105 | ] 106 | return response_model(results=results[: query_params.limit], count=len(results)) 107 | 108 | @app.route("/search/kwargs", methods=["POST"]) 109 | @validate() 110 | def post_kwargs(query: query_model, body: body_model): 111 | results = [ 112 | post_model(**p) 113 | for p in posts 114 | if pass_search(p, body.search_term, body.exclude, query.min_views) 115 | ] 116 | return response_model(results=results[: query.limit], count=len(results)) 117 | 118 | @app.route("/search/form/kwargs", methods=["POST"]) 119 | @validate() 120 | def post_kwargs_form(query: query_model, form: form_model): 121 | results = [ 122 | post_model(**p) 123 | for p in posts 124 | if pass_search(p, form.search_term, form.exclude, 
query.min_views) 125 | ] 126 | return response_model(results=results[: query.limit], count=len(results)) 127 | 128 | return app 129 | -------------------------------------------------------------------------------- /tests/func/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/func/__init__.py -------------------------------------------------------------------------------- /tests/func/test_app.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Optional 3 | 4 | import pytest 5 | from flask import jsonify, request 6 | from flask_pydantic import ValidationError, validate 7 | from pydantic import BaseModel, ConfigDict, RootModel 8 | 9 | from ..util import assert_matches 10 | 11 | 12 | class ArrayModel(BaseModel): 13 | arr1: List[str] 14 | arr2: Optional[List[int]] = None 15 | 16 | 17 | @pytest.fixture 18 | def app_with_array_route(app): 19 | @app.route("/arr", methods=["GET"]) 20 | @validate(query=ArrayModel, exclude_none=True) 21 | def pass_array(): 22 | return ArrayModel( 23 | arr1=request.query_params.arr1, arr2=request.query_params.arr2 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def app_with_optional_body(app): 29 | class Body(BaseModel): 30 | param: str 31 | 32 | @app.route("/no_params", methods=["POST"]) 33 | @validate() 34 | def no_params(body: Body): 35 | return body 36 | 37 | @app.route("/silent", methods=["POST"]) 38 | @validate(get_json_params={"silent": True}) 39 | def silent(body: Body): 40 | return body 41 | 42 | 43 | @pytest.fixture 44 | def app_raise_on_validation_error(app): 45 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 46 | 47 | def validation_error(error: ValidationError): 48 | return ( 49 | jsonify( 50 | { 51 | "title": "validation error", 52 | "body": error.body_params, 53 | } 54 | ), 55 | 422, 56 | ) 57 | 
58 | app.register_error_handler(ValidationError, validation_error) 59 | 60 | class Body(BaseModel): 61 | param: str 62 | 63 | @app.route("/silent", methods=["POST"]) 64 | @validate(get_json_params={"silent": True}) 65 | def silent(body: Body): 66 | return body 67 | 68 | 69 | @pytest.fixture 70 | def app_with_int_path_param_route(app): 71 | class IdObj(BaseModel): 72 | id: int 73 | 74 | @app.route("/path_param//", methods=["GET"]) 75 | @validate() 76 | def int_path_param(obj_id: int): 77 | return IdObj(id=obj_id) 78 | 79 | 80 | @pytest.fixture 81 | def app_with_untyped_path_param_route(app): 82 | class IdObj(BaseModel): 83 | id: str 84 | 85 | @app.route("/path_param//", methods=["GET"]) 86 | @validate() 87 | def int_path_param(obj_id): 88 | return IdObj(id=obj_id) 89 | 90 | 91 | @pytest.fixture 92 | def app_with_custom_root_type(app): 93 | class Person(BaseModel): 94 | name: str 95 | age: Optional[int] = None 96 | 97 | class PersonBulk(RootModel): 98 | root: List[Person] 99 | 100 | def __len__(self): 101 | return len(self.root) 102 | 103 | @app.route("/root_type", methods=["POST"]) 104 | @validate() 105 | def root_type(body: PersonBulk): 106 | return {"number": len(body)} 107 | 108 | 109 | @pytest.fixture 110 | def app_with_custom_headers(app): 111 | @app.route("/custom_headers", methods=["GET"]) 112 | @validate() 113 | def custom_headers(): 114 | return {"test": 1}, {"CUSTOM_HEADER": "UNIQUE"} 115 | 116 | 117 | @pytest.fixture 118 | def app_with_custom_headers_status(app): 119 | @app.route("/custom_headers_status", methods=["GET"]) 120 | @validate() 121 | def custom_headers(): 122 | return {"test": 1}, 201, {"CUSTOM_HEADER": "UNIQUE"} 123 | 124 | 125 | @pytest.fixture 126 | def app_with_camel_route(app): 127 | def to_camel(x: str) -> str: 128 | first, *rest = x.split("_") 129 | return "".join([first] + [x.capitalize() for x in rest]) 130 | 131 | class RequestModel(BaseModel): 132 | x: int 133 | y: int 134 | 135 | class ResultModel(BaseModel): 136 | model_config = 
ConfigDict(alias_generator=to_camel, populate_by_name=True) 137 | 138 | result_of_addition: int 139 | result_of_multiplication: int 140 | 141 | @app.route("/compute", methods=["GET"]) 142 | @validate(response_by_alias=True) 143 | def compute(query: RequestModel): 144 | return ResultModel( 145 | result_of_addition=query.x + query.y, 146 | result_of_multiplication=query.x * query.y, 147 | ) 148 | 149 | 150 | test_cases = [ 151 | pytest.param( 152 | "?limit=limit", 153 | {"search_term": "text"}, 154 | 400, 155 | { 156 | "validation_error": { 157 | "query_params": [ 158 | { 159 | "input": "limit", 160 | "loc": ["limit"], 161 | "msg": "Input should be a valid integer, unable to parse string as an integer", 162 | "type": "int_parsing", 163 | "url": re.compile( 164 | r"https://errors\.pydantic\.dev/.*/v/int_parsing" 165 | ), 166 | } 167 | ], 168 | } 169 | }, 170 | id="invalid limit", 171 | ), 172 | pytest.param( 173 | "?limit=2", 174 | {}, 175 | 400, 176 | { 177 | "validation_error": { 178 | "body_params": [ 179 | { 180 | "input": {}, 181 | "loc": ["search_term"], 182 | "msg": "Field required", 183 | "type": "missing", 184 | "url": re.compile( 185 | r"https://errors\.pydantic\.dev/.*/v/missing" 186 | ), 187 | } 188 | ], 189 | } 190 | }, 191 | id="missing required body parameter", 192 | ), 193 | pytest.param( 194 | "?limit=1&min_views=2", 195 | {"search_term": "text"}, 196 | 200, 197 | {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, 198 | id="valid parameters", 199 | ), 200 | pytest.param( 201 | "", 202 | {"search_term": "text"}, 203 | 200, 204 | { 205 | "count": 3, 206 | "results": [ 207 | {"title": "title 1", "text": "random text", "views": 1}, 208 | {"title": "2", "text": "another text", "views": 2}, 209 | ], 210 | }, 211 | id="valid params, no query", 212 | ), 213 | ] 214 | 215 | form_test_cases = [ 216 | pytest.param( 217 | "?limit=2", 218 | {}, 219 | 400, 220 | { 221 | "validation_error": { 222 | "form_params": [ 223 | { 224 | "input": 
{}, 225 | "loc": ["search_term"], 226 | "msg": "Field required", 227 | "type": "missing", 228 | "url": re.compile( 229 | r"https://errors\.pydantic\.dev/.*/v/missing" 230 | ), 231 | } 232 | ] 233 | } 234 | }, 235 | id="missing required form parameter", 236 | ), 237 | pytest.param( 238 | "?limit=1&min_views=2", 239 | {"search_term": "text"}, 240 | 200, 241 | {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, 242 | id="valid parameters", 243 | ), 244 | pytest.param( 245 | "", 246 | {"search_term": "text"}, 247 | 200, 248 | { 249 | "count": 3, 250 | "results": [ 251 | {"title": "title 1", "text": "random text", "views": 1}, 252 | {"title": "2", "text": "another text", "views": 2}, 253 | ], 254 | }, 255 | id="valid params, no query", 256 | ), 257 | ] 258 | 259 | 260 | class TestSimple: 261 | @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) 262 | def test_post(self, client, query, body, expected_status, expected_response): 263 | response = client.post(f"/search{query}", json=body) 264 | assert_matches(expected_response, response.json) 265 | assert response.status_code == expected_status 266 | 267 | @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) 268 | def test_post_kwargs(self, client, query, body, expected_status, expected_response): 269 | response = client.post(f"/search/kwargs{query}", json=body) 270 | assert_matches(expected_response, response.json) 271 | assert response.status_code == expected_status 272 | 273 | @pytest.mark.parametrize( 274 | "query,form,expected_status,expected_response", form_test_cases 275 | ) 276 | def test_post_kwargs_form( 277 | self, client, query, form, expected_status, expected_response 278 | ): 279 | response = client.post( 280 | f"/search/form/kwargs{query}", 281 | data=form, 282 | ) 283 | assert_matches(expected_response, response.json) 284 | assert response.status_code == expected_status 285 | 286 | def test_error_status_code(self, 
app, mocker, client): 287 | mocker.patch.dict( 288 | app.config, {"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE": 422} 289 | ) 290 | response = client.post("/search?limit=2", json={}) 291 | assert response.status_code == 422 292 | 293 | 294 | @pytest.mark.usefixtures("app_with_custom_root_type") 295 | def test_custom_root_types(client): 296 | response = client.post( 297 | "/root_type", 298 | json=[{"name": "Joshua Bardwell", "age": 46}, {"name": "Andrew Cambden"}], 299 | ) 300 | assert response.json == {"number": 2} 301 | 302 | 303 | @pytest.mark.usefixtures("app_with_custom_headers") 304 | def test_custom_headers(client): 305 | response = client.get("/custom_headers") 306 | assert response.json == {"test": 1} 307 | assert response.status_code == 200 308 | assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" 309 | 310 | 311 | @pytest.mark.usefixtures("app_with_custom_headers_status") 312 | def test_custom_headers_status(client): 313 | response = client.get("/custom_headers_status") 314 | assert response.json == {"test": 1} 315 | assert response.status_code == 201 316 | assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" 317 | 318 | 319 | @pytest.mark.usefixtures("app_with_array_route") 320 | class TestArrayQueryParam: 321 | def test_no_param_raises(self, client): 322 | response = client.get("/arr") 323 | assert_matches( 324 | { 325 | "validation_error": { 326 | "query_params": [ 327 | { 328 | "input": {}, 329 | "loc": ["arr1"], 330 | "msg": "Field required", 331 | "type": "missing", 332 | "url": re.compile( 333 | r"https://errors\.pydantic\.dev/.*/v/missing" 334 | ), 335 | } 336 | ] 337 | } 338 | }, 339 | response.json, 340 | ) 341 | 342 | def test_correctly_returns_first_arr(self, client): 343 | response = client.get("/arr?arr1=first&arr1=second") 344 | assert response.json == {"arr1": ["first", "second"]} 345 | 346 | def test_correctly_returns_first_arr_one_element(self, client): 347 | response = client.get("/arr?arr1=first") 348 | assert response.json == 
{"arr1": ["first"]} 349 | 350 | def test_correctly_returns_both_arrays(self, client): 351 | response = client.get("/arr?arr1=first&arr1=second&arr2=1&arr2=10") 352 | assert response.json == {"arr1": ["first", "second"], "arr2": [1, 10]} 353 | 354 | 355 | aliases_test_cases = [ 356 | pytest.param(1, 2, {"resultOfMultiplication": 2, "resultOfAddition": 3}), 357 | pytest.param(10, 20, {"resultOfMultiplication": 200, "resultOfAddition": 30}), 358 | pytest.param(999, 0, {"resultOfMultiplication": 0, "resultOfAddition": 999}), 359 | ] 360 | 361 | 362 | @pytest.mark.usefixtures("app_with_camel_route") 363 | @pytest.mark.parametrize("x,y,expected_result", aliases_test_cases) 364 | def test_aliases(x, y, expected_result, client): 365 | response = client.get(f"/compute?x={x}&y={y}") 366 | assert_matches(expected_result, response.json) 367 | 368 | 369 | @pytest.mark.usefixtures("app_with_int_path_param_route") 370 | class TestPathIntParameter: 371 | def test_correct_param_passes(self, client): 372 | id_ = 12 373 | expected_response = {"id": id_} 374 | response = client.get(f"/path_param/{id_}/") 375 | assert_matches(expected_response, response.json) 376 | 377 | def test_string_parameter(self, client): 378 | expected_response = { 379 | "validation_error": { 380 | "path_params": [ 381 | { 382 | "input": "not_an_int", 383 | "loc": ["obj_id"], 384 | "msg": "Input should be a valid integer, unable to parse string as an integer", 385 | "type": "int_parsing", 386 | "url": re.compile( 387 | r"https://errors\.pydantic\.dev/.*/v/int_parsing" 388 | ), 389 | } 390 | ] 391 | } 392 | } 393 | response = client.get("/path_param/not_an_int/") 394 | 395 | assert_matches(expected_response, response.json) 396 | assert response.status_code == 400 397 | 398 | 399 | @pytest.mark.usefixtures("app_with_untyped_path_param_route") 400 | class TestPathUnannotatedParameter: 401 | def test_int_str_param_passes(self, client): 402 | id_ = 12 403 | expected_response = {"id": str(id_)} 404 | response = 
client.get(f"/path_param/{id_}/") 405 | 406 | assert_matches(expected_response, response.json) 407 | 408 | def test_str_param_passes(self, client): 409 | id_ = "twelve" 410 | expected_response = {"id": id_} 411 | response = client.get(f"/path_param/{id_}/") 412 | 413 | assert_matches(expected_response, response.json) 414 | 415 | 416 | @pytest.mark.usefixtures("app_with_optional_body") 417 | class TestGetJsonParams: 418 | def test_empty_body_fails(self, client): 419 | response = client.post( 420 | "/no_params", headers={"Content-Type": "application/json"} 421 | ) 422 | 423 | assert response.status_code == 400 424 | assert ( 425 | "failed to decode json object: expecting value: line 1 column 1 (char 0)" 426 | in response.text.lower() 427 | ) 428 | 429 | def test_silent(self, client): 430 | response = client.post("/silent", headers={"Content-Type": "application/json"}) 431 | 432 | assert_matches( 433 | { 434 | "validation_error": { 435 | "body_params": [ 436 | { 437 | "input": {}, 438 | "loc": ["param"], 439 | "msg": "Field required", 440 | "type": "missing", 441 | "url": re.compile( 442 | r"https://errors\.pydantic\.dev/.*/v/missing" 443 | ), 444 | } 445 | ] 446 | } 447 | }, 448 | response.json, 449 | ) 450 | assert response.status_code == 400 451 | 452 | 453 | @pytest.mark.usefixtures("app_raise_on_validation_error") 454 | class TestCustomResponse: 455 | def test_silent(self, client): 456 | response = client.post("/silent", headers={"Content-Type": "application/json"}) 457 | 458 | assert response.json["title"] == "validation error" 459 | assert_matches( 460 | [ 461 | { 462 | "input": {}, 463 | "loc": ["param"], 464 | "msg": "Field required", 465 | "type": "missing", 466 | "url": re.compile(r"https://errors\.pydantic\.dev/.*/v/missing"), 467 | } 468 | ], 469 | response.json["body"], 470 | ) 471 | assert response.status_code == 422 472 | -------------------------------------------------------------------------------- /tests/pydantic_v1/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/pydantic_v1/__init__.py -------------------------------------------------------------------------------- /tests/pydantic_v1/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Specific conftest.py file for testing behavior with Pydantic V1. 3 | 4 | The fixtures below override the same-named fixtures of the parent tests/conftest.py for the tests under tests/pydantic_v1/ only. 5 | """ 6 | 7 | from typing import List, Optional, Type 8 | 9 | import pytest 10 | from pydantic.v1 import BaseModel  # V1 API, so every model defined below is a Pydantic V1 model 11 | 12 | 13 | @pytest.fixture 14 | def query_model() -> Type[BaseModel]:  # returns the model class itself, not an instance 15 | class Query(BaseModel): 16 | limit: int = 2 17 | min_views: Optional[int] = None 18 | 19 | return Query 20 | 21 | 22 | @pytest.fixture 23 | def body_model() -> Type[BaseModel]:  # returns the model class itself, not an instance 24 | class Body(BaseModel): 25 | search_term: str 26 | exclude: Optional[str] = None 27 | 28 | return Body 29 | 30 | 31 | @pytest.fixture 32 | def form_model() -> Type[BaseModel]:  # returns the model class itself, not an instance 33 | class Form(BaseModel): 34 | search_term: str 35 | exclude: Optional[str] = None 36 | 37 | return Form 38 | 39 | 40 | @pytest.fixture 41 | def post_model() -> Type[BaseModel]:  # returns the model class itself, not an instance 42 | class Post(BaseModel): 43 | title: str 44 | text: str 45 | views: int 46 | 47 | return Post 48 | 49 | 50 | @pytest.fixture 51 | def response_model(post_model: Type[BaseModel]) -> Type[BaseModel]:  # pytest injects the Post class returned by the post_model fixture 52 | class Response(BaseModel): 53 | results: List[post_model]  # a list of Post instances 54 | count: int 55 | 56 | return Response 57 | -------------------------------------------------------------------------------- /tests/pydantic_v1/func/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/pydantic_v1/func/__init__.py
-------------------------------------------------------------------------------- /tests/pydantic_v1/func/test_app.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Optional 3 | 4 | import pytest 5 | from flask import jsonify, request 6 | from flask_pydantic import ValidationError, validate 7 | from pydantic.v1 import BaseModel 8 | 9 | from ...util import assert_matches 10 | 11 | 12 | class ArrayModel(BaseModel): 13 | arr1: List[str] 14 | arr2: Optional[List[int]] = None 15 | 16 | 17 | @pytest.fixture 18 | def app_with_array_route(app): 19 | @app.route("/arr", methods=["GET"]) 20 | @validate(query=ArrayModel, exclude_none=True) 21 | def pass_array(): 22 | return ArrayModel( 23 | arr1=request.query_params.arr1, arr2=request.query_params.arr2 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def app_with_optional_body(app): 29 | class Body(BaseModel): 30 | param: str 31 | 32 | @app.route("/no_params", methods=["POST"]) 33 | @validate() 34 | def no_params(body: Body): 35 | return body 36 | 37 | @app.route("/silent", methods=["POST"]) 38 | @validate(get_json_params={"silent": True}) 39 | def silent(body: Body): 40 | return body 41 | 42 | 43 | @pytest.fixture 44 | def app_raise_on_validation_error(app): 45 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 46 | 47 | def validation_error(error: ValidationError): 48 | return ( 49 | jsonify( 50 | { 51 | "title": "validation error", 52 | "body": error.body_params, 53 | } 54 | ), 55 | 422, 56 | ) 57 | 58 | app.register_error_handler(ValidationError, validation_error) 59 | 60 | class Body(BaseModel): 61 | param: str 62 | 63 | @app.route("/silent", methods=["POST"]) 64 | @validate(get_json_params={"silent": True}) 65 | def silent(body: Body): 66 | return body 67 | 68 | 69 | @pytest.fixture 70 | def app_with_int_path_param_route(app): 71 | class IdObj(BaseModel): 72 | id: int 73 | 74 | @app.route("/path_param//", methods=["GET"]) 75 | @validate() 76 
| def int_path_param(obj_id: int): 77 | return IdObj(id=obj_id) 78 | 79 | 80 | @pytest.fixture 81 | def app_with_untyped_path_param_route(app): 82 | class IdObj(BaseModel): 83 | id: str 84 | 85 | @app.route("/path_param//", methods=["GET"]) 86 | @validate() 87 | def int_path_param(obj_id): 88 | return IdObj(id=obj_id) 89 | 90 | 91 | @pytest.fixture 92 | def app_with_custom_root_type(app): 93 | class Person(BaseModel): 94 | name: str 95 | age: Optional[int] = None 96 | 97 | class PersonBulk(BaseModel): 98 | __root__: List[Person] 99 | 100 | def __len__(self): 101 | return len(self.root) 102 | 103 | @app.route("/root_type", methods=["POST"]) 104 | @validate() 105 | def root_type(body: PersonBulk): 106 | return {"number": len(body)} 107 | 108 | 109 | @pytest.fixture 110 | def app_with_custom_headers(app): 111 | @app.route("/custom_headers", methods=["GET"]) 112 | @validate() 113 | def custom_headers(): 114 | return {"test": 1}, {"CUSTOM_HEADER": "UNIQUE"} 115 | 116 | 117 | @pytest.fixture 118 | def app_with_custom_headers_status(app): 119 | @app.route("/custom_headers_status", methods=["GET"]) 120 | @validate() 121 | def custom_headers(): 122 | return {"test": 1}, 201, {"CUSTOM_HEADER": "UNIQUE"} 123 | 124 | 125 | @pytest.fixture 126 | def app_with_camel_route(app): 127 | def to_camel(x: str) -> str: 128 | first, *rest = x.split("_") 129 | return "".join([first] + [x.capitalize() for x in rest]) 130 | 131 | class RequestModel(BaseModel): 132 | x: int 133 | y: int 134 | 135 | class ResultModel(BaseModel): 136 | result_of_addition: int 137 | result_of_multiplication: int 138 | 139 | class Config: 140 | alias_generator = to_camel 141 | allow_population_by_field_name = True 142 | 143 | @app.route("/compute", methods=["GET"]) 144 | @validate(response_by_alias=True) 145 | def compute(query: RequestModel): 146 | return ResultModel( 147 | result_of_addition=query.x + query.y, 148 | result_of_multiplication=query.x * query.y, 149 | ) 150 | 151 | 152 | test_cases = [ 153 | 
pytest.param( 154 | "?limit=limit", 155 | {"search_term": "text"}, 156 | 400, 157 | { 158 | "validation_error": { 159 | "query_params": [ 160 | { 161 | "loc": ["limit"], 162 | "msg": "value is not a valid integer", 163 | "type": "type_error.integer", 164 | } 165 | ] 166 | } 167 | }, 168 | id="invalid limit", 169 | ), 170 | pytest.param( 171 | "?limit=2", 172 | {}, 173 | 400, 174 | { 175 | "validation_error": { 176 | "body_params": [ 177 | { 178 | "loc": ["search_term"], 179 | "msg": "field required", 180 | "type": "value_error.missing", 181 | } 182 | ] 183 | } 184 | }, 185 | id="missing required body parameter", 186 | ), 187 | pytest.param( 188 | "?limit=1&min_views=2", 189 | {"search_term": "text"}, 190 | 200, 191 | {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, 192 | id="valid parameters", 193 | ), 194 | pytest.param( 195 | "", 196 | {"search_term": "text"}, 197 | 200, 198 | { 199 | "count": 3, 200 | "results": [ 201 | {"title": "title 1", "text": "random text", "views": 1}, 202 | {"title": "2", "text": "another text", "views": 2}, 203 | ], 204 | }, 205 | id="valid params, no query", 206 | ), 207 | ] 208 | 209 | form_test_cases = [ 210 | pytest.param( 211 | "?limit=2", 212 | {}, 213 | 400, 214 | { 215 | "validation_error": { 216 | "form_params": [ 217 | { 218 | "loc": ["search_term"], 219 | "msg": "field required", 220 | "type": "value_error.missing", 221 | } 222 | ] 223 | } 224 | }, 225 | id="missing required form parameter", 226 | ), 227 | pytest.param( 228 | "?limit=1&min_views=2", 229 | {"search_term": "text"}, 230 | 200, 231 | {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, 232 | id="valid parameters", 233 | ), 234 | pytest.param( 235 | "", 236 | {"search_term": "text"}, 237 | 200, 238 | { 239 | "count": 3, 240 | "results": [ 241 | {"title": "title 1", "text": "random text", "views": 1}, 242 | {"title": "2", "text": "another text", "views": 2}, 243 | ], 244 | }, 245 | id="valid params, no query", 
246 | ), 247 | ] 248 | 249 | 250 | class TestSimple: 251 | @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) 252 | def test_post(self, client, query, body, expected_status, expected_response): 253 | response = client.post(f"/search{query}", json=body) 254 | assert_matches(expected_response, response.json) 255 | assert response.status_code == expected_status 256 | 257 | @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) 258 | def test_post_kwargs(self, client, query, body, expected_status, expected_response): 259 | response = client.post(f"/search/kwargs{query}", json=body) 260 | assert_matches(expected_response, response.json) 261 | assert response.status_code == expected_status 262 | 263 | @pytest.mark.parametrize( 264 | "query,form,expected_status,expected_response", form_test_cases 265 | ) 266 | def test_post_kwargs_form( 267 | self, client, query, form, expected_status, expected_response 268 | ): 269 | response = client.post( 270 | f"/search/form/kwargs{query}", 271 | data=form, 272 | ) 273 | assert_matches(expected_response, response.json) 274 | assert response.status_code == expected_status 275 | 276 | def test_error_status_code(self, app, mocker, client): 277 | mocker.patch.dict( 278 | app.config, {"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE": 422} 279 | ) 280 | response = client.post("/search?limit=2", json={}) 281 | assert response.status_code == 422 282 | 283 | 284 | @pytest.mark.usefixtures("app_with_custom_root_type") 285 | def test_custom_root_types(client): 286 | response = client.post( 287 | "/root_type", 288 | json=[{"name": "Joshua Bardwell", "age": 46}, {"name": "Andrew Cambden"}], 289 | ) 290 | assert response.json == {"number": 2} 291 | 292 | 293 | @pytest.mark.usefixtures("app_with_custom_headers") 294 | def test_custom_headers(client): 295 | response = client.get("/custom_headers") 296 | assert response.json == {"test": 1} 297 | assert response.status_code == 200 298 | 
assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" 299 | 300 | 301 | @pytest.mark.usefixtures("app_with_custom_headers_status") 302 | def test_custom_headers_status(client): 303 | response = client.get("/custom_headers_status") 304 | assert response.json == {"test": 1} 305 | assert response.status_code == 201 306 | assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" 307 | 308 | 309 | @pytest.mark.usefixtures("app_with_array_route") 310 | class TestArrayQueryParam: 311 | def test_no_param_raises(self, client): 312 | response = client.get("/arr") 313 | assert_matches( 314 | { 315 | "validation_error": { 316 | "query_params": [ 317 | { 318 | "loc": ["arr1"], 319 | "msg": "field required", 320 | "type": "value_error.missing", 321 | } 322 | ] 323 | } 324 | }, 325 | response.json, 326 | ) 327 | 328 | def test_correctly_returns_first_arr(self, client): 329 | response = client.get("/arr?arr1=first&arr1=second") 330 | assert response.json == {"arr1": ["first", "second"]} 331 | 332 | def test_correctly_returns_first_arr_one_element(self, client): 333 | response = client.get("/arr?arr1=first") 334 | assert response.json == {"arr1": ["first"]} 335 | 336 | def test_correctly_returns_both_arrays(self, client): 337 | response = client.get("/arr?arr1=first&arr1=second&arr2=1&arr2=10") 338 | assert response.json == {"arr1": ["first", "second"], "arr2": [1, 10]} 339 | 340 | 341 | aliases_test_cases = [ 342 | pytest.param(1, 2, {"resultOfMultiplication": 2, "resultOfAddition": 3}), 343 | pytest.param(10, 20, {"resultOfMultiplication": 200, "resultOfAddition": 30}), 344 | pytest.param(999, 0, {"resultOfMultiplication": 0, "resultOfAddition": 999}), 345 | ] 346 | 347 | 348 | @pytest.mark.usefixtures("app_with_camel_route") 349 | @pytest.mark.parametrize("x,y,expected_result", aliases_test_cases) 350 | def test_aliases(x, y, expected_result, client): 351 | response = client.get(f"/compute?x={x}&y={y}") 352 | assert_matches(expected_result, response.json) 353 | 354 | 355 | 
@pytest.mark.usefixtures("app_with_int_path_param_route") 356 | class TestPathIntParameter: 357 | def test_correct_param_passes(self, client): 358 | id_ = 12 359 | expected_response = {"id": id_} 360 | response = client.get(f"/path_param/{id_}/") 361 | assert_matches(expected_response, response.json) 362 | 363 | def test_string_parameter(self, client): 364 | expected_response = { 365 | "validation_error": { 366 | "path_params": [ 367 | { 368 | "input": "not_an_int", 369 | "loc": ["obj_id"], 370 | "msg": "Input should be a valid integer, unable to parse string as an integer", 371 | "type": "int_parsing", 372 | "url": re.compile( 373 | r"https://errors\.pydantic\.dev/.*/v/int_parsing" 374 | ), 375 | } 376 | ] 377 | } 378 | } 379 | response = client.get("/path_param/not_an_int/") 380 | 381 | assert_matches(expected_response, response.json) 382 | assert response.status_code == 400 383 | 384 | 385 | @pytest.mark.usefixtures("app_with_untyped_path_param_route") 386 | class TestPathUnannotatedParameter: 387 | def test_int_str_param_passes(self, client): 388 | id_ = 12 389 | expected_response = {"id": str(id_)} 390 | response = client.get(f"/path_param/{id_}/") 391 | 392 | assert_matches(expected_response, response.json) 393 | 394 | def test_str_param_passes(self, client): 395 | id_ = "twelve" 396 | expected_response = {"id": id_} 397 | response = client.get(f"/path_param/{id_}/") 398 | 399 | assert_matches(expected_response, response.json) 400 | 401 | 402 | @pytest.mark.usefixtures("app_with_optional_body") 403 | class TestGetJsonParams: 404 | def test_empty_body_fails(self, client): 405 | response = client.post( 406 | "/no_params", headers={"Content-Type": "application/json"} 407 | ) 408 | 409 | assert response.status_code == 400 410 | assert ( 411 | "failed to decode json object: expecting value: line 1 column 1 (char 0)" 412 | in response.text.lower() 413 | ) 414 | 415 | def test_silent(self, client): 416 | response = client.post("/silent", headers={"Content-Type": 
"application/json"}) 417 | 418 | assert_matches( 419 | { 420 | "validation_error": { 421 | "body_params": [ 422 | { 423 | "loc": ["param"], 424 | "msg": "field required", 425 | "type": "value_error.missing", 426 | } 427 | ] 428 | } 429 | }, 430 | response.json, 431 | ) 432 | assert response.status_code == 400 433 | 434 | 435 | @pytest.mark.usefixtures("app_raise_on_validation_error") 436 | class TestCustomResponse: 437 | def test_silent(self, client): 438 | response = client.post("/silent", headers={"Content-Type": "application/json"}) 439 | 440 | assert response.json["title"] == "validation error" 441 | assert_matches( 442 | [ 443 | { 444 | "loc": ["param"], 445 | "msg": "field required", 446 | "type": "value_error.missing", 447 | } 448 | ], 449 | response.json["body"], 450 | ) 451 | assert response.status_code == 422 452 | -------------------------------------------------------------------------------- /tests/pydantic_v1/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/pydantic_v1/unit/__init__.py -------------------------------------------------------------------------------- /tests/pydantic_v1/unit/test_core.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, NamedTuple, Optional, Type, Union 2 | 3 | import pytest 4 | from flask import jsonify 5 | from flask_pydantic import ValidationError, validate 6 | from flask_pydantic.core import convert_query_params, is_iterable_of_models 7 | from flask_pydantic.exceptions import ( 8 | InvalidIterableOfModelsException, 9 | JsonBodyParsingError, 10 | ) 11 | from pydantic.v1 import BaseModel 12 | from werkzeug.datastructures import ImmutableMultiDict 13 | 14 | from ...util import assert_matches 15 | 16 | 17 | class ValidateParams(NamedTuple): 18 | body_model: Optional[Type[BaseModel]] = None 19 | 
query_model: Optional[Type[BaseModel]] = None 20 | form_model: Optional[Type[BaseModel]] = None 21 | response_model: Type[BaseModel] = None 22 | on_success_status: int = 200 23 | request_query: ImmutableMultiDict = ImmutableMultiDict({}) 24 | request_body: Union[dict, List[dict]] = {} 25 | request_form: ImmutableMultiDict = ImmutableMultiDict({}) 26 | expected_response_body: Optional[dict] = None 27 | expected_status_code: int = 200 28 | exclude_none: bool = False 29 | response_many: bool = False 30 | request_body_many: bool = False 31 | 32 | 33 | class ResponseModel(BaseModel): 34 | q1: int 35 | q2: str 36 | b1: float 37 | b2: Optional[str] = None 38 | 39 | 40 | class QueryModel(BaseModel): 41 | q1: int 42 | q2: str = "default" 43 | 44 | 45 | class RequestBodyModel(BaseModel): 46 | b1: float 47 | b2: Optional[str] = None 48 | 49 | 50 | class FormModel(BaseModel): 51 | f1: int 52 | f2: str = None 53 | 54 | 55 | class RequestBodyModelRoot(BaseModel): 56 | __root__: Union[str, RequestBodyModel] 57 | 58 | 59 | validate_test_cases = [ 60 | pytest.param( 61 | ValidateParams( 62 | request_body={"b1": 1.4}, 63 | request_query=ImmutableMultiDict({"q1": 1}), 64 | request_form=ImmutableMultiDict({"f1": 1}), 65 | form_model=FormModel, 66 | expected_response_body={"q1": 1, "q2": "default", "b1": 1.4, "b2": None}, 67 | response_model=ResponseModel, 68 | query_model=QueryModel, 69 | body_model=RequestBodyModel, 70 | ), 71 | id="simple valid example with default values", 72 | ), 73 | pytest.param( 74 | ValidateParams( 75 | request_body={"b1": 1.4}, 76 | request_query=ImmutableMultiDict({"q1": 1}), 77 | request_form=ImmutableMultiDict({"f1": 1}), 78 | form_model=FormModel, 79 | expected_response_body={"q1": 1, "q2": "default", "b1": 1.4}, 80 | response_model=ResponseModel, 81 | query_model=QueryModel, 82 | body_model=RequestBodyModel, 83 | exclude_none=True, 84 | ), 85 | id="simple valid example with default values, exclude none", 86 | ), 87 | pytest.param( 88 | ValidateParams( 89 
| query_model=QueryModel, 90 | expected_response_body={ 91 | "validation_error": { 92 | "query_params": [ 93 | { 94 | "loc": ["q1"], 95 | "msg": "field required", 96 | "type": "value_error.missing", 97 | } 98 | ] 99 | } 100 | }, 101 | expected_status_code=400, 102 | ), 103 | id="invalid query param", 104 | ), 105 | pytest.param( 106 | ValidateParams( 107 | body_model=RequestBodyModel, 108 | expected_response_body={ 109 | "validation_error": { 110 | "body_params": [ 111 | { 112 | "loc": ["root"], 113 | "msg": "is not an array of objects", 114 | "type": "type_error.array", 115 | } 116 | ] 117 | } 118 | }, 119 | request_body={"b1": 3.14, "b2": "str"}, 120 | expected_status_code=400, 121 | request_body_many=True, 122 | ), 123 | id="`request_body_many=True` but in request body is a single object", 124 | ), 125 | pytest.param( 126 | ValidateParams( 127 | expected_response_body={ 128 | "validation_error": { 129 | "body_params": [ 130 | { 131 | "loc": ["b1"], 132 | "msg": "field required", 133 | "type": "value_error.missing", 134 | } 135 | ] 136 | } 137 | }, 138 | body_model=RequestBodyModel, 139 | expected_status_code=400, 140 | ), 141 | id="invalid body param", 142 | ), 143 | pytest.param( 144 | ValidateParams( 145 | expected_response_body={ 146 | "validation_error": { 147 | "body_params": [ 148 | { 149 | "loc": ["b1"], 150 | "msg": "field required", 151 | "type": "value_error.missing", 152 | } 153 | ] 154 | } 155 | }, 156 | body_model=RequestBodyModel, 157 | expected_status_code=400, 158 | request_body=[{}], 159 | request_body_many=True, 160 | ), 161 | id="invalid body param in many-object request body", 162 | ), 163 | pytest.param( 164 | ValidateParams( 165 | form_model=FormModel, 166 | expected_response_body={ 167 | "validation_error": { 168 | "form_params": [ 169 | { 170 | "loc": ["f1"], 171 | "msg": "field required", 172 | "type": "value_error.missing", 173 | } 174 | ] 175 | } 176 | }, 177 | expected_status_code=400, 178 | ), 179 | id="invalid form param", 180 | ), 
181 | ] 182 | 183 | 184 | class TestValidate: 185 | @pytest.mark.parametrize("parameters", validate_test_cases) 186 | def test_validate(self, mocker, request_ctx, parameters: ValidateParams): 187 | mock_request = mocker.patch.object(request_ctx, "request") 188 | mock_request.args = parameters.request_query 189 | mock_request.get_json = lambda: parameters.request_body 190 | mock_request.form = parameters.request_form 191 | 192 | def f(): 193 | body = {} 194 | query = {} 195 | if mock_request.form_params: 196 | body = mock_request.form_params.dict() 197 | if mock_request.body_params: 198 | body = mock_request.body_params.dict() 199 | if mock_request.query_params: 200 | query = mock_request.query_params.dict() 201 | return parameters.response_model(**body, **query) 202 | 203 | response = validate( 204 | query=parameters.query_model, 205 | body=parameters.body_model, 206 | on_success_status=parameters.on_success_status, 207 | exclude_none=parameters.exclude_none, 208 | response_many=parameters.response_many, 209 | request_body_many=parameters.request_body_many, 210 | form=parameters.form_model, 211 | )(f)() 212 | 213 | assert response.status_code == parameters.expected_status_code 214 | assert_matches(parameters.expected_response_body, response.json) 215 | if 200 <= response.status_code < 300: 216 | assert ( 217 | mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) 218 | == parameters.request_body 219 | ) 220 | assert ( 221 | mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) 222 | == parameters.request_query.to_dict() 223 | ) 224 | 225 | @pytest.mark.parametrize("parameters", validate_test_cases) 226 | def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): 227 | mock_request = mocker.patch.object(request_ctx, "request") 228 | mock_request.args = parameters.request_query 229 | mock_request.get_json = lambda: parameters.request_body 230 | mock_request.form = parameters.request_form 231 | 232 | def f( 
233 | body: parameters.body_model, 234 | query: parameters.query_model, 235 | form: parameters.form_model, 236 | ): 237 | return parameters.response_model( 238 | **body.dict(), **query.dict(), **form.dict() 239 | ) 240 | 241 | response = validate( 242 | on_success_status=parameters.on_success_status, 243 | exclude_none=parameters.exclude_none, 244 | response_many=parameters.response_many, 245 | request_body_many=parameters.request_body_many, 246 | )(f)() 247 | 248 | assert_matches(parameters.expected_response_body, response.json) 249 | assert response.status_code == parameters.expected_status_code 250 | if 200 <= response.status_code < 300: 251 | assert ( 252 | mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) 253 | == parameters.request_body 254 | ) 255 | assert ( 256 | mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) 257 | == parameters.request_query.to_dict() 258 | ) 259 | 260 | @pytest.mark.usefixtures("request_ctx") 261 | def test_response_with_status(self): 262 | expected_status_code = 201 263 | expected_response_body = dict(q1=1, q2="2", b1=3.14, b2="b2") 264 | 265 | def f(): 266 | return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), expected_status_code 267 | 268 | response = validate()(f)() 269 | assert response.status_code == expected_status_code 270 | assert_matches(expected_response_body, response.json) 271 | 272 | @pytest.mark.usefixtures("request_ctx") 273 | def test_response_already_response(self): 274 | expected_response_body = {"a": 1, "b": 2} 275 | 276 | def f(): 277 | return jsonify(expected_response_body) 278 | 279 | response = validate()(f)() 280 | assert_matches(expected_response_body, response.json) 281 | 282 | @pytest.mark.usefixtures("request_ctx") 283 | def test_response_many_response_objs(self): 284 | response_content = [ 285 | ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), 286 | ResponseModel(q1=2, q2="3", b1=3.14), 287 | ResponseModel(q1=3, q2="4", b1=6.9, b2="b4"), 288 | ] 289 | 
expected_response_body = [ 290 | {"q1": 1, "q2": "2", "b1": 3.14, "b2": "b2"}, 291 | {"q1": 2, "q2": "3", "b1": 3.14}, 292 | {"q1": 3, "q2": "4", "b1": 6.9, "b2": "b4"}, 293 | ] 294 | 295 | def f(): 296 | return response_content 297 | 298 | response = validate(exclude_none=True, response_many=True)(f)() 299 | assert_matches(expected_response_body, response.json) 300 | 301 | @pytest.mark.usefixtures("request_ctx") 302 | def test_invalid_many_raises(self): 303 | def f(): 304 | return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2") 305 | 306 | with pytest.raises(InvalidIterableOfModelsException): 307 | validate(response_many=True)(f)() 308 | 309 | def test_valid_array_object_request_body(self, mocker, request_ctx): 310 | mock_request = mocker.patch.object(request_ctx, "request") 311 | mock_request.args = ImmutableMultiDict({"q1": 1}) 312 | mock_request.get_json = lambda: [ 313 | {"b1": 1.0, "b2": "str1"}, 314 | {"b1": 2.0, "b2": "str2"}, 315 | ] 316 | expected_response_body = [ 317 | {"q1": 1, "q2": "default", "b1": 1.0, "b2": "str1"}, 318 | {"q1": 1, "q2": "default", "b1": 2.0, "b2": "str2"}, 319 | ] 320 | 321 | def f(): 322 | query_params = mock_request.query_params 323 | body_params = mock_request.body_params 324 | return [ 325 | ResponseModel( 326 | q1=query_params.q1, 327 | q2=query_params.q2, 328 | b1=obj.b1, 329 | b2=obj.b2, 330 | ) 331 | for obj in body_params 332 | ] 333 | 334 | response = validate( 335 | query=QueryModel, 336 | body=RequestBodyModel, 337 | request_body_many=True, 338 | response_many=True, 339 | )(f)() 340 | 341 | assert response.status_code == 200 342 | assert_matches(expected_response_body, response.json) 343 | 344 | def test_unsupported_media_type(self, request_ctx, mocker): 345 | mock_request = mocker.patch.object(request_ctx, "request") 346 | content_type = "text/plain" 347 | mock_request.headers = {"Content-Type": content_type} 348 | mock_request.get_json = lambda: None 349 | body_model = RequestBodyModel 350 | response = 
validate(body_model)(lambda x: x)() 351 | assert response.status_code == 415 352 | assert response.json == { 353 | "detail": f"Unsupported media type '{content_type}' in request. " 354 | "'application/json' is required." 355 | } 356 | 357 | def test_invalid_body_model_root(self, request_ctx, mocker): 358 | mock_request = mocker.patch.object(request_ctx, "request") 359 | content_type = "application/json" 360 | mock_request.headers = {"Content-Type": content_type} 361 | mock_request.get_json = lambda: None 362 | body_model = RequestBodyModelRoot 363 | response = validate(body_model)(lambda x: x)() 364 | assert response.status_code == 400 365 | 366 | assert_matches( 367 | { 368 | "validation_error": { 369 | "body_params": [ 370 | { 371 | "loc": ["__root__"], 372 | "msg": "none is not an allowed value", 373 | "type": "type_error.none.not_allowed", 374 | } 375 | ] 376 | } 377 | }, 378 | response.json, 379 | ) 380 | 381 | def test_damaged_request_body_json_with_charset(self, request_ctx, mocker): 382 | mock_request = mocker.patch.object(request_ctx, "request") 383 | content_type = "application/json;charset=utf-8" 384 | mock_request.headers = {"Content-Type": content_type} 385 | mock_request.get_json = lambda: None 386 | body_model = RequestBodyModel 387 | with pytest.raises(JsonBodyParsingError): 388 | validate(body_model)(lambda x: x)() 389 | 390 | def test_damaged_request_body(self, request_ctx, mocker): 391 | mock_request = mocker.patch.object(request_ctx, "request") 392 | content_type = "application/json" 393 | mock_request.headers = {"Content-Type": content_type} 394 | mock_request.get_json = lambda: None 395 | body_model = RequestBodyModel 396 | with pytest.raises(JsonBodyParsingError): 397 | validate(body_model)(lambda x: x)() 398 | 399 | @pytest.mark.parametrize("parameters", validate_test_cases) 400 | def test_validate_func_having_return_type_annotation( 401 | self, mocker, request_ctx, parameters: ValidateParams 402 | ): 403 | mock_request = 
mocker.patch.object(request_ctx, "request") 404 | mock_request.args = parameters.request_query 405 | mock_request.get_json = lambda: parameters.request_body 406 | mock_request.form = parameters.request_form 407 | 408 | def f() -> Any: 409 | body = {} 410 | query = {} 411 | if mock_request.form_params: 412 | body = mock_request.form_params.dict() 413 | if mock_request.body_params: 414 | body = mock_request.body_params.dict() 415 | if mock_request.query_params: 416 | query = mock_request.query_params.dict() 417 | return parameters.response_model(**body, **query) 418 | 419 | response = validate( 420 | query=parameters.query_model, 421 | body=parameters.body_model, 422 | form=parameters.form_model, 423 | on_success_status=parameters.on_success_status, 424 | exclude_none=parameters.exclude_none, 425 | response_many=parameters.response_many, 426 | request_body_many=parameters.request_body_many, 427 | )(f)() 428 | 429 | assert response.status_code == parameters.expected_status_code 430 | assert_matches(parameters.expected_response_body, response.json) 431 | if 200 <= response.status_code < 300: 432 | assert ( 433 | mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) 434 | == parameters.request_body 435 | ) 436 | assert ( 437 | mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) 438 | == parameters.request_query.to_dict() 439 | ) 440 | 441 | def test_fail_validation_custom_status_code(self, app, request_ctx, mocker): 442 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE"] = 422 443 | mock_request = mocker.patch.object(request_ctx, "request") 444 | content_type = "application/json" 445 | mock_request.headers = {"Content-Type": content_type} 446 | mock_request.get_json = lambda: None 447 | body_model = RequestBodyModelRoot 448 | response = validate(body_model)(lambda x: x)() 449 | assert response.status_code == 422 450 | 451 | assert_matches( 452 | { 453 | "validation_error": { 454 | "body_params": [ 455 | { 456 | "loc": 
["__root__"], 457 | "msg": "none is not an allowed value", 458 | "type": "type_error.none.not_allowed", 459 | } 460 | ] 461 | } 462 | }, 463 | response.json, 464 | ) 465 | 466 | def test_body_fail_validation_raise_exception(self, app, request_ctx, mocker): 467 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 468 | mock_request = mocker.patch.object(request_ctx, "request") 469 | content_type = "application/json" 470 | mock_request.headers = {"Content-Type": content_type} 471 | mock_request.get_json = lambda: None 472 | body_model = RequestBodyModelRoot 473 | with pytest.raises(ValidationError) as excinfo: 474 | validate(body_model)(lambda x: x)() 475 | assert_matches( 476 | [ 477 | { 478 | "loc": ("__root__",), 479 | "msg": "none is not an allowed value", 480 | "type": "type_error.none.not_allowed", 481 | } 482 | ], 483 | excinfo.value.body_params, 484 | ) 485 | 486 | def test_query_fail_validation_raise_exception(self, app, request_ctx, mocker): 487 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 488 | mock_request = mocker.patch.object(request_ctx, "request") 489 | content_type = "application/json" 490 | mock_request.headers = {"Content-Type": content_type} 491 | mock_request.get_json = lambda: None 492 | query_model = QueryModel 493 | with pytest.raises(ValidationError) as excinfo: 494 | validate(query=query_model)(lambda x: x)() 495 | assert_matches( 496 | [ 497 | { 498 | "loc": ("q1",), 499 | "msg": "field required", 500 | "type": "value_error.missing", 501 | } 502 | ], 503 | excinfo.value.query_params, 504 | ) 505 | 506 | def test_form_fail_validation_raise_exception(self, app, request_ctx, mocker): 507 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 508 | mock_request = mocker.patch.object(request_ctx, "request") 509 | content_type = "application/json" 510 | mock_request.headers = {"Content-Type": content_type} 511 | mock_request.get_json = lambda: None 512 | form_model = FormModel 513 | with 
pytest.raises(ValidationError) as excinfo: 514 | validate(form=form_model)(lambda x: x)() 515 | assert_matches( 516 | [ 517 | { 518 | "loc": ("f1",), 519 | "msg": "field required", 520 | "type": "value_error.missing", 521 | } 522 | ], 523 | excinfo.value.form_params, 524 | ) 525 | 526 | 527 | class TestIsIterableOfModels: 528 | def test_simple_true_case(self): 529 | models = [ 530 | QueryModel(q1=1, q2="w"), 531 | QueryModel(q1=2, q2="wsdf"), 532 | RequestBodyModel(b1=3.1), 533 | RequestBodyModel(b1=0.1), 534 | ] 535 | assert is_iterable_of_models(models) 536 | 537 | def test_false_for_non_iterable(self): 538 | assert not is_iterable_of_models(1) 539 | 540 | def test_false_for_single_model(self): 541 | assert not is_iterable_of_models(RequestBodyModel(b1=12)) 542 | 543 | 544 | convert_query_params_test_cases = [ 545 | pytest.param( 546 | ImmutableMultiDict({"a": 1, "b": "b"}), {"a": 1, "b": "b"}, id="primitive types" 547 | ), 548 | pytest.param( 549 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"]}), 550 | {"a": 1, "b": "b", "c": ["one"]}, 551 | id="one element in array", 552 | ), 553 | pytest.param( 554 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1]}), 555 | {"a": 1, "b": "b", "c": ["one"], "d": [1]}, 556 | id="one element in arrays", 557 | ), 558 | pytest.param( 559 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}), 560 | {"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}, 561 | id="one element in array, multiple in the other", 562 | ), 563 | pytest.param( 564 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one", "two", "three"]}), 565 | {"a": 1, "b": "b", "c": ["one", "two", "three"]}, 566 | id="multiple elements in array", 567 | ), 568 | pytest.param( 569 | ImmutableMultiDict( 570 | {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]} 571 | ), 572 | {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]}, 573 | id="multiple in both arrays", 574 | ), 575 | ] 576 | 577 | 578 | @pytest.mark.parametrize( 
579 | "query_params,expected_result", convert_query_params_test_cases 580 | ) 581 | def test_convert_query_params(query_params: ImmutableMultiDict, expected_result: dict): 582 | class Model(BaseModel): 583 | a: int 584 | b: str 585 | c: Optional[List[str]] 586 | d: Optional[List[int]] 587 | 588 | assert convert_query_params(query_params, Model) == expected_result 589 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pallets-eco/flask-pydantic/ccdd2a2c9816012c440491af735fbaaa54f57c6b/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_core.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union 4 | 5 | import pytest 6 | from flask import jsonify 7 | from flask_pydantic import ValidationError, validate 8 | from flask_pydantic.core import convert_query_params, is_iterable_of_models 9 | from flask_pydantic.exceptions import ( 10 | InvalidIterableOfModelsException, 11 | JsonBodyParsingError, 12 | ) 13 | from pydantic import BaseModel, RootModel 14 | from werkzeug.datastructures import ImmutableMultiDict 15 | 16 | from ..util import assert_matches 17 | 18 | 19 | class EmptyModel(BaseModel): 20 | pass 21 | 22 | 23 | class ValidateParams(NamedTuple): 24 | body_model: Type[BaseModel] = EmptyModel 25 | query_model: Type[BaseModel] = EmptyModel 26 | form_model: Type[BaseModel] = EmptyModel 27 | response_model: Type[BaseModel] = EmptyModel 28 | on_success_status: int = 200 29 | request_query: ImmutableMultiDict = ImmutableMultiDict({}) 30 | flat_request_query: bool = True 31 | request_body: Union[dict, List[dict]] = {} 32 | request_form: ImmutableMultiDict = ImmutableMultiDict({}) 33 | expected_response_body: 
Optional[dict] = None 34 | expected_status_code: int = 200 35 | exclude_none: bool = False 36 | response_many: bool = False 37 | request_body_many: bool = False 38 | 39 | 40 | class ResponseModel(BaseModel): 41 | q1: int 42 | q2: str 43 | b1: float 44 | b2: Optional[str] = None 45 | 46 | 47 | class QueryModel(BaseModel): 48 | q1: int 49 | q2: str = "default" 50 | 51 | 52 | class RequestBodyModel(BaseModel): 53 | b1: float 54 | b2: Optional[str] = None 55 | 56 | 57 | class FormModel(BaseModel): 58 | f1: int 59 | f2: Optional[str] = None 60 | 61 | 62 | class RequestWithIterableModel(BaseModel): 63 | b1: List 64 | b2: List[str] 65 | b3: Tuple[str, int] 66 | b4: Optional[List[int]] = None 67 | b5: Union[Tuple[str, int], None] = None 68 | 69 | 70 | if sys.version_info >= (3, 10): 71 | # New Python(>=3.10) syntax tests 72 | class RequestWithIterableModelPy310(BaseModel): 73 | b1: list 74 | b2: list[str] 75 | b3: tuple[str, int] 76 | b4: list[int] | None = None 77 | b5: tuple[str, int] | None = None 78 | 79 | 80 | class RequestBodyModelRoot(RootModel): 81 | root: Union[str, RequestBodyModel] 82 | 83 | 84 | validate_test_cases = [ 85 | pytest.param( 86 | ValidateParams( 87 | request_body={"b1": 1.4}, 88 | request_query=ImmutableMultiDict({"q1": 1}), 89 | request_form=ImmutableMultiDict({"f1": 1}), 90 | form_model=FormModel, 91 | expected_response_body={"q1": 1, "q2": "default", "b1": 1.4, "b2": None}, 92 | response_model=ResponseModel, 93 | query_model=QueryModel, 94 | body_model=RequestBodyModel, 95 | ), 96 | id="simple valid example with default values", 97 | ), 98 | pytest.param( 99 | ValidateParams( 100 | request_body={"b1": 1.4}, 101 | request_query=ImmutableMultiDict({"q1": 1}), 102 | request_form=ImmutableMultiDict({"f1": 1}), 103 | form_model=FormModel, 104 | expected_response_body={"q1": 1, "q2": "default", "b1": 1.4}, 105 | response_model=ResponseModel, 106 | query_model=QueryModel, 107 | body_model=RequestBodyModel, 108 | exclude_none=True, 109 | ), 110 | 
id="simple valid example with default values, exclude none", 111 | ), 112 | pytest.param( 113 | ValidateParams( 114 | query_model=QueryModel, 115 | expected_response_body={ 116 | "validation_error": { 117 | "query_params": [ 118 | { 119 | "input": {}, 120 | "loc": ["q1"], 121 | "msg": "Field required", 122 | "type": "missing", 123 | "url": re.compile( 124 | r"https://errors\.pydantic\.dev/.*/v/missing" 125 | ), 126 | } 127 | ] 128 | } 129 | }, 130 | expected_status_code=400, 131 | ), 132 | id="invalid query param", 133 | ), 134 | pytest.param( 135 | ValidateParams( 136 | body_model=RequestBodyModel, 137 | expected_response_body={ 138 | "validation_error": { 139 | "body_params": [ 140 | { 141 | "loc": ["root"], 142 | "msg": "is not an array of objects", 143 | "type": "type_error.array", 144 | } 145 | ] 146 | } 147 | }, 148 | request_body={"b1": 3.14, "b2": "str"}, 149 | expected_status_code=400, 150 | request_body_many=True, 151 | ), 152 | id="`request_body_many=True` but in request body is a single object", 153 | ), 154 | pytest.param( 155 | ValidateParams( 156 | expected_response_body={ 157 | "validation_error": { 158 | "body_params": [ 159 | { 160 | "input": {}, 161 | "loc": ["b1"], 162 | "msg": "Field required", 163 | "type": "missing", 164 | "url": re.compile( 165 | r"https://errors\.pydantic\.dev/.*/v/missing" 166 | ), 167 | } 168 | ] 169 | } 170 | }, 171 | body_model=RequestBodyModel, 172 | expected_status_code=400, 173 | ), 174 | id="invalid body param", 175 | ), 176 | pytest.param( 177 | ValidateParams( 178 | expected_response_body={ 179 | "validation_error": { 180 | "body_params": [ 181 | { 182 | "input": {}, 183 | "loc": ["b1"], 184 | "msg": "Field required", 185 | "type": "missing", 186 | "url": re.compile( 187 | r"https://errors\.pydantic\.dev/.*/v/missing" 188 | ), 189 | } 190 | ] 191 | } 192 | }, 193 | body_model=RequestBodyModel, 194 | expected_status_code=400, 195 | request_body=[{}], 196 | request_body_many=True, 197 | ), 198 | id="invalid body 
param in many-object request body", 199 | ), 200 | pytest.param( 201 | ValidateParams( 202 | form_model=FormModel, 203 | expected_response_body={ 204 | "validation_error": { 205 | "form_params": [ 206 | { 207 | "input": {}, 208 | "loc": ["f1"], 209 | "msg": "Field required", 210 | "type": "missing", 211 | "url": re.compile( 212 | r"https://errors\.pydantic\.dev/.*/v/missing" 213 | ), 214 | } 215 | ] 216 | } 217 | }, 218 | expected_status_code=400, 219 | ), 220 | id="invalid form param", 221 | ), 222 | pytest.param( 223 | ValidateParams( 224 | request_query=ImmutableMultiDict( 225 | [ 226 | ("b1", "str1"), 227 | ("b1", "str2"), 228 | ("b2", "str1"), 229 | ("b2", "str2"), 230 | ("b3", "str"), 231 | ("b3", 123), 232 | ("b4", 1), 233 | ("b4", 2), 234 | ("b4", 3), 235 | ("b5", "str"), 236 | ("b5", 321), 237 | ] 238 | ), 239 | flat_request_query=False, 240 | expected_response_body={ 241 | "b1": ["str1", "str2"], 242 | "b2": ["str1", "str2"], 243 | "b3": ("str", 123), 244 | "b4": [1, 2, 3], 245 | "b5": ("str", 321), 246 | }, 247 | query_model=RequestWithIterableModel, 248 | response_model=RequestWithIterableModel, 249 | expected_status_code=200, 250 | ), 251 | id="iterable and Optional[Iterable] fields in pydantic model in query", 252 | ), 253 | ] 254 | 255 | if sys.version_info >= (3, 10): 256 | validate_test_cases.extend( 257 | [ 258 | pytest.param( 259 | ValidateParams( 260 | request_query=ImmutableMultiDict( 261 | [ 262 | ("b1", "str1"), 263 | ("b1", "str2"), 264 | ("b2", "str1"), 265 | ("b2", "str2"), 266 | ("b3", "str"), 267 | ("b3", 123), 268 | ("b4", 1), 269 | ("b4", 2), 270 | ("b4", 3), 271 | ("b5", "str"), 272 | ("b5", 321), 273 | ] 274 | ), 275 | flat_request_query=False, 276 | expected_response_body={ 277 | "b1": ["str1", "str2"], 278 | "b2": ["str1", "str2"], 279 | "b3": ("str", 123), 280 | "b4": [1, 2, 3], 281 | "b5": ("str", 321), 282 | }, 283 | query_model=RequestWithIterableModelPy310, 284 | response_model=RequestWithIterableModelPy310, 285 | 
expected_status_code=200, 286 | ), 287 | id="iterable and Iterable | None fields in pydantic model in query (Python 3.10+)", 288 | ), 289 | ] 290 | ) 291 | 292 | 293 | class TestValidate: 294 | @pytest.mark.parametrize("parameters", validate_test_cases) 295 | def test_validate(self, mocker, request_ctx, parameters: ValidateParams): 296 | mock_request = mocker.patch.object(request_ctx, "request") 297 | mock_request.args = parameters.request_query 298 | mock_request.get_json = lambda: parameters.request_body 299 | mock_request.form = parameters.request_form 300 | 301 | def f(): 302 | body = {} 303 | query = {} 304 | if mock_request.form_params: 305 | body = mock_request.form_params.model_dump() 306 | if mock_request.body_params: 307 | body = mock_request.body_params.model_dump() 308 | if mock_request.query_params: 309 | query = mock_request.query_params.model_dump() 310 | return parameters.response_model(**body, **query) 311 | 312 | response = validate( 313 | query=parameters.query_model, 314 | body=parameters.body_model, 315 | on_success_status=parameters.on_success_status, 316 | exclude_none=parameters.exclude_none, 317 | response_many=parameters.response_many, 318 | request_body_many=parameters.request_body_many, 319 | form=parameters.form_model, 320 | )(f)() 321 | 322 | assert response.status_code == parameters.expected_status_code 323 | assert_matches(parameters.expected_response_body, response.json) 324 | if 200 <= response.status_code < 300: 325 | assert_matches( 326 | parameters.request_body, 327 | mock_request.body_params.model_dump( 328 | exclude_none=True, exclude_defaults=True 329 | ), 330 | ) 331 | assert_matches( 332 | parameters.request_query.to_dict(flat=parameters.flat_request_query), 333 | mock_request.query_params.model_dump( 334 | exclude_none=True, exclude_defaults=True 335 | ), 336 | ) 337 | 338 | @pytest.mark.parametrize("parameters", validate_test_cases) 339 | def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): 340 
| mock_request = mocker.patch.object(request_ctx, "request") 341 | mock_request.args = parameters.request_query 342 | mock_request.get_json = lambda: parameters.request_body 343 | mock_request.form = parameters.request_form 344 | 345 | def f( 346 | body: parameters.body_model, 347 | query: parameters.query_model, 348 | form: parameters.form_model, 349 | ): 350 | return parameters.response_model( 351 | **body.model_dump(), **query.model_dump(), **form.model_dump() 352 | ) 353 | 354 | response = validate( 355 | on_success_status=parameters.on_success_status, 356 | exclude_none=parameters.exclude_none, 357 | response_many=parameters.response_many, 358 | request_body_many=parameters.request_body_many, 359 | )(f)() 360 | 361 | assert_matches(parameters.expected_response_body, response.json) 362 | assert response.status_code == parameters.expected_status_code 363 | if 200 <= response.status_code < 300: 364 | assert_matches( 365 | parameters.request_body, 366 | mock_request.body_params.model_dump( 367 | exclude_none=True, exclude_defaults=True 368 | ), 369 | ) 370 | assert_matches( 371 | parameters.request_query.to_dict(flat=parameters.flat_request_query), 372 | mock_request.query_params.model_dump( 373 | exclude_none=True, exclude_defaults=True 374 | ), 375 | ) 376 | 377 | @pytest.mark.usefixtures("request_ctx") 378 | def test_response_with_status(self): 379 | expected_status_code = 201 380 | expected_response_body = dict(q1=1, q2="2", b1=3.14, b2="b2") 381 | 382 | def f(): 383 | return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), expected_status_code 384 | 385 | response = validate()(f)() 386 | assert response.status_code == expected_status_code 387 | assert_matches(expected_response_body, response.json) 388 | 389 | @pytest.mark.usefixtures("request_ctx") 390 | def test_response_already_response(self): 391 | expected_response_body = {"a": 1, "b": 2} 392 | 393 | def f(): 394 | return jsonify(expected_response_body) 395 | 396 | response = validate()(f)() 397 | 
assert_matches(expected_response_body, response.json) 398 | 399 | @pytest.mark.usefixtures("request_ctx") 400 | def test_response_many_response_objs(self): 401 | response_content = [ 402 | ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), 403 | ResponseModel(q1=2, q2="3", b1=3.14), 404 | ResponseModel(q1=3, q2="4", b1=6.9, b2="b4"), 405 | ] 406 | expected_response_body = [ 407 | {"q1": 1, "q2": "2", "b1": 3.14, "b2": "b2"}, 408 | {"q1": 2, "q2": "3", "b1": 3.14}, 409 | {"q1": 3, "q2": "4", "b1": 6.9, "b2": "b4"}, 410 | ] 411 | 412 | def f(): 413 | return response_content 414 | 415 | response = validate(exclude_none=True, response_many=True)(f)() 416 | assert_matches(expected_response_body, response.json) 417 | 418 | @pytest.mark.usefixtures("request_ctx") 419 | def test_invalid_many_raises(self): 420 | def f(): 421 | return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2") 422 | 423 | with pytest.raises(InvalidIterableOfModelsException): 424 | validate(response_many=True)(f)() 425 | 426 | def test_valid_array_object_request_body(self, mocker, request_ctx): 427 | mock_request = mocker.patch.object(request_ctx, "request") 428 | mock_request.args = ImmutableMultiDict({"q1": 1}) 429 | mock_request.get_json = lambda: [ 430 | {"b1": 1.0, "b2": "str1"}, 431 | {"b1": 2.0, "b2": "str2"}, 432 | ] 433 | expected_response_body = [ 434 | {"q1": 1, "q2": "default", "b1": 1.0, "b2": "str1"}, 435 | {"q1": 1, "q2": "default", "b1": 2.0, "b2": "str2"}, 436 | ] 437 | 438 | def f(): 439 | query_params = mock_request.query_params 440 | body_params = mock_request.body_params 441 | return [ 442 | ResponseModel( 443 | q1=query_params.q1, 444 | q2=query_params.q2, 445 | b1=obj.b1, 446 | b2=obj.b2, 447 | ) 448 | for obj in body_params 449 | ] 450 | 451 | response = validate( 452 | query=QueryModel, 453 | body=RequestBodyModel, 454 | request_body_many=True, 455 | response_many=True, 456 | )(f)() 457 | 458 | assert response.status_code == 200 459 | assert_matches(expected_response_body, 
response.json) 460 | 461 | def test_unsupported_media_type(self, request_ctx, mocker): 462 | mock_request = mocker.patch.object(request_ctx, "request") 463 | content_type = "text/plain" 464 | mock_request.headers = {"Content-Type": content_type} 465 | mock_request.get_json = lambda: None 466 | body_model = RequestBodyModel 467 | response = validate(body_model)(lambda x: x)() 468 | assert response.status_code == 415 469 | assert response.json == { 470 | "detail": f"Unsupported media type '{content_type}' in request. " 471 | "'application/json' is required." 472 | } 473 | 474 | def test_invalid_body_model_root(self, request_ctx, mocker): 475 | mock_request = mocker.patch.object(request_ctx, "request") 476 | content_type = "application/json" 477 | mock_request.headers = {"Content-Type": content_type} 478 | mock_request.get_json = lambda: None 479 | body_model = RequestBodyModelRoot 480 | response = validate(body_model)(lambda x: x)() 481 | assert response.status_code == 400 482 | 483 | assert_matches( 484 | { 485 | "validation_error": { 486 | "body_params": [ 487 | { 488 | "input": None, 489 | "loc": ["str"], 490 | "msg": "Input should be a valid string", 491 | "type": "string_type", 492 | "url": re.compile( 493 | r"https://errors\.pydantic\.dev/.*/v/string_type" 494 | ), 495 | }, 496 | { 497 | "ctx": {"class_name": "RequestBodyModel"}, 498 | "input": None, 499 | "loc": ["RequestBodyModel"], 500 | "msg": "Input should be a valid dictionary or instance of RequestBodyModel", 501 | "type": "model_type", 502 | "url": re.compile( 503 | r"https://errors\.pydantic\.dev/.*/v/model_type" 504 | ), 505 | }, 506 | ] 507 | } 508 | }, 509 | response.json, 510 | ) 511 | 512 | def test_damaged_request_body_json_with_charset(self, request_ctx, mocker): 513 | mock_request = mocker.patch.object(request_ctx, "request") 514 | content_type = "application/json;charset=utf-8" 515 | mock_request.headers = {"Content-Type": content_type} 516 | mock_request.get_json = lambda: None 517 | 
body_model = RequestBodyModel 518 | with pytest.raises(JsonBodyParsingError): 519 | validate(body_model)(lambda x: x)() 520 | 521 | def test_damaged_request_body(self, request_ctx, mocker): 522 | mock_request = mocker.patch.object(request_ctx, "request") 523 | content_type = "application/json" 524 | mock_request.headers = {"Content-Type": content_type} 525 | mock_request.get_json = lambda: None 526 | body_model = RequestBodyModel 527 | with pytest.raises(JsonBodyParsingError): 528 | validate(body_model)(lambda x: x)() 529 | 530 | @pytest.mark.parametrize("parameters", validate_test_cases) 531 | def test_validate_func_having_return_type_annotation( 532 | self, mocker, request_ctx, parameters: ValidateParams 533 | ): 534 | mock_request = mocker.patch.object(request_ctx, "request") 535 | mock_request.args = parameters.request_query 536 | mock_request.get_json = lambda: parameters.request_body 537 | mock_request.form = parameters.request_form 538 | 539 | def f() -> Any: 540 | body = {} 541 | query = {} 542 | if mock_request.form_params: 543 | body = mock_request.form_params.model_dump() 544 | if mock_request.body_params: 545 | body = mock_request.body_params.model_dump() 546 | if mock_request.query_params: 547 | query = mock_request.query_params.model_dump() 548 | return parameters.response_model(**body, **query) 549 | 550 | response = validate( 551 | query=parameters.query_model, 552 | body=parameters.body_model, 553 | form=parameters.form_model, 554 | on_success_status=parameters.on_success_status, 555 | exclude_none=parameters.exclude_none, 556 | response_many=parameters.response_many, 557 | request_body_many=parameters.request_body_many, 558 | )(f)() 559 | 560 | assert response.status_code == parameters.expected_status_code 561 | assert_matches(parameters.expected_response_body, response.json) 562 | if 200 <= response.status_code < 300: 563 | assert_matches( 564 | parameters.request_body, 565 | mock_request.body_params.model_dump( 566 | exclude_none=True, 
exclude_defaults=True 567 | ), 568 | ) 569 | assert_matches( 570 | parameters.request_query.to_dict(flat=parameters.flat_request_query), 571 | mock_request.query_params.model_dump( 572 | exclude_none=True, exclude_defaults=True 573 | ), 574 | ) 575 | 576 | def test_fail_validation_custom_status_code(self, app, request_ctx, mocker): 577 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE"] = 422 578 | mock_request = mocker.patch.object(request_ctx, "request") 579 | content_type = "application/json" 580 | mock_request.headers = {"Content-Type": content_type} 581 | mock_request.get_json = lambda: None 582 | body_model = RequestBodyModelRoot 583 | response = validate(body_model)(lambda x: x)() 584 | assert response.status_code == 422 585 | 586 | assert_matches( 587 | { 588 | "validation_error": { 589 | "body_params": [ 590 | { 591 | "input": None, 592 | "loc": ["str"], 593 | "msg": "Input should be a valid string", 594 | "type": "string_type", 595 | "url": re.compile( 596 | r"https://errors\.pydantic\.dev/.*/v/string_type" 597 | ), 598 | }, 599 | { 600 | "ctx": {"class_name": "RequestBodyModel"}, 601 | "input": None, 602 | "loc": ["RequestBodyModel"], 603 | "msg": "Input should be a valid dictionary or instance of RequestBodyModel", 604 | "type": "model_type", 605 | "url": re.compile( 606 | r"https://errors\.pydantic\.dev/.*/v/model_type" 607 | ), 608 | }, 609 | ] 610 | } 611 | }, 612 | response.json, 613 | ) 614 | 615 | def test_body_fail_validation_raise_exception(self, app, request_ctx, mocker): 616 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 617 | mock_request = mocker.patch.object(request_ctx, "request") 618 | content_type = "application/json" 619 | mock_request.headers = {"Content-Type": content_type} 620 | mock_request.get_json = lambda: None 621 | body_model = RequestBodyModelRoot 622 | with pytest.raises(ValidationError) as excinfo: 623 | validate(body_model)(lambda x: x)() 624 | assert_matches( 625 | [ 626 | { 627 | "input": None, 628 | 
"loc": ("str",), 629 | "msg": "Input should be a valid string", 630 | "type": "string_type", 631 | "url": re.compile( 632 | r"https://errors\.pydantic\.dev/.*/v/string_type" 633 | ), 634 | }, 635 | { 636 | "ctx": {"class_name": "RequestBodyModel"}, 637 | "input": None, 638 | "loc": ("RequestBodyModel",), 639 | "msg": "Input should be a valid dictionary or instance of RequestBodyModel", 640 | "type": "model_type", 641 | "url": re.compile(r"https://errors\.pydantic\.dev/.*/v/model_type"), 642 | }, 643 | ], 644 | excinfo.value.body_params, 645 | ) 646 | 647 | def test_query_fail_validation_raise_exception(self, app, request_ctx, mocker): 648 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 649 | mock_request = mocker.patch.object(request_ctx, "request") 650 | content_type = "application/json" 651 | mock_request.headers = {"Content-Type": content_type} 652 | mock_request.get_json = lambda: None 653 | query_model = QueryModel 654 | with pytest.raises(ValidationError) as excinfo: 655 | validate(query=query_model)(lambda x: x)() 656 | assert_matches( 657 | [ 658 | { 659 | "input": {}, 660 | "loc": ("q1",), 661 | "msg": "Field required", 662 | "type": "missing", 663 | "url": re.compile(r"https://errors\.pydantic\.dev/.*/v/missing"), 664 | } 665 | ], 666 | excinfo.value.query_params, 667 | ) 668 | 669 | def test_form_fail_validation_raise_exception(self, app, request_ctx, mocker): 670 | app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True 671 | mock_request = mocker.patch.object(request_ctx, "request") 672 | content_type = "application/json" 673 | mock_request.headers = {"Content-Type": content_type} 674 | mock_request.get_json = lambda: None 675 | form_model = FormModel 676 | with pytest.raises(ValidationError) as excinfo: 677 | validate(form=form_model)(lambda x: x)() 678 | assert_matches( 679 | [ 680 | { 681 | "input": {}, 682 | "loc": ("f1",), 683 | "msg": "Field required", 684 | "type": "missing", 685 | "url": 
re.compile(r"https://errors\.pydantic\.dev/.*/v/missing"), 686 | } 687 | ], 688 | excinfo.value.form_params, 689 | ) 690 | 691 | 692 | class TestIsIterableOfModels: 693 | def test_simple_true_case(self): 694 | models = [ 695 | QueryModel(q1=1, q2="w"), 696 | QueryModel(q1=2, q2="wsdf"), 697 | RequestBodyModel(b1=3.1), 698 | RequestBodyModel(b1=0.1), 699 | ] 700 | assert is_iterable_of_models(models) 701 | 702 | def test_false_for_non_iterable(self): 703 | assert not is_iterable_of_models(1) 704 | 705 | def test_false_for_single_model(self): 706 | assert not is_iterable_of_models(RequestBodyModel(b1=12)) 707 | 708 | 709 | convert_query_params_test_cases = [ 710 | pytest.param( 711 | ImmutableMultiDict({"a": 1, "b": "b"}), {"a": 1, "b": "b"}, id="primitive types" 712 | ), 713 | pytest.param( 714 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"]}), 715 | {"a": 1, "b": "b", "c": ["one"]}, 716 | id="one element in array", 717 | ), 718 | pytest.param( 719 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1]}), 720 | {"a": 1, "b": "b", "c": ["one"], "d": [1]}, 721 | id="one element in arrays", 722 | ), 723 | pytest.param( 724 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}), 725 | {"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}, 726 | id="one element in array, multiple in the other", 727 | ), 728 | pytest.param( 729 | ImmutableMultiDict({"a": 1, "b": "b", "c": ["one", "two", "three"]}), 730 | {"a": 1, "b": "b", "c": ["one", "two", "three"]}, 731 | id="multiple elements in array", 732 | ), 733 | pytest.param( 734 | ImmutableMultiDict( 735 | {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]} 736 | ), 737 | {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]}, 738 | id="multiple in both arrays", 739 | ), 740 | ] 741 | 742 | 743 | @pytest.mark.parametrize( 744 | "query_params,expected_result", convert_query_params_test_cases 745 | ) 746 | def test_convert_query_params(query_params: ImmutableMultiDict, 
expected_result: dict): 747 | class Model(BaseModel): 748 | a: int 749 | b: str 750 | c: Optional[List[str]] 751 | d: Optional[List[int]] 752 | 753 | assert convert_query_params(query_params, Model) == expected_result 754 | -------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List, Union 3 | 4 | ExpectedType = Union[re.Pattern, str, List["ExpectedType"], Dict[str, "ExpectedType"]] 5 | ActualType = Union[str, List["ActualType"], Dict[str, "ActualType"]] 6 | 7 | 8 | def assert_matches(expected: ExpectedType, actual: ActualType): 9 | """ 10 | Recursively compare the expected and actual values. 11 | 12 | Args: 13 | expected: The expected value. If this is a compiled regex, 14 | it will be matched against the actual value. 15 | actual: The actual value. 16 | 17 | Raises: 18 | AssertionError: If the expected and actual values do not match. 19 | """ 20 | if isinstance(expected, dict): 21 | assert set(expected.keys()) == set(actual.keys()) 22 | for key, value in expected.items(): 23 | assert_matches(value, actual[key]) 24 | elif isinstance(expected, (list, tuple)): 25 | assert len(expected) == len(actual) 26 | for a, b in zip(expected, actual): 27 | assert_matches(a, b) 28 | elif isinstance(expected, re.Pattern): 29 | assert expected.match(actual) 30 | else: 31 | assert expected == actual 32 | --------------------------------------------------------------------------------