├── spectree ├── py.typed ├── __init__.py ├── plugins │ ├── __init__.py │ ├── base.py │ ├── flask_plugin.py │ ├── quart_plugin.py │ ├── starlette_plugin.py │ └── werkzeug_utils.py ├── _types.py ├── _pydantic.py ├── config.py ├── models.py ├── page.py └── response.py ├── tests ├── __init__.py ├── import_module │ ├── __init__.py │ ├── test_flask_plugin.py │ ├── test_quart_plugin.py │ ├── test_starlette_plugin.py │ └── test_falcon_plugin.py ├── conftest.py ├── quart_imports │ ├── __init__.py │ └── dry_plugin_quart.py ├── flask_imports │ ├── __init__.py │ └── dry_plugin_flask.py ├── test_config.py ├── test_plugin.py ├── test_response.py ├── test_pydantic.py ├── test_base_plugin.py ├── test_spec.py ├── common.py ├── test_utils.py └── test_plugin_flask_blueprint.py ├── examples ├── __init__.py ├── common.py ├── security_demo.py ├── starlette_demo.py ├── quart_demo.py ├── flask_demo.py ├── falcon_asgi_demo.py └── falcon_demo.py ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yaml │ └── bug_report.yaml ├── workflows │ ├── python-publish.yml │ ├── codeql.yml │ ├── pythonpackage.yml │ └── pythondoc.yml ├── CONTRIBUTING.md ├── FUNDING.yml └── dependabot.yml ├── docs ├── source │ ├── models.rst │ ├── robots.txt │ ├── utils.rst │ ├── config.rst │ ├── spectree.rst │ ├── response.rst │ ├── _static │ │ └── custom.css │ ├── plugins.rst │ ├── index.md │ └── conf.py ├── Makefile └── make.bat ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── Makefile ├── pyproject.toml ├── .gitignore └── LICENSE /spectree/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @kemingy 2 | -------------------------------------------------------------------------------- /tests/import_module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ========== 3 | 4 | .. automodule:: spectree.models 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | 3 | Sitemap: https://0b01001001.github.io/spectree/sitemap.xml 4 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ==================== 3 | 4 | .. automodule:: spectree.utils 5 | :members: -------------------------------------------------------------------------------- /docs/source/config.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ==================== 3 | 4 | .. automodule:: spectree.config 5 | :members: -------------------------------------------------------------------------------- /docs/source/spectree.rst: -------------------------------------------------------------------------------- 1 | SpecTree 2 | ==================== 3 | 4 | .. 
automodule:: spectree.spec 5 | :members: -------------------------------------------------------------------------------- /docs/source/response.rst: -------------------------------------------------------------------------------- 1 | Response 2 | ==================== 3 | 4 | .. automodule:: spectree.response 5 | :members: -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | p > a > img { 2 | display: inline; 3 | } 4 | 5 | #spectree img { 6 | margin: 0; 7 | } 8 | -------------------------------------------------------------------------------- /tests/import_module/test_flask_plugin.py: -------------------------------------------------------------------------------- 1 | from spectree import SpecTree 2 | 3 | SpecTree("flask") 4 | print("=> passed flask plugin import test") 5 | -------------------------------------------------------------------------------- /tests/import_module/test_quart_plugin.py: -------------------------------------------------------------------------------- 1 | from spectree import SpecTree 2 | 3 | SpecTree("quart") 4 | print("=> passed quart plugin import test") 5 | -------------------------------------------------------------------------------- /tests/import_module/test_starlette_plugin.py: -------------------------------------------------------------------------------- 1 | from spectree import SpecTree 2 | 3 | SpecTree("starlette") 4 | print("=> passed starlette plugin import test") 5 | -------------------------------------------------------------------------------- /tests/import_module/test_falcon_plugin.py: -------------------------------------------------------------------------------- 1 | from spectree import SpecTree 2 | 3 | SpecTree("falcon") 4 | SpecTree("falcon-asgi") 5 | print("=> passed falcon plugin import test") 6 | 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from syrupy.extensions.json import JSONSnapshotExtension 3 | 4 | 5 | @pytest.fixture 6 | def anyio_backend(): 7 | return "asyncio" 8 | 9 | 10 | @pytest.fixture 11 | def snapshot_json(snapshot): 12 | return snapshot.use_extension(JSONSnapshotExtension).with_defaults() 13 | -------------------------------------------------------------------------------- /docs/source/plugins.rst: -------------------------------------------------------------------------------- 1 | Plugins 2 | ==================== 3 | 4 | .. automodule:: spectree.plugins.base 5 | :members: 6 | 7 | .. automodule:: spectree.plugins.flask_plugin 8 | :members: 9 | 10 | .. automodule:: spectree.plugins.falcon_plugin 11 | :members: 12 | 13 | .. automodule:: spectree.plugins.starlette_plugin 14 | :members: 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | 3 | contact_links: 4 | - name: Have you read the docs? 
5 | url: https://0b01001001.github.io/spectree/ 6 | about: Much help can be found in the docs 7 | - name: Ask a question 8 | url: https://github.com/0b01001001/spectree/discussions 9 | about: Ask a question or start a discussion 10 | -------------------------------------------------------------------------------- /examples/common.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | from spectree import BaseFile 4 | 5 | 6 | class File(BaseModel): 7 | uid: str 8 | file: BaseFile 9 | 10 | 11 | class FileResp(BaseModel): 12 | filename: str 13 | type: str 14 | 15 | 16 | class Query(BaseModel): 17 | text: str = Field( 18 | ..., 19 | max_length=100, 20 | ) 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | - repo: local 10 | hooks: 11 | - id: make-lint 12 | name: Lint 13 | entry: make lint 14 | language: system 15 | types: [python] 16 | pass_filenames: false 17 | always_run: true 18 | -------------------------------------------------------------------------------- /spectree/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .models import BaseFile, ExternalDocs, SecurityScheme, SecuritySchemeData, Tag 4 | from .response import Response 5 | from .spec import SpecTree 6 | 7 | __all__ = [ 8 | "BaseFile", 9 | "ExternalDocs", 10 | "Response", 11 | "SecurityScheme", 12 | "SecuritySchemeData", 13 | "SpecTree", 14 | "Tag", 15 | ] 16 | 17 | # setup library logging 18 | logging.getLogger(__name__).addHandler(logging.NullHandler()) 19 | 
-------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | # Allows you to run this workflow manually from the Actions tab 7 | workflow_dispatch: 8 | 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | id-token: write 14 | steps: 15 | - uses: actions/checkout@v6 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v7 18 | with: 19 | enable-cache: true 20 | ignore-nothing-to-cache: true 21 | - name: Publish to PyPI 22 | run: make publish 23 | -------------------------------------------------------------------------------- /spectree/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from .base import BasePlugin 4 | 5 | Plugin = namedtuple("Plugin", ("name", "package", "class_name")) 6 | 7 | PLUGINS = { 8 | "base": Plugin(".base", __name__, "BasePlugin"), 9 | "flask": Plugin(".flask_plugin", __name__, "FlaskPlugin"), 10 | "quart": Plugin(".quart_plugin", __name__, "QuartPlugin"), 11 | "falcon": Plugin(".falcon_plugin", __name__, "FalconPlugin"), 12 | "falcon-asgi": Plugin(".falcon_plugin", __name__, "FalconAsgiPlugin"), 13 | "starlette": Plugin(".starlette_plugin", __name__, "StarlettePlugin"), 14 | } 15 | 16 | __all__ = ["PLUGINS", "BasePlugin", "Plugin"] 17 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. meta:: 3 | :description lang=en: 4 | SpecTree is an API spec validator and OpenAPI document generator for Python web frameworks. 
5 | ``` 6 | 7 | ```{include} ../../README.md 8 | :relative-docs: examples/ 9 | :relative-images: 10 | ``` 11 | 12 | ```{toctree} 13 | --- 14 | maxdepth: 2 15 | hidden: 16 | caption: API reference 17 | --- 18 | spectree 19 | config 20 | response 21 | models 22 | utils 23 | plugins 24 | ``` 25 | 26 | ```{toctree} 27 | --- 28 | hidden: 29 | caption: Project Links 30 | --- 31 | 32 | GitHub <https://github.com/0b01001001/spectree> 33 | ``` 34 | 35 | ## Indices and tables 36 | 37 | - {ref}`genindex` 38 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SpecTree 2 | 3 | :tada: First off, thanks for taking the time to contribute! :tada: 4 | 5 | > [!TIP] 6 | > To get an isolated development environment, you can try 7 | > [envd](https://github.com/tensorchord/envd) with command `envd up`. 8 | 9 | ## Pull Requests 10 | 11 | * fork this repo and clone it 12 | * install `spectree` and related libraries with `make install` 13 | * create a new branch: `git checkout -b fix-<issue>` 14 | * make your changes (code, doc, test) 15 | * check the coding style with `make lint` and run the test cases with `make test` 16 | * open a pull request following the [semantic commit message format](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/source/conf.py 17 | 18 | # We recommend specifying 
your dependencies to enable reproducible builds: 19 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - method: pip 23 | path: . 24 | extra_requirements: 25 | - docs 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= uv run -- sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [kemingy] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Feature request for spectree 3 | labels: ["enhancement"] 4 | title: "feat: " 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this feature request! 10 | - type: textarea 11 | attributes: 12 | label: Describe the feature 13 | description: A clear and concise description of what the feature is. 14 | validations: 15 | required: true 16 | - type: textarea 17 | attributes: 18 | label: Additional context 19 | description: Add any other context about the problem here. 
20 | validations: 21 | required: false 22 | - type: markdown 23 | attributes: 24 | value: | 25 | Love this enhancement proposal? Give it a 👍. We prioritize the proposals with the most 👍. 26 | -------------------------------------------------------------------------------- /tests/quart_imports/__init__.py: -------------------------------------------------------------------------------- 1 | from .dry_plugin_quart import ( 2 | test_quart_custom_error, 3 | test_quart_doc, 4 | test_quart_forced_serializer, 5 | test_quart_list_json_request, 6 | test_quart_no_response, 7 | test_quart_return_list_request, 8 | test_quart_return_model, 9 | test_quart_return_string_status, 10 | test_quart_skip_validation, 11 | test_quart_validate, 12 | test_quart_validation_error_response_status_code, 13 | ) 14 | 15 | __all__ = [ 16 | "test_quart_custom_error", 17 | "test_quart_doc", 18 | "test_quart_forced_serializer", 19 | "test_quart_list_json_request", 20 | "test_quart_no_response", 21 | "test_quart_return_list_request", 22 | "test_quart_return_model", 23 | "test_quart_return_string_status", 24 | "test_quart_skip_validation", 25 | "test_quart_validate", 26 | "test_quart_validation_error_response_status_code", 27 | ] 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 
22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | schedule: 9 | - cron: "43 19 * * 5" 10 | 11 | jobs: 12 | analyze: 13 | name: Analyze 14 | runs-on: ubuntu-latest 15 | permissions: 16 | actions: read 17 | contents: read 18 | security-events: write 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | language: [ python ] 24 | 25 | steps: 26 | - name: Checkout 27 | uses: actions/checkout@v6 28 | 29 | - name: Initialize CodeQL 30 | uses: github/codeql-action/init@v4 31 | with: 32 | languages: ${{ matrix.language }} 33 | queries: +security-and-quality 34 | 35 | - name: Autobuild 36 | uses: github/codeql-action/autobuild@v4 37 | 38 | - name: Perform CodeQL Analysis 39 | uses: github/codeql-action/analyze@v4 40 | with: 41 | category: "/language:${{ matrix.language }}" 42 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | cooldown: 13 | default-days: 5 14 | semver-major-days: 30 15 | semver-minor-days: 7 16 | semver-patch-days: 3 17 | commit-message: 18 | prefix: "chore(pip)" 19 | groups: 20 | all-pips: 21 | patterns: 22 | - "*" 23 | 24 | - package-ecosystem: "github-actions" 25 | directory: "/" 26 | schedule: 27 | interval: "monthly" 28 | cooldown: 29 | default-days: 5 30 | commit-message: 31 | prefix: "chore(actions)" 32 | groups: 33 | all-actions: 34 | patterns: 35 | - "*" 36 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL:=install 2 | 3 | SOURCE_FILES=spectree tests examples 4 | MYPY_SOURCE_FILES=spectree tests # temporary 5 | 6 | install: 7 | uv sync --all-extras --all-groups 8 | uv run -- prek install 9 | 10 | import_test: 11 | for module in flask quart falcon starlette; do \ 12 | uv sync --extra $$module; \ 13 | bash -c "uv run tests/import_module/test_$${module}_plugin.py" || exit 1; \ 14 | done 15 | 16 | test: import_test 17 | uv sync --all-extras --group dev 18 | uv run -- pytest tests -vv -rs --disable-warnings 19 | 20 | update_snapshot: 21 | @uv run -- pytest --snapshot-update 22 | 23 | doc: 24 | @cd docs && make html 25 | 26 | opendoc: 27 | @cd docs/build/html && uv run -m http.server 8765 -b 127.0.0.1 28 | 29 | clean: 30 | @-rm -rf build/ dist/ *.egg-info .pytest_cache 31 | @find . -name '*.pyc' -type f -exec rm -rf {} + 32 | @find . 
-name '__pycache__' -exec rm -rf {} + 33 | 34 | package: clean 35 | @uv build 36 | 37 | publish: package 38 | @uv publish dist/* 39 | 40 | format: 41 | @uv run -- ruff format ${SOURCE_FILES} 42 | @uv run -- ruff check --fix ${SOURCE_FILES} 43 | 44 | lint: 45 | @uv run -- ruff format --check ${SOURCE_FILES} 46 | @uv run -- ruff check ${SOURCE_FILES} 47 | @uv run -- mypy --install-types --non-interactive ${MYPY_SOURCE_FILES} 48 | 49 | .PHONY: test doc 50 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python Check 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | - main 9 | merge_group: 10 | # Allows you to run this workflow manually from the Actions tab 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | group: ${{ github.ref }}-${{ github.workflow }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | lint: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v6 22 | - name: Install uv 23 | uses: astral-sh/setup-uv@v7 24 | with: 25 | enable-cache: true 26 | ignore-nothing-to-cache: true 27 | - name: Install dependencies 28 | run: make install 29 | - name: Lint 30 | run: make lint 31 | 32 | test: 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | python-version: ["3.10", "3.11", "3.12", "3.13", "3.14", "3.14t", "pypy3.11"] 38 | os: [ubuntu-latest] 39 | 40 | steps: 41 | - uses: actions/checkout@v6 42 | - name: Install uv 43 | uses: astral-sh/setup-uv@v7 44 | with: 45 | enable-cache: true 46 | ignore-nothing-to-cache: true 47 | python-version: ${{ matrix.python-version }} 48 | - name: Install dependencies 49 | run: make install 50 | - name: Test 51 | run: make test 52 | -------------------------------------------------------------------------------- /spectree/_types.py: 
-------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Callable, 4 | Dict, 5 | Iterator, 6 | List, 7 | Mapping, 8 | Optional, 9 | Protocol, 10 | Sequence, 11 | Type, 12 | TypeVar, 13 | Union, 14 | ) 15 | 16 | from spectree._pydantic import BaseModel 17 | 18 | BaseModelSubclassType = TypeVar("BaseModelSubclassType", bound=BaseModel) 19 | ModelType = Type[BaseModelSubclassType] 20 | OptionalModelType = Optional[ModelType] 21 | NamingStrategy = Callable[[ModelType], str] 22 | NestedNamingStrategy = Callable[[str, str], str] 23 | 24 | 25 | class MultiDict(Protocol): 26 | def get(self, key: str) -> Optional[str]: 27 | pass 28 | 29 | def getlist(self, key: str) -> List[str]: 30 | pass 31 | 32 | def __iter__(self) -> Iterator[str]: 33 | pass 34 | 35 | 36 | class MultiDictStarlette(Protocol): 37 | def __iter__(self) -> Iterator[str]: 38 | pass 39 | 40 | def getlist(self, key: Any) -> List[Any]: 41 | pass 42 | 43 | def __getitem__(self, key: Any) -> Any: 44 | pass 45 | 46 | 47 | class FunctionDecorator(Protocol): 48 | resp: Any 49 | tags: Sequence[Any] 50 | security: Union[None, Dict, List[Any]] 51 | deprecated: bool 52 | path_parameter_descriptions: Optional[Mapping[str, str]] 53 | _decorator: Any 54 | 55 | 56 | JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]] 57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Bug report for spectree 3 | labels: ["bug"] 4 | title: "bug: <title>" 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | - type: textarea 11 | attributes: 12 | label: Describe the bug 13 | description: A clear and concise description of what the bug is. 
14 | validations: 15 | required: true 16 | - type: textarea 17 | attributes: 18 | label: To Reproduce 19 | description: Steps to reproduce the behavior. 20 | validations: 21 | required: true 22 | - type: textarea 23 | attributes: 24 | label: Expected behavior 25 | description: A clear and concise description of what you expected to happen. 26 | validations: 27 | required: false 28 | - type: textarea 29 | attributes: 30 | label: The spectree version 31 | description: The output of `pip show spectree`, `uname -a`, `python --version` commands. 32 | validations: 33 | required: true 34 | - type: textarea 35 | attributes: 36 | label: Additional context 37 | description: Add any other context about the problem here. 38 | validations: 39 | required: false 40 | - type: markdown 41 | attributes: 42 | value: | 43 | Impacted by this bug? Give it a 👍. We prioritize the issues with the most 👍. 44 | -------------------------------------------------------------------------------- /tests/flask_imports/__init__.py: -------------------------------------------------------------------------------- 1 | from .dry_plugin_flask import ( 2 | test_flask_custom_error, 3 | test_flask_doc, 4 | test_flask_forced_serializer, 5 | test_flask_list_json_request, 6 | test_flask_make_response_get, 7 | test_flask_make_response_post, 8 | test_flask_no_response, 9 | test_flask_optional_alias_response, 10 | test_flask_query_list, 11 | test_flask_return_list_request, 12 | test_flask_return_model, 13 | test_flask_return_model_request, 14 | test_flask_return_root_request, 15 | test_flask_return_string_status, 16 | test_flask_skip_validation, 17 | test_flask_upload_file, 18 | test_flask_validate_basic, 19 | test_flask_validate_post_data, 20 | test_flask_validation_error_response_status_code, 21 | ) 22 | 23 | __all__ = [ 24 | "test_flask_custom_error", 25 | "test_flask_doc", 26 | "test_flask_forced_serializer", 27 | "test_flask_list_json_request", 28 | "test_flask_make_response_get", 29 | 
"test_flask_make_response_post", 30 | "test_flask_no_response", 31 | "test_flask_optional_alias_response", 32 | "test_flask_query_list", 33 | "test_flask_return_list_request", 34 | "test_flask_return_model", 35 | "test_flask_return_model_request", 36 | "test_flask_return_root_request", 37 | "test_flask_return_string_status", 38 | "test_flask_skip_validation", 39 | "test_flask_upload_file", 40 | "test_flask_validate_basic", 41 | "test_flask_validate_post_data", 42 | "test_flask_validation_error_response_status_code", 43 | ] 44 | -------------------------------------------------------------------------------- /.github/workflows/pythondoc.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds the Sphinx documentation and, on push, deploys it to GitHub Pages 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python document 5 | 6 | on: 7 | pull_request: 8 | paths: 9 | - 'spectree/**' 10 | - 'docs/**' 11 | - '.github/workflows/pythondoc.yml' 12 | - 'examples/**' 13 | push: 14 | branches: 15 | - master 16 | - main 17 | paths: 18 | - 'spectree/**' 19 | - 'docs/**' 20 | - '.github/workflows/pythondoc.yml' 21 | - 'examples/**' 22 | # Allows you to run this workflow manually from the Actions tab 23 | workflow_dispatch: 24 | 25 | concurrency: 26 | group: ${{ github.ref }}-${{ github.workflow }} 27 | cancel-in-progress: true 28 | 29 | jobs: 30 | build: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v6 34 | - name: Setup Pages 35 | uses: actions/configure-pages@v5 36 | - name: Install uv 37 | uses: astral-sh/setup-uv@v7 38 | with: 39 | enable-cache: true 40 | python-version: "3.13" 41 | - name: Install dependencies 42 | run: make install 43 | - name: Generate docs 44 | run: make doc 45 | - name: Upload artifact 46 | uses: actions/upload-pages-artifact@v4 47 | with: 48 | # Upload the built HTML documentation 
49 | path: 'docs/build/html' 50 | 51 | deploy: 52 | runs-on: ubuntu-latest 53 | needs: build 54 | if: ${{ github.event_name == 'push' }} 55 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 56 | permissions: 57 | pages: write 58 | id-token: write 59 | environment: 60 | name: github-pages 61 | url: ${{ steps.deployment.outputs.page_url }} 62 | steps: 63 | - name: Deploy to GitHub Pages 64 | id: deployment 65 | uses: actions/deploy-pages@v4 66 | -------------------------------------------------------------------------------- /examples/security_demo.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from pydantic import BaseModel 3 | 4 | from spectree import SecurityScheme, SecuritySchemeData, SpecTree 5 | 6 | 7 | class Req(BaseModel): 8 | name: str 9 | 10 | 11 | security_schemes = [ 12 | SecurityScheme( 13 | name="PartnerID", 14 | data=SecuritySchemeData.model_validate( 15 | {"type": "apiKey", "name": "partner-id", "in": "header"} 16 | ), 17 | ), 18 | SecurityScheme( 19 | name="PartnerToken", 20 | data=SecuritySchemeData.model_validate( 21 | {"type": "apiKey", "name": "partner-access-token", "in": "header"} 22 | ), 23 | ), 24 | SecurityScheme( 25 | name="test_secure", 26 | data=SecuritySchemeData.model_validate( 27 | { 28 | "type": "http", 29 | "scheme": "bearer", 30 | } 31 | ), 32 | ), 33 | SecurityScheme( 34 | name="auth_oauth2", 35 | data=SecuritySchemeData.model_validate( 36 | { 37 | "type": "oauth2", 38 | "flows": { 39 | "authorizationCode": { 40 | "authorizationUrl": ( 41 | "https://accounts.google.com/o/oauth2/v2/auth" 42 | ), 43 | "tokenUrl": "https://sts.googleapis.com", 44 | "scopes": { 45 | "https://www.googleapis.com/auth/tasks.readonly": "tasks", 46 | }, 47 | }, 48 | }, 49 | } 50 | ), 51 | ), 52 | ] 53 | 54 | app = Flask(__name__) 55 | spec = SpecTree( 56 | "flask", 57 | security_schemes=security_schemes, 58 | SECURITY=[ 59 | {"test_secure": []}, 60 | 
{"PartnerID": [], "PartnerToken": []}, 61 | ], 62 | client_id="client_id", 63 | ) 64 | 65 | 66 | @app.route("/ping", methods=["POST"]) 67 | @spec.validate() 68 | def ping(json: Req): 69 | return "pong" 70 | 71 | 72 | @app.route("/ping/oauth", methods=["POST"]) 73 | @spec.validate(security=[{"auth_oauth2": ["read"]}]) 74 | def oauth_only(json: Req): 75 | return "pong" 76 | 77 | 78 | @app.route("/") 79 | def index(): 80 | return "hello" 81 | 82 | 83 | if __name__ == "__main__": 84 | spec.register(app) 85 | app.run(port=8000) 86 | -------------------------------------------------------------------------------- /examples/starlette_demo.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from pydantic import BaseModel, Field 3 | from starlette.applications import Starlette 4 | from starlette.datastructures import UploadFile 5 | from starlette.endpoints import HTTPEndpoint 6 | from starlette.responses import JSONResponse 7 | from starlette.routing import Mount, Route 8 | 9 | from examples.common import File, FileResp, Query 10 | from spectree import Response, SpecTree 11 | 12 | spec = SpecTree("starlette", annotations=True) 13 | 14 | 15 | class Resp(BaseModel): 16 | label: int = Field( 17 | ..., 18 | ge=0, 19 | le=9, 20 | ) 21 | score: float = Field( 22 | ..., 23 | gt=0, 24 | lt=1, 25 | ) 26 | 27 | 28 | class Data(BaseModel): 29 | uid: str 30 | limit: int 31 | vip: bool 32 | 33 | 34 | @spec.validate(resp=Response(HTTP_200=Resp), tags=["api"]) 35 | async def predict(request, query: Query, json: Data): 36 | """ 37 | async api 38 | 39 | descriptions about this function 40 | """ 41 | print(request.path_params) 42 | print(query, json) 43 | return JSONResponse({"label": 5, "score": 0.5}) 44 | # return PydanticResponse(Resp(label=5, score=0.5)) 45 | 46 | 47 | @spec.validate(resp=Response(HTTP_200=FileResp), tags=["file-upload"]) 48 | async def file_upload(request, form: File): 49 | """ 50 | post multipart/form-data demo 51 
| 52 | demo for 'form' 53 | """ 54 | file: UploadFile = form.file 55 | return JSONResponse({"filename": file.filename, "type": file.content_type}) 56 | 57 | 58 | class Ping(HTTPEndpoint): 59 | @spec.validate(tags=["health check", "api"]) 60 | def get(self, request): 61 | """ 62 | health check 63 | """ 64 | return JSONResponse({"msg": "pong"}) 65 | 66 | 67 | if __name__ == "__main__": 68 | """ 69 | cmd: 70 | http :8000/ping 71 | http ':8000/api/predict/233?text=hello' vip=true uid=admin limit=1 72 | """ 73 | app = Starlette( 74 | routes=[ 75 | Route("/ping", Ping), 76 | Mount( 77 | "/api", 78 | routes=[ 79 | Route("/predict/{luck:int}", predict, methods=["POST"]), 80 | Route("/file-upload", file_upload, methods=["POST"]), 81 | ], 82 | ), 83 | ] 84 | ) 85 | spec.register(app) 86 | 87 | uvicorn.run(app, log_level="info") 88 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "spectree" 3 | version = "2.0.1" 4 | dynamic = [] 5 | description = "Generate OpenAPI document and validate request & response with Python annotations." 
6 | readme = "README.md" 7 | license = "Apache-2.0" 8 | requires-python = ">=3.10" 9 | authors = [{ name = "Keming Yang", email = "kemingy94@gmail.com" }] 10 | classifiers = [ 11 | "Intended Audience :: Developers", 12 | "Operating System :: OS Independent", 13 | "Programming Language :: Python :: 3 :: Only", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: 3.11", 16 | "Programming Language :: Python :: 3.12", 17 | "Programming Language :: Python :: 3.13", 18 | "Programming Language :: Python :: 3.14", 19 | "Programming Language :: Python :: Free Threading", 20 | "Programming Language :: Python :: Implementation :: CPython", 21 | "Topic :: Software Development :: Libraries :: Python Modules", 22 | ] 23 | dependencies = [ 24 | "pydantic>=2.11,<3", 25 | ] 26 | 27 | [project.optional-dependencies] 28 | offline = ["offapi>=0.1.1"] 29 | falcon = ["falcon>=3"] 30 | starlette = ["starlette[full]>=0.16"] 31 | flask = ["flask>=2"] 32 | quart = ["quart>=0.16"] 33 | 34 | [project.urls] 35 | Homepage = "https://github.com/0b01001001/spectree" 36 | documentation = "https://0b01001001.github.io/spectree/" 37 | repository = "https://github.com/0b01001001/spectree" 38 | changelog = "https://github.com/0b01001001/spectree/releases" 39 | 40 | [tool.uv] 41 | package = true 42 | 43 | [tool.ruff] 44 | target-version = "py310" 45 | line-length = 88 46 | [tool.ruff.lint] 47 | select = ["E", "F", "B", "G", "I", "SIM", "TID", "PL", "RUF"] 48 | ignore = ["E501", "PLR2004", "RUF012", "B009"] 49 | [tool.ruff.lint.pylint] 50 | max-args = 12 51 | max-branches = 15 52 | 53 | [tool.mypy] 54 | plugins = ["pydantic.mypy"] 55 | follow_imports = "silent" 56 | ignore_missing_imports = true 57 | show_error_codes = true 58 | warn_unused_ignores = false 59 | warn_redundant_casts = true 60 | no_implicit_reexport = true 61 | disable_error_code = ["attr-defined"] 62 | 63 | [tool.pydantic-mypy] 64 | init_typed = true 65 | init_forbid_extra = true 66 | 
warn_required_dynamic_aliases = true 67 | warn_untyped_fields = true 68 | 69 | [dependency-groups] 70 | dev = [ 71 | "anyio>=4.10.0", 72 | "mypy>=1.16.0", 73 | "prek>=0.1.2", 74 | "pytest>=8.3.5", 75 | "ruff>=0.11.12", 76 | "syrupy>=4.9.1", 77 | "uvicorn>=0.35.0", 78 | ] 79 | docs = [ 80 | "myst-parser>=3.0.1", 81 | "shibuya>=2025.3.24", 82 | "sphinx>=7.4.7", 83 | "sphinx-sitemap>=2.6.0", 84 | ] 85 | -------------------------------------------------------------------------------- /examples/quart_demo.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from random import random 3 | 4 | from pydantic import BaseModel, ConfigDict, Field 5 | from quart import Quart, abort, jsonify 6 | from quart.views import MethodView 7 | 8 | from spectree import Response, SpecTree 9 | 10 | app = Quart(__name__) 11 | spec = SpecTree("quart", annotations=True) 12 | 13 | 14 | class Query(BaseModel): 15 | text: str = "default query strings" 16 | 17 | 18 | class Resp(BaseModel): 19 | label: int 20 | score: float = Field( 21 | ..., 22 | gt=0, 23 | lt=1, 24 | ) 25 | 26 | 27 | class Data(BaseModel): 28 | uid: str 29 | limit: int = 5 30 | vip: bool 31 | 32 | model_config = ConfigDict( 33 | schema_extra={ 34 | "example": { 35 | "uid": "very_important_user", 36 | "limit": 10, 37 | "vip": True, 38 | } 39 | } 40 | ) 41 | 42 | 43 | class Language(str, Enum): 44 | en = "en-US" 45 | zh = "zh-CN" 46 | 47 | 48 | class Header(BaseModel): 49 | Lang: Language 50 | 51 | 52 | class Cookie(BaseModel): 53 | key: str 54 | 55 | 56 | @app.route( 57 | "/api/predict/<string(length=2):source>/<string(length=2):target>", methods=["POST"] 58 | ) 59 | @spec.validate(resp=Response("HTTP_403", HTTP_200=Resp), tags=["model"]) 60 | def predict(source, target, json: Data, query: Query): 61 | """ 62 | predict demo 63 | 64 | demo for `query`, `data`, `resp` 65 | """ 66 | print(f"=> from {source} to {target}") # path 67 | print(f"JSON: {json}") # Data 68 | 
print(f"Query: {query}") # Query 69 | if random() < 0.5: 70 | abort(403) 71 | 72 | return jsonify(label=int(10 * random()), score=random()) 73 | 74 | 75 | @app.route("/api/header", methods=["POST"]) 76 | @spec.validate(resp=Response("HTTP_203"), tags=["test", "demo"]) 77 | async def with_code_header(headers: Header, cookies: Cookie): 78 | """ 79 | demo for JSON with status code and header 80 | """ 81 | return jsonify(language=headers.Lang), 203, {"X": cookies.key} 82 | 83 | 84 | class UserAPI(MethodView): 85 | @spec.validate(resp=Response(HTTP_200=Resp), tags=["test"]) 86 | async def post(self, json: Data): 87 | return jsonify(label=int(10 * random()), score=random()) 88 | # return Resp(label=int(10 * random()), score=random()) 89 | 90 | 91 | if __name__ == "__main__": 92 | """ 93 | cmd: 94 | http :8000/api/user uid=12 limit=1 vip=false 95 | http ':8000/api/predict/zh/en?text=hello' vip=true uid=aa limit=1 96 | http POST :8000/api/header Lang:zh-CN Cookie:key=hello 97 | """ 98 | app.add_url_rule("/api/user", view_func=UserAPI.as_view("user_id")) 99 | spec.register(app) 100 | app.run(port=8000) 101 | -------------------------------------------------------------------------------- /spectree/_pydantic.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Union, cast 3 | 4 | from pydantic import BaseModel, RootModel 5 | 6 | __all__ = [ 7 | "BaseModel", 8 | "generate_root_model", 9 | "is_base_model", 10 | "is_base_model_instance", 11 | "is_pydantic_model", 12 | "is_root_model", 13 | "is_root_model_instance", 14 | "serialize_model_instance", 15 | ] 16 | 17 | 18 | def generate_root_model(root_type, name="GeneratedRootModel") -> type[BaseModel]: 19 | return type(name, (RootModel[root_type],), {}) 20 | 21 | 22 | def is_pydantic_model(t: Any) -> bool: 23 | return issubclass(t, BaseModel) 24 | 25 | 26 | def is_base_model(t: Any) -> bool: 27 | """Check whether a type is a Pydantic 
BaseModel""" 28 | try: 29 | return is_pydantic_model(t) 30 | except TypeError: 31 | return False 32 | 33 | 34 | def is_base_model_instance(value: Any) -> bool: 35 | """Check whether a value is a Pydantic BaseModel instance.""" 36 | return is_base_model(type(value)) 37 | 38 | 39 | def is_partial_base_model_instance(instance: Any) -> bool: 40 | """Check if it's a Pydantic BaseModel instance or [BaseModel] 41 | or {key: BaseModel} instance. 42 | """ 43 | if not instance: 44 | return False 45 | if is_base_model_instance(instance): 46 | return True 47 | if isinstance(instance, dict): 48 | return any( 49 | is_partial_base_model_instance(key) or is_partial_base_model_instance(value) 50 | for key, value in instance.items() 51 | ) 52 | if isinstance(instance, (list, tuple)): 53 | return any(is_partial_base_model_instance(value) for value in instance) 54 | return False 55 | 56 | 57 | def is_root_model(t: Any) -> bool: 58 | """Check whether a type is a Pydantic RootModel.""" 59 | pydantic_v2_root = is_base_model(t) and any( 60 | f"{m.__module__}.{m.__name__}" == "pydantic.root_model.RootModel" 61 | for m in t.mro() 62 | ) 63 | return pydantic_v2_root 64 | 65 | 66 | def is_root_model_instance(value: Any): 67 | """Check whether a value is a Pydantic RootModel instance.""" 68 | return is_root_model(type(value)) 69 | 70 | 71 | @dataclass(frozen=True) 72 | class SerializedPydanticResponse: 73 | data: bytes 74 | 75 | 76 | _PydanticResponseModel = generate_root_model(Any, name="_PydanticResponseModel") 77 | 78 | 79 | def serialize_model_instance( 80 | value: Union[BaseModel, list[BaseModel], dict[Any, BaseModel]], 81 | ) -> SerializedPydanticResponse: 82 | """Serialize a (partial) Pydantic BaseModel to json string.""" 83 | if not is_base_model_instance(value): 84 | value = _PydanticResponseModel.model_validate(value) 85 | else: 86 | value = cast(BaseModel, value) 87 | serialized = value.model_dump_json() 88 | return SerializedPydanticResponse(serialized.encode("utf-8")) 89 | 
-------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../../")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "spectree" 22 | copyright = "2020, Keming Yang" 23 | author = "Keming Yang" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | autodoc_class_signature = "separated" 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | "sphinx.ext.viewcode", 34 | "sphinx.ext.autodoc", 35 | "sphinx.ext.githubpages", 36 | "myst_parser", 37 | "sphinx_sitemap", 38 | ] 39 | 40 | # myst 41 | myst_enable_extensions = [ 42 | "tasklist", 43 | "fieldlist", 44 | "colon_fence", 45 | "replacements", 46 | "substitution", 47 | "smartquotes", 48 | "html_admonition", 49 | "deflist", 50 | ] 51 | 52 | # Add any paths that contain templates here, relative to this directory. 53 | templates_path = ["_templates"] 54 | 55 | # List of patterns, relative to source directory, that match files and 56 | # directories to ignore when looking for source files. 
57 | # This pattern also affects html_static_path and html_extra_path. 58 | exclude_patterns = [] 59 | source_suffix = [".rst", ".md"] 60 | language = "en" 61 | html_baseurl = "https://0b01001001.github.io/spectree/" 62 | html_extra_path = ["robots.txt"] 63 | 64 | # -- Options for HTML output ------------------------------------------------- 65 | 66 | # The theme to use for HTML and HTML Help pages. See the documentation for 67 | # a list of builtin themes. 68 | # 69 | html_theme = "shibuya" 70 | html_theme_options = { 71 | "og_image_url": "https://repository-images.githubusercontent.com/225120376/c3469400-c16d-11ea-9498-093594983a5a", 72 | "nav_links": [ 73 | { 74 | "title": "Sponsor me", 75 | "url": "https://github.com/sponsors/kemingy", 76 | }, 77 | ], 78 | } 79 | html_context = { 80 | "source_type": "github", 81 | "source_user": "0b01001001", 82 | "source_repo": "spectree", 83 | } 84 | 85 | # Add any paths that contain custom static files (such as style sheets) here, 86 | # relative to this directory. They are copied after the builtin static files, 87 | # so a file named "default.css" will overwrite the builtin "default.css". 
88 | html_static_path = ["_static"] 89 | html_css_files = ["custom.css"] 90 | 91 | # read the doc 92 | master_doc = "index" 93 | -------------------------------------------------------------------------------- /examples/flask_demo.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from random import random 3 | 4 | from flask import Flask, abort, jsonify 5 | from flask.views import MethodView 6 | from pydantic import BaseModel, ConfigDict, Field 7 | from werkzeug.datastructures import FileStorage 8 | 9 | from examples.common import File, FileResp, Query 10 | from spectree import Response, SpecTree 11 | 12 | app = Flask(__name__) 13 | spec = SpecTree("flask") 14 | 15 | 16 | class Resp(BaseModel): 17 | label: int 18 | score: float = Field( 19 | ..., 20 | gt=0, 21 | lt=1, 22 | ) 23 | 24 | 25 | class Data(BaseModel): 26 | uid: str 27 | limit: int = 5 28 | vip: bool 29 | 30 | model_config = ConfigDict( 31 | schema_extra={ 32 | "example": { 33 | "uid": "very_important_user", 34 | "limit": 10, 35 | "vip": True, 36 | } 37 | } 38 | ) 39 | 40 | 41 | class Language(str, Enum): 42 | en = "en-US" 43 | zh = "zh-CN" 44 | 45 | 46 | class Header(BaseModel): 47 | Lang: Language 48 | 49 | 50 | class Cookie(BaseModel): 51 | key: str 52 | 53 | 54 | @app.route( 55 | "/api/predict/<string(length=2):source>/<string(length=2):target>", methods=["POST"] 56 | ) 57 | @spec.validate(resp=Response("HTTP_403", HTTP_200=Resp), tags=["model"]) 58 | def predict(source, target, query: Query, json: Data): 59 | """ 60 | predict demo 61 | 62 | demo for `query`, `data`, `resp` 63 | """ 64 | print(f"=> from {source} to {target}") # path 65 | print(f"JSON: {json}") # Data 66 | print(f"Query: {query}") # Query 67 | if random() < 0.5: 68 | abort(403) 69 | 70 | return jsonify(label=int(10 * random()), score=random()) 71 | 72 | 73 | @app.route("/api/header", methods=["POST"]) 74 | @spec.validate(resp=Response("HTTP_203"), tags=["test", "demo"]) 75 
| def with_code_header(headers: Header, cookies: Cookie): 76 | """ 77 | demo for JSON with status code and header 78 | """ 79 | return jsonify(language=headers.Lang), 203, {"X": cookies.key} 80 | 81 | 82 | @app.route("/api/file_upload", methods=["POST"]) 83 | @spec.validate(resp=Response(HTTP_200=FileResp), tags=["file-upload"]) 84 | def with_file(form: File): 85 | """ 86 | post multipart/form-data demo 87 | 88 | demo for 'form' 89 | """ 90 | file: FileStorage = form.file 91 | return {"filename": file.filename, "type": file.content_type} 92 | 93 | 94 | class UserAPI(MethodView): 95 | @spec.validate(resp=Response(HTTP_200=Resp), tags=["test"]) 96 | def post(self, json: Data): 97 | return jsonify(label=int(10 * random()), score=random()) 98 | 99 | 100 | if __name__ == "__main__": 101 | """ 102 | cmd: 103 | http :8000/api/user uid=12 limit=1 vip=false 104 | http ':8000/api/predict/zh/en?text=hello' vip=true uid=aa limit=1 105 | http POST :8000/api/header Lang:zh-CN Cookie:key=hello 106 | """ 107 | app.add_url_rule("/api/user", view_func=UserAPI.as_view("user_id")) 108 | spec.register(app) 109 | app.run(port=8000) 110 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import string 2 | from typing import Type 3 | 4 | import pytest 5 | from pydantic import ValidationError 6 | 7 | from spectree import SecurityScheme 8 | from spectree.config import Configuration 9 | 10 | from .common import SECURITY_SCHEMAS, WRONG_SECURITY_SCHEMAS_DATA 11 | 12 | 13 | def test_config_license(): 14 | config = Configuration(license={"name": "MIT"}) 15 | assert config.license.name == "MIT" 16 | 17 | config = Configuration( 18 | license={"name": "MIT", "url": "https://opensource.org/licenses/MIT"} 19 | ) 20 | assert config.license.name == "MIT" 21 | assert str(config.license.url) == "https://opensource.org/licenses/MIT" 22 | 23 | with 
pytest.raises(ValidationError): 24 | Configuration(license={"name": "MIT", "url": "url"}) 25 | 26 | 27 | def test_config_contact(): 28 | config = Configuration(contact={"name": "John"}) 29 | assert config.contact.name == "John" 30 | 31 | config = Configuration(contact={"name": "John", "url": "https://example.com"}) 32 | assert config.contact.name == "John" 33 | assert str(config.contact.url).rstrip("/") == "https://example.com" 34 | 35 | config = Configuration(contact={"name": "John", "email": "hello@github.com"}) 36 | assert config.contact.name == "John" 37 | assert config.contact.email == "hello@github.com" 38 | 39 | with pytest.raises(ValidationError): 40 | Configuration(contact={"name": "John", "url": "url"}) 41 | 42 | 43 | @pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS) 44 | def test_update_security_scheme(secure_item: Type[SecurityScheme]): 45 | # update and validate each schema type 46 | config = Configuration( 47 | security_schemes=[SecurityScheme(name=secure_item.name, data=secure_item.data)] 48 | ) 49 | assert config.security_schemes 50 | assert config.security_schemes[0].name == secure_item.name 51 | assert config.security_schemes[0].data == secure_item.data 52 | 53 | 54 | def test_update_security_schemes(): 55 | # update and validate ALL schemas types 56 | config = Configuration(security_schemes=SECURITY_SCHEMAS) 57 | assert config.security_schemes == SECURITY_SCHEMAS 58 | 59 | 60 | @pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS) 61 | def test_update_security_scheme_wrong_type(secure_item: SecurityScheme): 62 | # update and validate each schema type 63 | with pytest.raises(ValidationError): 64 | secure_item.data.type += "_wrong" # type: ignore 65 | 66 | 67 | @pytest.mark.parametrize( 68 | "symbol", [symb for symb in string.punctuation if symb not in "-._"] 69 | ) 70 | @pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS) 71 | def test_update_security_scheme_wrong_name(secure_item: SecurityScheme, symbol: str): 72 | # update 
and validate each schema name 73 | with pytest.raises(ValidationError): 74 | secure_item.name += symbol 75 | 76 | with pytest.raises(ValidationError): 77 | secure_item.name = symbol + secure_item.name 78 | 79 | 80 | @pytest.mark.parametrize(("secure_item"), WRONG_SECURITY_SCHEMAS_DATA) 81 | def test_update_security_scheme_wrong_data(secure_item: dict): 82 | # update and validate each schema type 83 | with pytest.raises(ValidationError): 84 | SecurityScheme(**secure_item) 85 | -------------------------------------------------------------------------------- /examples/falcon_asgi_demo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from random import random 3 | 4 | import falcon.asgi 5 | import uvicorn 6 | from falcon.asgi.multipart import BodyPart 7 | from pydantic import BaseModel, Field 8 | 9 | from examples.common import File, FileResp, Query 10 | from spectree import ExternalDocs, Response, SpecTree, Tag 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | logger = logging.getLogger() 14 | 15 | spec = SpecTree( 16 | "falcon-asgi", 17 | title="Demo Service", 18 | version="0.1.2", 19 | annotations=True, 20 | ) 21 | 22 | demo = Tag( 23 | name="demo", description="😊", externalDocs=ExternalDocs(url="https://github.com") 24 | ) 25 | 26 | 27 | class Resp(BaseModel): 28 | label: int = Field( 29 | ..., 30 | ge=0, 31 | le=9, 32 | ) 33 | score: float = Field( 34 | ..., 35 | gt=0, 36 | lt=1, 37 | ) 38 | 39 | 40 | class BadLuck(BaseModel): 41 | loc: str 42 | msg: str 43 | typ: str 44 | 45 | 46 | class Data(BaseModel): 47 | uid: str 48 | limit: int 49 | vip: bool 50 | 51 | 52 | class Ping: 53 | def check(self): 54 | pass 55 | 56 | @spec.validate(tags=[demo]) 57 | async def on_get(self, req, resp): 58 | """ 59 | health check 60 | """ 61 | self.check() 62 | logger.debug("ping <> pong") 63 | resp.media = {"msg": "pong"} 64 | 65 | 66 | class Classification: 67 | """ 68 | classification demo 69 | """ 70 | 71 | 
@spec.validate(tags=[demo]) 72 | async def on_get(self, req, resp, source, target): 73 | """ 74 | API summary 75 | 76 | description here: test information with `source` and `target` 77 | """ 78 | resp.media = {"msg": f"hello from {source} to {target}"} 79 | 80 | @spec.validate(resp=Response(HTTP_200=Resp, HTTP_403=BadLuck)) 81 | async def on_post(self, req, resp, source, target, query: Query, json: Data): 82 | """ 83 | post demo 84 | 85 | demo for `query`, `data`, `resp` 86 | """ 87 | logger.debug("%s => %s", source, target) 88 | logger.info(query) 89 | logger.info(json) 90 | if random() < 0.5: 91 | resp.status = falcon.HTTP_403 92 | resp.media = {"loc": "unknown", "msg": "bad luck", "typ": "random"} 93 | return 94 | resp.media = {"label": int(10 * random()), "score": random()} 95 | 96 | 97 | class FileUpload: 98 | """ 99 | file-handling demo 100 | """ 101 | 102 | @spec.validate(resp=Response(HTTP_200=FileResp), tags=["file-upload"]) 103 | async def on_post(self, req, resp, form: File): 104 | """ 105 | post multipart/form-data demo 106 | 107 | demo for 'form' 108 | """ 109 | file: BodyPart = req.context.form.file 110 | resp.media = {"filename": file.secure_filename, "type": file.content_type} 111 | 112 | 113 | if __name__ == "__main__": 114 | app = falcon.asgi.App() 115 | app.add_route("/ping", Ping()) 116 | app.add_route("/api/{source}/{target}", Classification()) 117 | app.add_route("/api/file_upload", FileUpload()) 118 | spec.register(app) 119 | 120 | uvicorn.run(app, log_level="info") 121 | -------------------------------------------------------------------------------- /examples/falcon_demo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from random import random 3 | from wsgiref import simple_server 4 | 5 | import falcon 6 | from falcon.media.multipart import BodyPart 7 | from pydantic import BaseModel, Field 8 | 9 | from examples.common import File, FileResp, Query 10 | from spectree import 
ExternalDocs, Response, SpecTree, Tag 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | logger = logging.getLogger() 14 | 15 | spec = SpecTree( 16 | "falcon", 17 | annotations=True, 18 | title="Demo Service", 19 | version="0.1.2", 20 | description="This is a demo service.", 21 | terms_of_service="https://github.io", 22 | contact={"name": "John", "email": "hello@github.com", "url": "https://github.com"}, 23 | license={"name": "MIT", "url": "https://opensource.org/licenses/MIT"}, 24 | ) 25 | 26 | demo = Tag( 27 | name="demo", description="😊", externalDocs=ExternalDocs(url="https://github.com") 28 | ) 29 | 30 | 31 | class Resp(BaseModel): 32 | label: int = Field( 33 | ..., 34 | ge=0, 35 | le=9, 36 | ) 37 | score: float = Field( 38 | ..., 39 | gt=0, 40 | lt=1, 41 | ) 42 | 43 | 44 | class BadLuck(BaseModel): 45 | loc: str 46 | msg: str 47 | typ: str 48 | 49 | 50 | class Data(BaseModel): 51 | uid: str 52 | limit: int 53 | vip: bool 54 | 55 | 56 | class Ping: 57 | def check(self): 58 | pass 59 | 60 | @spec.validate(tags=[demo]) 61 | def on_get(self, req, resp): 62 | """ 63 | health check 64 | """ 65 | self.check() 66 | logger.debug("ping <> pong") 67 | resp.media = {"msg": "pong"} 68 | 69 | 70 | class Classification: 71 | """ 72 | classification demo 73 | """ 74 | 75 | @spec.validate(tags=[demo]) 76 | def on_get(self, req, resp, source, target): 77 | """ 78 | API summary 79 | 80 | description here: test information with `source` and `target` 81 | """ 82 | resp.media = {"msg": f"hello from {source} to {target}"} 83 | 84 | @spec.validate(resp=Response(HTTP_200=Resp, HTTP_403=BadLuck)) 85 | def on_post(self, req, resp, source, target, query: Query, json: Data): 86 | """ 87 | post demo 88 | 89 | demo for `query`, `data`, `resp` 90 | """ 91 | logger.debug("%s => %s", source, target) 92 | logger.info(query) 93 | logger.info(json) 94 | if random() < 0.5: 95 | resp.status = falcon.HTTP_403 96 | resp.media = {"loc": "unknown", "msg": "bad luck", "typ": "random"} 97 | return 
98 | resp.media = {"label": int(10 * random()), "score": random()} 99 | 100 | 101 | class FileUpload: 102 | """ 103 | file-handling demo 104 | """ 105 | 106 | @spec.validate(resp=Response(HTTP_200=FileResp), tags=["file-upload"]) 107 | def on_post(self, req, resp, form: File): 108 | """ 109 | post multipart/form-data demo 110 | 111 | demo for 'form' 112 | """ 113 | file: BodyPart = form.file 114 | resp.media = {"filename": file.secure_filename, "type": file.content_type} 115 | 116 | 117 | def create_app(): 118 | app = falcon.App() 119 | app.add_route("/ping", Ping()) 120 | app.add_route("/api/{source}/{target}", Classification()) 121 | app.add_route("/api/file_upload", FileUpload()) 122 | spec.register(app) 123 | return app 124 | 125 | 126 | if __name__ == "__main__": 127 | """ 128 | cmd: 129 | http :8000/ping 130 | http ':8000/api/zh/en?text=hi' uid=neo limit=1 vip=true 131 | """ 132 | app = create_app() 133 | httpd = simple_server.make_server("localhost", 8000, app) 134 | httpd.serve_forever() 135 | -------------------------------------------------------------------------------- /tests/test_plugin.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from spectree.utils import get_model_key, get_model_schema 4 | 5 | from .common import JSON, SECURITY_SCHEMAS, Cookies, Headers, Query, Resp 6 | from .test_plugin_falcon import api as falcon_api 7 | from .test_plugin_flask import api as flask_api 8 | from .test_plugin_flask import api_global_secure as flask_api_global_secure 9 | from .test_plugin_flask import api_secure as flask_api_secure 10 | from .test_plugin_flask_blueprint import api as flask_bp_api 11 | from .test_plugin_flask_view import api as flask_view_api 12 | from .test_plugin_starlette import api as starlette_api 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "api", 17 | [ 18 | pytest.param(flask_api, id="flask"), 19 | pytest.param(flask_bp_api, id="flask_blueprint"), 20 | 
pytest.param(flask_view_api, id="flask_view"), 21 | pytest.param(starlette_api, id="starlette"), 22 | pytest.param(falcon_api, id="falcon"), 23 | ], 24 | ) 25 | def test_plugin_spec(api, snapshot_json): 26 | models = { 27 | get_model_key(model=m): get_model_schema(model=m) 28 | for m in (Query, JSON, Resp, Cookies, Headers) 29 | } 30 | for name, schema in models.items(): 31 | schema.pop("definitions", None) 32 | schema.pop("$defs", None) 33 | assert api.spec["components"]["schemas"][name] == schema 34 | 35 | assert api.spec == snapshot_json(name="full_spec") 36 | 37 | 38 | def test_secure_spec(): 39 | assert [*flask_api_secure.spec["components"]["securitySchemes"].keys()] == [ 40 | scheme.name for scheme in SECURITY_SCHEMAS 41 | ] 42 | 43 | paths = flask_api_secure.spec["paths"] 44 | # iter paths 45 | for path, path_data in paths.items(): 46 | security = path_data["get"].get("security") 47 | # check empty-secure path 48 | if path == "/no-secure-ping": 49 | assert security is None 50 | else: 51 | # iter secure names and params 52 | for secure_key, secure_value in security[0].items(): 53 | # check secure names valid 54 | assert secure_key in [scheme.name for scheme in SECURITY_SCHEMAS] 55 | 56 | # check if flow exist 57 | if secure_value: 58 | scopes = [ 59 | scheme.data.flows["authorizationCode"]["scopes"] 60 | for scheme in SECURITY_SCHEMAS 61 | if scheme.name == secure_key 62 | ] 63 | 64 | assert set(secure_value).issubset(*scopes) 65 | 66 | 67 | def test_secure_global_spec(): 68 | assert [*flask_api_global_secure.spec["components"]["securitySchemes"].keys()] == [ 69 | scheme.name for scheme in SECURITY_SCHEMAS 70 | ] 71 | 72 | paths = flask_api_global_secure.spec["paths"] 73 | global_security = flask_api_global_secure.spec["security"] 74 | 75 | assert global_security == [{"auth_apiKey": []}] 76 | 77 | # iter paths 78 | for path, path_data in paths.items(): 79 | security = path_data["get"].get("security") 80 | # check empty-secure path 81 | if path == 
"/no-secure-override-ping": 82 | # check if it is defined overridden no auth specification 83 | assert security == [] 84 | elif path == "/oauth2-flows-override-ping": 85 | # check if it is defined overridden security specification 86 | assert security == [{"auth_oauth2": ["admin", "read"]}] 87 | elif path == "/global-secure-ping": 88 | # check if local security specification is missing, 89 | # when was not specified explicitly 90 | assert security is None 91 | elif path == "/security_and": 92 | # check if AND operation is supported 93 | assert security == [{"auth_apiKey": [], "auth_apiKey_backup": []}] 94 | elif path == "/security_or": 95 | # check if OR operation is supported 96 | assert security == [{"auth_apiKey": []}, {"auth_apiKey_backup": []}] 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
def test_init_response():
    """Exercise ``Response`` construction: rejected arguments and model lookup."""
    # each (positional, keyword) pair below must make ``Response`` assert
    rejected_arguments = (
        ([200], {}),
        (["HTTP_110"], {}),
        ([], {"HTTP_200": NormalClass}),
        ([], {"HTTP_200": (NormalClass, "custom code description")}),
        ([], {"HTTP_200": (DemoModel, 1)}),
        ([], {"HTTP_200": (DemoModel,)}),
    )
    for positional, keyword in rejected_arguments:
        with pytest.raises(AssertionError):
            Response(*positional, **keyword)

    # a bare status code plus one code bound to a model
    simple = Response("HTTP_200", HTTP_201=DemoModel)
    assert simple.has_model()
    assert simple.find_model(201) == DemoModel
    for code in ("HTTP_200", "HTTP_201"):
        assert simple.code_descriptions.get(code) is None
    assert DemoModel in simple.models

    # mix of None, list model, plain model and (model, description) tuples
    mixed = Response(
        HTTP_200=None,
        HTTP_400=List[JSON],
        HTTP_401=DemoModel,
        HTTP_402=(None, "custom code description"),
        HTTP_403=(DemoModel, "custom code description"),
    )
    expect_400_model = gen_list_model(JSON)
    assert mixed.has_model()
    assert mixed.find_model(200) is None

    found_400_model = mixed.find_model(400)
    assert type(found_400_model) is type(expect_400_model)
    assert get_type_hints(found_400_model) == get_type_hints(expect_400_model)

    assert mixed.find_model(401) == DemoModel
    assert mixed.find_model(402) is None
    assert mixed.find_model(403) == DemoModel

    descriptions = mixed.code_descriptions
    assert descriptions.get("HTTP_200") is None
    assert descriptions.get("HTTP_401") is None
    assert descriptions.get("HTTP_402") == "custom code description"
    assert descriptions.get("HTTP_403") == "custom code description"
    assert DemoModel in mixed.models

    # an empty response declaration carries no model at all
    assert not Response().has_model()
def test_list_model():
    """A ``List[Model]`` response becomes a list-root model that round-trips data."""
    resp = Response(HTTP_200=List[JSON])
    model = resp.find_model(200)
    expect_model = gen_list_model(JSON)
    assert resp.expect_list_result(200)
    assert not resp.expect_list_result(500)
    assert get_type_hints(model) == get_type_hints(expect_model)
    assert type(model) is type(expect_model)
    assert issubclass(model, BaseModel)
    data = [
        {"name": "a", "limit": 1},
        {"name": "b", "limit": 2},
    ]
    instance = model.model_validate(data)
    items = instance.model_dump()
    if isinstance(items, dict):
        # a dict dump means the list is wrapped under the ``__root__`` key
        items = items["__root__"]
    # without this check the loop below would pass vacuously on an empty dump
    assert len(items) == len(data)
    for item, expected in zip(items, data):
        obj = JSON.model_validate(item)
        assert obj.name == expected["name"]
        assert obj.limit == expected["limit"]
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # root model classes count as base models, their instances do not
        (DummyRootModel, True),
        (DummyRootModel.model_validate([1, 2, 3]), False),
        (NestedRootModel, True),
        (
            NestedRootModel.model_validate(DummyRootModel.model_validate([1, 2, 3])),
            False,
        ),
        # plain model class versus instance
        (SimpleModel, True),
        (SimpleModel(user_id=1), False),
        # a dataclass that merely has a ``__root__`` field
        (RootModelLookalike, False),
        (RootModelLookalike(__root__=["False"]), False),
        # builtins and their values
        (list, False),
        ([1, 2, 3], False),
        (str, False),
        ("str", False),
        (int, False),
        (1, False),
    ],
)
def test_is_base_model(value: Any, expected: bool):
    """``is_base_model`` is true only for pydantic model classes."""
    outcome = is_base_model(value)
    assert outcome is expected
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # plain model serializes to an object
        (SimpleModel(user_id=1), {"user_id": 1}),
        # root models serialize to their wrapped value
        (DummyRootModel.model_validate([1, 2, 3]), [1, 2, 3]),
        (
            NestedRootModel.model_validate(DummyRootModel.model_validate([1, 2, 3])),
            [1, 2, 3],
        ),
        # root model wrapping a list of models
        (
            Users.model_validate(
                [
                    SimpleModel(user_id=1),
                    SimpleModel(user_id=2),
                ]
            ),
            [{"user_id": 1}, {"user_id": 2}],
        ),
    ],
)
def test_serialize_model_instance(value: Any, expected: Any):
    """The serialized ``data`` must decode to the expected JSON structure."""
    serialized = serialize_model_instance(value)
    decoded = json.loads(serialized.data)
    assert decoded == expected
@dataclass(frozen=True)
class DummyResponse:
    """Immutable stand-in for a non-JSON response object.

    ``test_validate_response`` passes it with ``validation_model=None`` and
    expects the payload to come back untouched.
    """

    # raw response body
    payload: bytes
    # MIME type of the body (the test uses "text/html")
    content_type: str
"limit": 1}), 85 | ResponseValidationResult({"name": "user2", "limit": 1}), 86 | ), 87 | ( 88 | RootResp, 89 | JSON(name="user3", limit=5), 90 | ResponseValidationResult({"name": "user3", "limit": 5}), 91 | ), 92 | ( 93 | RootResp, 94 | RootResp.model_validate(JSON(name="user4", limit=23)), 95 | ResponseValidationResult({"name": "user4", "limit": 23}), 96 | ), 97 | ( 98 | RootResp, 99 | {}, 100 | ValidationError, 101 | ), 102 | ( 103 | RespList, 104 | [], 105 | ResponseValidationResult([]), 106 | ), 107 | ( 108 | RespList, 109 | [{"name": "user5", "score": [5, 10]}], 110 | ResponseValidationResult([{"name": "user5", "score": [5, 10]}]), 111 | ), 112 | ( 113 | RespList, 114 | [Resp(name="user6", score=[10, 20]), Resp(name="user7", score=[30, 40])], 115 | ResponseValidationResult( 116 | [ 117 | {"name": "user6", "score": [10, 20]}, 118 | {"name": "user7", "score": [30, 40]}, 119 | ] 120 | ), 121 | ), 122 | ( 123 | None, 124 | {"user_id": "user1", "locale": "en-gb"}, 125 | ResponseValidationResult({"user_id": "user1", "locale": "en-gb"}), 126 | ), 127 | ( 128 | None, 129 | DummyResponse(payload="<html></html>".encode(), content_type="text/html"), 130 | ResponseValidationResult( 131 | DummyResponse( 132 | payload="<html></html>".encode(), content_type="text/html" 133 | ) 134 | ), 135 | ), 136 | ( 137 | ComplexResp, 138 | ComplexResp( 139 | date=datetime(2025, 1, 1), 140 | uuid=uuid.UUID("48b417cd-a884-4e54-9f5b-85c584e5ce77"), 141 | ), 142 | ResponseValidationResult( 143 | { 144 | "date": "2025-01-01T00:00:00", 145 | "uuid": "48b417cd-a884-4e54-9f5b-85c584e5ce77", 146 | } 147 | ), 148 | ), 149 | ], 150 | ) 151 | def test_validate_response( 152 | validation_model: OptionalModelType, 153 | response_payload: Any, 154 | expected_result: Union[ResponseValidationResult, ValidationError], 155 | ): 156 | runtime_expectation = ( 157 | pytest.raises(ValidationError) 158 | if expected_result == ValidationError 159 | else does_not_raise() 160 | ) 161 | with runtime_expectation: 
class Configuration(BaseModel):
    """Global configuration.

    Holds the OpenAPI ``info`` fields, the SpecTree routing/validation options
    and the Swagger UI OAuth2 settings. Keys are lower-cased before validation,
    so ``TITLE="x"`` and ``title="x"`` are equivalent.
    """

    # OpenAPI configurations
    #: title of the service
    title: str = "Service API Document"
    #: service OpenAPI document description
    description: Optional[str] = None
    #: service version
    version: str = "0.1.0"
    #: terms of service url
    terms_of_service: Optional[AnyUrl] = None
    #: author contact information
    contact: Optional[Contact] = None
    #: license information
    license: Optional[License] = None

    # SpecTree configurations
    #: OpenAPI doc route path prefix (i.e. /apidoc/)
    path: str = "apidoc"
    #: OpenAPI file route path suffix (i.e. /apidoc/openapi.json)
    filename: str = "openapi.json"
    #: OpenAPI version (doesn't affect anything)
    openapi_version: str = "3.1.0"
    #: the mode of the SpecTree validator :class:`ModeEnum`
    mode: ModeEnum = ModeEnum.normal
    #: A dictionary of documentation page templates. The key is the
    #: name of the template, that is also used in the URL path, while the value is used
    #: to render the documentation page content. (Each page template should contain a
    #: `{spec_url}` placeholder, that'll be replaced by the actual OpenAPI spec URL in
    #: the rendered documentation page.)
    page_templates: Dict[str, str] = PAGE_TEMPLATES
    #: opt-in type annotation feature, see the README examples
    annotations: bool = True
    #: servers section of OAS :py:class:`spectree.models.Server`
    servers: Optional[List[Server]] = []
    #: OpenAPI `securitySchemes` :py:class:`spectree.models.SecurityScheme`
    security_schemes: Optional[List[SecurityScheme]] = None
    #: OpenAPI `security` JSON at the global level
    security: Union[Dict[str, List[str]], List[Dict[str, List[str]]]] = {}
    # Swagger OAuth2 configs
    #: OAuth2 client id
    client_id: str = ""
    #: OAuth2 client secret
    client_secret: str = ""
    #: OAuth2 realm
    realm: str = ""
    #: OAuth2 app name
    app_name: str = "spectree_app"
    #: OAuth2 scope separator
    scope_separator: str = " "
    #: OAuth2 scopes
    scopes: List[str] = []
    #: OAuth2 additional query string params
    additional_query_string_params: Dict[str, str] = {}
    #: OAuth2 use basic authentication with access code grant
    use_basic_authentication_with_access_code_grant: bool = False
    #: OAuth2 use PKCE with authorization code grant
    use_pkce_with_authorization_code_grant: bool = False

    model_config = ConfigDict(validate_assignment=True)

    @model_validator(mode="before")
    @classmethod
    def convert_to_lower_case(cls, values: Any) -> Any:
        """Lower-case the keys of mapping input so option names are case-insensitive.

        A ``before`` validator receives the raw input, which is not guaranteed
        to be a mapping (for example an existing model instance). Such input is
        returned untouched so pydantic can handle or reject it itself.
        """
        if not isinstance(values, Mapping):
            return values
        return {
            (key.lower() if isinstance(key, str) else key): value
            for key, value in values.items()
        }

    @property
    def spec_url(self) -> str:
        """URL path of the OpenAPI file, built from ``path`` and ``filename``."""
        return f"/{self.path}/{self.filename}"

    def swagger_oauth2_config(self) -> Dict[str, str]:
        """
        return the swagger UI OAuth2 configs

        ref: https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
        """
        if self.client_secret:
            warnings.warn(
                "Do not use client_secret in production", UserWarning, stacklevel=1
            )

        config = self.model_dump(
            include={
                "client_id",
                "client_secret",
                "realm",
                "app_name",
                "scope_separator",
                "scopes",
                "additional_query_string_params",
                "use_basic_authentication_with_access_code_grant",
                "use_pkce_with_authorization_code_grant",
            }
        )
        # the two boolean flags are emitted as the literal strings true/false
        config["use_basic_authentication_with_access_code_grant"] = (
            "true"
            if config["use_basic_authentication_with_access_code_grant"]
            else "false"
        )
        config["use_pkce_with_authorization_code_grant"] = (
            "true" if config["use_pkce_with_authorization_code_grant"] else "false"
        )
        return config

    def openapi_info(self) -> Dict[str, str]:
        """Return the OpenAPI ``info`` object, omitting fields that are ``None``."""
        info = self.model_dump(
            include={
                "title",
                "description",
                "version",
                "terms_of_service",
                "contact",
                "license",
            },
            exclude_none=True,
            mode="json",
        )
        # OpenAPI spells this key in camelCase
        if info.get("terms_of_service") is not None:
            info["termsOfService"] = info.pop("terms_of_service")
        return info
class SecuritySchemeData(BaseModel):
    """
    Security scheme data
    https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#securitySchemeObject

    Which optional fields become mandatory depends on ``type``; the mapping
    lives in ``type_req_fields`` and is enforced before field validation.
    """

    type: SecureType = Field(..., description="Secure scheme type")
    description: Optional[str] = Field(
        None,
        description="A short description for security scheme.",
    )
    name: Optional[str] = Field(
        None,
        description="The name of the header, query or cookie parameter to be used.",
    )
    # ``in`` is a Python keyword, so the attribute is ``field_in`` with alias ``in``
    field_in: Optional[InType] = Field(
        None, alias="in", description="The location of the API key."
    )
    scheme: Optional[str] = Field(
        None, description="The name of the HTTP Authorization scheme."
    )
    bearerFormat: Optional[str] = Field(
        None,
        description=(
            "A hint to the client to identify how the bearer token is formatted."
        ),
    )
    flows: Optional[dict] = Field(
        None,
        description=(
            "Containing configuration information for the flow types supported."
        ),
    )
    openIdConnectUrl: Optional[str] = Field(
        None, description="OpenId Connect URL to discover OAuth2 configuration values."
    )

    @model_validator(mode="before")
    @classmethod
    def check_type_required_fields(cls, values: dict):
        """Reject input that lacks the fields its ``type`` requires.

        :raises ValueError: when ``type`` is missing/empty or a field required
            for that type has no truthy value.
        """
        if not isinstance(values, dict):
            # non-dict input (e.g. a model instance) is left for pydantic to
            # accept or reject; the key-based check below cannot apply to it
            return values

        exist_fields = {key for key, value in values.items() if value}
        if not values.get("type"):
            raise ValueError("Type field is required")

        # ``validate_by_name`` lets callers pass ``field_in`` instead of the
        # alias ``in``; both spellings satisfy the ``in`` requirement
        satisfied_fields = set(exist_fields)
        if "field_in" in exist_fields:
            satisfied_fields.add("in")

        required_fields = type_req_fields.get(values["type"], _EMPTY_SET)
        if not required_fields.issubset(satisfied_fields):
            raise ValueError(
                f"For `{values['type']}` type "
                f"`{', '.join(type_req_fields[values['type']])}` field(s) is required. "
                f"But only found `{', '.join(exist_fields)}`."
            )
        return values

    model_config = ConfigDict(
        validate_assignment=True,
        validate_by_alias=True,
        validate_by_name=True,
    )
class Context(NamedTuple):
    """Validated request data attached to the framework request object.

    Each slot holds the validated model instance for that part of the request,
    or ``None`` when no model was declared for it (or, for ``json``/``form``,
    when the request carries no matching body). The slots are therefore typed
    ``Any`` rather than ``list``/``dict``.
    """

    # validated query-string model instance, or None
    query: Any
    # validated JSON body model instance, or None
    json: Any
    # validated form body model instance, or None
    form: Any
    # validated headers model instance, or None
    headers: Any
    # validated cookies model instance, or None
    cookies: Any
42 | 43 | :param spectree: :class:`spectree.SpecTree` instance 44 | """ 45 | 46 | # ASYNC: is it an async framework or not 47 | ASYNC = False 48 | FORM_MIMETYPE = ("application/x-www-form-urlencoded", "multipart/form-data") 49 | 50 | def __init__(self, spectree: "SpecTree"): 51 | self.spectree = spectree 52 | self.config: Configuration = spectree.config 53 | self.logger = logging.getLogger(__name__) 54 | 55 | def register_route(self, app: Any): 56 | """ 57 | :param app: backend framework application 58 | 59 | register document API routes to application 60 | """ 61 | raise NotImplementedError 62 | 63 | def validate( 64 | self, 65 | func: Callable, 66 | query: Optional[ModelType], 67 | json: Optional[ModelType], 68 | form: Optional[ModelType], 69 | headers: Optional[ModelType], 70 | cookies: Optional[ModelType], 71 | resp: Optional[Response], 72 | before: Callable, 73 | after: Callable, 74 | validation_error_status: int, 75 | skip_validation: bool, 76 | force_resp_serialize: bool, 77 | *args: Any, 78 | **kwargs: Any, 79 | ): 80 | """ 81 | validate the request and response 82 | """ 83 | raise NotImplementedError 84 | 85 | def find_routes(self) -> BackendRoute: 86 | """ 87 | find the routes from application 88 | """ 89 | raise NotImplementedError 90 | 91 | def bypass(self, func: Callable, method: str) -> bool: 92 | """ 93 | :param func: route function (endpoint) 94 | :param method: HTTP method for this route function 95 | 96 | bypass some routes that shouldn't be shown in document 97 | """ 98 | raise NotImplementedError 99 | 100 | def parse_path( 101 | self, route: Any, path_parameter_descriptions: Optional[Mapping[str, str]] 102 | ): 103 | """ 104 | :param route: API routes 105 | :param path_parameter_descriptions: A dictionary of path parameter names and 106 | their description. 
def get_func_operation_id(self, func: Callable, path: str, method: str):
    """
    :param func: route function (endpoint)
    :param path: URI path for this route function
    :param method: HTTP method for this route function

    get the operation_id value for the endpoint

    An ``operation_id`` attribute set on ``func`` wins when it is truthy;
    otherwise the id is ``<lower-cased method>_<path with "/" replaced by "_">``.
    """
    operation_id = getattr(func, "operation_id", None)
    if not operation_id:
        # e.g. ("GET", "/api/user") -> "get__api_user"
        operation_id = f"{method.lower()}_{path.replace('/', '_')}"
    return operation_id
158 | """ 159 | if not validation_model: 160 | return ResponseValidationResult(payload=response_payload) 161 | 162 | final_response_payload: Any = None 163 | skip_validation = False 164 | if isinstance(response_payload, RawResponsePayload): 165 | final_response_payload = response_payload.payload 166 | elif isinstance(response_payload, validation_model): 167 | skip_validation = True 168 | final_response_payload = serialize_model_instance(response_payload) 169 | else: 170 | # non-BaseModel response or partial BaseModel response 171 | final_response_payload = response_payload 172 | 173 | if not skip_validation: 174 | validator = ( 175 | validation_model.model_validate_json 176 | if isinstance(final_response_payload, bytes) 177 | else validation_model.model_validate 178 | ) 179 | validated_instance = validator(final_response_payload) 180 | # in case the response model contains (alias, default_none, unset fields) which 181 | # might not be what the users want, we only return the validated dict when 182 | # the response contains BaseModel or the user explicitly sets `force_serialize` 183 | if force_serialize or is_partial_base_model_instance(final_response_payload): 184 | final_response_payload = serialize_model_instance(validated_instance) 185 | 186 | return ResponseValidationResult(payload=final_response_payload) 187 | -------------------------------------------------------------------------------- /spectree/plugins/flask_plugin.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | import flask 4 | from flask import Blueprint, abort, current_app, jsonify, make_response, request 5 | from pydantic import ValidationError 6 | 7 | from spectree._pydantic import ( 8 | SerializedPydanticResponse, 9 | is_partial_base_model_instance, 10 | serialize_model_instance, 11 | ) 12 | from spectree._types import ModelType 13 | from spectree.plugins.base import Context, validate_response 14 | from 
class FlaskPlugin(WerkzeugPlugin):
    """Flask backend built on :class:`WerkzeugPlugin`.

    Validates the incoming Flask request, calls the view function and validates
    the value it returns.
    """

    def get_current_app(self):
        """Return Flask's ``current_app``."""
        return current_app

    def is_app_response(self, resp):
        """Tell whether ``resp`` is a :class:`flask.Response`."""
        return isinstance(resp, flask.Response)

    @staticmethod
    def is_blueprint(app: Any) -> bool:
        """Tell whether ``app`` is a Flask :class:`Blueprint`."""
        return isinstance(app, Blueprint)

    def request_validation(self, request, query, json, form, headers, cookies):
        """Validate the request parts and store them on ``request.context``.

        Each of ``query``, ``json``, ``form``, ``headers`` and ``cookies`` is a
        model (or a falsy value to skip that part). Parts are validated in that
        order; the first failing ``model_validate`` call propagates its error.

        The JSON body and the form are only considered for methods other than
        ``GET`` and ``DELETE``; which of the two is used depends on whether the
        request mimetype is listed in ``self.FORM_MIMETYPE``.
        """
        query_items = get_multidict_items(request.args, query)
        header_items = dict(iter(request.headers)) or {}
        cookie_items = get_multidict_items(request.cookies)
        carries_body = request.method not in ("GET", "DELETE")
        # flask Request.mimetype is already normalized
        wants_json = (
            json and carries_body and request.mimetype not in self.FORM_MIMETYPE
        )
        wants_form = form and carries_body and request.mimetype in self.FORM_MIMETYPE

        validated_query = query.model_validate(query_items) if query else None

        validated_json = None
        if wants_json:
            validated_json = json.model_validate(request.get_json(silent=True) or {})

        validated_form = None
        if wants_form:
            validated_form = form.model_validate(self.fill_form(request))

        validated_headers = headers.model_validate(header_items) if headers else None
        validated_cookies = cookies.model_validate(cookie_items) if cookies else None

        request.context = Context(
            validated_query,
            validated_json,
            validated_form,
            validated_headers,
            validated_cookies,
        )

    def validate_response(
        self,
        resp,
        resp_model: Optional[Response],
        skip_validation: bool,
        force_resp_serialize: bool,
    ):
        """Turn the view's return value into a Flask response, validating it.

        :param resp: value returned by the view function.
        :param resp_model: response definition used to look up the model for
            the response status code; falsy to skip validation.
        :param skip_validation: when true the payload is not validated.
        :param force_resp_serialize: forwarded as ``force_serialize`` to the
            module-level ``validate_response`` helper.
        :returns: ``(response, validation_error)``; the error is ``None`` unless
            validation failed, in which case the response is a 500 carrying the
            error list.
        """
        payload, status, additional_headers = flask_response_unpack(resp)

        if self.is_app_response(payload):
            inner_status, inner_headers = payload.status_code, payload.headers
            payload = payload.get_data()
            # the inner flask.Response.status_code only takes effect when there
            # is no other status code
            if status == 200:
                status = inner_status
            # use the `Header` object to avoid deduplicated by `make_response`
            inner_headers.extend(additional_headers)
            additional_headers = inner_headers

        if skip_validation or not resp_model:
            if is_partial_base_model_instance(payload):
                payload = self.get_current_app().response_class(
                    serialize_model_instance(payload).data,
                    mimetype="application/json",
                )
            return make_response(payload, status, additional_headers), None

        try:
            validation_result = validate_response(
                validation_model=resp_model.find_model(status),
                response_payload=payload,
                force_serialize=force_resp_serialize,
            )
        except ValidationError as err:
            return make_response(err.errors(include_context=False), 500), err

        body = validation_result.payload
        if isinstance(body, SerializedPydanticResponse):
            # already serialized: wrap the data in a JSON response object
            body = self.get_current_app().response_class(
                body.data,
                mimetype="application/json",
            )
        return make_response(body, status, additional_headers), None

    def validate(
        self,
        func: Callable,
        query: Optional[ModelType],
        json: Optional[ModelType],
        form: Optional[ModelType],
        headers: Optional[ModelType],
        cookies: Optional[ModelType],
        resp: Optional[Response],
        before: Callable,
        after: Callable,
        validation_error_status: int,
        skip_validation: bool,
        force_resp_serialize: bool,
        *args: Any,
        **kwargs: Any,
    ):
        """Run request validation, the view ``func`` and response validation.

        ``before`` is called after request validation (with the error response
        and the error, both ``None`` on success); a failed request validation
        then aborts with a response using ``validation_error_status``.
        ``after`` is called with the final response and the response
        validation error, if any.

        :returns: the response built by :meth:`validate_response`.
        """
        response = None
        req_validation_error = None
        if not skip_validation:
            try:
                self.request_validation(request, query, json, form, headers, cookies)
            except ValidationError as err:
                req_validation_error = err
                response = make_response(
                    jsonify(err.errors(include_context=False)),
                    validation_error_status,
                )

        before(request, response, req_validation_error, None)

        if req_validation_error is not None:
            assert response  # make mypy happy
            abort(response)

        if self.config.annotations:
            # inject validated request parts into parameters the view annotates
            hints = cached_type_hints(func)
            kwargs.update(
                (name, getattr(getattr(request, "context", None), name, None))
                for name in ("query", "json", "form", "headers", "cookies")
                if hints.get(name)
            )

        view_result = func(*args, **kwargs)

        response, resp_validation_error = self.validate_response(
            view_result,
            resp,
            skip_validation,
            force_resp_serialize,
        )
        after(request, response, resp_validation_error, None)

        return response
@pytest.mark.parametrize("name, app", backend_app())
def test_spec_servers_only(name, app):
    """Servers declared with only a URL appear in the spec with only that URL."""
    urls = ["http://foo/bar", "/foo/bar/"]

    spec = _get_spec(name, app, servers=[Server(url=url) for url in urls])

    assert spec["servers"] == [{"url": url} for url in urls]
def test_spec_bypass_mode():
    """Check which routes each SpecTree instance documents.

    ``create_app`` registers ``/foo`` (decorated by ``api``), ``/bar``
    (decorated by ``api_strict``) and the undecorated ``/lone`` routes.
    """
    # default mode: own routes and undecorated routes, but not the route
    # decorated by another instance (`/bar`)
    app = create_app()
    api.register(app)
    with app.app_context():
        assert get_paths(api.spec) == ["/foo", "/lone"]

    # The original asserted on `api.spec` again here, so the instance created
    # with a custom backend was never checked. It uses the same default mode as
    # `api` and decorates no route, so by the rule shown above only the
    # undecorated route is expected.
    app = create_app()
    api_customize_backend.register(app)
    with app.app_context():
        assert get_paths(api_customize_backend.spec) == ["/lone"]

    # greedy: every route
    app = create_app()
    api_greedy.register(app)
    with app.app_context():
        assert get_paths(api_greedy.spec) == ["/bar", "/foo", "/lone"]

    # strict: only routes decorated by this instance
    app = create_app()
    api_strict.register(app)
    with app.app_context():
        assert get_paths(api_strict.spec) == ["/bar"]
@pytest.mark.parametrize(
    ["override_operation_id", "expected_operation_id"],
    [(None, "get__foo"), ("getFoo", "getFoo")],
)
def test_operation_id_override(override_operation_id, expected_operation_id):
    """``operation_id`` passed to ``validate`` wins over the generated one."""
    flask_app = Flask(__name__)
    spectree = SpecTree("flask")

    @flask_app.route("/foo")
    @spectree.validate(operation_id=override_operation_id)
    def foo():
        pass

    spectree.register(flask_app)

    with flask_app.app_context():
        get_operation = spectree.spec["paths"]["/foo"]["get"]
    assert get_operation["operationId"] == expected_operation_id
class QuartPlugin(WerkzeugPlugin):
    """Quart (async) backend built on :class:`WerkzeugPlugin`.

    Validates the incoming Quart request, awaits/calls the view function and
    validates the value it returns.
    """

    FORM_MIMETYPE = ("application/x-www-form-urlencoded", "multipart/form-data")
    ASYNC = True

    def get_current_app(self):
        """Return Quart's ``current_app``."""
        return current_app

    def is_app_response(self, resp):
        """Tell whether ``resp`` is a :class:`quart.Response`."""
        return isinstance(resp, quart.Response)

    @staticmethod
    def is_blueprint(app: Any) -> bool:
        """Tell whether ``app`` is a Quart :class:`Blueprint`."""
        return isinstance(app, Blueprint)

    async def request_validation(self, request, query, json, form, headers, cookies):
        """Validate the request parts and store them on ``request.context``.

        Each of ``query``, ``json``, ``form``, ``headers`` and ``cookies`` is a
        model (or a falsy value to skip that part). Parts are validated in that
        order; the first failing ``model_validate`` call propagates its error.

        The JSON body is only used for ``application/json`` requests and the
        form only when the mimetype contains one of ``self.FORM_MIMETYPE``;
        neither is considered for ``GET`` and ``DELETE``.
        """
        # Pass the query model along, as FlaskPlugin.request_validation does,
        # so both backends extract query items the same way.
        # NOTE(review): relies on the optional second argument of
        # `get_multidict_items`; confirm against spectree.utils.
        req_query = get_multidict_items(request.args, query)
        req_headers = dict(iter(request.headers)) or {}
        req_cookies = get_multidict_items(request.cookies) or {}
        has_data = request.method not in ("GET", "DELETE")
        use_json = json and has_data and request.mimetype == "application/json"
        use_form = (
            form
            and has_data
            # generator instead of a throwaway list: stops at the first match
            and any(x in request.mimetype for x in self.FORM_MIMETYPE)
        )

        request.context = Context(
            query.model_validate(req_query) if query else None,
            json.model_validate(await request.get_json(silent=True) or {})
            if use_json
            else None,
            form.model_validate(self.fill_form(request)) if use_form else None,
            headers.model_validate(req_headers) if headers else None,
            cookies.model_validate(req_cookies) if cookies else None,
        )

    async def validate_response(
        self,
        resp,
        resp_model: Optional[Response],
        skip_validation: bool,
        force_resp_serialize: bool,
    ):
        """Turn the view's return value into a Quart response, validating it.

        :param resp: value returned by the view function.
        :param resp_model: response definition used to look up the model for
            the response status code; falsy to skip validation.
        :param skip_validation: when true the payload is not validated.
        :param force_resp_serialize: forwarded as ``force_serialize`` to the
            module-level ``validate_response`` helper.
        :returns: ``(response, validation_error)``; the error is ``None`` unless
            validation failed, in which case the response is a 500 carrying the
            error list.
        """
        resp_validation_error = None
        payload, status, additional_headers = flask_response_unpack(resp)

        if self.is_app_response(payload):
            resp_status, resp_headers = payload.status_code, payload.headers
            payload = await payload.get_data()
            # the inner quart.Response.status_code only takes effect when there
            # is no other status code
            if status == 200:
                status = resp_status
            # use the `Header` object to avoid deduplicated by `make_response`
            resp_headers.extend(additional_headers)
            additional_headers = resp_headers

        if not skip_validation and resp_model:
            try:
                response_validation_result = validate_response(
                    validation_model=resp_model.find_model(status),
                    response_payload=payload,
                    force_serialize=force_resp_serialize,
                )
            except ValidationError as err:
                errors = err.errors(include_context=False)
                response = await make_response(errors, 500)
                resp_validation_error = err
            else:
                response = await make_response(
                    self.get_current_app().response_class(
                        response_validation_result.payload.data,
                        mimetype="application/json",
                    )
                    if isinstance(
                        response_validation_result.payload,
                        SerializedPydanticResponse,
                    )
                    else response_validation_result.payload,
                    status,
                    additional_headers,
                )
        else:
            if is_partial_base_model_instance(payload):
                payload = self.get_current_app().response_class(
                    serialize_model_instance(payload).data,
                    mimetype="application/json",
                )
            response = await make_response(payload, status, additional_headers)

        return response, resp_validation_error

    async def validate(
        self,
        func: Callable,
        query: Optional[ModelType],
        json: Optional[ModelType],
        form: Optional[ModelType],
        headers: Optional[ModelType],
        cookies: Optional[ModelType],
        resp: Optional[Response],
        before: Callable,
        after: Callable,
        validation_error_status: int,
        skip_validation: bool,
        force_resp_serialize: bool,
        *args: Any,
        **kwargs: Any,
    ):
        """Run request validation, the view ``func`` and response validation.

        ``before`` is called after request validation (with the error response
        and the error, both ``None`` on success); a failed request validation
        then aborts with a response using ``validation_error_status``.
        ``func`` is awaited when it is a coroutine function and called directly
        otherwise. ``after`` is called with the final response and the response
        validation error, if any.

        :returns: the response built by :meth:`validate_response`.
        """
        response, req_validation_error = None, None
        if not skip_validation:
            try:
                await self.request_validation(
                    request, query, json, form, headers, cookies
                )
            except ValidationError as err:
                req_validation_error = err
                errors = err.errors(include_context=False)
                response = await make_response(jsonify(errors), validation_error_status)

        before(request, response, req_validation_error, None)
        # explicit `is not None`, consistent with FlaskPlugin.validate
        if req_validation_error is not None:
            assert response  # make mypy happy
            abort(response)  # type: ignore

        if self.config.annotations:
            # inject validated request parts into parameters the view annotates
            annotations = cached_type_hints(func)
            for name in ("query", "json", "form", "headers", "cookies"):
                if annotations.get(name):
                    kwargs[name] = getattr(
                        getattr(request, "context", None), name, None
                    )

        result = (
            await func(*args, **kwargs)
            if inspect.iscoroutinefunction(func)
            else func(*args, **kwargs)
        )

        response, resp_validation_error = await self.validate_response(
            result,
            resp,
            skip_validation,
            force_resp_serialize,
        )
        after(request, response, resp_validation_error, None)

        return response
27 | 28 | """, 29 | # https://swagger.io 30 | "swagger": """ 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | SwaggerUI 39 | 40 | 41 | 42 | 43 | 44 |
# Use the templates provided by the optional `offapi` package when it can be
# imported; otherwise fall back to ONLINE_PAGE_TEMPLATES.
try:
    from offapi import OpenAPITemplate
except ImportError:
    PAGE_TEMPLATES = ONLINE_PAGE_TEMPLATES
else:
    PAGE_TEMPLATES = {
        "redoc": OpenAPITemplate.REDOC.value,
        "swagger": OpenAPITemplate.SWAGGER.value,
        "scalar": OpenAPITemplate.SCALAR.value,
    }
async def test_quart_return_model(client):
    """A view returning a model instance yields the serialized, sorted user."""
    cookie_options = dict(secure=True, httponly=True, samesite="Strict")
    client.set_cookie("quart", "pub", "abcdefg", **cookie_options)

    response = await client.post(
        "/api/user_model/quart?order=1",
        json={"name": "quart", "limit": 10},
        headers={"Content-Type": "application/json"},
    )
    body = await response.json

    assert response.status_code == 200, body
    response_headers = response.headers
    assert response_headers.get("X-Validation") is None
    assert response_headers.get("X-API") == "OK"
    assert body["name"] == "quart"
    scores = body["score"]
    assert scores == sorted(scores, reverse=True)
@pytest.mark.parametrize(
    "test_client_and_api, expected_doc_pages",
    [
        pytest.param({}, ["redoc", "swagger"], id="default-page-templates"),
        pytest.param(
            {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}},
            ["custom_page"],
            id="custom-page-templates",
        ),
    ],
    indirect=["test_client_and_api"],
)
async def test_quart_doc(test_client_and_api, expected_doc_pages):
    """The spec and every documentation page are served under ``/apidoc``."""
    client, api = test_client_and_api

    spec_response = await client.get("/apidoc/openapi.json")
    assert (await spec_response.json) == api.spec

    for page in expected_doc_pages:
        # the page is served with a trailing slash; the bare URL redirects
        for url, expected_status in (
            (f"/apidoc/{page}/", 200),
            (f"/apidoc/{page}", 308),
        ):
            page_response = await client.get(url)
            assert page_response.status_code == expected_status
@pytest.mark.parametrize("pre_serialize", [False, True])
async def test_quart_return_list_request(client, pre_serialize: bool):
    """``/api/return_list`` returns both users with or without pre-serialization."""
    flag = int(pre_serialize)
    response = await client.get(f"/api/return_list?pre_serialize={flag}")
    assert response.status_code == 200

    users = await response.json
    expected_users = [{"name": f"user{i}", "limit": i} for i in (1, 2)]
    assert users == expected_users
class Response:
    """
    Description of the responses an endpoint may return.

    :param codes: HTTP status codes without a model, each formatted as
        ``'HTTP_[0-9]{3}'`` (for example ``'HTTP_200'``)
    :param code_models: mapping of status code name to one of: a
        `pydantic.BaseModel` subclass, ``None``, ``List[<model>]``, or a two
        element tuple whose first item is any of the former and whose second
        item is a custom description string for the status code.

    examples:

    >>> from typing import List
    >>> from spectree.response import Response
    >>> from pydantic import BaseModel
    ...
    >>> class User(BaseModel):
    ...     id: int
    ...
    >>> response = Response("HTTP_200")
    >>> response = Response(HTTP_200=None)
    >>> response = Response(HTTP_200=User)
    >>> response = Response(HTTP_200=(User, "status code description"))
    >>> response = Response(HTTP_200=List[User])
    >>> response = Response(HTTP_200=(List[User], "status code description"))
    """

    def __init__(
        self,
        *codes: str,
        **code_models: Union[
            OptionalModelType,
            Tuple[OptionalModelType, str],
            Type[List[BaseModelSubclassType]],
            Tuple[Type[List[BaseModelSubclassType]], str],
        ],
    ) -> None:
        for code in codes:
            assert code in DEFAULT_CODE_DESC, "invalid HTTP status code"
        # codes that carry no model
        self.codes: List[str] = list(codes)

        self.code_models: Dict[str, ModelType] = {}
        self.code_descriptions: Dict[str, Optional[str]] = {}
        self.code_list_item_types: Dict[str, ModelType] = {}
        for code, spec in code_models.items():
            assert code in DEFAULT_CODE_DESC, "invalid HTTP status code"

            if isinstance(spec, tuple):
                assert len(spec) == 2, (
                    "unexpected number of arguments for a tuple of "
                    "`pydantic.BaseModel` and HTTP status code description"
                )
                model, description = spec
            else:
                model, description = spec, None

            if not model:
                # a code declared with ``None`` behaves like a positional code
                self.codes.append(code)
            else:
                origin_type = getattr(model, "__origin__", None)
                if origin_type is list or origin_type is List:
                    # type is List[BaseModel]: remember the item type and
                    # validate against a generated list model
                    item_type = model.__args__[0]  # type: ignore
                    model = gen_list_model(item_type)
                    self.code_list_item_types[code] = item_type
                assert is_pydantic_model(model), (
                    f"invalid `pydantic.BaseModel`: {model}"
                )
                assert description is None or isinstance(description, str), (
                    "invalid HTTP status code description"
                )
                self.code_models[code] = model

            if description:
                self.code_descriptions[code] = description

    def add_model(
        self,
        code: int,
        model: ModelType,
        replace: bool = True,
        description: Optional[str] = None,
    ) -> None:
        """Register *model* for the status *code*.

        :param code: An HTTP status code.
        :param model: A `pydantic.BaseModel`.
        :param replace: When `False`, a model already registered for *code* is
            kept and this call does nothing; when `True` it is overwritten.
        :param description: Optional description string for the code.
        """
        if replace or not self.find_model(code):
            key: str = f"HTTP_{code}"
            self.code_models[key] = model
            if description:
                self.code_descriptions[key] = description

    def has_model(self) -> bool:
        """
        :returns: boolean -- whether at least one model is registered
        """
        return bool(self.code_models)

    def find_model(self, code: int) -> OptionalModelType:
        """Return the model registered for *code*, or ``None``.

        :param code: ``r'\\d{3}'``
        """
        key = f"HTTP_{code}"
        return self.code_models.get(key)

    def expect_list_result(self, code: int) -> bool:
        """Tell whether the given HTTP code was declared with a list model.

        :param code: Status code (example: 200)
        """
        key = f"HTTP_{code}"
        return key in self.code_list_item_types

    def get_expected_list_item_type(self, code: int) -> ModelType:
        """Return the item type of the list model declared for *code*.

        :param code: Status code (example: 200)
        """
        key = f"HTTP_{code}"
        return self.code_list_item_types[key]

    def get_code_description(self, code: str) -> str:
        """Return the custom description of *code*, else the default one.

        :param code: Status code string, format('HTTP_[0-9]{3}'), 'HTTP_200'.
        :returns: The status code's description.
        """
        custom = self.code_descriptions.get(code)
        return custom or DEFAULT_CODE_DESC[code]

    @property
    def models(self) -> Iterable[ModelType]:
        """
        :returns: dict_values -- every model registered on this response
        """
        return self.code_models.values()

    def generate_spec(
        self, naming_strategy: NamingStrategy = get_model_key
    ) -> Dict[str, Any]:
        """
        Build the ``responses`` part of the spec.

        Codes without a model only get a description; codes with a model also
        get an ``application/json`` schema reference named by
        ``naming_strategy``.

        :returns: JSON
        """
        spec: Dict[str, Any] = {}

        for code in self.codes:
            entry: Dict[str, Any] = {"description": self.get_code_description(code)}
            spec[parse_code(code)] = entry

        # model entries are written last, so they win over a plain code entry
        for code, model in self.code_models.items():
            schema_ref = f"#/components/schemas/{naming_strategy(model)}"
            entry = {
                "description": self.get_code_description(code),
                "content": {"application/json": {"schema": {"$ref": schema_ref}}},
            }
            spec[parse_code(code)] = entry

        return spec
class Headers(BaseModel):
    """Request headers model; incoming keys are lower-cased before validation."""

    lang: Language

    @model_validator(mode="before")
    @classmethod
    def lower_keys(cls, data: Any):
        # Rebuild the mapping with every key lower-cased so that the ``lang``
        # field is found no matter how the header name was capitalised.
        lowered = {}
        for key, value in data.items():
            lowered[key.lower()] = value
        return lowered
def get_paths(spec: Dict[str, Any]) -> List[str]:
    """Return the sorted paths of *spec* whose path item is not empty.

    :param spec: a mapping with a ``"paths"`` entry that maps each path
        to its path item
    :returns: the paths with a truthy path item, in ascending order
    """
    return sorted(path for path, item in spec["paths"].items() if item)
def get_model_path_key(model_path: str) -> str:
    """
    Build the key for a dotted model path: the model name followed by a hash
    of its module path (the module path itself is not exposed, to avoid
    code-structure leaking). A path without a module part is returned as is.

    :param model_path: `str` model path in string
    """
    module_path, _, model_name = model_path.rpartition(".")
    return (
        f"{model_name}.{hash_module_path(module_path=module_path)}"
        if module_path
        else model_name
    )
@dataclass(frozen=True)
class UserXmlData:
    """A user record that can be read from and written to a small XML document.

    The document has a ``user`` root element with exactly two children, in
    this order: ``name`` (text) and ``x_score`` (comma-separated integers).
    """

    name: str
    score: List[int]

    @staticmethod
    def parse_xml(data: str) -> "UserXmlData":
        """Parse *data* and return the matching `UserXmlData`.

        The assertions check the document shape: root tag ``user`` with the
        two children ``name`` and ``x_score``.
        """
        root = ET.fromstring(data)
        assert root.tag == "user"
        children = list(root)
        assert len(children) == 2
        assert children[0].tag == "name"
        assert children[1].tag == "x_score"
        return UserXmlData(
            name=cast(str, children[0].text),
            score=[int(entry) for entry in cast(str, children[1].text).split(",")],
        )

    def dump_xml(self) -> str:
        """Serialize to the XML layout that `parse_xml` accepts."""
        scores = ",".join(str(entry) for entry in self.score)
        # The element names must match the tags asserted in `parse_xml`,
        # otherwise the output cannot be read back.
        return f"""
<user>
    <name>{self.name}</name>
    <x_score>{scores}</x_score>
</user>
        """
def PydanticResponse(content):
    """Wrap *content* in a `JSONResponse` subclass that serializes it through
    the pydantic root model ``_PydanticResponseModel``.

    While rendering, the response records the class of *content* as
    ``_model_class``.
    """

    class _PydanticResponse(JSONResponse):
        def render(self, content) -> bytes:
            # Keep the payload's class on the response object; it is read
            # back through ``response._model_class`` after the endpoint ran.
            self._model_class = content.__class__
            wrapped = _PydanticResponseModel.model_validate(content)
            serialized = serialize_model_instance(wrapped)
            return serialized.data

    response = _PydanticResponse(content)
    return response
async def validate(
    self,
    func: Callable,
    query: Optional[ModelType],
    json: Optional[ModelType],
    form: Optional[ModelType],
    headers: Optional[ModelType],
    cookies: Optional[ModelType],
    resp: Optional[Response],
    before: Callable,
    after: Callable,
    validation_error_status: int,
    skip_validation: bool,
    force_resp_serialize: bool,
    *args: Any,
    **kwargs: Any,
):
    """Validate the request, call the endpoint *func*, validate its response.

    :param func: the endpoint; awaited when it is a coroutine function,
        called directly otherwise
    :param query: model for the query string, or `None`
    :param json: model for the JSON body, or `None`
    :param form: model for the form body, or `None`
    :param headers: model for the request headers, or `None`
    :param cookies: model for the request cookies, or `None`
    :param resp: expected responses; when falsy the response is not validated
    :param before: hook called as ``before(request, response, error, instance)``
        after request validation
    :param after: hook called as ``after(request, response, error, instance)``
        after response validation
    :param validation_error_status: status code of the JSON error response
        returned when request validation fails
    :param skip_validation: when true, neither request nor response is validated
    :param force_resp_serialize: forwarded to ``validate_response`` as
        ``force_serialize``
    :param args: positional arguments of the endpoint call: either
        ``(request, ...)`` or ``(instance, request, ...)``
    :param kwargs: keyword arguments forwarded to *func*
    :returns: the error response when request validation failed, otherwise
        the endpoint's response (replaced by a 500 JSON response when
        response validation fails)
    """
    # A plain function endpoint receives the request first; otherwise the
    # first two positional arguments are taken as (instance, request).
    if isinstance(args[0], Request):
        instance, request = None, args[0]
    else:
        instance, request = args[:2]

    response = None
    req_validation_error = resp_validation_error = json_decode_error = None

    if not skip_validation:
        try:
            await self.request_validation(
                request, query, json, form, headers, cookies
            )
        except ValidationError as err:
            req_validation_error = err
            response = JSONResponse(
                err.errors(include_context=False), validation_error_status
            )
        except JSONDecodeError as err:
            # A body that is not valid JSON is reported with the same status
            # code as a model validation failure.
            json_decode_error = err
            self.logger.info(
                "%s Validation Error",
                validation_error_status,
                extra={"spectree_json_decode_error": str(err)},
            )
            response = JSONResponse(
                {"error_msg": str(err)}, validation_error_status
            )

    # The hook runs whether or not validation succeeded; on failure it
    # receives the error response built above.
    before(request, response, req_validation_error, instance)
    if req_validation_error or json_decode_error:
        return response

    if self.config.annotations:
        # Inject the validated data into every endpoint parameter that has a
        # type hint and is named after one of the request parts.
        annotations = cached_type_hints(func)
        for name in ("query", "json", "form", "headers", "cookies"):
            if annotations.get(name):
                kwargs[name] = getattr(
                    getattr(request, "context", None), name, None
                )

    if inspect.iscoroutinefunction(func):
        response = await func(*args, **kwargs)
    else:
        response = func(*args, **kwargs)

    # Response validation is skipped when the response is a JSONResponse that
    # carries a ``_model_class`` equal to the model registered for its
    # status code.
    if (
        not skip_validation
        and resp
        and response
        and not (
            isinstance(response, JSONResponse)
            and hasattr(response, "_model_class")
            and response._model_class == resp.find_model(response.status_code)
        )
    ):
        try:
            response_validation_result = validate_response(
                validation_model=resp.find_model(response.status_code),
                response_payload=RawResponsePayload(payload=response.body),
                force_serialize=force_resp_serialize,
            )
        except ValidationError as err:
            response = JSONResponse(
                err.errors(include_context=False),
                500,
            )
            resp_validation_error = err
        else:
            # replace the body of the response if it was serialized during validation
            # NOTE(review): only ``body`` is reassigned here; confirm that the
            # response's Content-Length header stays consistent with it.
            if isinstance(
                response_validation_result.payload, SerializedPydanticResponse
            ):
                response.body = response_validation_result.payload.data

    after(request, response, resp_validation_error, instance)

    return response
# Matches one "<converter(args):variable>" placeholder together with the
# static text in front of it. The group names are the keys read from
# ``groupdict()`` in `werkzeug_parse_rule`.
RE_FLASK_RULE = re.compile(
    r"""
    (?P<static>[^<]*)                           # static rule data
    <
    (?:
        (?P<converter>[a-zA-Z_][a-zA-Z0-9_]*)   # converter name
        (?:\((?P<args>.*?)\))?                  # converter arguments
        \:                                      # variable delimiter
    )?
    (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)        # variable name
    >
    """,
    re.VERBOSE,
)


def werkzeug_parse_rule(
    rule: str,
) -> Iterator[Tuple[Optional[str], Optional[str], str]]:
    """A copy of werkzeug.parse_rule which is now removed.

    Parse a rule and return it as generator. Each iteration yields tuples
    in the form ``(converter, arguments, variable)``. If the converter is
    `None` it's a static url part (its text is in the ``variable`` slot),
    otherwise it's a dynamic one. A placeholder without an explicit
    converter is reported with the converter name ``"default"``.

    :param rule: the URL rule, e.g. ``"/user/<int:uid>"``
    :raises ValueError: if a variable name is used twice, or if an unmatched
        ``<`` or ``>`` is left in the rule
    """
    pos = 0
    end = len(rule)
    do_match = RE_FLASK_RULE.match
    used_names = set()
    while pos < end:
        m = do_match(rule, pos)
        if m is None:
            # no further placeholder; the tail is handled after the loop
            break
        data = m.groupdict()
        if data["static"]:
            yield None, None, data["static"]
        variable = data["variable"]
        converter = data["converter"] or "default"
        if variable in used_names:
            raise ValueError(f"variable name {variable!r} used twice.")
        used_names.add(variable)
        yield converter, data["args"] or None, variable
        pos = m.end()
    if pos < end:
        remaining = rule[pos:]
        # a stray bracket here means a placeholder was left unterminated
        if ">" in remaining or "<" in remaining:
            raise ValueError(f"malformed url rule: {rule!r}")
        yield None, None, remaining
def find_routes(self):
    """Yield the URL rules of the current app that belong in the spec.

    Skipped are: rules under the documentation path or ``/static``, rules
    whose endpoint starts with ``openapi``, websocket rules and, when a
    blueprint state with a URL prefix is recorded, rules outside that
    prefix or under the blueprint's own documentation path.
    """
    # https://werkzeug.palletsprojects.com/en/stable/routing/#werkzeug.routing.Rule
    for rule in self.get_current_app().url_map.iter_rules():
        rule_path = str(rule)
        if rule_path.startswith((f"/{self.config.path}", "/static")):
            continue
        if rule.endpoint.startswith("openapi") or getattr(rule, "websocket", False):
            continue

        state = self.blueprint_state
        prefix = state.url_prefix if state else None
        if prefix:
            docs_prefix = "/".join([prefix, self.config.path])
            if not rule_path.startswith(prefix) or rule_path.startswith(docs_prefix):
                continue

        yield rule
def parse_path(
    self,
    route: Optional[Mapping[str, str]],
    path_parameter_descriptions: Optional[Mapping[str, str]],
) -> Tuple[str, list]:
    """Translate a rule into an OpenAPI path and its path parameters.

    :param route: the rule to translate; only ``str(route)`` is used
    :param path_parameter_descriptions: optional mapping from variable name
        to description; missing names get an empty description
    :returns: ``(path, parameters)`` where every ``<...>`` placeholder of the
        rule is rewritten as ``{variable}`` and ``parameters`` holds one
        parameter object per placeholder. A converter that is not handled
        below gets ``schema = None``.
    """
    subs = []
    parameters = []

    for converter, arguments, variable in werkzeug_parse_rule(str(route)):
        if converter is None:
            # static part of the rule, copied through unchanged
            subs.append(variable)
            continue
        subs.append(f"{{{variable}}}")

        args: tuple = ()
        kwargs: dict = {}

        if arguments:
            args, kwargs = parse_converter_args(arguments)

        schema = None
        if converter == "any":
            schema = {
                "type": "string",
                "enum": args,
            }
        elif converter == "int":
            schema = {
                "type": "integer",
                "format": "int32",
            }
            if "max" in kwargs:
                schema["maximum"] = kwargs["max"]
            if "min" in kwargs:
                schema["minimum"] = kwargs["min"]
        elif converter == "float":
            schema = {
                "type": "number",
                "format": "float",
            }
        elif converter == "uuid":
            schema = {
                "type": "string",
                "format": "uuid",
            }
        elif converter == "path":
            schema = {
                "type": "string",
                "format": "path",
            }
        elif converter == "string":
            schema = {
                "type": "string",
            }
            # kept as before so that existing output does not lose keys
            for prop in ["length", "maxLength", "minLength"]:
                if prop in kwargs:
                    schema[prop] = kwargs[prop]
            # Werkzeug's string converter spells its arguments in lower case
            # (``minlength``, ``maxlength``, ``length``); translate them to
            # the schema keywords. An exact ``length`` bounds both ends.
            if "length" in kwargs:
                schema.setdefault("minLength", kwargs["length"])
                schema.setdefault("maxLength", kwargs["length"])
            if "minlength" in kwargs:
                schema["minLength"] = kwargs["minlength"]
            if "maxlength" in kwargs:
                schema["maxLength"] = kwargs["maxlength"]
        elif converter == "default":
            schema = {"type": "string"}

        description = (
            path_parameter_descriptions.get(variable, "")
            if path_parameter_descriptions
            else ""
        )
        parameters.append(
            {
                "name": variable,
                "in": "path",
                "required": True,
                "schema": schema,
                "description": description,
            }
        )

    return "".join(subs), parameters
def register_route(self, app):
    """Register the spec endpoint and one page per documentation template.

    For a blueprint the page's spec URL is resolved when the page is
    requested, using the URL prefix of the recorded blueprint state; the
    state itself is captured through ``app.record``.
    """
    app.add_url_rule(
        rule=self.config.spec_url,
        endpoint=f"openapi_{self.config.path}",
        view_func=lambda: self.get_current_app().json.response(self.spectree.spec),
    )

    def render_page(ui, spec_url):
        template = self.config.page_templates[ui]
        return template.format(
            spec_url=spec_url,
            spec_path=self.config.path,
            **self.config.swagger_oauth2_config(),
        )

    def blueprint_spec_url():
        # evaluated per request: the state is only known once the blueprint
        # has been registered on an application
        prefix = self.blueprint_state.url_prefix
        if prefix is None:
            return self.config.spec_url
        return "/".join((prefix.rstrip("/"), self.config.spec_url.lstrip("/")))

    if not self.is_blueprint(app):
        for ui in self.config.page_templates:
            app.add_url_rule(
                rule=f"/{self.config.path}/{ui}/",
                endpoint=f"openapi_{self.config.path}_{ui}",
                view_func=lambda ui=ui: render_page(ui, self.config.spec_url),
            )
        return

    for ui in self.config.page_templates:
        endpoint_suffix = ui.replace(".", "_")
        app.add_url_rule(
            rule=f"/{self.config.path}/{ui}/",
            endpoint=f"openapi_{self.config.path}_{endpoint_suffix}",
            view_func=lambda ui=ui: render_page(ui, blueprint_spec_url()),
        )

    def remember_state(state):
        self.blueprint_state = state

    app.record(remember_state)
def test_flask_return_model(client):
    """POST a valid user to the model-returning endpoint and check the reply."""
    client.set_cookie(
        key="pub", value="abcdefg", secure=True, httponly=True, samesite="Strict"
    )

    request_body = dict(name="flask", limit=10)
    resp = client.post(
        "/api/user_model/flask?order=1",
        json=request_body,
        content_type="application/json",
    )

    assert resp.status_code == 200, resp.text
    response_headers = resp.headers
    assert response_headers.get("X-Validation") is None
    assert response_headers.get("X-API") == "OK"
    assert resp.json["name"] == "flask"
    # order=1 is expected to give the scores in descending order
    scores = resp.json["score"]
    assert scores == sorted(scores, reverse=True)
@pytest.mark.parametrize(
    "test_client_and_api, expected_doc_pages",
    [
        pytest.param({}, ["redoc", "swagger"], id="default-page-templates"),
        pytest.param(
            {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}},
            ["custom_page"],
            id="custom-page-templates",
        ),
    ],
    indirect=["test_client_and_api"],
)
def test_flask_doc(test_client_and_api, expected_doc_pages):
    """The spec is served as JSON and every doc page answers at its URL."""
    client, api = test_client_and_api

    spec_resp = client.get("/apidoc/openapi.json")
    assert spec_resp.json == api.spec

    for page in expected_doc_pages:
        # with the trailing slash the page is served directly
        with_slash = client.get(f"/apidoc/{page}/")
        assert with_slash.status_code == 200

        # without it a 308 redirect is expected
        without_slash = client.get(f"/apidoc/{page}")
        assert without_slash.status_code == 308
@pytest.mark.parametrize(
    ["fragment"],
    [
        ("user",),
        ("user_annotated",),
    ],
)
def test_flask_validate_post_data(client, fragment):
    """Valid JSON and form bodies are accepted; a POST without body is rejected."""
    client.set_cookie(
        key="pub", value="abcdefg", secure=True, httponly=True, samesite="Strict"
    )
    resp = client.post(
        f"/api/{fragment}/flask?order=1",
        json=dict(name="flask", limit=10),
    )
    assert resp.status_code == 200, resp.json
    assert resp.headers.get("X-Validation") is None
    assert resp.headers.get("X-API") == "OK"
    assert resp.json["name"] == "flask"
    assert resp.json["score"] == sorted(resp.json["score"], reverse=True)

    resp = client.post(
        f"/api/{fragment}/flask?order=0",
        json=dict(name="flask", limit=10),
    )
    assert resp.status_code == 200, resp.json
    assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

    resp = client.post(
        f"/api/{fragment}/flask?order=0",
        data=dict(name="flask", limit=10),
        content_type="application/x-www-form-urlencoded",
    )
    assert resp.status_code == 200, resp.json
    assert resp.json["score"] == sorted(resp.json["score"], reverse=False)

    # POST without body
    resp = client.post(
        f"/api/{fragment}/flask?order=0",
    )
    # ``resp.data`` (as in the other tests of this module) so that a failing
    # assertion reports the body instead of raising while building its message
    assert resp.status_code == 422, resp.data
def test_flask_make_response_post(client):
    """A POST to the ``make_response`` endpoint returns 201, echoes the
    ``lang`` header and sets a ``test_cookie`` that carries the user name."""
    payload = JSON(
        limit=random.randint(1, 10),
        name="user make_response name",
    )
    resp = client.post(
        "/api/return_make_response",
        json=payload.model_dump(),
        headers={"lang": "en-US"},
    )
    assert resp.status_code == 201
    assert resp.json == {"name": payload.name, "score": [payload.limit]}
    assert resp.headers.get("lang") == "en-US"

    # Check the header and the match explicitly: a missing or differently
    # formatted cookie then fails as an assertion, not as an exception
    # raised by ``re.match`` or ``.group``.
    set_cookie = resp.headers.get("Set-Cookie")
    assert set_cookie is not None, resp.headers
    cookie_result = re.match(
        r"^test_cookie=\"((\w+\s?){3})\"; Secure; HttpOnly; Path=/; SameSite=Strict$",
        set_cookie,
    )
    assert cookie_result is not None, set_cookie
    assert cookie_result.group(1) == payload.name
@pytest.mark.parametrize(
    "return_what", ["RootResp_JSON", "RootResp_List", "JSON", "ModelList"]
)
def test_flask_return_model_request(client, return_what: str):
    """The model-returning endpoint serializes each kind of return value."""
    resp = client.get(f"/api/return_model?return_what={return_what}")
    assert resp.status_code == 200
    if return_what in ("RootResp_JSON", "JSON"):
        assert resp.json == {"name": "user1", "limit": 1}
    # Plain equality: ``in ("RootResp_List")`` is a substring test against a
    # string, because parentheses without a comma do not make a tuple.
    elif return_what == "RootResp_List":
        assert resp.json == [1, 2, 3, 4]
    elif return_what == "ModelList":
        assert resp.json == [{"name": "user1", "limit": 1}]
    else:
        raise AssertionError(f"unexpected return_what: {return_what!r}")
test_flask_custom_error(client): 302 | # request error 303 | resp = client.post("/api/custom_error", json={"foo": "bar"}) 304 | assert resp.status_code == 422 305 | 306 | # response error 307 | resp = client.post("/api/custom_error", json={"foo": "foo"}) 308 | assert resp.status_code == 500 309 | 310 | 311 | def test_flask_set_cookies(client): 312 | resp = client.get("/api/set_cookies") 313 | assert resp.status_code == 200 314 | set_cookies = resp.headers.getlist("Set-Cookie") 315 | assert len(set_cookies) == 2 316 | assert "foo=hello" in set_cookies 317 | assert "bar=world" in set_cookies 318 | 319 | 320 | def test_flask_forced_serializer(client): 321 | resp = client.get("/api/force_serialize") 322 | assert resp.status_code == 200 323 | assert resp.json["name"] == "flask" 324 | assert resp.json["score"] == [1, 2, 3] 325 | assert "comment" not in resp.json 326 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [Yang Keming] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | from pydantic import BaseModel, computed_field 5 | 6 | from spectree.models import ValidationError 7 | from spectree.response import DEFAULT_CODE_DESC, Response 8 | from spectree.spec import SpecTree 9 | from spectree.utils import ( 10 | get_model_schema, 11 | has_model, 12 | json_compatible_deepcopy, 13 | parse_code, 14 | parse_comments, 15 | parse_name, 16 | parse_params, 17 | parse_request, 18 | parse_resp, 19 | ) 20 | 21 | from .common import DefaultEnumValue, DemoModel, DemoQuery, Numeric, get_model_path_key 22 | 23 | api = SpecTree() 24 | 25 | 26 | def undecorated_func(): 27 | """summary 28 | 29 | description""" 30 | 31 | 32 | @api.validate(json=DemoModel, resp=Response(HTTP_200=DemoModel)) 33 | def demo_func(): 34 | """ 35 | summary 36 | 37 | description""" 38 | 39 | 40 | @api.validate(query=DemoQuery) 41 | def demo_func_with_query(): 42 | """ 43 | a summary 44 | 45 | a description 46 | """ 47 | 48 | 49 | 
class DemoClass: 50 | @api.validate(query=DemoModel) 51 | def demo_method(self): 52 | """summary 53 | 54 | description 55 | """ 56 | 57 | 58 | demo_class = DemoClass() 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "docstring, expected_summary, expected_description", 63 | [ 64 | pytest.param(None, None, None, id="no-docstring"), 65 | pytest.param("", "", None, id="empty-docstring"), 66 | pytest.param(" ", "", None, id="all-whitespace-docstring"), 67 | pytest.param("summary", "summary", None, id="single-line-docstring"), 68 | pytest.param( 69 | " summary ", "summary", None, id="single-line-docstring-with-whitespace" 70 | ), 71 | pytest.param( 72 | "summary first line\nsummary second line", 73 | "summary first line summary second line", 74 | None, 75 | id="multi-line-docstring-without-empty-line", 76 | ), 77 | pytest.param( 78 | " summary first line \n summary second line ", 79 | "summary first line summary second line", 80 | None, 81 | id="multi-line-docstring-without-empty-line-whitespace", 82 | ), 83 | pytest.param( 84 | "summary\n\ndescription", 85 | "summary", 86 | "description", 87 | id="multi-line-docstring-with-empty-line", 88 | ), 89 | pytest.param( 90 | " summary \n\n description ", 91 | "summary", 92 | "description", 93 | id="multi-line-docstring-with-empty-line-whitespace", 94 | ), 95 | pytest.param( 96 | "summary\n\t \ndescription", 97 | "summary", 98 | "description", 99 | id="multi-line-docstring-with-whitespace-line", 100 | ), 101 | pytest.param( 102 | "summary\n \n \n \n \n \ndescription", 103 | "summary", 104 | "description", 105 | id="multi-line-docstring-with-multiple-whitespace-lines", 106 | ), 107 | pytest.param( 108 | "summary first line\nsummary second line\nsummary third line" 109 | "\n\t \n" 110 | "description first line\ndescription second line\ndescription third line", 111 | "summary first line summary second line summary third line", 112 | "description first line description second line description third line", 113 | 
id="large-multi-line-docstring-with-whitespace-line", 114 | ), 115 | pytest.param( 116 | "summary first line\nsummary second line\ftruncated part", 117 | "summary first line summary second line", 118 | None, 119 | id="multi-line-docstring-without-empty-line-and-truncation-char", 120 | ), 121 | pytest.param( 122 | "summary first line\nsummary second line\nsummary third line" 123 | "\n\t \n" 124 | "description first line\ndescription second line\ndescription third line" 125 | "\ftruncated part", 126 | "summary first line summary second line summary third line", 127 | "description first line description second line description third line", 128 | id="large-multi-line-docstring-with-whitespace-line-and-truncation-char", 129 | ), 130 | pytest.param( 131 | "summary first line\nsummary second line\n" 132 | "\t \n" 133 | "description first line \ndescription second line\n" 134 | "\t \n" 135 | "description second paragraph \n" 136 | "\n \n \n" 137 | "description third paragraph\ndescription third paragraph second line", 138 | "summary first line summary second line", 139 | "description first line description second line" 140 | "\n\n" 141 | "description second paragraph" 142 | "\n\n" 143 | "description third paragraph description third paragraph second line", 144 | id="large-multi-line-docstring-with-multiple-paragraphs", 145 | ), 146 | pytest.param( 147 | "\tcode block while indented\n" 148 | "\t\n" 149 | "\tdescription first paragraph\n" 150 | "\t\n" 151 | "\t\tcode block\n" 152 | "\t\n" 153 | "\tdescription third paragraph\n", 154 | "code block while indented", 155 | "description first paragraph" 156 | "\n\n" 157 | " code block" 158 | "\n\n" 159 | "description third paragraph", 160 | id="multi-line-docstring-with-code-block", 161 | ), 162 | ], 163 | ) 164 | def test_parse_comments(docstring, expected_summary, expected_description): 165 | def func(): 166 | pass 167 | 168 | func.__doc__ = docstring 169 | 170 | assert parse_comments(func) == (expected_summary, 
expected_description) 171 | 172 | 173 | @pytest.mark.parametrize( 174 | "func, expected_summary, expected_description", 175 | [ 176 | pytest.param(lambda x: x, None, None, id="lambda"), 177 | pytest.param( 178 | undecorated_func, "summary", "description", id="undecorated-function" 179 | ), 180 | pytest.param(demo_func, "summary", "description", id="decorated-function"), 181 | pytest.param( 182 | demo_class.demo_method, "summary", "description", id="class-method" 183 | ), 184 | ], 185 | ) 186 | def test_parse_comments_with_different_callable_types( 187 | func, expected_summary, expected_description 188 | ): 189 | assert parse_comments(func) == (expected_summary, expected_description) 190 | 191 | 192 | def test_parse_code(): 193 | with pytest.raises(TypeError): 194 | assert parse_code(200) == 200 195 | 196 | assert parse_code("200") == "" 197 | assert parse_code("HTTP_404") == "404" 198 | 199 | 200 | def test_parse_name(): 201 | assert parse_name(lambda x: x) == "" 202 | assert parse_name(undecorated_func) == "undecorated_func" 203 | assert parse_name(demo_func) == "demo_func" 204 | assert parse_name(demo_class.demo_method) == "demo_method" 205 | 206 | 207 | def test_has_model(): 208 | assert not has_model(undecorated_func) 209 | assert has_model(demo_func) 210 | assert has_model(demo_class.demo_method) 211 | 212 | 213 | def test_parse_resp(): 214 | assert parse_resp(undecorated_func) == {} 215 | resp_spec = parse_resp(demo_func) 216 | 217 | assert resp_spec["422"]["description"] == DEFAULT_CODE_DESC["HTTP_422"] 218 | model_path_key = get_model_path_key( 219 | f"{ValidationError.__module__}.{ValidationError.__name__}" 220 | ) 221 | assert ( 222 | resp_spec["422"]["content"]["application/json"]["schema"]["$ref"] 223 | == f"#/components/schemas/{model_path_key}" 224 | ) 225 | model_path_key = get_model_path_key(f"{DemoModel.__module__}.{DemoModel.__name__}") 226 | assert ( 227 | resp_spec["200"]["content"]["application/json"]["schema"]["$ref"] 228 | == 
f"#/components/schemas/{model_path_key}" 229 | ) 230 | 231 | 232 | def test_parse_request(): 233 | model_path_key = get_model_path_key(f"{DemoModel.__module__}.{DemoModel.__name__}") 234 | assert ( 235 | parse_request(demo_func)["content"]["application/json"]["schema"]["$ref"] 236 | == f"#/components/schemas/{model_path_key}" 237 | ) 238 | assert parse_request(demo_class.demo_method) == {} 239 | 240 | 241 | def test_parse_params(): 242 | models = { 243 | get_model_path_key( 244 | f"{DemoModel.__module__}.{DemoModel.__name__}" 245 | ): DemoModel.model_json_schema(ref_template="#/components/schemas/{model}") 246 | } 247 | assert parse_params(demo_func, [], models) == [] 248 | params = parse_params(demo_class.demo_method, [], models) 249 | assert len(params) == 3 250 | assert params[0] == { 251 | "name": "uid", 252 | "in": "query", 253 | "required": True, 254 | "description": "", 255 | "schema": {"title": "Uid", "type": "integer"}, 256 | } 257 | assert params[2]["description"] == "user name" 258 | 259 | 260 | def test_parse_params_with_route_param_keywords(): 261 | models = { 262 | get_model_path_key("tests.common.DemoQuery"): DemoQuery.model_json_schema( 263 | ref_template="#/components/schemas/{model}" 264 | ) 265 | } 266 | params = parse_params(demo_func_with_query, [], models) 267 | assert params == [ 268 | { 269 | "name": "names1", 270 | "in": "query", 271 | "required": True, 272 | "description": "", 273 | "schema": {"title": "Names1", "type": "array", "items": {"type": "string"}}, 274 | }, 275 | { 276 | "name": "names2", 277 | "in": "query", 278 | "required": True, 279 | "description": "", 280 | "schema": { 281 | "title": "Names2", 282 | "type": "array", 283 | "items": {"type": "string"}, 284 | "non_keyword": "dummy", 285 | }, 286 | "style": "matrix", 287 | "explode": True, 288 | }, 289 | ] 290 | 291 | 292 | def test_json_compatible_schema(): 293 | schema = get_model_schema(Numeric) 294 | 295 | with pytest.raises(ValueError): 296 | json.dumps(schema, 
allow_nan=False) 297 | 298 | json_schema = json_compatible_deepcopy(schema) 299 | assert json.dumps(json_schema, allow_nan=False) 300 | 301 | schema = get_model_schema(DefaultEnumValue) 302 | json_schema = json_compatible_deepcopy(schema) 303 | 304 | 305 | def test_get_model_schema_mode_parameter(): 306 | """Test get_model_schema mode parameter for Pydantic v2""" 307 | 308 | class TestModel(BaseModel): 309 | """Model with computed field""" 310 | 311 | name: str 312 | value: int 313 | 314 | @computed_field 315 | @property 316 | def computed_name(self) -> str: 317 | """Computed field - only in serialization""" 318 | return f"computed_{self.name}" 319 | 320 | # Test validation mode - computed fields excluded 321 | validation_schema = get_model_schema(TestModel, mode="validation") 322 | assert "name" in validation_schema["properties"] 323 | assert "value" in validation_schema["properties"] 324 | assert "computed_name" not in validation_schema["properties"], ( 325 | "Computed field should NOT be in validation mode" 326 | ) 327 | 328 | # Test serialization mode - computed fields included 329 | serialization_schema = get_model_schema(TestModel, mode="serialization") 330 | assert "name" in serialization_schema["properties"] 331 | assert "value" in serialization_schema["properties"] 332 | assert "computed_name" in serialization_schema["properties"], ( 333 | "Computed field SHOULD be in serialization mode" 334 | ) 335 | 336 | # Verify computed field is marked as readOnly and required 337 | assert serialization_schema["properties"]["computed_name"].get("readOnly") is True 338 | assert "computed_name" in serialization_schema["required"] 339 | -------------------------------------------------------------------------------- /tests/test_plugin_flask_blueprint.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | from typing import List 3 | 4 | import flask 5 | import pytest 6 | from flask import Blueprint, Flask, 
jsonify, make_response, request 7 | 8 | from spectree import Response, SpecTree 9 | 10 | from .common import ( 11 | JSON, 12 | Cookies, 13 | CustomError, 14 | Form, 15 | FormFileUpload, 16 | Headers, 17 | ListJSON, 18 | OptionalAliasResp, 19 | Order, 20 | Query, 21 | QueryList, 22 | Resp, 23 | RespFromAttrs, 24 | RespObject, 25 | RootResp, 26 | StrDict, 27 | UserXmlData, 28 | api_tag, 29 | get_paths, 30 | get_root_resp_data, 31 | ) 32 | 33 | # import tests to execute 34 | from .flask_imports import * # NOQA 35 | 36 | 37 | def before_handler(req, resp, err, _): 38 | if err: 39 | resp.headers["X-Error"] = "Validation Error" 40 | 41 | 42 | def after_handler(req, resp, err, _): 43 | resp.headers["X-Validation"] = "Pass" 44 | 45 | 46 | def api_after_handler(req, resp, err, _): 47 | resp.headers["X-API"] = "OK" 48 | 49 | 50 | api = SpecTree("flask", before=before_handler, after=after_handler, annotations=True) 51 | app = Blueprint("test_blueprint", __name__) 52 | 53 | 54 | @app.route("/ping") 55 | @api.validate(headers=Headers, resp=Response(HTTP_202=StrDict), tags=["test", "health"]) 56 | def ping(): 57 | """summary 58 | 59 | description""" 60 | return jsonify(msg="pong"), 202, request.context.headers.model_dump() 61 | 62 | 63 | @app.route("/api/file_upload", methods=["POST"]) 64 | @api.validate( 65 | form=FormFileUpload, 66 | ) 67 | def file_upload(): 68 | upload = request.context.form.file 69 | assert upload 70 | return { 71 | "content": upload.stream.read().decode("utf-8"), 72 | "other": request.context.form.other, 73 | } 74 | 75 | 76 | @app.route("/api/user/", methods=["POST"]) 77 | @api.validate( 78 | query=Query, 79 | json=JSON, 80 | cookies=Cookies, 81 | form=Form, 82 | resp=Response(HTTP_200=Resp, HTTP_401=None), 83 | tags=[api_tag, "test"], 84 | after=api_after_handler, 85 | ) 86 | def user_score(name): 87 | data_src = request.context.json or request.context.form 88 | score = [randint(0, int(data_src.limit)) for _ in range(5)] 89 | 
score.sort(reverse=request.context.query.order) 90 | assert request.context.cookies.pub == "abcdefg" 91 | assert request.cookies["pub"] == "abcdefg" 92 | return jsonify(name=data_src.name, score=score) 93 | 94 | 95 | @app.route("/api/user_annotated/", methods=["POST"]) 96 | @api.validate( 97 | resp=Response(HTTP_200=Resp, HTTP_401=None), 98 | tags=[api_tag, "test"], 99 | after=api_after_handler, 100 | ) 101 | def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies, form: Form): 102 | data_src = json or form 103 | score = [randint(0, int(data_src.limit)) for _ in range(5)] 104 | score.sort(reverse=(query.order == Order.desc)) 105 | assert cookies.pub == "abcdefg" 106 | assert request.cookies["pub"] == "abcdefg" 107 | return jsonify(name=data_src.name, score=score) 108 | 109 | 110 | @app.route("/api/user_skip/", methods=["POST"]) 111 | @api.validate( 112 | query=Query, 113 | json=JSON, 114 | cookies=Cookies, 115 | resp=Response(HTTP_200=Resp, HTTP_401=None), 116 | tags=[api_tag, "test"], 117 | after=api_after_handler, 118 | skip_validation=True, 119 | ) 120 | def user_score_skip_validation(name): 121 | response_format = request.args.get("response_format") 122 | assert response_format in ("json", "xml") 123 | json = request.get_json() 124 | score = [randint(0, json.get("limit")) for _ in range(5)] 125 | score.sort(reverse=int(request.args.get("order")) == Order.desc) 126 | assert request.cookies["pub"] == "abcdefg" 127 | if response_format == "json": 128 | return jsonify(name=name, x_score=score) 129 | else: 130 | return flask.Response( 131 | UserXmlData(name=name, score=score).dump_xml(), 132 | content_type="text/xml", 133 | ) 134 | 135 | 136 | @app.route("/api/user_model/", methods=["POST"]) 137 | @api.validate( 138 | query=Query, 139 | json=JSON, 140 | cookies=Cookies, 141 | resp=Response(HTTP_200=Resp, HTTP_401=None), 142 | tags=[api_tag, "test"], 143 | after=api_after_handler, 144 | ) 145 | def user_score_model(name): 146 | score = [randint(0, 
request.context.json.limit) for _ in range(5)] 147 | score.sort(reverse=request.context.query.order == Order.desc) 148 | assert request.context.cookies.pub == "abcdefg" 149 | assert request.cookies["pub"] == "abcdefg" 150 | return Resp(name=request.context.json.name, score=score) 151 | 152 | 153 | @app.route("/api/user//address/", methods=["GET"]) 154 | @api.validate( 155 | query=Query, 156 | path_parameter_descriptions={ 157 | "name": "The name that uniquely identifies the user.", 158 | "non-existent-param": "description", 159 | }, 160 | ) 161 | def user_address(name, address_id): 162 | return None 163 | 164 | 165 | @app.route("/api/no_response", methods=["GET", "POST"]) 166 | @api.validate( 167 | json=StrDict, 168 | ) 169 | def no_response(): 170 | return {} 171 | 172 | 173 | @app.route("/api/set_cookies", methods=["GET"]) 174 | @api.validate(resp=Response(HTTP_200=StrDict)) 175 | def set_cookies(): 176 | # related to GitHub issue #415 177 | resp = make_response(jsonify(msg="ping")) 178 | resp.set_cookie("foo", "hello") 179 | resp.set_cookie("bar", "world") 180 | return resp 181 | 182 | 183 | @app.route("/api/list_json", methods=["POST"]) 184 | @api.validate( 185 | json=ListJSON, 186 | ) 187 | def list_json(): 188 | return {} 189 | 190 | 191 | @app.route("/api/query_list") 192 | @api.validate(query=QueryList) 193 | def query_list(): 194 | assert request.context.query.ids == [1, 2, 3] 195 | return {} 196 | 197 | 198 | @app.route("/api/return_list", methods=["GET"]) 199 | @api.validate(resp=Response(HTTP_200=List[JSON])) 200 | def return_list(): 201 | pre_serialize = bool(int(request.args.get("pre_serialize", default=0))) 202 | data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] 203 | return [entry.model_dump() if pre_serialize else entry for entry in data] 204 | 205 | 206 | @app.route("/api/return_make_response", methods=["POST"]) 207 | @api.validate(json=JSON, headers=Headers, resp=Response(HTTP_201=Resp)) 208 | def return_make_response_post(): 209 
| model_data = request.context.json 210 | headers = request.context.headers 211 | response = make_response( 212 | Resp(name=model_data.name, score=[model_data.limit]).model_dump(), 201, headers 213 | ) 214 | response.set_cookie( 215 | key="test_cookie", 216 | value=model_data.name, 217 | secure=True, 218 | httponly=True, 219 | samesite="Strict", 220 | ) 221 | return response 222 | 223 | 224 | @app.route("/api/return_make_response", methods=["GET"]) 225 | @api.validate(query=JSON, headers=Headers, resp=Response(HTTP_201=Resp)) 226 | def return_make_response_get(): 227 | model_data = request.context.query 228 | headers = request.context.headers 229 | response = make_response( 230 | Resp(name=model_data.name, score=[model_data.limit]).model_dump(), 201, headers 231 | ) 232 | response.set_cookie( 233 | key="test_cookie", 234 | value=model_data.name, 235 | secure=True, 236 | httponly=True, 237 | samesite="Strict", 238 | ) 239 | return response 240 | 241 | 242 | @app.route("/api/return_root", methods=["GET"]) 243 | @api.validate(resp=Response(HTTP_200=RootResp)) 244 | def return_root(): 245 | return get_root_resp_data( 246 | pre_serialize=bool(int(request.args.get("pre_serialize", default=0))), 247 | return_what=request.args.get("return_what", default="RootResp"), 248 | ) 249 | 250 | 251 | @app.route("/api/return_model", methods=["GET"]) 252 | @api.validate() 253 | def return_model(): 254 | return get_root_resp_data( 255 | pre_serialize=False, 256 | return_what=request.args.get("return_what", default="RootResp"), 257 | ) 258 | 259 | 260 | @app.route("/api/return_string_status", methods=["GET"]) 261 | @api.validate() 262 | def return_string_status(): 263 | return "Response text string", 200 264 | 265 | 266 | @app.route("/api/return_optional_alias", methods=["GET"]) 267 | @api.validate(resp=Response(HTTP_200=OptionalAliasResp)) 268 | def return_optional_alias(): 269 | return {"schema": "test"} 270 | 271 | 272 | @app.route("/api/custom_error", methods=["POST"]) 273 | 
@api.validate(resp=Response(HTTP_200=CustomError)) 274 | def custom_error(json: CustomError): 275 | return {"foo": "bar"} 276 | 277 | 278 | @app.route("/api/force_serialize", methods=["GET"]) 279 | @api.validate(resp=Response(HTTP_200=RespFromAttrs), force_resp_serialize=True) 280 | def force_serialize(): 281 | return RespObject(name="flask", score=[1, 2, 3], comment="hello") 282 | 283 | 284 | api.register(app) 285 | 286 | flask_app = Flask(__name__) 287 | flask_app.config["DEBUG"] = True 288 | flask_app.config["TESTING"] = True 289 | flask_app.register_blueprint(app) 290 | with flask_app.app_context(): 291 | _ = api.spec 292 | 293 | 294 | @pytest.fixture 295 | def client(request): 296 | parent_app = Flask(__name__) 297 | url_prefix = getattr(request, "param", None) 298 | parent_app.register_blueprint(app, url_prefix=url_prefix) 299 | with parent_app.test_client() as client: 300 | yield client 301 | 302 | 303 | @pytest.mark.parametrize( 304 | ("client", "prefix"), [(None, ""), ("/prefix", "/prefix")], indirect=["client"] 305 | ) 306 | def test_blueprint_prefix(client, prefix): 307 | resp = client.get(prefix + "/ping") 308 | assert resp.status_code == 422 309 | assert resp.headers.get("X-Error") == "Validation Error" 310 | 311 | resp = client.get(prefix + "/ping", headers={"lang": "en-US"}) 312 | assert resp.status_code == 202, resp.text 313 | assert resp.json == {"msg": "pong"} 314 | assert resp.headers.get("X-Error") is None 315 | assert resp.headers.get("X-Validation") == "Pass" 316 | assert resp.headers.get("lang") == "en-US" 317 | 318 | 319 | @pytest.fixture 320 | def test_client_and_api(request): 321 | api_args = ["flask"] 322 | api_kwargs = {} 323 | endpoint_kwargs = { 324 | "headers": Headers, 325 | "resp": Response(HTTP_200=StrDict), 326 | "tags": ["test", "health"], 327 | } 328 | register_blueprint_kwargs = {} 329 | if hasattr(request, "param"): 330 | api_args.extend(request.param.get("api_args", ())) 331 | api_kwargs.update(request.param.get("api_kwargs", 
{})) 332 | endpoint_kwargs.update(request.param.get("endpoint_kwargs", {})) 333 | register_blueprint_kwargs.update( 334 | request.param.get("register_blueprint_kwargs", {}) 335 | ) 336 | 337 | api = SpecTree(*api_args, **api_kwargs) 338 | app = Blueprint("test_blueprint", __name__) 339 | 340 | @app.route("/ping") 341 | @api.validate(**endpoint_kwargs) 342 | def ping(): 343 | """summary 344 | 345 | description""" 346 | return jsonify(msg="pong") 347 | 348 | api.register(app) 349 | 350 | flask_app = Flask(__name__) 351 | flask_app.register_blueprint(app, **register_blueprint_kwargs) 352 | 353 | with flask_app.app_context(): 354 | _ = api.spec 355 | 356 | with flask_app.test_client() as test_client: 357 | yield test_client, api 358 | 359 | 360 | @pytest.mark.parametrize( 361 | ("test_client_and_api", "prefix"), 362 | [ 363 | ({"register_blueprint_kwargs": {}}, ""), 364 | ({"register_blueprint_kwargs": {"url_prefix": "/prefix"}}, "/prefix"), 365 | ], 366 | indirect=["test_client_and_api"], 367 | ) 368 | def test_flask_doc_prefix(test_client_and_api, prefix): 369 | client, api = test_client_and_api 370 | 371 | resp = client.get(prefix + "/apidoc/openapi.json") 372 | assert resp.json == api.spec 373 | 374 | resp = client.get(prefix + "/apidoc/redoc/") 375 | assert resp.status_code == 200 376 | 377 | resp = client.get(prefix + "/apidoc/swagger/") 378 | assert resp.status_code == 200 379 | 380 | assert get_paths(api.spec) == [ 381 | prefix + "/ping", 382 | ] 383 | --------------------------------------------------------------------------------