├── .coveragerc ├── .github └── workflows │ ├── build.yml │ └── update.yml ├── .gitignore ├── .readthedocs.yml ├── .vscode └── launch.json ├── LICENSE ├── Makefile ├── dev ├── install ├── lint-code └── test-version ├── docs ├── conf.py ├── env.rst ├── glossary.rst ├── index.rst ├── queries.rst ├── reference.rst ├── requirements.txt ├── tutorial.rst ├── validation.rst └── websocket.rst ├── openapi ├── __init__.py ├── cli.py ├── data │ ├── __init__.py │ ├── db.py │ ├── dump.py │ ├── exc.py │ ├── fields.py │ ├── validate.py │ └── view.py ├── db │ ├── __init__.py │ ├── columns.py │ ├── commands.py │ ├── container.py │ ├── dbmodel.py │ ├── migrations.py │ ├── openapi │ │ ├── alembic.ini.mako │ │ ├── env.py │ │ └── script.py.mako │ └── path.py ├── exc.py ├── json.py ├── logger.py ├── middleware.py ├── pagination │ ├── __init__.py │ ├── create.py │ ├── cursor.py │ ├── offset.py │ ├── pagination.py │ └── search.py ├── py.typed ├── rest.py ├── sentry.py ├── spec │ ├── __init__.py │ ├── hdrs.py │ ├── operation.py │ ├── path.py │ ├── redoc.py │ ├── server.py │ ├── spec.py │ └── utils.py ├── testing.py ├── types.py ├── tz.py ├── utils.py └── ws │ ├── __init__.py │ ├── channel.py │ ├── channels.py │ ├── errors.py │ ├── manager.py │ ├── path.py │ ├── pubsub.py │ ├── rpc.py │ └── utils.py ├── poetry.lock ├── pyproject.toml ├── readme.md └── tests ├── __init__.py ├── conftest.py ├── core ├── test_cli.py ├── test_columns.py ├── test_cruddb.py ├── test_db.py ├── test_db_cli.py ├── test_db_model.py ├── test_db_path_extra.py ├── test_db_single.py ├── test_dc_db.py ├── test_errors.py ├── test_filters.py ├── test_json.py ├── test_logger.py ├── test_paths.py ├── test_union.py └── test_utils.py ├── data ├── test_fields.py ├── test_json_field.py ├── test_validate_nested.py ├── test_validator.py └── test_view.py ├── example ├── __init__.py ├── db │ ├── __init__.py │ ├── tables1.py │ └── tables2.py ├── endpoints.py ├── endpoints_additional.py ├── endpoints_base.py ├── endpoints_form.py ├── endpoints_pagination.py ├── main.py ├── models.py └── ws.py ├── pagination ├── __init__.py ├── conftest.py ├── test_base_classes.py ├── test_cursor_pagination.py ├── test_offset_pagination.py └── utils.py ├── spec ├── test_docstrings.py ├── test_schema_parser.py ├── test_spec.py ├── test_spec_utils.py └── test_validate_spec.py ├── test.env ├── utils.py └── ws ├── test_channels.py └── test_ws.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = openapi 3 | concurrency=greenlet 4 | data_file = build/.coverage 5 | omit = 6 | openapi/db/openapi 7 | 8 | [html] 9 | directory = build/coverage/html 10 | 11 | [xml] 12 | output = build/coverage.xml 13 | 14 | [report] 15 | exclude_lines = 16 | pragma: no cover 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches-ignore: 6 | - deploy 7 | tags-ignore: 8 | - v* 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | env: 14 | PYTHON_ENV: ci 15 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 16 | strategy: 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: run postgres 23 | run: make postgresql 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ 
matrix.python-version }} 28 | - name: Install dependencies 29 | run: make install 30 | - name: check version 31 | run: make test-version 32 | - name: run lint 33 | run: make test-lint 34 | - name: test docs 35 | run: make test-docs 36 | - name: run tests 37 | run: make test 38 | - name: upload coverage reports to codecov 39 | if: matrix.python-version == '3.11' 40 | uses: codecov/codecov-action@v3 41 | with: 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | files: ./build/coverage.xml 44 | - name: publish 45 | if: ${{ matrix.python-version == '3.11' && github.event.head_commit.message == 'release' }} 46 | run: make publish 47 | - name: create github release 48 | if: ${{ matrix.python-version == '3.12' && github.event.head_commit.message == 'release' }} 49 | uses: ncipollo/release-action@v1 50 | with: 51 | artifacts: "dist/*" 52 | token: ${{ secrets.GITHUB_TOKEN }} 53 | draft: false 54 | prerelease: steps.check-version.outputs.prerelease == 'true' 55 | -------------------------------------------------------------------------------- /.github/workflows/update.yml: -------------------------------------------------------------------------------- 1 | name: update 2 | 3 | on: 4 | schedule: 5 | - cron: "0 6 * * 5" 6 | 7 | jobs: 8 | update: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: checkout repo 13 | uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: "3.11" 18 | - name: install poetry 19 | run: pip install -U pip poetry 20 | - name: update dependencies 21 | run: poetry update 22 | - name: Create Pull Request 23 | uses: peter-evans/create-pull-request@v3 24 | with: 25 | token: ${{ secrets.QMBOT_GITHUB_TOKEN }} 26 | author: qmbot 27 | commit-message: update dependencies 28 | title: Automated Dependency Updates 29 | body: This is an auto-generated PR with dependency updates. 
30 | branch: ci-poetry-update 31 | labels: ci, automated pr, automerge 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | __pycache__ 3 | *.pyc 4 | *.pyd 5 | *.so 6 | *.o 7 | *.def 8 | *.egg-info 9 | .python-version 10 | .pytest_cache 11 | .mypy_cache 12 | venv 13 | build 14 | dist 15 | 16 | # ides 17 | .vscode/.ropeproject 18 | .idea 19 | *.swp 20 | 21 | # Mac 22 | .DS_Store 23 | 24 | # files 25 | .env 26 | *.log 27 | migrations 28 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Build documentation in the docs/ directory with Sphinx 5 | sphinx: 6 | configuration: docs/conf.py 7 | 8 | # Optionally build your docs in additional formats such as PDF and ePub 9 | formats: all 10 | 11 | # Optionally set the version of Python and requirements required to build your docs 12 | python: 13 | version: 3.8 14 | install: 15 | - requirements: docs/requirements.txt 16 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "PyTest", 6 | "type": "python", 7 | "request": "launch", 8 | "cwd": "${workspaceFolder}", 9 | "console": "integratedTerminal", 10 | "module": "pytest", 11 | "env": {}, 12 | "debugOptions": [ 13 | "RedirectOutput" 14 | ], 15 | "args": [ 16 | "tests/core/test_cruddb.py" 17 | ] 18 | } 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Quantmind 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright notice, 9 | this list of conditions and the following disclaimer in the documentation 10 | and/or other materials provided with the distribution. 11 | * Neither the name of the author nor the names of its contributors 12 | may be used to endorse or promote products derived from this software without 13 | specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 19 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 20 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 23 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 24 | OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for development & CI 2 | 3 | .PHONY: help clean docs 4 | 5 | help: 6 | @fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//' 7 | 8 | clean: ## remove python cache files 9 | find . -name '__pycache__' | xargs rm -rf 10 | find . -name '*.pyc' -delete 11 | rm -rf build 12 | rm -rf dist 13 | rm -rf *.egg-info 14 | rm -rf .pytest_cache 15 | rm -rf .mypy_cache 16 | rm -rf .coverage 17 | 18 | docs: ## build sphinx docs 19 | @poetry run sphinx-build ./docs ./build/docs 20 | 21 | 22 | docs-requirements: ## requrement file for docs 23 | @poetry export -f requirements.txt -E docs --output docs/requirements.txt 24 | 25 | 26 | version: ## display software version 27 | @python setup.py --version 28 | 29 | 30 | install: ## install packages via poetry 31 | @./dev/install 32 | 33 | 34 | lint: ## run linters 35 | poetry run ./dev/lint-code 36 | 37 | 38 | outdated: ## Show outdated packages 39 | poetry show -o 40 | 41 | 42 | postgresql: ## run postgresql for testing 43 | docker run -e POSTGRES_PASSWORD=postgres --rm --network=host --name=openapi-db -d postgres:13 44 | 45 | 46 | postgresql-nd: ## run postgresql for testing - non daemon 47 | docker run -e POSTGRES_PASSWORD=postgres --rm --network=host --name=openapi-db postgres:13 48 | 49 | 50 | test: ## test with coverage 51 | @poetry run pytest -v -x --cov --cov-report xml --cov-report html 52 | 53 | 54 | test-version: ## check version compatibility 55 | ./dev/test-version 56 | 57 | 58 | test-lint: ## run linters checks 59 | @poetry run ./dev/lint-code --check 60 | 61 | 62 | test-docs: ## run docs in CI 63 | make docs 64 | 65 | 66 | publish: ## release to pypi and github tag 67 | @poetry publish --build -u lsbardel -p $(PYPI_PASSWORD) 68 | -------------------------------------------------------------------------------- /dev/install: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | pip install -U pip poetry 4 | poetry install --with=docs --with=extras 5 | -------------------------------------------------------------------------------- /dev/lint-code: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | BLACK_ARG="--check" 5 | RUFF_ARG="" 6 | 7 | if [ "$1" = "fix" ] ; then 8 | BLACK_ARG="" 9 | RUFF_ARG="--fix" 10 | fi 11 | 12 | echo "run black" 13 | black openapi tests ${BLACK_ARG} 14 | echo "run ruff" 15 | ruff openapi tests ${RUFF_ARG} 16 | #echo "run mypy" 17 | #mypy openapi 18 | -------------------------------------------------------------------------------- /dev/test-version: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | poetry=`poetry version` 4 | code=`poetry run python -c 'import openapi; print(f"aio-openapi {openapi.__version__}")'` 5 | 6 | echo ${poetry} 7 | 8 | if [ "${poetry}" != "${code}" ]; then 9 | echo "ERROR: poetry version ${poetry} different from code version ${code}" 10 | exit 1 11 | fi 12 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | from datetime import date 16 | 17 | sys.path.insert(0, os.path.abspath("..")) 18 | 19 | from recommonmark.parser import CommonMarkParser 20 | 21 | import openapi 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | year = date.today().year 26 | project = "aio-openapi" 27 | author = "Quantmind" 28 | copyright = f"{year}, {author}" 29 | 30 | release = openapi.__version__ 31 | source_suffix = [".rst", ".md"] 32 | source_parsers = { 33 | ".md": CommonMarkParser, 34 | } 35 | # The master toctree document. 36 | master_doc = "index" 37 | 38 | # -- General configuration --------------------------------------------------- 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | extensions = [ 43 | "sphinx.ext.viewcode", 44 | "sphinx.ext.intersphinx", 45 | "sphinx.ext.autodoc", 46 | "sphinx_copybutton", 47 | "sphinx_autodoc_typehints", 48 | ] 49 | 50 | try: 51 | import sphinxcontrib.spelling 52 | 53 | extensions.append("sphinxcontrib.spelling") 54 | except ImportError: 55 | pass 56 | 57 | templates_path = ["_templates"] 58 | 59 | # List of patterns, relative to source directory, that match files and 60 | # directories to ignore when looking for source files. 61 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 62 | 63 | 64 | # -- Options for HTML output ------------------------------------------------- 65 | 66 | # The theme to use for HTML and HTML Help pages. See the documentation for 67 | # a list of builtin themes. 68 | # html_theme = "alabaster" 69 | html_theme = "aiohttp_theme" 70 | 71 | # Add any paths that contain custom static files (such as style sheets) here, 72 | # relative to this directory. 
They are copied after the builtin static files, 73 | html_static_path = ["_static"] 74 | 75 | 76 | intersphinx_mapping = { 77 | "python": ("http://docs.python.org/3", None), 78 | "asyncpg": ("https://magicstack.github.io/asyncpg/current/", None), 79 | "sqlalchemy": ("https://docs.sqlalchemy.org/", None), 80 | } 81 | 82 | highlight_language = "python3" 83 | 84 | html_theme_options = { 85 | "description": "Async web middleware for aiohttp, asyncpg and OpenAPI", 86 | "canonical_url": "https://aio-openapi.readthedocs.io/en/latest/", 87 | "github_user": "quantmind", 88 | "github_repo": "aio-openapi", 89 | "github_button": True, 90 | "github_type": "star", 91 | "github_banner": True, 92 | "badges": [ 93 | { 94 | "image": "https://badge.fury.io/py/aio-openapi.svg", 95 | "target": "https://pypi.org/project/aio-openapi", 96 | "height": "20", 97 | "alt": "Latest PyPI package version", 98 | }, 99 | { 100 | "image": "https://img.shields.io/pypi/pyversions/aio-openapi.svg", 101 | "target": "https://pypi.org/project/aio-openapi", 102 | "height": "20", 103 | "alt": "Supported python versions", 104 | }, 105 | { 106 | "image": "https://github.com/quantmind/aio-openapi/workflows/build/badge.svg", 107 | "target": "https://github.com/quantmind/aio-openapi/actions?query=workflow%3Abuild", 108 | "height": "20", 109 | "alt": "Build status", 110 | }, 111 | { 112 | "image": "https://coveralls.io/repos/github/quantmind/aio-openapi/badge.svg?branch=HEAD", 113 | "target": "https://coveralls.io/github/quantmind/aio-openapi?branch=HEAD", 114 | "height": "20", 115 | "alt": "Coverage status", 116 | }, 117 | ], 118 | } 119 | 120 | html_sidebars = {"**": ["about.html", "navigation.html", "searchbox.html",]} 121 | -------------------------------------------------------------------------------- /docs/env.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-env: 2 | 3 | 4 | ====================== 5 | Environment Variables 6 | ====================== 7 | 8 | Several environment variables can be configured at application level 9 | 10 | * **DATASTORE** Connection string for postgresql database 11 | * **BAD_DATA_MESSAGE** (Invalid data format), message displayed when data is not in valid format (not JSON for example) 12 | * **ERROR_500_MESSSAGE** (Internal Server Error), message displayed when things go wrong 13 | * **DBPOOL_MAX_SIZE** (10), maximum number of connections in postgres connection pool 14 | * **DBECHO**, if set to `true` or `yes` it will use `echo=True` when setting up sqlalchemy engine 15 | * **MICRO_SERVICE_PORT** (8080), default port when running the `serve` command 16 | * **MICRO_SERVICE_HOST** (0.0.0.0), default host when running the `serve` command 17 | * **MAX_PAGINATION_LIMIT** (100), maximum number of objects displayed at once 18 | * **DEF_PAGINATION_LIMIT** (50), default value of pagination 19 | * **SPEC_ROUTE** (/spec), path of OpenAPI spec doc (JSON) 20 | -------------------------------------------------------------------------------- /docs/glossary.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-glossary: 2 | 3 | 4 | ========== 5 | Glossary 6 | ========== 7 | 8 | .. if you add new entries, keep the alphabetical sorting! 9 | 10 | .. 
glossary:: 11 | :sorted: 12 | 13 | asyncio 14 | 15 | Python :mod:`asyncio` module for asynchronous IO programming 16 | 17 | dataclasses 18 | 19 | Python :mod:`dataclasses` module 20 | 21 | openapi 22 | 23 | The OpenAPI_ specification: a broadly adopted industry standard for describing modern APIs. 24 | The most up-to-date versions are available on github OpenAPI-Specification_ 25 | 26 | schema 27 | 28 | Types and syntax for a valid :ref:`aio-openapi-schema` 29 | 30 | 31 | .. _OpenAPI: https://www.openapis.org/ 32 | .. _OpenAPI-Specification: https://github.com/OAI/OpenAPI-Specification/tree/master/versions 33 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. aio-openapi 2 | 3 | ====================== 4 | Welcome to aio-openapi 5 | ====================== 6 | 7 | Asynchronous web middleware for aiohttp_ and serving Rest APIs with OpenAPI_ v 3 8 | specification and with optional PostgreSql database bindings. 9 | 10 | Current version is |release|. 11 | 12 | Installation 13 | ============ 14 | 15 | It requires python 3.8 or above. 16 | 17 | .. code-block:: bash 18 | 19 | pip install aio-openapi 20 | 21 | Development 22 | =========== 23 | 24 | Clone the repository and install dependencies (via poetry): 25 | 26 | .. code-block:: bash 27 | 28 | make install 29 | 30 | To run tests 31 | 32 | .. code-block:: bash 33 | 34 | poetry run pytest --cov 35 | 36 | 37 | Features 38 | ======== 39 | 40 | - Asynchronous web routes with aiohttp_ 41 | - Data validation, serialization and unserialization with python :term:`dataclasses` 42 | - OpenApi_ v 3 auto documentation 43 | - SqlAlchemy_ expression language 44 | - Asynchronous DB interaction with asyncpg_ 45 | - Migrations with alembic_ 46 | - SqlAlchemy tables as python dataclasses 47 | - Support click_ command line interface 48 | - Redoc_ document rendering (like https://api.metablock.io/v1/docs) 49 | - Optional sentry_ middleware 50 | 51 | 52 | Contents 53 | ======== 54 | 55 | .. toctree:: 56 | :maxdepth: 2 57 | 58 | tutorial 59 | reference 60 | validation 61 | queries 62 | websocket 63 | env 64 | glossary 65 | 66 | 67 | Indices and tables 68 | ================== 69 | 70 | * :ref:`genindex` 71 | * :ref:`modindex` 72 | * :ref:`search` 73 | 74 | 75 | .. _aiohttp: https://github.com/aio-libs/aiohttp 76 | .. _OpenApi: https://www.openapis.org/ 77 | .. _sentry: https://sentry.io 78 | .. _click: https://github.com/pallets/click 79 | .. _SqlAlchemy: https://www.sqlalchemy.org/ 80 | .. _alembic: http://alembic.zzzcomputing.com/en/latest/ 81 | .. _asyncpg: https://github.com/MagicStack/asyncpg 82 | .. _Redoc: https://github.com/Redocly/redoc 83 | -------------------------------------------------------------------------------- /docs/queries.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-queries: 2 | 3 | 4 | ======== 5 | Queries 6 | ======== 7 | 8 | .. module:: openapi.pagination 9 | 10 | The library provide some useful tooling for creating dataclasses for validating schema when querying paginated endpoints. 11 | 12 | Pagination 13 | =========== 14 | 15 | Base class 16 | ------------ 17 | 18 | .. autoclass:: Pagination 19 | :members: 20 | 21 | 22 | Paginated Data 23 | --------------- 24 | 25 | .. autoclass:: PaginatedData 26 | :members: 27 | 28 | 29 | Visitor 30 | ------------ 31 | .. 
autoclass:: PaginationVisitor 32 | :members: 33 | 34 | 35 | Limit/Offset Pagination 36 | ========================= 37 | 38 | .. autofunction:: offsetPagination 39 | 40 | 41 | Cursor Pagination 42 | ========================= 43 | 44 | .. autofunction:: cursorPagination 45 | 46 | 47 | 48 | searchable 49 | ========== 50 | 51 | .. autofunction:: searchable 52 | -------------------------------------------------------------------------------- /docs/reference.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-reference: 2 | 3 | ========= 4 | Reference 5 | ========= 6 | 7 | 8 | Data 9 | ==== 10 | 11 | DataView 12 | -------- 13 | 14 | .. module:: openapi.data.view 15 | 16 | .. autoclass:: DataView 17 | :members: 18 | 19 | TypeInfo 20 | -------- 21 | 22 | .. module:: openapi.utils 23 | 24 | .. autoclass:: TypingInfo 25 | :members: 26 | 27 | 28 | .. _aio-openapi-data-fields: 29 | 30 | Data Fields 31 | =========== 32 | 33 | .. module:: openapi.data.fields 34 | 35 | .. autofunction:: data_field 36 | 37 | 38 | String field 39 | ------------ 40 | .. autofunction:: str_field 41 | 42 | Bool field 43 | ---------- 44 | .. autofunction:: bool_field 45 | 46 | UUID field 47 | ------------- 48 | .. autofunction:: uuid_field 49 | 50 | 51 | Numeric field 52 | ------------- 53 | .. autofunction:: number_field 54 | 55 | 56 | Integer field 57 | ------------- 58 | .. autofunction:: integer_field 59 | 60 | 61 | Email field 62 | ------------- 63 | .. autofunction:: email_field 64 | 65 | Enum field 66 | ---------- 67 | 68 | .. autofunction:: enum_field 69 | 70 | Date field 71 | ---------- 72 | 73 | .. autofunction:: date_field 74 | 75 | Datetime field 76 | -------------- 77 | 78 | .. autofunction:: date_time_field 79 | 80 | 81 | JSON field 82 | ---------- 83 | 84 | .. autofunction:: json_field 85 | 86 | Data Validation 87 | =============== 88 | 89 | .. module:: openapi.data.validate 90 | 91 | Validate 92 | ----------------------- 93 | 94 | The entry function to validate input data and return a python representation. 95 | The function accept as input a valid type annotation or a :class:`.TypingInfo` object. 96 | 97 | .. autofunction:: validate 98 | 99 | 100 | Validate Schema 101 | ----------------------- 102 | 103 | Same as the :func:`.validate` but returns the validation schema object rather than 104 | simple data types (this is mainly different for dataclasses) 105 | 106 | .. autofunction:: validated_schema 107 | 108 | 109 | Dataclass from db table 110 | ----------------------- 111 | .. module:: openapi.data.db 112 | 113 | .. autofunction:: dataclass_from_table 114 | 115 | 116 | Dump data 117 | --------- 118 | .. module:: openapi.data.dump 119 | 120 | .. autofunction:: dump 121 | 122 | 123 | Openapi Specification 124 | ====================== 125 | 126 | .. module:: openapi.spec 127 | 128 | OpenApiInfo 129 | ------------- 130 | 131 | .. autoclass:: OpenApiInfo 132 | :members: 133 | 134 | 135 | OpenApiSpec 136 | ------------- 137 | 138 | .. autoclass:: OpenApiSpec 139 | :members: 140 | 141 | 142 | op decorator 143 | ------------ 144 | 145 | Decorator for specifying schemas at route/method level. It is used by both 146 | the business logic as well the auto-documentation. 147 | 148 | .. autoclass:: op 149 | 150 | Redoc 151 | ------------ 152 | 153 | Allow to add redoc_ redering to your api. 154 | 155 | .. autoclass:: Redoc 156 | 157 | DB 158 | == 159 | 160 | This module provides integration with SqlAlchemy_ asynchronous engine for postgresql. 
161 | The connection string supported is of this type only:: 162 | 163 | postgresql+asyncpg://:@:/ 164 | 165 | 166 | .. module:: openapi.db.container 167 | 168 | 169 | Database 170 | -------- 171 | 172 | .. autoclass:: Database 173 | :members: 174 | :member-order: bysource 175 | :special-members: __getattr__ 176 | 177 | 178 | .. module:: openapi.db.dbmodel 179 | 180 | CrudDB 181 | ------ 182 | 183 | Database container with CRUD operations. Used extensively by the :class:`.SqlApiPath` routing class. 184 | 185 | 186 | .. autoclass:: CrudDB 187 | :members: 188 | 189 | 190 | get_db 191 | ------- 192 | 193 | .. module:: openapi.db 194 | 195 | .. autofunction:: get_db 196 | 197 | 198 | .. module:: openapi.testing 199 | 200 | SingleConnDatabase 201 | ------------------ 202 | 203 | A :class:`.CrudDB` container for testing database driven Rest APIs. 204 | 205 | .. autoclass:: SingleConnDatabase 206 | :members: 207 | 208 | 209 | Routes 210 | ====== 211 | 212 | 213 | ApiPath 214 | ---------- 215 | 216 | .. module:: openapi.spec.path 217 | 218 | .. autoclass:: ApiPath 219 | :members: 220 | :member-order: bysource 221 | 222 | 223 | SqlApiPath 224 | ---------- 225 | 226 | .. module:: openapi.db.path 227 | 228 | .. autoclass:: SqlApiPath 229 | :members: 230 | :member-order: bysource 231 | 232 | 233 | Websocket 234 | ========= 235 | 236 | .. module:: openapi.ws.manager 237 | 238 | 239 | Websocket RPC 240 | ------------- 241 | 242 | .. autoclass:: Websocket 243 | :members: 244 | :member-order: bysource 245 | 246 | 247 | SocketsManager 248 | -------------- 249 | 250 | .. autoclass:: SocketsManager 251 | :members: 252 | :member-order: bysource 253 | 254 | 255 | Channels 256 | ----------- 257 | 258 | .. module:: openapi.ws.channels 259 | 260 | .. autoclass:: Channels 261 | :members: 262 | :member-order: bysource 263 | 264 | 265 | Channel 266 | ----------- 267 | 268 | .. module:: openapi.ws.channel 269 | 270 | .. autoclass:: Channel 271 | :members: 272 | :member-order: bysource 273 | 274 | 275 | WsPathMixin 276 | ----------- 277 | 278 | .. module:: openapi.ws.path 279 | 280 | 281 | .. autoclass:: WsPathMixin 282 | :members: 283 | :member-order: bysource 284 | 285 | 286 | .. module:: openapi.ws.pubsub 287 | 288 | Subscribe 289 | ----------- 290 | 291 | .. autoclass:: Subscribe 292 | :members: 293 | :member-order: bysource 294 | 295 | Publish 296 | ----------- 297 | 298 | .. autoclass:: Publish 299 | :members: 300 | :member-order: bysource 301 | 302 | 303 | .. _redoc: https://gith 304 | .. _SqlAlchemy: https://www.sqlalchemy.org/ 305 | -------------------------------------------------------------------------------- /docs/validation.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-validation: 2 | 3 | 4 | =========== 5 | Validation 6 | =========== 7 | 8 | Validation is an important component of the library and it is designed to validate 9 | data to and from JSON serializable objects. 10 | 11 | To validate a simple list of integers 12 | 13 | .. code-block:: python 14 | 15 | from typing import List 16 | 17 | from openapi.data.validate import validate 18 | 19 | validate(List[int], [5,2,4,8]) 20 | # ValidatedData(data=[5, 2, 4, 8], errors={}) 21 | 22 | validate(List[int], [5,2,"5",8]) 23 | # ValidatedData(data=None, errors='not valid type') 24 | 25 | The main object for validation are python dataclasses: 26 | 27 | 28 | .. 
code-block:: python 29 | 30 | from dataclasses import dataclass 31 | from typing import Union 32 | 33 | @dataclass 34 | class Foo: 35 | text: str 36 | param: Union[str, int] 37 | done: bool = False 38 | 39 | 40 | validate(Foo, {}) 41 | # ValidatedData(data=None, errors={'text': 'required', 'param': 'required'}) 42 | 43 | validate(Foo, dict(text=1)) 44 | # ValidatedData(data=None, errors={'text': 'not valid type', 'param': 'required'}) 45 | 46 | validate(Foo, dict(text="ciao", param=3)) 47 | # ValidatedData(data={'text': 'ciao', 'param': 3, 'done': False}, errors={}) 48 | 49 | 50 | 51 | Validated Schema 52 | ================ 53 | 54 | Use the :func:`.validated_schema` to validate input data and return an instance of the 55 | validation schema. This differs from :func:`.validate` only when dataclasses are involved 56 | 57 | .. code-block:: python 58 | 59 | from openapi.data.validate import validated_schema 60 | 61 | validated_schema(Foo, dict(text="ciao", param=3)) 62 | # Foo(text='ciao', param=3, done=False) 63 | 64 | 65 | .. _aio-openapi-schema: 66 | 67 | Supported Schema 68 | ================ 69 | 70 | The library support the following schemas 71 | 72 | * Primitive types: ``str``, ``bytes``, ``int``, ``float``, ``bool``, ``date``, ``datetime`` and ``Decimal`` 73 | * Python :mod:`dataclasses` with fields from this supported schema 74 | * ``List`` from ``typing`` annotation with items from this supported schema 75 | * ``Dict`` from ``typing`` with keys as string and items from this supported schema 76 | * ``Union`` from ``typing`` with items from this supported schema 77 | * ``Any`` to skip validation and allow for any value 78 | 79 | Additional, and more powerful, validation can be achieved via the use of custom :func:`dataclasses.field` 80 | constructors (see :ref:`aio-openapi-data-fields` reference). 81 | 82 | .. code-block:: python 83 | 84 | from dataclasses import dataclass 85 | from typing import Union 86 | from openapi.data import fields 87 | 88 | @dataclass 89 | class Foo: 90 | text: str = fields.str_field(min_length=3, description="Just some text") 91 | param: Union[str, int] = fields.integer_field(description="String accepted but convert to int") 92 | done: bool = False = fields.bool_field(description="Is Foo done?") 93 | 94 | validated_schema(Foo, dict(text="ciao", param="2", done="no")) 95 | # Foo(text='ciao', param=2, done=False) 96 | 97 | 98 | Dump 99 | ==== 100 | 101 | Validated schema can be dump into valid JSON via the :func:`.dump` function 102 | -------------------------------------------------------------------------------- /docs/websocket.rst: -------------------------------------------------------------------------------- 1 | .. _aio-openapi-websocket: 2 | 3 | 4 | ============== 5 | Websocket RPC 6 | ============== 7 | 8 | 9 | The library includes a minimal API for Websocket JSON-RPC (remote procedure calls). 10 | 11 | To add websockets RPC you need to create a websocket route: 12 | 13 | .. code-block:: python 14 | 15 | from aiohttp import web 16 | from openapi.spec import ApiPath 17 | from openapi.ws import WsPathMixin 18 | from openapi.ws.pubsub import Publish, Subscribe 19 | 20 | ws_routes = web.RouteTableDef() 21 | 22 | @ws_routes.view("/stream") 23 | class Websocket(ApiPath, WsPathMixin, Subscribe, Publish): 24 | 25 | async def ws_rpc_info(self, payload): 26 | """Server information""" 27 | return self.sockets.server_info() 28 | 29 | the :class:`.WsPathMixin` adds the get method for accepting websocket requests with the RPC protocol. 
30 | :class:`.Subscribe` and :class:`.Publish` are optional mixins for adding 31 | Pub/Sub RPC methods to the endpoint. 32 | 33 | The endpoint can be added to an application in the setup function: 34 | 35 | .. code-block:: python 36 | 37 | from aiohttp.web import Application 38 | 39 | from openapi.ws import SocketsManager 40 | 41 | def setup_app(app: Application) -> None: 42 | app['web_sockets'] = SocketsManager() 43 | app.router.add_routes(ws_routes) 44 | 45 | RPC protocol 46 | =============== 47 | 48 | The RPC protocol has the following structure for incoming messages 49 | 50 | .. code-block:: javascript 51 | 52 | { 53 | "id": "abc", 54 | "method": "rpc_method_name", 55 | "payload": { 56 | ... 57 | } 58 | } 59 | 60 | The `id` is used by clients to link the request with the corresponding response. 61 | The response for an RPC call is either a success 62 | 63 | .. code-block:: javascript 64 | 65 | { 66 | "id": "abc", 67 | "method": "rpc_method_name", 68 | "response": { 69 | ... 70 | } 71 | } 72 | 73 | 74 | or an error 75 | 76 | .. code-block:: javascript 77 | 78 | { 79 | "id": "abc", 80 | "method": "rpc_method_name": 81 | "error": { 82 | ... 83 | } 84 | } 85 | 86 | 87 | Publish/Subscribe 88 | ================= 89 | 90 | To subscribe to messages, one need to use the :class:`.Subscribe` mixin with the websocket route (like we have done in this example). 91 | Messages take the form: 92 | 93 | .. code-block:: javascript 94 | 95 | { 96 | "channel": "channel_name", 97 | "event": "event_name", 98 | "data": { 99 | ... 100 | } 101 | } 102 | 103 | 104 | Backend 105 | ======== 106 | 107 | The websocket backend is implemented by subclassing the :class:`.SocketsManager` and implement the methods required by your application. 108 | This example implements a very simple backend for testing the websocket module in unittests. 109 | 110 | 111 | .. 
code-block:: python 112 | 113 | import asyncio 114 | 115 | from aiohttp import web 116 | from openapi.ws.manager import SocketsManager 117 | 118 | class LocalBroker(SocketsManager): 119 | """A local broker for testing""" 120 | 121 | def __init__(self): 122 | self.binds = set() 123 | self.messages: asyncio.Queue = asyncio.Queue() 124 | self.worker = None 125 | self._stop = False 126 | 127 | @classmethod 128 | def for_app(cls, app: web.Application) -> "LocalBroker": 129 | broker = cls() 130 | app.on_startup.append(broker.start) 131 | app.on_shutdown.append(broker.close) 132 | return broker 133 | 134 | async def start(self, *arg): 135 | if not self.worker: 136 | self.worker = asyncio.ensure_future(self._work()) 137 | 138 | async def publish(self, channel: str, event: str, body: Any): 139 | """simulate network latency""" 140 | if channel.lower() != channel: 141 | raise CannotPublish 142 | payload = dict(event=event, data=self.get_data(body)) 143 | asyncio.get_event_loop().call_later( 144 | 0.01, self.messages.put_nowait, (channel, payload) 145 | ) 146 | 147 | async def subscribe(self, channel: str) -> None: 148 | """ force channel names to be lowercase""" 149 | if channel.lower() != channel: 150 | raise CannotSubscribe 151 | 152 | async def close(self, *arg): 153 | self._stop = True 154 | await self.close_sockets() 155 | if self.worker: 156 | self.messages.put_nowait((None, None)) 157 | await self.worker 158 | self.worker = None 159 | 160 | async def _work(self): 161 | while True: 162 | channel, body = await self.messages.get() 163 | if self._stop: 164 | break 165 | await self.channels(channel, body) 166 | 167 | def get_data(self, data: Any) -> Any: 168 | if data == "error": 169 | return self.raise_error 170 | elif data == "runtime_error": 171 | return self.raise_runtime 172 | return data 173 | 174 | def raise_error(self): 175 | raise ValueError 176 | 177 | def raise_runtime(self): 178 | raise RuntimeError 179 | -------------------------------------------------------------------------------- /openapi/__init__.py: -------------------------------------------------------------------------------- 1 | """Minimal OpenAPI asynchronous server application""" 2 | __version__ = "3.2.1" 3 | -------------------------------------------------------------------------------- /openapi/cli.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import sys 4 | from functools import lru_cache 5 | from typing import Callable, Iterable, List, Optional 6 | 7 | import click 8 | from aiohttp import web 9 | from aiohttp.web import Application 10 | 11 | from .logger import logger, setup_logging 12 | from .spec import OpenApiSpec 13 | from .utils import get_debug_flag 14 | 15 | HOST = os.environ.get("MICRO_SERVICE_HOST", "0.0.0.0") 16 | PORT = os.environ.get("MICRO_SERVICE_PORT", 8080) 17 | 18 | 19 | class OpenApiClient(click.Group): 20 | def __init__( 21 | self, 22 | spec: Optional[OpenApiSpec] = None, 23 | setup_app: Optional[Callable[[Application], None]] = None, 24 | base_path: str = "", 25 | commands: Optional[List] = None, 26 | index: int = -1, 27 | loop: Optional[asyncio.AbstractEventLoop] = None, 28 | **extra, 29 | ) -> None: 30 | params = list(extra.pop("params", None) or ()) 31 | self.spec = spec 32 | self.debug = get_debug_flag() 33 | self.setup_app = setup_app 34 | self.base_path: str = base_path or "" 35 | self.index = index 36 | self.loop = loop 37 | params.extend( 38 | ( 39 | click.Option( 40 | ["--version"], 41 | help="Show the server 
version", 42 | expose_value=False, 43 | callback=self.get_server_version, 44 | is_flag=True, 45 | is_eager=True, 46 | ), 47 | click.Option( 48 | ["-v", "--verbose"], 49 | help="Increase logging verbosity", 50 | is_flag=True, 51 | is_eager=True, 52 | ), 53 | click.Option( 54 | ["-q", "--quiet"], 55 | help="Decrease logging verbosity", 56 | is_flag=True, 57 | is_eager=True, 58 | ), 59 | ) 60 | ) 61 | extra.setdefault("callback", setup_logging) 62 | super().__init__(params=params, **extra) 63 | self.add_command(serve) 64 | for command in commands or (): 65 | self.add_command(command) 66 | 67 | @lru_cache(None) 68 | def web(self, server: bool = False) -> Application: 69 | """Return the web application""" 70 | app = Application() 71 | app["cli"] = self 72 | app["cwd"] = os.getcwd() 73 | app["index"] = self.index 74 | app["server"] = server 75 | if self.spec: 76 | self.spec.setup_app(app) 77 | if self.setup_app: 78 | self.setup_app(app) 79 | return app 80 | 81 | def get_serve_app(self) -> Application: 82 | """Create the application which runs the server""" 83 | app = self.web(server=True) 84 | if self.base_path: 85 | base = Application() 86 | base.add_subapp(self.base_path, app) 87 | base["cli"] = self 88 | app = base 89 | return app 90 | 91 | def get_command(self, ctx: click.Context, name: str) -> Optional[click.Command]: 92 | ctx.obj = dict(cli=self) 93 | return super().get_command(ctx, name) 94 | 95 | def list_commands(self, ctx: click.Context) -> Iterable[str]: 96 | ctx.obj = dict(cli=self) 97 | return super().list_commands(ctx) 98 | 99 | def get_server_version(self, ctx, param, value) -> None: 100 | if not value or ctx.resilient_parsing: 101 | return 102 | spec = self.spec 103 | message = "%(title)s %(version)s\nPython %(python_version)s" 104 | click.echo( 105 | message 106 | % { 107 | "title": spec.info.title if spec else self.name or "Open API", 108 | "version": spec.info.version if spec else "", 109 | "python_version": sys.version, 110 | }, 111 | color=ctx.color, 112 | ) 113 | ctx.exit() 114 | 115 | 116 | def open_api_cli(ctx: click.Context) -> OpenApiClient: 117 | return ctx.obj["cli"] 118 | 119 | 120 | @click.command("serve", short_help="Start aiohttp server.") 121 | @click.option( 122 | "--host", "-h", default=HOST, help=f"The interface to bind to (default to {HOST})" 123 | ) 124 | @click.option( 125 | "--port", "-p", default=PORT, help=f"The port to bind to (default to {PORT}." 126 | ) 127 | @click.option( 128 | "--index", 129 | default=0, 130 | type=int, 131 | help="Optional index for stateful set deployment", 132 | ) 133 | @click.option( 134 | "--reload/--no-reload", 135 | default=None, 136 | help="Enable or disable the reloader. 
By default the reloader " 137 | "is active if debug is enabled.", 138 | ) 139 | @click.pass_context 140 | def serve(ctx, host, port, index, reload): 141 | """Run the aiohttp server.""" 142 | cli = open_api_cli(ctx) 143 | cli.index = index 144 | app = cli.get_serve_app() 145 | access_log = logger if ctx.obj["log_level"] else None 146 | web.run_app( 147 | app, 148 | host=host, 149 | port=port, 150 | access_log=access_log, 151 | loop=cli.loop, 152 | print=access_log.info if access_log else None, 153 | ) 154 | -------------------------------------------------------------------------------- /openapi/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantmind/aio-openapi/afe56f7b36cadf32643569b8ffce63da29802801/openapi/data/__init__.py -------------------------------------------------------------------------------- /openapi/data/db.py: -------------------------------------------------------------------------------- 1 | from dataclasses import Field, make_dataclass 2 | from datetime import date, datetime 3 | from decimal import Decimal 4 | from functools import partial 5 | from typing import ( 6 | Callable, 7 | Dict, 8 | List, 9 | Optional, 10 | Sequence, 11 | Set, 12 | Tuple, 13 | Type, 14 | Union, 15 | cast, 16 | ) 17 | 18 | import sqlalchemy as sa 19 | from sqlalchemy_utils import UUIDType 20 | 21 | from . import fields 22 | 23 | ConverterType = Callable[[sa.Column, bool, bool, Sequence[str]], Tuple[Type, Field]] 24 | CONVERTERS: Dict[str, ConverterType] = {} 25 | 26 | 27 | def dataclass_from_table( 28 | name: str, 29 | table: sa.Table, 30 | *, 31 | exclude: Optional[Sequence[str]] = None, 32 | include: Optional[Sequence[str]] = None, 33 | default: Union[bool, Sequence[str]] = False, 34 | required: Union[bool, Sequence[str]] = False, 35 | ops: Optional[Dict[str, Sequence[str]]] = None, 36 | ) -> Type: 37 | """Create a dataclass from an :class:`sqlalchemy.schema.Table` 38 | 39 | :param name: dataclass name 40 | :param table: sqlalchemy table 41 | :param exclude: fields to exclude from the dataclass 42 | :param include: fields to include in the dataclass 43 | :param default: use columns defaults in the dataclass 44 | :param required: set non nullable columns without a default as 45 | required fields in the dataclass 46 | :param ops: additional operation for fields 47 | """ 48 | columns = [] 49 | includes = set(include or table.columns.keys()) - set(exclude or ()) 50 | defaults = column_info(includes, default) 51 | requireds = column_info(includes, required) 52 | column_ops = cast(Dict[str, Sequence[str]], ops or {}) 53 | for col in table.columns: 54 | if col.name not in includes: 55 | continue 56 | ctype = type(col.type) 57 | converter = CONVERTERS.get(ctype) 58 | if not converter: # pragma: no cover 59 | raise NotImplementedError(f"Cannot convert column {col.name}: {ctype}") 60 | required = col.name in requireds 61 | use_default = col.name in defaults 62 | field = ( 63 | col.name, 64 | *converter(col, required, use_default, column_ops.get(col.name, ())), 65 | ) 66 | columns.append(field) 67 | return make_dataclass(name, columns) 68 | 69 | 70 | def column_info(columns: Set[str], value: Union[bool, Sequence[str]]) -> Set[str]: 71 | if value is False: 72 | return set() 73 | elif value is True: 74 | return columns.copy() 75 | else: 76 | return set(value if value is not None else columns) 77 | 78 | 79 | def converter(*types): 80 | def _(f): 81 | for type_ in types: 82 | CONVERTERS[type_] = f 83 | return f 84 | 85 | 
return _ 86 | 87 | 88 | @converter(sa.Boolean) 89 | def bl( 90 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 91 | ) -> Tuple[Type, Field]: 92 | data_field = col.info.get("data_field", fields.bool_field) 93 | return (bool, data_field(**info(col, required, use_default, ops))) 94 | 95 | 96 | @converter(sa.Integer) 97 | def integer( 98 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 99 | ) -> Tuple[Type, Field]: 100 | data_field = col.info.get("data_field", fields.number_field) 101 | return (int, data_field(precision=0, **info(col, required, use_default, ops))) 102 | 103 | 104 | @converter(sa.Numeric) 105 | def number( 106 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 107 | ) -> Tuple[Type, Field]: 108 | data_field = col.info.get("data_field", fields.decimal_field) 109 | return ( 110 | Decimal, 111 | data_field(precision=col.type.scale, **info(col, required, use_default, ops)), 112 | ) 113 | 114 | 115 | @converter(sa.String, sa.Text, sa.CHAR, sa.VARCHAR) 116 | def string( 117 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 118 | ) -> Tuple[Type, Field]: 119 | data_field = col.info.get("data_field", fields.str_field) 120 | return ( 121 | str, 122 | data_field( 123 | max_length=col.type.length or 0, **info(col, required, use_default, ops) 124 | ), 125 | ) 126 | 127 | 128 | @converter(sa.DateTime) 129 | def dt_ti( 130 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 131 | ) -> Tuple[Type, Field]: 132 | data_field = col.info.get("data_field", fields.date_time_field) 133 | return ( 134 | datetime, 135 | data_field(timezone=col.type.timezone, **info(col, required, use_default, ops)), 136 | ) 137 | 138 | 139 | @converter(sa.Date) 140 | def dt( 141 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 142 | ) -> Tuple[Type, Field]: 143 | data_field = col.info.get("data_field", fields.date_field) 144 | return (date, data_field(**info(col, required, use_default, ops))) 145 | 146 | 147 | @converter(sa.Enum) 148 | def en( 149 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 150 | ) -> Tuple[Type, Field]: 151 | data_field = col.info.get("data_field", fields.enum_field) 152 | return ( 153 | col.type.enum_class, 154 | data_field(col.type.enum_class, **info(col, required, use_default, ops)), 155 | ) 156 | 157 | 158 | @converter(sa.JSON) 159 | def js( 160 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 161 | ) -> Tuple[Type, Field]: 162 | data_field = col.info.get("data_field", fields.json_field) 163 | val = None 164 | if col.default: 165 | arg = col.default.arg 166 | val = arg() if col.default.is_callable else arg 167 | return ( 168 | JsonTypes.get(type(val), Dict), 169 | data_field(**info(col, required, use_default, ops)), 170 | ) 171 | 172 | 173 | @converter(UUIDType) 174 | def uuid( 175 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 176 | ) -> Tuple[Type, Field]: 177 | data_field = col.info.get("data_field", fields.uuid_field) 178 | return (str, data_field(**info(col, required, use_default, ops))) 179 | 180 | 181 | def info( 182 | col: sa.Column, required: bool, use_default: bool, ops: Sequence[str] 183 | ) -> Tuple[Type, Field]: 184 | data = dict(ops=ops) 185 | if use_default: 186 | default = col.default.arg if col.default is not None else None 187 | if callable(default): 188 | data.update(default_factory=partial(default, None)) 189 | required = False 190 | elif isinstance(default, (list, dict, 
set)): 191 | data.update(default_factory=lambda: default.copy()) 192 | required = False 193 | else: 194 | data.update(default=default) 195 | if required and (col.nullable or default is not None): 196 | required = False 197 | elif required and col.nullable: 198 | required = False 199 | data.update(required=required) 200 | if col.doc: 201 | data.update(description=col.doc) 202 | data.update(col.info) 203 | data.pop("data_field", None) 204 | return data 205 | 206 | 207 | JsonTypes = {list: List, dict: Dict} 208 | -------------------------------------------------------------------------------- /openapi/data/dump.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, fields 2 | from typing import Any, Dict, List, Optional, Union, cast 3 | 4 | from openapi.types import Record 5 | 6 | from ..utils import TypingInfo, iter_items 7 | from .fields import DUMP 8 | 9 | 10 | def is_nothing(value: Any) -> bool: 11 | if value == 0 or value is False: 12 | return False 13 | return not value 14 | 15 | 16 | def dump(schema: Any, data: Any) -> Any: 17 | """Dump data with a given schema. 18 | 19 | :param schema: a valid :ref:`aio-openapi-schema` 20 | :param data: data to dump, if dataclasses are part of the schema, 21 | the `dump` metadata function will be used if available (see :func:`.data_field`) 22 | """ 23 | type_info = cast(TypingInfo, TypingInfo.get(schema)) 24 | if type_info.container is list: 25 | return dump_list(type_info.element, cast(List, data)) 26 | elif type_info.container is dict: 27 | return dump_dict(type_info.element, cast(Dict, data)) 28 | elif type_info.is_dataclass: 29 | return dump_dataclass(type_info.element, data) 30 | else: 31 | return data 32 | 33 | 34 | def dump_dataclass(schema: Any, data: Optional[Union[Dict, Record]] = None) -> Dict: 35 | """Dump a dictionary of data with a given dataclass dump functions 36 | If the data is not given, the schema object is assumed to be 37 | an instance of a dataclass. 
38 | """ 39 | if data is None: 40 | data = asdict(schema) 41 | elif isinstance(data, schema): 42 | data = asdict(data) 43 | cleaned = {} 44 | fields_ = {f.name: f for f in fields(schema)} 45 | for name, value in iter_items(data): 46 | if name not in fields_ or is_nothing(value): 47 | continue 48 | field = fields_[name] 49 | dump_value = field.metadata.get(DUMP) 50 | if dump_value: 51 | value = dump_value(value) 52 | cleaned[field.name] = dump(field.type, value) 53 | 54 | return cleaned 55 | 56 | 57 | def dump_list(schema: Any, data: List) -> List[Dict]: 58 | """Validate a dictionary of data with a given dataclass""" 59 | return [dump(schema, d) for d in data] 60 | 61 | 62 | def dump_dict(schema: Any, data: Dict[str, Any]) -> List[Dict]: 63 | """Validate a dictionary of data with a given dataclass""" 64 | return {name: dump(schema, d) for name, d in data.items()} 65 | -------------------------------------------------------------------------------- /openapi/data/exc.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from .fields import data_field, str_field 5 | 6 | 7 | @dataclass 8 | class ErrorMessage: 9 | """Error message and list of errors for data fields""" 10 | 11 | message: str = str_field(description="Error message") 12 | 13 | 14 | @dataclass 15 | class FieldError(ErrorMessage): 16 | """Error message for a data field""" 17 | 18 | field: str = str_field(description="name of the data field with error") 19 | 20 | 21 | @dataclass 22 | class ValidationErrors(ErrorMessage): 23 | """Error message and list of errors for data fields""" 24 | 25 | errors: List[FieldError] = data_field(description="List of field errors") 26 | 27 | 28 | def error_response_schema(status): 29 | return ValidationErrors if status == 422 else ErrorMessage 30 | -------------------------------------------------------------------------------- /openapi/data/view.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, NoReturn, Optional, cast 4 | 5 | from aiohttp import web 6 | 7 | from ..types import DataType, QueryType 8 | from ..utils import TypingInfo, as_list, compact 9 | from .dump import dump 10 | from .validate import ErrorType, ValidationErrors, validate 11 | 12 | BAD_DATA_MESSAGE = os.getenv("BAD_DATA_MESSAGE", "Invalid data format") 13 | 14 | 15 | @dataclass 16 | class Operation: 17 | body_schema: Optional[TypingInfo] = None 18 | query_schema: Optional[TypingInfo] = None 19 | response_schema: Optional[TypingInfo] = None 20 | response: int = 200 21 | 22 | 23 | class DataView: 24 | """Utility class for data with a valid :ref:`aio-openapi-schema`""" 25 | 26 | operation: Operation = Operation() 27 | 28 | def cleaned( 29 | self, 30 | schema: Any, 31 | data: QueryType, 32 | *, 33 | multiple: bool = False, 34 | strict: bool = True, 35 | Error: Optional[type] = None, 36 | ) -> DataType: 37 | """Clean data using a given schema 38 | 39 | :param schema: a valid :ref:`aio-openapi-schema` or an the name of an 40 | attribute in :class:`.Operation` 41 | :param data: data to validate and clean 42 | :param multiple: multiple values for a given key are acceptable 43 | :param strict: all required attributes in schema must be available 44 | :param Error: optional :class:`.Exception` class 45 | """ 46 | type_info = self.get_schema(schema) 47 | validated = validate(type_info, data, strict=strict, multiple=multiple) 
48 | if validated.errors: 49 | if Error: 50 | raise Error 51 | elif schema == "path_schema": 52 | raise web.HTTPNotFound 53 | self.raise_validation_error(errors=validated.errors) 54 | return validated.data 55 | 56 | def dump(self, schema: Any, data: DataType) -> DataType: 57 | """Dump data using a given a valid :ref:`aio-openapi-schema`, 58 | if the schema is `None` it returns the same `data` as the input. 59 | 60 | :param schema: a schema or an the name of an attribute in :class:`.Operation` 61 | :param data: data to clean and dump 62 | """ 63 | return data if schema is None else dump(self.get_schema(schema), data) 64 | 65 | def get_schema(self, schema: Any = None) -> TypingInfo: 66 | """Get the :ref:`aio-openapi-schema`. If not found it raises an exception 67 | 68 | :param schema: a schema or an the name of an attribute in :class:`.Operation` 69 | """ 70 | if isinstance(schema, str): 71 | Schema = getattr(self.operation, schema, None) 72 | else: 73 | Schema = schema 74 | if Schema is None: 75 | Schema = getattr(self, str(schema), None) 76 | if Schema is None: 77 | raise web.HTTPNotImplemented 78 | return cast(TypingInfo, TypingInfo.get(Schema)) 79 | 80 | def validation_error( 81 | self, message: str = "", errors: Optional[ErrorType] = None 82 | ) -> Exception: 83 | """Create the validation exception used by :meth:`.raise_validation_error`""" 84 | return ValidationErrors(self.as_errors(message, errors)) 85 | 86 | def raise_validation_error( 87 | self, message: str = "", errors: Optional[ErrorType] = None 88 | ) -> NoReturn: 89 | """Raise an :class:`aiohttp.web.HTTPUnprocessableEntity`""" 90 | raise self.validation_error(message, errors) 91 | 92 | def raise_bad_data( 93 | self, exc: Optional[Exception] = None, message: str = "" 94 | ) -> None: 95 | if not message and exc: 96 | raise exc from exc 97 | raise TypeError(message or BAD_DATA_MESSAGE) 98 | 99 | def as_errors(self, message: str = "", errors: Optional[ErrorType] = None) -> Dict: 100 | if isinstance(errors, str): 101 | message = cast(str, message or errors) 102 | errors = None 103 | return compact(message=message, errors=as_list(errors or ())) 104 | -------------------------------------------------------------------------------- /openapi/db/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from aiohttp.web import Application 5 | 6 | from .container import Database 7 | from .dbmodel import CrudDB 8 | 9 | __all__ = ["compile_query", "Database", "CrudDB", "get_db"] 10 | 11 | 12 | def get_db(app: Application, store_url: Optional[str] = None) -> Optional[CrudDB]: 13 | """Create an Open API db handler and set it for use in an aiohttp application 14 | 15 | :param app: aiohttp Application 16 | :param store_url: datastore connection string, if not provided the env 17 | variable `DATASTORE` is used instead. 
If the env variable is not available 18 | either the method logs a warning and return `None` 19 | 20 | This function 1) adds the database to the aiohttp application at key "db", 21 | 2) add the db command to the command line client (if command is True) 22 | and 3) add the close handler on application shutdown 23 | """ 24 | store_url = store_url or os.environ.get("DATASTORE") 25 | if not store_url: # pragma: no cover 26 | app.logger.warning("DATASTORE url not available") 27 | return None 28 | else: 29 | app["db"] = CrudDB(store_url) 30 | app.on_shutdown.append(close_db) 31 | return app["db"] 32 | 33 | 34 | async def close_db(app: Application) -> None: 35 | await app["db"].close() 36 | -------------------------------------------------------------------------------- /openapi/db/columns.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import sqlalchemy as sa 4 | from sqlalchemy_utils import UUIDType 5 | 6 | 7 | def UUIDColumn(name, primary_key=True, nullable=False, make_default=False, **kw): 8 | if primary_key and nullable: 9 | raise RuntimeError("Primary key must be NOT NULL") 10 | if make_default: 11 | kw.setdefault("default", uuid.uuid4) 12 | return sa.Column(name, UUIDType(), primary_key=primary_key, nullable=nullable, **kw) 13 | -------------------------------------------------------------------------------- /openapi/db/commands.py: -------------------------------------------------------------------------------- 1 | import click 2 | from sqlalchemy import inspect 3 | from sqlalchemy_utils import create_database, database_exists, drop_database 4 | 5 | from openapi.cli import open_api_cli 6 | 7 | from .dbmodel import CrudDB 8 | from .migrations import Migration 9 | 10 | 11 | def migration(ctx: click.Context) -> Migration: 12 | return Migration(open_api_cli(ctx).web()) 13 | 14 | 15 | def get_db(ctx: click.Context) -> CrudDB: 16 | return open_api_cli(ctx).web()["db"] 17 | 18 | 19 | @click.group() 20 | def db(): 21 | """Perform database migrations and utilities""" 22 | pass 23 | 24 | 25 | @db.command() 26 | @click.pass_context 27 | def init(ctx): 28 | """Creates a new migration repository.""" 29 | migration(ctx).init() 30 | 31 | 32 | @db.command() 33 | @click.option("-m", "--message", help="Revision message", required=True) 34 | @click.option( 35 | "--branch-label", help="Specify a branch label to apply to the new revision" 36 | ) 37 | @click.pass_context 38 | def migrate(ctx, message: str, branch_label: str): 39 | """Autogenerate a new revision file 40 | 41 | alias for 'revision --autogenerate' 42 | """ 43 | return migration(ctx).revision( 44 | message, autogenerate=True, branch_label=branch_label 45 | ) 46 | 47 | 48 | @db.command() 49 | @click.option("-m", "--message", help="Revision message", required=True) 50 | @click.option( 51 | "--branch-label", help="Specify a branch label to apply to the new revision" 52 | ) 53 | @click.option( 54 | "--autogenerate", 55 | default=False, 56 | is_flag=True, 57 | help=( 58 | "Populate revision script with candidate migration " 59 | "operations, based on comparison of database to model" 60 | ), 61 | ) 62 | @click.pass_context 63 | def revision(ctx, message: str, branch_label: str, autogenerate: bool): 64 | """Autogenerate a new revision file""" 65 | return migration(ctx).revision( 66 | message, autogenerate=autogenerate, branch_label=branch_label 67 | ) 68 | 69 | 70 | @db.command() 71 | @click.option("--revision", default="heads") 72 | @click.option( 73 | "--drop-tables", 74 | default=False, 75 | 
--------------------------------------------------------------------------------
/openapi/db/commands.py:
--------------------------------------------------------------------------------
1 | import click
2 | from sqlalchemy import inspect
3 | from sqlalchemy_utils import create_database, database_exists, drop_database
4 | 
5 | from openapi.cli import open_api_cli
6 | 
7 | from .dbmodel import CrudDB
8 | from .migrations import Migration
9 | 
10 | 
11 | def migration(ctx: click.Context) -> Migration:
12 |     return Migration(open_api_cli(ctx).web())
13 | 
14 | 
15 | def get_db(ctx: click.Context) -> CrudDB:
16 |     return open_api_cli(ctx).web()["db"]
17 | 
18 | 
19 | @click.group()
20 | def db():
21 |     """Perform database migrations and utilities"""
22 |     pass
23 | 
24 | 
25 | @db.command()
26 | @click.pass_context
27 | def init(ctx):
28 |     """Creates a new migration repository."""
29 |     migration(ctx).init()
30 | 
31 | 
32 | @db.command()
33 | @click.option("-m", "--message", help="Revision message", required=True)
34 | @click.option(
35 |     "--branch-label", help="Specify a branch label to apply to the new revision"
36 | )
37 | @click.pass_context
38 | def migrate(ctx, message: str, branch_label: str):
39 |     """Autogenerate a new revision file
40 | 
41 |     alias for 'revision --autogenerate'
42 |     """
43 |     return migration(ctx).revision(
44 |         message, autogenerate=True, branch_label=branch_label
45 |     )
46 | 
47 | 
48 | @db.command()
49 | @click.option("-m", "--message", help="Revision message", required=True)
50 | @click.option(
51 |     "--branch-label", help="Specify a branch label to apply to the new revision"
52 | )
53 | @click.option(
54 |     "--autogenerate",
55 |     default=False,
56 |     is_flag=True,
57 |     help=(
58 |         "Populate revision script with candidate migration "
59 |         "operations, based on comparison of database to model"
60 |     ),
61 | )
62 | @click.pass_context
63 | def revision(ctx, message: str, branch_label: str, autogenerate: bool):
64 |     """Create a new revision file"""
65 |     return migration(ctx).revision(
66 |         message, autogenerate=autogenerate, branch_label=branch_label
67 |     )
68 | 
69 | 
70 | @db.command()
71 | @click.option("--revision", default="heads")
72 | @click.option(
73 |     "--drop-tables",
74 |     default=False,
75 |     is_flag=True,
76 |     help="Drop tables before applying migrations",
77 | )
78 | @click.pass_context
79 | def upgrade(ctx, revision: str, drop_tables: bool):
80 |     """Upgrade to a later version"""
81 |     if drop_tables:
82 |         _drop_tables(ctx)
83 |     migration(ctx).upgrade(revision)
84 |     click.echo(f"upgraded successfully to {revision}")
85 | 
86 | 
87 | @db.command()
88 | @click.option("--revision", help="Revision id", required=True)
89 | @click.pass_context
90 | def downgrade(ctx, revision: str):
91 |     """Downgrade to a previous version"""
92 |     migration(ctx).downgrade(revision)
93 |     click.echo(f"downgraded successfully to {revision}")
94 | 
95 | 
96 | @db.command()
97 | @click.option("--revision", default="heads")
98 | @click.pass_context
99 | def show(ctx, revision: str):
100 |     """Show revision ID and creation date"""
101 |     click.echo(migration(ctx).show(revision))
102 | 
103 | 
104 | @db.command()
105 | @click.pass_context
106 | def history(ctx):
107 |     """List changeset scripts in chronological order"""
108 |     click.echo(migration(ctx).history())
109 | 
110 | 
111 | @db.command()
112 | @click.option("--verbose/--quiet", default=False)
113 | @click.pass_context
114 | def current(ctx, verbose: bool):
115 |     """Display the current revision"""
116 |     click.echo(migration(ctx).current(verbose))
117 | 
118 | 
119 | @db.command()
120 | @click.argument("dbname", nargs=1)
121 | @click.option(
122 |     "--force", default=False, is_flag=True, help="Force removal of an existing database"
123 | )
124 | @click.pass_context
125 | def create(ctx, dbname: str, force: bool):
126 |     """Creates a new database"""
127 |     engine = get_db(ctx).sync_engine
128 |     url = engine.url.set(database=dbname)
129 |     if database_exists(url):
130 |         if force:
131 |             drop_database(url)
132 |         else:
133 |             return click.echo(f"database {dbname} already available")
134 |     create_database(url)
135 |     click.echo(f"database {dbname} created")
136 | 
137 | 
138 | @db.command()
139 | @click.option(
140 |     "--db",
141 |     default=False,
142 |     is_flag=True,
143 |     help="List tables in database rather than in sqlalchemy metadata",
144 | )
145 | @click.pass_context
146 | def tables(ctx, db):
147 |     """List all tables managed by the app"""
148 |     d = get_db(ctx)
149 |     if db:
150 |         tables = inspect(d.sync_engine).get_table_names()
151 |     else:
152 |         tables = d.metadata.tables
153 |     for name in sorted(tables):
154 |         click.echo(name)
155 | 
156 | 
157 | @db.command()
158 | @click.pass_context
159 | def drop(ctx):
160 |     """Drop all tables in database"""
161 |     _drop_tables(ctx)
162 | 
163 | 
164 | def _drop_tables(ctx):
165 |     get_db(ctx).drop_all_schemas()
166 |     click.echo("tables dropped")
167 | 
--------------------------------------------------------------------------------
/openapi/db/container.py:
--------------------------------------------------------------------------------
1 | import os
2 | from contextlib import asynccontextmanager
3 | from typing import Any, Optional
4 | 
5 | import sqlalchemy as sa
6 | from sqlalchemy.engine import Engine, create_engine
7 | from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
8 | 
9 | from openapi.types import Connection
10 | from openapi.utils import str2bool
11 | 
12 | from ..exc import ImproperlyConfigured
13 | 
14 | DBPOOL_MAX_SIZE = int(os.environ.get("DBPOOL_MAX_SIZE") or "10")
15 | DBPOOL_MAX_OVERFLOW = int(os.environ.get("DBPOOL_MAX_OVERFLOW") or "10")
16 | DBECHO = str2bool(os.environ.get("DBECHO") or "no")
17 | 
18 | 
19 | class Database:
20 |     """A container for tables in a database and a manager of asynchronous
connections to a PostgreSQL database
22 | 
23 |     :param dsn: Data source name used for database connections
24 |     :param metadata: :class:`sqlalchemy.schema.MetaData` containing tables
25 |     """
26 | 
27 |     def __init__(self, dsn: str = "", metadata: Optional[sa.MetaData] = None) -> None:
28 |         self._dsn = dsn
29 |         self._metadata = metadata or sa.MetaData()
30 |         self._engine = None
31 | 
32 |     def __repr__(self) -> str:
33 |         return self._dsn
34 | 
35 |     __str__ = __repr__
36 | 
37 |     @property
38 |     def dsn(self) -> str:
39 |         """Data source name used for database connections"""
40 |         return self._dsn
41 | 
42 |     @property
43 |     def metadata(self) -> sa.MetaData:
44 |         """The :class:`sqlalchemy.schema.MetaData` containing tables"""
45 |         return self._metadata
46 | 
47 |     @property
48 |     def engine(self) -> AsyncEngine:
49 |         """The :class:`sqlalchemy.ext.asyncio.AsyncEngine` creating connections
50 |         and transactions"""
51 |         if self._engine is None:
52 |             if not self._dsn:
53 |                 raise ImproperlyConfigured("DSN not available")
54 |             self._engine = create_async_engine(
55 |                 self._dsn,
56 |                 echo=DBECHO,
57 |                 pool_size=DBPOOL_MAX_SIZE,
58 |                 max_overflow=DBPOOL_MAX_OVERFLOW,
59 |             )
60 |         return self._engine
61 | 
62 |     @property
63 |     def sync_engine(self) -> Engine:
64 |         """The :class:`sqlalchemy.engine.Engine` for synchronous operations"""
65 |         return create_engine(self._dsn.replace("+asyncpg", ""))
66 | 
67 |     def __getattr__(self, name: str) -> Any:
68 |         """Retrieve a :class:`sqlalchemy.schema.Table` from the metadata tables
69 | 
70 |         :param name: if this is a valid table name in the tables of :attr:`.metadata`
71 |             it returns the table, otherwise it defaults to the superclass method
72 |         """
73 |         if name in self._metadata.tables:
74 |             return self._metadata.tables[name]
75 |         return super().__getattribute__(name)
76 | 
77 |     @asynccontextmanager
78 |     async def connection(self) -> Connection:
79 |         """Context manager for obtaining an asynchronous connection"""
80 |         async with self.engine.connect() as conn:
81 |             yield conn
82 | 
83 |     @asynccontextmanager
84 |     async def transaction(self) -> Connection:
85 |         """Context manager for initializing an asynchronous database transaction"""
86 |         async with self.engine.begin() as conn:
87 |             yield conn
88 | 
89 |     @asynccontextmanager
90 |     async def ensure_transaction(self, conn: Optional[Connection] = None) -> Connection:
91 |         """Context manager ensuring a connection has an initialized
92 |         database transaction"""
93 |         if conn:
94 |             if not conn.in_transaction():
95 |                 async with conn.begin():
96 |                     yield conn
97 |             else:
98 |                 yield conn
99 |         else:
100 |             async with self.transaction() as conn:
101 |                 yield conn
102 | 
103 |     # backward compatibility
104 |     ensure_connection = ensure_transaction
105 | 
106 |     async def close(self) -> None:
107 |         """Close the asynchronous db engine if opened"""
108 |         if self._engine:
109 |             engine, self._engine = self._engine, None
110 |             await engine.dispose()
111 | 
112 |     # SQL Alchemy Sync Operations
113 |     def create_all(self) -> None:
114 |         """Create all tables defined in :attr:`metadata`"""
115 |         self.metadata.create_all(self.sync_engine)
116 | 
117 |     def drop_all(self) -> None:
118 |         """Truncate all tables from :attr:`metadata` in the database"""
119 |         with self.sync_engine.begin() as conn:
120 |             conn.execute(sa.text(f'truncate {", ".join(self.metadata.tables)}'))
121 |             try:
122 |                 conn.execute(sa.text("drop table alembic_version"))
123 |             except Exception:  # noqa
124 |                 pass
125 | 
126 |     def drop_all_schemas(self) -> None:
127 |         """Drop all schemas in the database"""
128 |         with self.sync_engine.begin() as conn:
129 |             conn.execute(sa.text("DROP SCHEMA IF EXISTS public CASCADE"))
130 |             conn.execute(sa.text("CREATE SCHEMA IF NOT EXISTS public"))
131 | 
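A short sketch of how the `Database` container above is typically driven (not part of the repository; the DSN and table are illustrative):

import asyncio

import sqlalchemy as sa

from openapi.db.container import Database

async def main() -> None:
    db = Database("postgresql+asyncpg://postgres:postgres@localhost:5432/openapi")
    sa.Table("tasks", db.metadata, sa.Column("title", sa.String(64)))
    db.create_all()  # synchronous DDL through sync_engine
    async with db.transaction() as conn:
        # tables registered in the metadata are reachable as attributes
        await conn.execute(db.tasks.insert().values(title="hello"))
    await db.close()

asyncio.run(main())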
--------------------------------------------------------------------------------
/openapi/db/migrations.py:
--------------------------------------------------------------------------------
1 | """Alembic migrations handler
2 | """
3 | import os
4 | from io import StringIO
5 | 
6 | from alembic import command as alembic_cmd
7 | from alembic.config import Config
8 | 
9 | 
10 | def get_template_directory():
11 |     return os.path.dirname(os.path.realpath(__file__))
12 | 
13 | 
14 | class Migration:
15 |     def __init__(self, app):
16 |         self.app = app
17 |         self.cfg = create_config(app)
18 | 
19 |     def init(self):
20 |         dirname = self.cfg.get_main_option("script_location")
21 |         alembic_cmd.init(self.cfg, dirname, template="openapi")
22 |         return self.message()
23 | 
24 |     def show(self, revision):
25 |         alembic_cmd.show(self.cfg, revision)
26 |         return self.message()
27 | 
28 |     def history(self):
29 |         alembic_cmd.history(self.cfg)
30 |         return self.message()
31 | 
32 |     def revision(self, message, autogenerate=False, branch_label=None):
33 |         alembic_cmd.revision(
34 |             self.cfg,
35 |             autogenerate=autogenerate,
36 |             message=message,
37 |             branch_label=branch_label,
38 |         )
39 |         return self.message()
40 | 
41 |     def upgrade(self, revision):
42 |         alembic_cmd.upgrade(self.cfg, revision)
43 |         return self.message()
44 | 
45 |     def downgrade(self, revision):
46 |         alembic_cmd.downgrade(self.cfg, revision)
47 |         return self.message()
48 | 
49 |     def current(self, verbose=False):
50 |         alembic_cmd.current(self.cfg, verbose=verbose)
51 |         return self.message()
52 | 
53 |     def message(self):
54 |         msg = self.cfg.stdout.getvalue()
55 |         self.cfg.stdout.seek(0)
56 |         self.cfg.stdout.truncate()
57 |         return msg
58 | 
59 | 
60 | def create_config(app):
61 |     """Programmatically create the Alembic config"""
62 |     cfg = Config(stdout=StringIO())
63 |     cfg.get_template_directory = get_template_directory
64 |     migrations = app.get("migrations_dir") or os.path.join(app["cwd"], "migrations")
65 |     cfg.set_main_option("script_location", migrations)
66 |     cfg.config_file_name = os.path.join(migrations, "alembic.ini")
67 |     db = app["db"]
68 |     cfg.set_section_option(
69 |         "default",
70 |         "sqlalchemy.url",
71 |         db.sync_engine.url.render_as_string(hide_password=False),
72 |     )
73 |     # put database in main options
74 |     cfg.set_main_option("databases", "default")
75 |     # create empty logging section to avoid raising errors in env.py
76 |     cfg.set_section_option("logging", "path", "")
77 |     cfg.metadata = dict(default=db.metadata)
78 |     return cfg
79 | 
--------------------------------------------------------------------------------
/openapi/db/openapi/alembic.ini.mako:
--------------------------------------------------------------------------------
1 | # aio-openapi configuration.
2 | 
3 | [alembic]
4 | # path to migration scripts
5 | script_location = ${script_location}
6 | 
--------------------------------------------------------------------------------
/openapi/db/openapi/env.py:
--------------------------------------------------------------------------------
1 | from __future__ import with_statement
2 | 
3 | import logging
4 | import re
5 | from logging.config import fileConfig
6 | 
7 | from alembic import context
8 | from sqlalchemy import engine_from_config, pool
9 | 
10 | USE_TWOPHASE = False
11 | 
12 | # this is the Alembic Config object, which provides
13 | # access to the values within.
14 | config = context.config
15 | 
16 | # Interpret the config file for Python logging.
17 | # This line sets up loggers basically.
18 | cfg_file = config.get_section_option("logging", "path", "")
19 | if len(cfg_file) > 0:
20 |     fileConfig(open(cfg_file, "r"))
21 | logger = logging.getLogger("alembic.env")
22 | 
23 | # gather section names referring to different
24 | # databases.
25 | db_names = config.get_main_option("databases")
26 | 
27 | # your model's MetaData objects are obtained below
28 | # for 'autogenerate' support. You can remove this code
29 | # and import the necessary metadata manually.
30 | # from myapp import mymodel
31 | # target_metadata = {
32 | #     'engine1':mymodel.metadata,
33 | #     'engine2':mymodel.metadata,
34 | # }
35 | target_metadata = {name: meta for name, meta in config.metadata.items()}
36 | 
37 | # other values from the config, defined by the needs of env.py,
38 | # can be acquired:
39 | # my_important_option = config.get_main_option("my_important_option")
40 | # ... etc.
41 | 
42 | 
43 | def run_migrations_offline():
44 |     """Run migrations in 'offline' mode.
45 | 
46 |     This configures the context with just a URL
47 |     and not an Engine, though an Engine is acceptable
48 |     here as well. By skipping the Engine creation
49 |     we don't even need a DBAPI to be available.
50 | 
51 |     Calls to context.execute() here emit the given string to the
52 |     script output.
53 | 
54 |     """
55 |     # for the --sql use case, run migrations for each URL into
56 |     # individual files.
57 | 
58 |     engines = {}
59 |     for name in re.split(r",\s*", db_names):
60 |         engines[name] = rec = {}
61 |         rec["url"] = context.config.get_section_option(name, "sqlalchemy.url")
62 | 
63 |     for name, rec in engines.items():
64 |         logger.info("Migrating database %s" % name)
65 |         file_ = "%s.sql" % name
66 |         logger.info("Writing output to %s" % file_)
67 |         with open(file_, "w") as buffer:
68 |             context.configure(
69 |                 url=rec["url"],
70 |                 output_buffer=buffer,
71 |                 target_metadata=target_metadata.get(name),
72 |                 literal_binds=True,
73 |             )
74 |             with context.begin_transaction():
75 |                 context.run_migrations(engine_name=name)
76 | 
77 | 
78 | def run_migrations_online():
79 |     """Run migrations in 'online' mode.
80 | 
81 |     In this scenario we need to create an Engine
82 |     and associate a connection with the context.
83 | 
84 |     """
85 | 
86 |     # for the direct-to-DB use case, start a transaction on all
87 |     # engines, then run all migrations, then commit all transactions.
88 | 
89 |     engines = {}
90 |     for name in re.split(r",\s*", db_names):
91 |         engines[name] = rec = {}
92 |         rec["engine"] = engine_from_config(
93 |             context.config.get_section(name),
94 |             prefix="sqlalchemy.",
95 |             poolclass=pool.NullPool,
96 |         )
97 | 
98 |     for name, rec in engines.items():
99 |         engine = rec["engine"]
100 |         rec["connection"] = conn = engine.connect()
101 | 
102 |         if USE_TWOPHASE:
103 |             rec["transaction"] = conn.begin_twophase()
104 |         else:
105 |             rec["transaction"] = conn.begin()
106 | 
107 |     try:
108 |         for name, rec in engines.items():
109 |             logger.info("Migrating database %s" % name)
110 |             context.configure(
111 |                 connection=rec["connection"],
112 |                 upgrade_token="%s_upgrades" % name,
113 |                 downgrade_token="%s_downgrades" % name,
114 |                 target_metadata=target_metadata.get(name),
115 |                 compare_type=True,
116 |             )
117 |             context.run_migrations(engine_name=name)
118 | 
119 |         if USE_TWOPHASE:
120 |             for rec in engines.values():
121 |                 rec["transaction"].prepare()
122 | 
123 |         for rec in engines.values():
124 |             rec["transaction"].commit()
125 |     except Exception:
126 |         for rec in engines.values():
127 |             rec["transaction"].rollback()
128 |         raise
129 |     finally:
130 |         for rec in engines.values():
131 |             rec["connection"].close()
132 | 
133 | 
134 | if context.is_offline_mode():
135 |     run_migrations_offline()
136 | else:
137 |     run_migrations_online()
138 | 
--------------------------------------------------------------------------------
/openapi/db/openapi/script.py.mako:
--------------------------------------------------------------------------------
1 | <%!
2 | import re
3 | 
4 | %>"""${message}
5 | 
6 | Revision ID: ${up_revision}
7 | Revises: ${down_revision | comma,n}
8 | Create Date: ${create_date}
9 | 
10 | """
11 | 
12 | # revision identifiers, used by Alembic.
13 | revision = ${repr(up_revision)}
14 | down_revision = ${repr(down_revision)}
15 | branch_labels = ${repr(branch_labels)}
16 | depends_on = ${repr(depends_on)}
17 | 
18 | from alembic import op
19 | import sqlalchemy as sa
20 | import sqlalchemy_utils
21 | ${imports if imports else ""}
22 | 
23 | 
24 | def upgrade(engine_name):
25 |     globals()["upgrade_%s" % engine_name]()
26 | 
27 | 
28 | def downgrade(engine_name):
29 |     globals()["downgrade_%s" % engine_name]()
30 | 
31 | <%
32 |     db_names = config.get_main_option("databases")
33 | %>
34 | 
35 | ## generate an "upgrade_<db_name>() / downgrade_<db_name>()" function
36 | ## for each database name in the ini file.
37 | 38 | % for db_name in re.split(r',\s*', db_names): 39 | 40 | def upgrade_${db_name}(): 41 | ${context.get("%s_upgrades" % db_name, "pass")} 42 | 43 | 44 | def downgrade_${db_name}(): 45 | ${context.get("%s_downgrades" % db_name, "pass")} 46 | 47 | % endfor 48 | -------------------------------------------------------------------------------- /openapi/exc.py: -------------------------------------------------------------------------------- 1 | from aiohttp.web import HTTPException 2 | 3 | from .json import dumps 4 | 5 | 6 | class OpenApiError(RuntimeError): 7 | pass 8 | 9 | 10 | class ImproperlyConfigured(OpenApiError): 11 | pass 12 | 13 | 14 | class InvalidSpecException(OpenApiError): 15 | pass 16 | 17 | 18 | class InvalidTypeException(TypeError): 19 | pass 20 | 21 | 22 | class JsonHttpException(HTTPException): 23 | def __init__(self, status=None, **kw): 24 | self.status_code = status or 500 25 | kw["content_type"] = "application/json" 26 | super().__init__(**kw) 27 | reason = self.reason 28 | if isinstance(reason, str): 29 | reason = {"message": reason} 30 | self.text = dumps(reason) 31 | -------------------------------------------------------------------------------- /openapi/json.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import Enum 3 | from functools import partial 4 | from uuid import UUID 5 | 6 | import simplejson 7 | from simplejson.errors import JSONDecodeError 8 | 9 | 10 | def encoder(obj): 11 | if isinstance(obj, UUID): 12 | return obj.hex 13 | if isinstance(obj, datetime): 14 | return obj.isoformat() 15 | if isinstance(obj, Enum): 16 | return obj.name 17 | raise TypeError 18 | 19 | 20 | loads = partial(simplejson.loads, use_decimal=True) 21 | dumps = partial( 22 | simplejson.dumps, use_decimal=True, default=encoder, iterable_as_array=True 23 | ) 24 | 25 | __all__ = ["loads", "dumps", "JSONDecodeError"] 26 | -------------------------------------------------------------------------------- /openapi/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | 6 | try: 7 | import colorlog 8 | except ImportError: # pragma: no cover 9 | colorlog = None 10 | 11 | 12 | LEVEL = (os.environ.get("LOG_LEVEL") or "info").upper() 13 | LOGGER_NAME = os.environ.get("APP_NAME") or "" 14 | LOG_FORMAT = "%(levelname)s: %(name)s: %(message)s" 15 | 16 | logger = logging.getLogger(LOGGER_NAME) 17 | 18 | 19 | def get_logger(name: str = "") -> logging.Logger: 20 | return logger.getChild(name) if name else logger 21 | 22 | 23 | getLogger = get_logger 24 | 25 | 26 | @click.pass_context 27 | def setup_logging(ctx, verbose, quiet): 28 | if verbose: 29 | level = "DEBUG" 30 | elif quiet: 31 | level = "ERROR" 32 | else: 33 | level = LEVEL 34 | level = getattr(logging, level) if level != "NONE" else None 35 | ctx.obj["log_level"] = level 36 | if level: 37 | logger.setLevel(level) 38 | if not logger.hasHandlers(): 39 | fmt = LOG_FORMAT 40 | if colorlog: 41 | handler = colorlog.StreamHandler() 42 | fmt = colorlog.ColoredFormatter(f"%(log_color)s{LOG_FORMAT}") 43 | else: # pragma: no cover 44 | handler = logging.StreamHandler() 45 | fmt = logging.Formatter(LOG_FORMAT) 46 | handler.setFormatter(fmt) 47 | logger.addHandler(handler) 48 | -------------------------------------------------------------------------------- /openapi/middleware.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
3 | from aiohttp import web 4 | 5 | from .exc import ImproperlyConfigured 6 | 7 | try: 8 | from . import sentry 9 | except ImportError: # pragma: no cover 10 | sentry = None 11 | 12 | ERROR_500 = os.environ.get("ERROR_500_MESSSAGE", "Internal Server Error") 13 | 14 | 15 | def sentry_middleware(app, dsn, env="dev"): 16 | if not sentry: # pragma: no cover 17 | raise ImproperlyConfigured("Sentry middleware requires sentry-sdk") 18 | sentry.setup(app, dsn, env) 19 | 20 | 21 | def json_error(status_codes=None): 22 | status_codes = set(status_codes or (404, 405, 500)) 23 | content_type = "application/json" 24 | 25 | @web.middleware 26 | async def json_middleware(request, handler): 27 | try: 28 | response = await handler(request) 29 | if response.status not in status_codes: 30 | return response 31 | message = response.message 32 | status = response.status 33 | except web.HTTPException as ex: 34 | if ex.status not in status_codes or ex.content_type == content_type: 35 | raise 36 | message = ex.reason 37 | status = ex.status 38 | if isinstance(message, str): 39 | message = {"error": message} 40 | except Exception: 41 | if 500 in status_codes: 42 | status = 500 43 | message = {"error": ERROR_500} 44 | request.app.logger.exception(ERROR_500) 45 | else: 46 | raise 47 | return web.json_response(message, status=status) 48 | 49 | return json_middleware 50 | 51 | 52 | # backward compatibility 53 | json404 = json_error((404,)) 54 | -------------------------------------------------------------------------------- /openapi/pagination/__init__.py: -------------------------------------------------------------------------------- 1 | from .create import create_dataclass 2 | from .cursor import cursorPagination 3 | from .offset import offsetPagination 4 | from .pagination import PaginatedData, Pagination, PaginationVisitor, fields_flip_sign 5 | from .search import Search, SearchVisitor, searchable 6 | 7 | __all__ = [ 8 | "Pagination", 9 | "PaginatedData", 10 | "PaginationVisitor", 11 | "cursorPagination", 12 | "offsetPagination", 13 | "create_dataclass", 14 | "fields_flip_sign", 15 | "Search", 16 | "searchable", 17 | "SearchVisitor", 18 | ] 19 | -------------------------------------------------------------------------------- /openapi/pagination/create.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type, TypeVar 2 | 3 | from ..utils import TypingInfo 4 | from .pagination import Pagination 5 | from .search import Search 6 | 7 | T = TypeVar("T") 8 | 9 | CREATE_MAP = {Pagination: "create_pagination", Search: "create_search"} 10 | 11 | 12 | def create_dataclass( 13 | type_info: Optional[TypingInfo], data: dict, DataClass: Type[T] 14 | ) -> T: 15 | if type_info is None: 16 | return DataClass() 17 | if type_info.is_dataclass and issubclass(type_info.element, DataClass): 18 | method_name = CREATE_MAP.get(DataClass) 19 | if method_name: 20 | return getattr(type_info.element, method_name)(data) 21 | return DataClass() 22 | -------------------------------------------------------------------------------- /openapi/pagination/cursor.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from dataclasses import dataclass 3 | from datetime import date, datetime 4 | from functools import cached_property 5 | from typing import Any, Dict, Optional, Tuple, Type 6 | 7 | from dateutil.parser import parse as parse_date 8 | from yarl import URL 9 | 10 | from openapi import json 11 | from openapi.data.fields import 
Choice, integer_field, str_field
12 | from openapi.data.validate import ValidationErrors
13 | 
14 | from .pagination import (
15 |     DEF_PAGINATION_LIMIT,
16 |     MAX_PAGINATION_LIMIT,
17 |     Pagination,
18 |     PaginationVisitor,
19 |     fields_flip_sign,
20 |     fields_no_sign,
21 |     from_filters_and_dataclass,
22 | )
23 | 
24 | CursorType = Tuple[Tuple[str, str], ...]
25 | 
26 | 
27 | def encode_cursor(data: Tuple[str, ...], previous: bool = False) -> str:
28 |     cursor_bytes = json.dumps((data, previous)).encode("ascii")
29 |     base64_bytes = base64.b64encode(cursor_bytes)
30 |     return base64_bytes.decode("ascii")
31 | 
32 | 
33 | def decode_cursor(
34 |     cursor: Optional[str], field_names: Tuple[str, ...]
35 | ) -> Tuple[CursorType, bool]:
36 |     try:
37 |         if cursor:
38 |             base64_bytes = cursor.encode("ascii")
39 |             cursor_bytes = base64.b64decode(base64_bytes)
40 |             values, previous = json.loads(cursor_bytes)
41 |             if len(values) == len(field_names):
42 |                 return tuple(zip(field_names, values)), previous
43 |             raise ValueError
44 |         return (), False
45 |     except Exception as e:
46 |         raise ValidationErrors("invalid cursor") from e
47 | 
48 | 
49 | def cursor_url(url: URL, cursor: str) -> URL:
50 |     query = url.query.copy()
51 |     query.update(_cursor=cursor)
52 |     return url.with_query(query)
53 | 
54 | 
55 | def start_values(record: dict, field_names: Tuple[str, ...]) -> Tuple[str, ...]:
56 |     """Start values for pagination"""
57 |     return tuple(record[field] for field in field_names)
58 | 
59 | 
60 | def cursorPagination(
61 |     *order_by_fields: str,
62 |     default_limit: int = DEF_PAGINATION_LIMIT,
63 |     max_limit: int = MAX_PAGINATION_LIMIT,
64 | ) -> Type[Pagination]:
65 |     if len(order_by_fields) == 0:
66 |         raise ValueError("order_by_fields must be specified")
67 | 
68 |     field_names = fields_no_sign(order_by_fields)
69 | 
70 |     @dataclass
71 |     class CursorPagination(Pagination):
72 |         limit: int = integer_field(
73 |             min_value=1,
74 |             max_value=max_limit,
75 |             default=default_limit,
76 |             required=False,
77 |             description="Limit the number of objects returned from the endpoint",
78 |         )
79 |         direction: str = str_field(
80 |             validator=Choice(("asc", "desc")),
81 |             required=False,
82 |             default="asc",
83 |             description=(
84 |                 f"Sort results via `{', '.join(order_by_fields)}` "
85 |                 "in descending or ascending order"
86 |             ),
87 |         )
88 |         _cursor: str = str_field(default="", hidden=True)
89 | 
90 |         @cached_property
91 |         def cursor_info(self) -> Tuple[CursorType, Tuple[str, ...], bool]:
92 |             order_by = (
93 |                 fields_flip_sign(order_by_fields)
94 |                 if self.direction == "desc"
95 |                 else order_by_fields
96 |             )
97 |             cursor, previous = decode_cursor(self._cursor, order_by)
98 |             return cursor, order_by, previous
99 | 
100 |         @property
101 |         def previous(self) -> bool:
102 |             return self.cursor_info[2]
103 | 
104 |         def apply(self, visitor: PaginationVisitor) -> None:
105 |             cursor, order_by, previous = self.cursor_info
106 |             visitor.apply_cursor_pagination(
107 |                 cursor,
108 |                 self.limit,
109 |                 order_by,
110 |                 previous=previous,
111 |             )
112 | 
113 |         @classmethod
114 |         def create_pagination(cls, data: dict) -> "CursorPagination":
115 |             return from_filters_and_dataclass(CursorPagination, data)
116 | 
117 |         def links(
118 |             self, url: URL, data: list, total: Optional[int] = None
119 |         ) -> Dict[str, str]:
120 |             links = {}
121 |             if self.previous:
122 |                 if len(data) > self.limit + 1:
123 |                     links["prev"] = cursor_url(
124 |                         url,
125 |                         encode_cursor(
126 |                             start_values(data[self.limit], field_names), previous=True
127 |                         ),
128 |                     )
129 |                 if self._cursor:
130 |                     links["next"] = cursor_url(
131 |                         url,
132 |                         encode_cursor(
133 |                             start_values(data[0], field_names),
134 |                         ),
135 |                     )
136 |             else:
137 |                 if len(data) > self.limit:
138 |                     links["next"] = cursor_url(
139 |                         url,
140 |                         encode_cursor(start_values(data[self.limit], field_names)),
141 |                     )
142 |                 if self._cursor:
143 |                     links["prev"] = cursor_url(
144 |                         url,
145 |                         encode_cursor(
146 |                             start_values(data[0], field_names),
147 |                             previous=True,
148 |                         ),
149 |                     )
150 |             return links
151 | 
152 |         def get_data(self, data: list) -> list:
153 |             if self.previous:
154 |                 data = list(reversed(data[1:]))
155 |                 return data if len(data) <= self.limit else data[1:]
156 |             return data if len(data) <= self.limit else data[: self.limit]
157 | 
158 |     return CursorPagination
159 | 
160 | 
161 | def cursor_to_python(py_type: Type, value: Any) -> Any:
162 |     try:
163 |         if py_type is datetime:
164 |             return parse_date(value)
165 |         elif py_type is date:
166 |             return parse_date(value).date()
167 |         elif py_type is int:
168 |             return int(value)
169 |         else:
170 |             return value
171 |     except Exception as e:
172 |         raise ValidationErrors("invalid cursor") from e
173 | 
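The cursor helpers above are easiest to understand as a round trip; a small sketch (not part of the repository; the field names and values are illustrative):

from openapi.pagination.cursor import decode_cursor, encode_cursor

# encode the start values of the next page for a ("created", "id") ordering
cursor = encode_cursor(("2023-01-01T00:00:00", "42"))
# decoding returns the (field, value) pairs plus the `previous` flag
values, previous = decode_cursor(cursor, ("created", "id"))
assert values == (("created", "2023-01-01T00:00:00"), ("id", "42"))
assert previous is False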
--------------------------------------------------------------------------------
/openapi/pagination/offset.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, NamedTuple, Optional, Type
3 | 
4 | from multidict import MultiDict
5 | from yarl import URL
6 | 
7 | from openapi.data.fields import Choice, integer_field, str_field
8 | from openapi.utils import docjoin
9 | 
10 | from .pagination import (
11 |     DEF_PAGINATION_LIMIT,
12 |     MAX_PAGINATION_LIMIT,
13 |     Pagination,
14 |     PaginationVisitor,
15 |     from_filters_and_dataclass,
16 | )
17 | 
18 | 
19 | def offsetPagination(
20 |     *order_by_fields: str,
21 |     default_limit: int = DEF_PAGINATION_LIMIT,
22 |     max_limit: int = MAX_PAGINATION_LIMIT,
23 | ) -> Type[Pagination]:
24 |     """Create a limit/offset :class:`.Pagination` dataclass"""
25 |     if len(order_by_fields) == 0:
26 |         raise ValueError("order_by_fields must be specified")
27 | 
28 |     @dataclass
29 |     class OffsetPagination(Pagination):
30 |         limit: int = integer_field(
31 |             min_value=1,
32 |             max_value=max_limit,
33 |             default=default_limit,
34 |             description="Limit the number of objects returned from the endpoint",
35 |         )
36 |         offset: int = integer_field(
37 |             min_value=0,
38 |             default=0,
39 |             description=(
40 |                 "Number of objects to exclude. "
41 |                 "Use in conjunction with limit to paginate results"
42 |             ),
43 |         )
44 |         order_by: str = str_field(
45 |             validator=Choice(order_by_fields),
46 |             default=order_by_fields[0],
47 |             description=(
48 |                 "Order results by given column (default ascending order). 
" 49 | f"Possible values are {docjoin(order_by_fields)}" 50 | ), 51 | ) 52 | 53 | def apply(self, visitor: PaginationVisitor) -> None: 54 | visitor.apply_offset_pagination( 55 | limit=self.limit, offset=self.offset, order_by=self.order_by 56 | ) 57 | 58 | @classmethod 59 | def create_pagination(cls, data: dict) -> "OffsetPagination": 60 | return from_filters_and_dataclass(OffsetPagination, data) 61 | 62 | def links( 63 | self, url: URL, data: list, total: Optional[int] = None 64 | ) -> Dict[str, str]: 65 | """Return links for paginated data""" 66 | return Links(url=url, query=MultiDict(url.query)).links( 67 | total, self.limit, self.offset 68 | ) 69 | 70 | return OffsetPagination 71 | 72 | 73 | class Links(NamedTuple): 74 | url: URL 75 | query: MultiDict 76 | 77 | def first_link(self, total, limit, offset): 78 | n = self._count_part(offset, limit, 0) 79 | if n: 80 | offset -= n * limit 81 | if offset > 0: 82 | return self.link(0, min(limit, offset)) 83 | 84 | def prev_link(self, total, limit, offset): 85 | if offset: 86 | olimit = min(limit, offset) 87 | prev_offset = offset - olimit 88 | return self.link(prev_offset, olimit) 89 | 90 | def next_link(self, total, limit, offset): 91 | next_offset = offset + limit 92 | if total > next_offset: 93 | return self.link(next_offset, limit) 94 | 95 | def last_link(self, total, limit, offset): 96 | n = self._count_part(total, limit, offset) 97 | if n > 0: 98 | return self.link(offset + n * limit, limit) 99 | 100 | def link(self, offset, limit): 101 | query = self.query.copy() 102 | query.update({"offset": offset, "limit": limit}) 103 | return self.url.with_query(query) 104 | 105 | def _count_part(self, total, limit, offset): 106 | n = (total - offset) // limit 107 | # make sure we account for perfect matching 108 | if n * limit + offset == total: 109 | n -= 1 110 | return max(0, n) 111 | 112 | def links(self, total, limit, offset): 113 | links = {} 114 | first = self.first_link(total, limit, offset) 115 | if first: 116 | links["first"] = first 117 | links["prev"] = self.prev_link(total, limit, offset) 118 | next_ = self.next_link(total, limit, offset) 119 | if next_: 120 | links["next"] = next_ 121 | links["last"] = self.last_link(total, limit, offset) 122 | return links 123 | -------------------------------------------------------------------------------- /openapi/pagination/pagination.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, fields 3 | from typing import ( 4 | Dict, 5 | List, 6 | NamedTuple, 7 | Optional, 8 | Sequence, 9 | Tuple, 10 | Type, 11 | TypeVar, 12 | Union, 13 | ) 14 | 15 | from aiohttp import web 16 | from yarl import URL 17 | 18 | from openapi.json import dumps 19 | 20 | MAX_PAGINATION_LIMIT: int = int(os.environ.get("MAX_PAGINATION_LIMIT") or 100) 21 | DEF_PAGINATION_LIMIT: int = int(os.environ.get("DEF_PAGINATION_LIMIT") or 50) 22 | 23 | 24 | class PaginationVisitor: 25 | """Visitor for pagination""" 26 | 27 | def apply_offset_pagination( 28 | self, limit: int, offset: int, order_by: Union[str, List[str]] 29 | ): 30 | """Apply limit/offset pagination""" 31 | raise NotImplementedError 32 | 33 | def apply_cursor_pagination( 34 | self, 35 | cursor: Sequence[Tuple[str, str]], 36 | limit: int, 37 | order_by: Sequence[str], 38 | previous: bool, 39 | ): 40 | """Apply cursor pagination""" 41 | raise NotImplementedError 42 | 43 | 44 | T = TypeVar("T") 45 | 46 | 47 | def from_filters_and_dataclass(data_class: Type[T], data: dict) -> T: 48 | params = 
{} 49 | for field in fields(data_class): 50 | if field.name in data: 51 | params[field.name] = data.pop(field.name) 52 | return data_class(**params) 53 | 54 | 55 | def fields_no_sign(fields: Sequence[str]) -> Tuple[str, ...]: 56 | return tuple(field[1:] if field.startswith("-") else field for field in fields) 57 | 58 | 59 | def fields_flip_sign(fields: Sequence[str]) -> Tuple[str, ...]: 60 | return tuple(flip_field_sign(field) for field in fields) 61 | 62 | 63 | def flip_field_sign(field: str) -> str: 64 | return field[1:] if field.startswith("-") else f"-{field}" 65 | 66 | 67 | @dataclass 68 | class Pagination: 69 | """Base class for Pagination""" 70 | 71 | @classmethod 72 | def create_pagination(cls, data: dict) -> "Pagination": 73 | return cls() 74 | 75 | def apply(self, visitor: PaginationVisitor) -> None: 76 | """Apply pagination to the visitor""" 77 | pass 78 | 79 | def paginated( 80 | self, url: URL, data: list, total: Optional[int] = None 81 | ) -> "PaginatedData": 82 | """Return paginated data""" 83 | return PaginatedData(url=url, data=data, pagination=self, total=total) 84 | 85 | def links( 86 | self, url: URL, data: list, total: Optional[int] = None 87 | ) -> Dict[str, str]: 88 | """Return links for paginated data""" 89 | return {} 90 | 91 | def get_data(self, data: list) -> list: 92 | return data 93 | 94 | 95 | class PaginatedData(NamedTuple): 96 | """Named tuple containing paginated data and methods for retrieving 97 | links to previous or next data in the pagination 98 | """ 99 | 100 | url: URL 101 | """Base url""" 102 | data: list 103 | """Paginated list of data""" 104 | pagination: Pagination 105 | """Pagination dataclass which created the data""" 106 | total: Optional[int] = None 107 | """Total number of records (supported by limit/offset pagination only)""" 108 | 109 | def json_response(self, headers: Optional[Dict[str, str]] = None, **kwargs): 110 | """Create a JSON response with link header""" 111 | headers = headers or {} 112 | links = self.header_links() 113 | if links: 114 | headers["Link"] = links 115 | if self.total is not None: 116 | headers["X-Total-Count"] = str(self.total) 117 | kwargs.setdefault("dumps", dumps) 118 | return web.json_response( 119 | self.pagination.get_data(self.data), headers=headers, **kwargs 120 | ) 121 | 122 | def header_links(self) -> str: 123 | """Header links""" 124 | links = self.pagination.links(self.url, self.data, self.total) 125 | return ", ".join(f'<{value}>; rel="{name}"' for name, value in links.items()) 126 | -------------------------------------------------------------------------------- /openapi/pagination/search.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Sequence 3 | 4 | from openapi.data.fields import str_field 5 | from openapi.utils import docjoin 6 | 7 | from .pagination import from_filters_and_dataclass 8 | 9 | 10 | class SearchVisitor: 11 | def apply_search(self, search: str, search_fields: Sequence[str]) -> None: 12 | raise NotImplementedError 13 | 14 | 15 | @dataclass 16 | class Search: 17 | @classmethod 18 | def create_search(cls, data: dict) -> "Search": 19 | return cls() 20 | 21 | def apply(self, visitor: SearchVisitor) -> None: 22 | pass 23 | 24 | 25 | def searchable(*searchable_fields) -> type: 26 | """Create a dataclass with `search_fields` class attribute and `search` field. 
27 |     The `search_fields` attribute is the set of fields which can be used for
28 |     searching; it is used internally by the library, while the `search` field
29 |     carries the query string passed in the URL.
30 | 
31 |     :param searchable_fields: fields which can be used for searching
32 |     """
33 |     fields = docjoin(searchable_fields)
34 | 
35 |     @dataclass
36 |     class Searchable(Search):
37 |         search_fields = frozenset(searchable_fields)
38 |         search: str = str_field(
39 |             description=(
40 |                 "Search query string. " f"The search is performed on {fields} fields."
41 |             )
42 |         )
43 | 
44 |         @classmethod
45 |         def create_search(cls, data: dict) -> "Searchable":
46 |             return from_filters_and_dataclass(Searchable, data)
47 | 
48 |         def apply(self, visitor: SearchVisitor) -> None:
49 |             visitor.apply_search(self.search, self.search_fields)
50 | 
51 |     return Searchable
52 | 
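A brief sketch of `searchable` in action (not part of the repository; the field names are illustrative):

from openapi.pagination import searchable

TaskQuery = searchable("title", "description")

# `search_fields` drives the server-side filtering, while `search`
# carries the query string supplied by the client
query = TaskQuery.create_search({"search": "urgent"})
assert TaskQuery.search_fields == frozenset(("title", "description"))
assert query.search == "urgent"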
"OpenApiInfo", 11 | "OpenApiSpec", 12 | "SchemaParser", 13 | "SpecDoc", 14 | "Redoc", 15 | ] 16 | -------------------------------------------------------------------------------- /openapi/spec/hdrs.py: -------------------------------------------------------------------------------- 1 | from multidict import istr 2 | 3 | X_FORWARDED_PROTO = istr("X-Forwarded-Proto") 4 | X_FORWARDED_HOST = istr("X-Forwarded-Host") 5 | X_FORWARDED_PORT = istr("X-Forwarded-Port") 6 | -------------------------------------------------------------------------------- /openapi/spec/operation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import wraps 3 | from typing import Any, Callable 4 | 5 | from ..data.view import DataView, Operation 6 | from ..utils import TypingInfo 7 | 8 | 9 | @dataclass 10 | class op: 11 | """Decorator for a :class:`.ApiPath` view which specifies an operation object 12 | in an OpenAPI Path. Parameters are dataclasses used for validation and 13 | OpenAPI auto documentation. 14 | """ 15 | 16 | body_schema: Any = None 17 | query_schema: Any = None 18 | response_schema: Any = None 19 | response: int = 200 20 | # responses: List[Any] = [] 21 | 22 | def __call__(self, method) -> Callable: 23 | method.op = Operation( 24 | body_schema=TypingInfo.get(self.body_schema), 25 | query_schema=TypingInfo.get(self.query_schema), 26 | response_schema=TypingInfo.get(self.response_schema), 27 | response=self.response, 28 | ) 29 | 30 | @wraps(method) 31 | async def _(view: DataView) -> Any: 32 | view.operation = method.op 33 | return await method(view) 34 | 35 | return _ 36 | -------------------------------------------------------------------------------- /openapi/spec/path.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from aiohttp import web 4 | from multidict import MultiDict 5 | from yarl import URL 6 | 7 | from openapi.json import dumps, loads 8 | 9 | from ..data.validate import ValidationErrors 10 | from ..data.view import BAD_DATA_MESSAGE, DataView, ErrorType 11 | from ..types import DataType, QueryType, SchemaTypeOrStr 12 | from ..utils import compact 13 | from . import hdrs 14 | 15 | 16 | class ApiPath(web.View, DataView): 17 | """A :class:`.DataView` class for OpenAPI path""" 18 | 19 | path_schema: Optional[type] = None 20 | """Optional dataclass for validating path variables""" 21 | private: bool = False 22 | 23 | # UTILITIES 24 | 25 | def insert_data( 26 | self, 27 | data: DataType, 28 | *, 29 | multiple: bool = False, 30 | strict: bool = True, 31 | body_schema: SchemaTypeOrStr = "body_schema", 32 | ) -> Dict[str, Any]: 33 | """Validate data for insertion 34 | 35 | if a :attr:`.path_schema` is given, it validate the request `match_info` 36 | against it and add it to the validated data. 
--------------------------------------------------------------------------------
/openapi/spec/path.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 | 
3 | from aiohttp import web
4 | from multidict import MultiDict
5 | from yarl import URL
6 | 
7 | from openapi.json import dumps, loads
8 | 
9 | from ..data.validate import ValidationErrors
10 | from ..data.view import BAD_DATA_MESSAGE, DataView, ErrorType
11 | from ..types import DataType, QueryType, SchemaTypeOrStr
12 | from ..utils import compact
13 | from . import hdrs
14 | 
15 | 
16 | class ApiPath(web.View, DataView):
17 |     """A :class:`.DataView` class for an OpenAPI path"""
18 | 
19 |     path_schema: Optional[type] = None
20 |     """Optional dataclass for validating path variables"""
21 |     private: bool = False
22 | 
23 |     # UTILITIES
24 | 
25 |     def insert_data(
26 |         self,
27 |         data: DataType,
28 |         *,
29 |         multiple: bool = False,
30 |         strict: bool = True,
31 |         body_schema: SchemaTypeOrStr = "body_schema",
32 |     ) -> Dict[str, Any]:
33 |         """Validate data for insertion
34 | 
35 |         If a :attr:`.path_schema` is given, it validates the request `match_info`
36 |         against it and adds the result to the validated data.
37 | 
38 |         :param data: object to be validated against the body_schema, usually obtained
39 |             from the request body (JSON)
40 |         :param multiple: multiple values for a given key are acceptable
41 |         :param strict: all required attributes in schema must be available
42 |         :param body_schema: the schema to validate against
43 |         """
44 |         data = self.cleaned(body_schema, data, multiple=multiple, strict=strict)
45 |         if self.path_schema:
46 |             path = self.cleaned("path_schema", self.request.match_info)
47 |             data.update(path)
48 |         return data
49 | 
50 |     def get_filters(
51 |         self,
52 |         *,
53 |         query: Optional[QueryType] = None,
54 |         query_schema: SchemaTypeOrStr = "query_schema",
55 |     ) -> Dict[str, Any]:
56 |         """Collect a dictionary of filters from the request query string.
57 |         If :attr:`path_schema` is defined, it collects filter data
58 |         from path variables as well.
59 | 
60 |         :param query: optional query dictionary (overridden by the
61 |             request query)
62 |         :param query_schema: a dataclass or the name of an attribute in the Operation
63 |             for collecting query filters
64 |         """
65 |         combined = MultiDict(query or ())
66 |         combined.update(self.request.query)
67 |         try:
68 |             params = self.cleaned(query_schema, combined, multiple=True)
69 |         except web.HTTPNotImplemented:
70 |             params = {}
71 |         if self.path_schema:
72 |             path = self.cleaned("path_schema", self.request.match_info)
73 |             params.update(path)
74 |         return params
75 | 
76 |     async def json_data(self) -> DataType:
77 |         """Load JSON data from the request.
78 | 
79 |         :raise HTTPBadRequest: when body data is not valid JSON
80 |         """
81 |         try:
82 |             return await self.request.json(loads=loads)
83 |         except Exception:
84 |             self.raise_bad_data()
85 | 
86 |     def validation_error(
87 |         self, message: str = "", errors: Optional[ErrorType] = None
88 |     ) -> Exception:
89 |         """Create an :class:`aiohttp.web.HTTPUnprocessableEntity`"""
90 |         raw = self.as_errors(message, errors)
91 |         data = self.dump(ValidationErrors, raw)
92 |         return web.HTTPUnprocessableEntity(**self.api_response_data(data))
93 | 
94 |     def raise_bad_data(
95 |         self, exc: Optional[Exception] = None, message: str = ""
96 |     ) -> None:
97 |         raw = compact(message=message or BAD_DATA_MESSAGE)
98 |         data = self.dump(ValidationErrors, raw)
99 |         raise web.HTTPBadRequest(**self.api_response_data(data))
100 | 
101 |     def full_url(self) -> URL:
102 |         return full_url(self.request)
103 | 
104 |     @classmethod
105 |     def api_response_data(cls, data: DataType) -> Dict[str, Any]:
106 |         return dict(text=dumps(data), content_type="application/json")
107 | 
108 |     @classmethod
109 |     def json_response(cls, data, **kwargs):
110 |         kwargs.setdefault("dumps", dumps)
111 |         return web.json_response(data, **kwargs)
112 | 
113 | 
114 | def full_url(request) -> URL:
115 |     headers = request.headers
116 |     proto = headers.get(hdrs.X_FORWARDED_PROTO)
117 |     host = headers.get(hdrs.X_FORWARDED_HOST)
118 |     port = headers.get(hdrs.X_FORWARDED_PORT)
119 |     if proto and host:
120 |         url = URL.build(scheme=proto, host=host)
121 |         if port:
122 |             port = int(port)
123 |             if url.port != port:
124 |                 url = url.with_port(port)
125 |         return url.join(request.rel_url)
126 |     else:
127 |         return request.url
128 | 
--------------------------------------------------------------------------------
/openapi/spec/redoc.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from aiohttp import web
4 | 
5 | 
6 | @dataclass
7 | class Redoc:
8 |     """A dataclass for redoc rendering"""
9 | 
10 |     path: str = "/docs"
str = "/docs" 11 | favicon_url: str = ( 12 | "https://raw.githubusercontent.com/Redocly/redoc/master/demo/favicon.png" 13 | ) 14 | redoc_js_url: str = ( 15 | "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js" 16 | ) 17 | font: str = "family=Montserrat:300,400,700|Roboto:300,400,700" 18 | 19 | async def handle_doc(self, request: web.Request) -> web.Response: 20 | """Render a webpage with redoc and the spec form the app""" 21 | spec = request.app["spec"] 22 | spec_url = request.app.router["openapi_spec"].url_for() 23 | title = spec.info.title 24 | html = f""" 25 | 26 | 27 | 28 | {title} 29 | 30 | 31 | 32 | """ 33 | if self.font: 34 | html += f""" 35 | 36 | """ 37 | html += f""" 38 | 39 | 42 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | """ 55 | return web.Response(text=html, content_type="text/html") 56 | -------------------------------------------------------------------------------- /openapi/spec/server.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from aiohttp import web 4 | 5 | from .path import full_url 6 | 7 | 8 | def default_server(request: web.Request) -> Dict[str, str]: 9 | app = request.app 10 | url = full_url(request) 11 | url = url.with_path(app["cli"].base_path) 12 | return dict(url=str(url), description="Api server") 13 | 14 | 15 | def server_urls(request: web.Request, paths: List[str]) -> List[str]: 16 | base_path = request.app["cli"].base_path 17 | n = len(base_path) 18 | spec = request.app.get("spec") 19 | server = spec.servers[0] if spec and spec.servers else default_server(request) 20 | base_url = server["url"] 21 | return [f"{base_url}{p[n:]}" for p in paths] 22 | -------------------------------------------------------------------------------- /openapi/spec/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, Optional 3 | 4 | import yaml 5 | 6 | from ..exc import InvalidSpecException 7 | 8 | 9 | # from django.contrib.admindocs.utils 10 | def trim_docstring(docstring: str) -> str: 11 | """Uniformly trims leading/trailing whitespace from docstrings. 12 | Based on 13 | http://www.python.org/peps/pep-0257.html#handling-docstring-indentation 14 | """ 15 | if not docstring or not docstring.strip(): 16 | return "" 17 | # Convert tabs to spaces and split into lines 18 | lines = docstring.expandtabs().splitlines() 19 | indent = min(len(line) - len(line.lstrip()) for line in lines if line.lstrip()) 20 | trimmed = [lines[0].lstrip()] + [line[indent:].rstrip() for line in lines[1:]] 21 | return "\n".join(trimmed).strip() 22 | 23 | 24 | # from rest_framework.utils.formatting 25 | def dedent(content: str) -> str: 26 | """ 27 | Remove leading indent from a block of text. 28 | Used when generating descriptions from docstrings. 29 | Note that python's `textwrap.dedent` doesn't quite cut it, 30 | as it fails to dedent multiline docstrings that include 31 | unindented text on the initial line. 
32 | """ 33 | whitespace_counts = [ 34 | len(line) - len(line.lstrip(" ")) 35 | for line in content.splitlines()[1:] 36 | if line.lstrip() 37 | ] 38 | 39 | # unindent the content if needed 40 | if whitespace_counts: 41 | whitespace_pattern = "^" + (" " * min(whitespace_counts)) 42 | content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content) 43 | 44 | return content.strip() 45 | 46 | 47 | def load_yaml_from_docstring(docstring: str) -> Optional[Dict]: 48 | """Loads YAML from docstring.""" 49 | split_lines = trim_docstring(docstring).split("\n") 50 | 51 | # Cut YAML from rest of docstring 52 | for index, line in enumerate(split_lines): 53 | line = line.strip() 54 | if line.startswith("---"): 55 | cut_from = index 56 | break 57 | else: 58 | return None 59 | 60 | yaml_string = "\n".join(split_lines[cut_from:]) 61 | yaml_string = dedent(yaml_string) 62 | try: 63 | return yaml.load(yaml_string, Loader=yaml.FullLoader) 64 | except Exception as e: 65 | raise InvalidSpecException("Invalid yaml %s" % e) from None 66 | -------------------------------------------------------------------------------- /openapi/testing.py: -------------------------------------------------------------------------------- 1 | """Testing utilities 2 | """ 3 | import asyncio 4 | from contextlib import asynccontextmanager, contextmanager 5 | from typing import Any 6 | 7 | from aiohttp.client import ClientResponse 8 | from aiohttp.test_utils import TestClient, TestServer 9 | from aiohttp.web import Application 10 | 11 | from .db import CrudDB, Database 12 | from .json import dumps, loads 13 | from .types import Connection 14 | 15 | 16 | async def json_body(response: ClientResponse, status: int = 200) -> Any: 17 | assert response.content_type == "application/json" 18 | data = await response.json(loads=loads) 19 | if response.status != status: # pragma: no cover 20 | print(dumps({"status": response.status, "data": data}, indent=4)) 21 | 22 | assert response.status == status 23 | return data 24 | 25 | 26 | @contextmanager 27 | def with_test_db(db: CrudDB) -> CrudDB: 28 | db.create_all() 29 | try: 30 | yield db 31 | finally: 32 | db.drop_all_schemas() 33 | 34 | 35 | class SingleConnDatabase(CrudDB): # noqa 36 | """Useful for speedup testing""" 37 | 38 | def __init__(self, *args, **kwargs) -> None: 39 | super().__init__(*args, **kwargs) 40 | self._lock = asyncio.Lock() 41 | self._connection = None 42 | 43 | @classmethod 44 | def from_db(cls, db: Database) -> "SingleConnDatabase": 45 | return cls(dsn=db.dsn, metadata=db.metadata) 46 | 47 | async def __aenter__(self) -> "SingleConnDatabase": 48 | self._connection = await self.engine.begin() 49 | return self 50 | 51 | async def __aexit__(self, exc_type, exc, tb): 52 | transaction = self._connection.get_transaction() 53 | await transaction.rollback() 54 | self._connection = None 55 | 56 | @asynccontextmanager 57 | async def connection(self) -> Connection: 58 | async with self._lock: 59 | yield self._connection 60 | 61 | @asynccontextmanager 62 | async def transaction(self) -> Connection: 63 | async with self._lock: 64 | yield self._connection 65 | 66 | 67 | @asynccontextmanager 68 | async def app_cli(app: Application) -> TestClient: 69 | server = TestServer(app) 70 | client = TestClient(server, json_serialize=dumps) 71 | await client.start_server() 72 | try: 73 | yield client 74 | finally: 75 | await client.close() 76 | -------------------------------------------------------------------------------- /openapi/types.py: 
--------------------------------------------------------------------------------
/openapi/types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Type, Union
2 | 
3 | from multidict import MultiDict
4 | from sqlalchemy.engine import CursorResult, Row
5 | from sqlalchemy.ext.asyncio import AsyncConnection
6 | 
7 | PrimitiveType = Union[int, float, bool, str]
8 | JSONType = Union[PrimitiveType, List, Dict[str, Any]]
9 | DataType = Any
10 | 
11 | SchemaType = Union[List[Type], Type]
12 | SchemaTypeOrStr = Union[str, SchemaType]
13 | StrDict = Dict[str, Any]
14 | QueryType = Union[StrDict, MultiDict]
15 | Record = Row
16 | Records = CursorResult
17 | Connection = AsyncConnection
18 | 
--------------------------------------------------------------------------------
/openapi/tz.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 | 
3 | UTC = timezone.utc
4 | 
5 | 
6 | def utcnow() -> datetime:
7 |     return datetime.now(tz=UTC)
8 | 
9 | 
10 | def as_utc(dt: datetime) -> datetime:
11 |     return dt.replace(tzinfo=UTC)
12 | 
--------------------------------------------------------------------------------
/openapi/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import is_dataclass
3 | from inspect import isclass
4 | from typing import (
5 |     Any,
6 |     Dict,
7 |     Hashable,
8 |     Iterable,
9 |     Iterator,
10 |     List,
11 |     Mapping,
12 |     NamedTuple,
13 |     Optional,
14 |     TypeVar,
15 |     Union,
16 |     cast,
17 | )
18 | 
19 | from .exc import InvalidTypeException
20 | from .types import Record
21 | 
22 | 
23 | def get_origin(value: Any) -> Any:
24 |     return getattr(value, "__origin__", None)
25 | 
26 | 
27 | LOCAL = "local"
28 | DEV = "dev"
29 | PRODUCTION = "production"
30 | NO_DEBUG = {"0", "false", "no"}
31 | Null = object()
32 | #
33 | # this should be Union[type, "TypingInfo"] but recursive types are not supported in mypy
34 | ElementType = Any
35 | 
36 | 
37 | def get_args(value, defaults):
38 |     return getattr(value, "__args__", None) or defaults
39 | 
40 | 
41 | KT, VT = get_args(Dict, (TypeVar("KT"), TypeVar("VT")))
42 | (T,) = get_args(List, (TypeVar("T"),))
43 | 
44 | 
45 | class TypingInfo(NamedTuple):
46 |     """Information about a type annotation"""
47 | 
48 |     element: ElementType
49 |     container: Optional[type] = None
50 | 
51 |     @property
52 |     def is_dataclass(self) -> bool:
53 |         """True if :attr:`.element` is a dataclass"""
54 |         return not self.container and is_dataclass(self.element)
55 | 
56 |     @property
57 |     def is_union(self) -> bool:
58 |         """True if :attr:`.element` is a union of typing info"""
59 |         return isinstance(self.element, tuple)
60 | 
61 |     @property
62 |     def is_complex(self) -> bool:
63 |         """True if the annotation has a container or is a union"""
64 |         return self.container is not None or self.is_union
65 | 
66 |     @property
67 |     def is_none(self) -> bool:
68 |         """True if :attr:`.element` is the ``None`` type"""
69 |         return self.element is type(None)  # noqa: E721
70 | 
71 |     @classmethod
72 |     def get(cls, value: Any) -> Optional["TypingInfo"]:
73 |         """Create a :class:`.TypingInfo` from a typing annotation or
74 |         another typing info
75 | 
76 |         :param value: typing annotation
77 |         """
78 |         if value is None or isinstance(value, cls):
79 |             return value
80 |         origin = get_origin(value)
81 |         if not origin:
82 |             if value is Any or isclass(value):
83 |                 return cls(value)
84 |             else:
85 |                 raise InvalidTypeException(
86 |                     f"a class or typing annotation is required, got {value}"
87 |                 )
88 |         elif
origin is list: 89 | (val,) = get_args(value, (T,)) 90 | if val is T: 91 | val = Any 92 | elem_info = cast(TypingInfo, cls.get(val)) 93 | elem = elem_info if elem_info.is_complex else elem_info.element 94 | return cls(elem, list) 95 | elif origin is dict: 96 | key, val = get_args(value, (KT, VT)) 97 | if key is KT: 98 | key = str 99 | if val is VT: 100 | val = Any 101 | if key is not str: 102 | raise InvalidTypeException( 103 | f"Dict key annotation must be a string, got {key}" 104 | ) 105 | 106 | elem_info = cast(TypingInfo, cls.get(val)) 107 | elem = elem_info if elem_info.is_complex else elem_info.element 108 | return cls(elem, dict) 109 | elif origin is Union: 110 | elem = tuple(cls.get(val) for val in value.__args__) 111 | return cls(elem) 112 | else: 113 | raise InvalidTypeException( 114 | f"Types or List and Dict typing is required, got {value}" 115 | ) 116 | 117 | 118 | def get_env() -> str: 119 | return os.environ.get("PYTHON_ENV") or PRODUCTION 120 | 121 | 122 | def get_debug_flag() -> bool: 123 | val = os.environ.get("DEBUG") 124 | if not val: 125 | return get_env() == LOCAL 126 | return val.lower() not in NO_DEBUG 127 | 128 | 129 | def compact(**kwargs) -> Dict: 130 | return {k: v for k, v in kwargs.items() if v} 131 | 132 | 133 | def compact_dict(kwargs: Dict) -> Dict: 134 | return {k: v for k, v in kwargs.items() if v is not None} 135 | 136 | 137 | def replace_key(kwargs: Dict, from_key: Hashable, to_key: Hashable) -> Dict: 138 | value = kwargs.pop(from_key, Null) 139 | if value is not Null: 140 | kwargs[to_key] = value 141 | return kwargs 142 | 143 | 144 | def iter_items(data: Iterable) -> Iterator: 145 | if isinstance(data, Record): 146 | data = data._asdict() 147 | if isinstance(data, Mapping): 148 | return iter(data.items()) 149 | return iter(data) 150 | 151 | 152 | def is_subclass(value: Any, Type: type) -> bool: 153 | origin = getattr(value, "__origin__", None) or value 154 | return isclass(origin) and issubclass(origin, Type) 155 | 156 | 157 | def as_list(errors: Iterable) -> List[Dict[str, Any]]: 158 | return [ 159 | {"field": field, "message": message} for field, message in iter_items(errors) 160 | ] 161 | 162 | 163 | def error_dict(errors: List) -> Dict: 164 | return dict(((d["field"], d["message"]) for d in errors)) 165 | 166 | 167 | TRUE_VALUES = frozenset(("yes", "true", "t", "1")) 168 | 169 | 170 | def str2bool(v: Union[str, bool, int]): 171 | return str(v).lower() in TRUE_VALUES 172 | 173 | 174 | def docjoin(iterable: Iterable) -> str: 175 | return ", ".join(f"``{v}``" for v in iterable) 176 | -------------------------------------------------------------------------------- /openapi/ws/__init__.py: -------------------------------------------------------------------------------- 1 | """Web socket handler with Publish/Subscribe capabilities 2 | 3 | Pub/Sub requires a message broker object in the "broker" app key 4 | """ 5 | from .channel import Channel, Event 6 | from .channels import Channels 7 | from .errors import CannotPublish, CannotSubscribe, ChannelCallbackError 8 | from .manager import SocketsManager, Websocket, WsHandlerType 9 | from .path import WsPathMixin 10 | from .rpc import ws_rpc 11 | 12 | __all__ = [ 13 | "WsPathMixin", 14 | "WsHandlerType", 15 | "SocketsManager", 16 | "Websocket", 17 | "Channels", 18 | "Channel", 19 | "Event", 20 | "CannotPublish", 21 | "CannotSubscribe", 22 | "ChannelCallbackError", 23 | "ws_rpc", 24 | ] 25 | -------------------------------------------------------------------------------- /openapi/ws/channel.py: 
/openapi/ws/channel.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import logging 5 | import re 6 | from dataclasses import dataclass, field 7 | from typing import Any, Awaitable, Callable, Dict, Sequence, Set 8 | 9 | from .errors import ChannelCallbackError 10 | from .utils import redis_to_py_pattern 11 | 12 | logger = logging.getLogger("openapi.ws") 13 | 14 | 15 | @dataclass 16 | class Event: 17 | name: str 18 | pattern: str 19 | regex: Any 20 | callbacks: Set[CallbackType] = field(default_factory=set) 21 | 22 | 23 | CallbackType = Callable[[str, str, Any], Awaitable[Any]] 24 | 25 | 26 | @dataclass 27 | class Channel: 28 | """A websocket channel""" 29 | 30 | name: str 31 | _events: Dict[str, Event] = field(default_factory=dict) 32 | 33 | @property 34 | def events(self): 35 | """List of event names this channel is registered with""" 36 | return tuple((e.name for e in self._events.values())) 37 | 38 | def __len__(self) -> int: 39 | return len(self._events) 40 | 41 | def __contains__(self, pattern: str) -> bool: 42 | return pattern in self._events 43 | 44 | def __iter__(self): 45 | return iter(self._events) 46 | 47 | def info(self) -> Dict: 48 | return {e.name: len(e.callbacks) for e in self._events.values()} 49 | 50 | async def __call__(self, message: Dict[str, Any]) -> Sequence[CallbackType]: 51 | """Execute callbacks for a new message 52 | 53 | Return the callbacks which raised :class:`.ChannelCallbackError` or any other exception 54 | """ 55 | event_name = message.get("event") or "" 56 | data = message.get("data") 57 | for event in tuple(self._events.values()): 58 | match = event.regex.match(event_name) 59 | if match: 60 | match = match.group() 61 | results = await asyncio.gather( 62 | *[ 63 | self._execute_callback(callback, event, match, data) 64 | for callback in event.callbacks 65 | ] 66 | ) 67 | return tuple(c for c in results if c) 68 | return () 69 | 70 | def register(self, event_name: str, callback: CallbackType): 71 | """Register a ``callback`` for ``event_name``""" 72 | event_name = event_name or "*" 73 | pattern = self.event_pattern(event_name) 74 | event = self._events.get(pattern) 75 | if not event: 76 | event = Event(name=event_name, pattern=pattern, regex=re.compile(pattern)) 77 | self._events[event.pattern] = event 78 | event.callbacks.add(callback) 79 | return event 80 | 81 | def get_subscribed(self, callback: CallbackType): 82 | events = [] 83 | for event in self._events.values(): 84 | if callback in event.callbacks: 85 | events.append(event.name) 86 | return events 87 | 88 | def unregister(self, event_name: str, callback: CallbackType): 89 | pattern = self.event_pattern(event_name) 90 | event = self._events.get(pattern) 91 | if event: 92 | return self.remove_event_callback(event, callback) 93 | 94 | def event_pattern(self, event): 95 | """Channel pattern for an event name""" 96 | return redis_to_py_pattern(event or "*") 97 | 98 | def remove_callback(self, callback: CallbackType) -> None: 99 | for event in tuple(self._events.values()): 100 | self.remove_event_callback(event, callback) 101 | 102 | def remove_event_callback(self, event: Event, callback: CallbackType) -> None: 103 | event.callbacks.discard(callback) 104 | if not event.callbacks: 105 | self._events.pop(event.pattern) 106 | 107 | async def _execute_callback( 108 | self, callback: CallbackType, event: Event, match: str, data: Any 109 | ) -> Any: 110 | try: 111 | await callback(self.name, match, data) 112 | except 
ChannelCallbackError: 113 | return callback 114 | except Exception: 115 | logger.exception('callback exception: channel "%s" event "%s"', self, event) 116 | return callback 117 | -------------------------------------------------------------------------------- /openapi/ws/channels.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple 2 | 3 | from openapi.data.validate import ValidationErrors 4 | 5 | from .channel import CallbackType, Channel 6 | from .errors import CannotSubscribe 7 | 8 | if TYPE_CHECKING: # pragma: no cover 9 | from .manager import SocketsManager 10 | 11 | 12 | class Channels: 13 | """Manage channels for publish/subscribe""" 14 | 15 | def __init__(self, sockets: "SocketsManager") -> None: 16 | self.sockets: "SocketsManager" = sockets 17 | self._channels: Dict[str, Channel] = {} 18 | 19 | @property 20 | def registered(self) -> Tuple[str, ...]: 21 | """Registered channels""" 22 | return tuple(self._channels) 23 | 24 | def __len__(self) -> int: 25 | return len(self._channels) 26 | 27 | def __contains__(self, channel_name: str) -> bool: 28 | return channel_name in self._channels 29 | 30 | def __iter__(self) -> Iterator[Channel]: 31 | return iter(self._channels.values()) 32 | 33 | def clear(self) -> None: 34 | self._channels.clear() 35 | 36 | def get(self, channel_name: str) -> Optional[Channel]: 37 | return self._channels.get(channel_name) 38 | 39 | def info(self) -> Dict: 40 | return {channel.name: channel.info() for channel in self} 41 | 42 | async def __call__(self, channel_name: str, message: Dict) -> None: 43 | """Channel callback""" 44 | channel = self.get(channel_name) 45 | if channel: 46 | closed = await channel(message) 47 | for websocket in closed: 48 | for channel_name, channel in tuple(self._channels.items()): 49 | channel.remove_callback(websocket) 50 | await self._maybe_remove_channel(channel) 51 | 52 | async def register( 53 | self, channel_name: str, event_name: str, callback: CallbackType 54 | ) -> Channel: 55 | """Register a callback 56 | 57 | :param channel_name: name of the channel 58 | :param event_name: name of the event in the channel or a pattern 59 | :param callback: the callback to invoke when the `event` on `channel` occurs 60 | """ 61 | channel = self.get(channel_name) 62 | if channel is None: 63 | try: 64 | await self.sockets.subscribe(channel_name) 65 | except CannotSubscribe: 66 | raise ValidationErrors(dict(channel="Invalid channel")) 67 | else: 68 | channel = Channel(channel_name) 69 | self._channels[channel_name] = channel 70 | event = channel.register(event_name, callback) 71 | await self.sockets.subscribe_to_event(channel.name, event.name) 72 | return channel 73 | 74 | async def unregister( 75 | self, channel_name: str, event: str, callback: CallbackType 76 | ) -> Optional[Channel]: 77 | """Safely unregister a callback from the list of event 78 | callbacks for channel_name 79 | """ 80 | channel = self.get(channel_name) 81 | if channel is None: 82 | raise ValidationErrors(dict(channel="Invalid channel")) 83 | channel.unregister(event, callback) 84 | return await self._maybe_remove_channel(channel) 85 | 86 | async def _maybe_remove_channel(self, channel: Channel) -> Channel: 87 | if not channel: 88 | await self.sockets.unsubscribe(channel.name) 89 | self._channels.pop(channel.name) 90 | return channel 91 | 92 | def get_subscribed(self, callback: CallbackType) -> Dict[str, List[str]]: 93 | subscribed = {} 94 | for channel in self: 95 | 
events = channel.get_subscribed(callback) 96 | if events: 97 | subscribed[channel.name] = events 98 | return subscribed 99 | -------------------------------------------------------------------------------- /openapi/ws/errors.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | class CannotSubscribe(RuntimeError): 5 | """Raised by a :class:`.SocketsManager`. 6 | 7 | When a :class:`.SocketsManager` is not able to subscribe to a channel 8 | it should raise this exception 9 | """ 10 | 11 | 12 | class ChannelCallbackError(RuntimeError): 13 | """Exception which allows for a clean callback removal""" 14 | 15 | 16 | class CannotPublish(RuntimeError): 17 | """Raised when not possible to publish event into channels""" 18 | 19 | 20 | CONNECTION_ERRORS = ( 21 | asyncio.CancelledError, 22 | asyncio.TimeoutError, 23 | RuntimeError, 24 | ConnectionResetError, 25 | ) 26 | -------------------------------------------------------------------------------- /openapi/ws/manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from functools import cached_property 3 | from typing import Any, Callable, Dict, Set 4 | 5 | from .channels import Channels 6 | from .errors import CannotPublish, CannotSubscribe 7 | 8 | WsHandlerType = Callable[[str, Any], None] 9 | 10 | 11 | class Websocket: 12 | """A websocket connection""" 13 | 14 | socket_id: str = "" 15 | """websocket ID""" 16 | 17 | def __str__(self) -> str: 18 | return self.socket_id 19 | 20 | 21 | class SocketsManager: 22 | """A base class for websocket managers""" 23 | 24 | @cached_property 25 | def sockets(self) -> Set[Websocket]: 26 | """Set of connected :class:`.Websocket`""" 27 | return set() 28 | 29 | @cached_property 30 | def channels(self) -> Channels: 31 | """Pub/sub :class:`.Channels` currently active on the running pod""" 32 | return Channels(self) 33 | 34 | def add(self, ws: Websocket) -> None: 35 | """Add a new websocket to the connected set""" 36 | self.sockets.add(ws) 37 | 38 | def remove(self, ws: Websocket) -> None: 39 | """Remove a websocket from the connected set""" 40 | self.sockets.discard(ws) 41 | 42 | def server_info(self) -> Dict: 43 | """Server information""" 44 | return dict(connections=len(self.sockets), channels=self.channels.info()) 45 | 46 | async def close_sockets(self) -> None: 47 | """Close and remove all websockets from the connected set""" 48 | await asyncio.gather(*[view.response.close() for view in self.sockets]) 49 | self.sockets.clear() 50 | self.channels.clear() 51 | 52 | async def publish( 53 | self, channel: str, event: str, body: Dict 54 | ) -> None: # pragma: no cover 55 | """Publish an event to a channel 56 | 57 | :param channel: the channel to publish to 58 | :param event: the event in the channel 59 | :param body: the body of the event to broadcast in the channel 60 | 61 | This method should raise :class:`.CannotPublish` if it is not possible to publish 62 | """ 63 | raise CannotPublish 64 | 65 | async def subscribe(self, channel: str) -> None: # pragma: no cover 66 | """Subscribe to a channel 67 | 68 | This method should raise :class:`.CannotSubscribe` if it is not possible to subscribe 69 | """ 70 | raise CannotSubscribe 71 | 72 | async def subscribe_to_event(self, channel: str, event: str) -> None: 73 | """Callback when a subscription to an event is done 74 | 75 | :param channel: the channel subscribed to 76 | :param event: the event in the channel 77 | 78 | You can use this callback to perform any backend subscriptions to 79 | third-party streaming services if required. 80 | 81 | By default it does nothing. 82 | """ 83 | 84 | async def unsubscribe(self, channel: str) -> None: 85 | """Unsubscribe from a channel""" 86 | -------------------------------------------------------------------------------- 
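A sketch of how a concrete manager can satisfy this contract (LocalSocketsManager is an assumed name for illustration, not part of the package): a broker-less manager can implement publish by dispatching straight to its local Channels:

    from typing import Dict

    from openapi.ws import SocketsManager

    class LocalSocketsManager(SocketsManager):
        """Single-process manager: no external message broker."""

        async def publish(self, channel: str, event: str, body: Dict) -> None:
            # deliver directly to the channels registered on this pod;
            # the message shape matches what Channel.__call__ expects
            await self.channels(channel, dict(event=event, data=body))

        async def subscribe(self, channel: str) -> None:
            pass  # nothing to do: there is no broker to subscribe to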
/openapi/ws/path.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | import time 4 | from dataclasses import dataclass, field 5 | from typing import Any, Dict 6 | 7 | from aiohttp import web 8 | 9 | from openapi.ws.channels import Channels 10 | 11 | from .. import json 12 | from ..data.validate import ValidationErrors, validated_schema 13 | from ..utils import compact 14 | from .errors import CONNECTION_ERRORS 15 | from .manager import SocketsManager, Websocket 16 | 17 | logger = logging.getLogger("openapi.ws") 18 | 19 | 20 | @dataclass 21 | class RpcProtocol: 22 | id: str 23 | method: str 24 | payload: Dict = field(default_factory=dict) 25 | 26 | 27 | class ProtocolError(RuntimeError): 28 | pass 29 | 30 | 31 | class WsPathMixin(Websocket): 32 | """API path mixin implementing the Websocket RPC protocol""" 33 | 34 | SOCKETS_KEY = "web_sockets" 35 | """Key in the app where the Web Sockets manager is located""" 36 | 37 | @property 38 | def sockets(self) -> SocketsManager: 39 | """The application's :class:`.SocketsManager`""" 40 | return self.request.app[self.SOCKETS_KEY] 41 | 42 | @property 43 | def channels(self) -> Channels: 44 | """Channels for pub/sub""" 45 | return self.sockets.channels 46 | 47 | async def get(self): 48 | response = web.WebSocketResponse() 49 | available = response.can_prepare(self.request) 50 | if not available: 51 | raise web.HTTPBadRequest( 52 | **self.api_response_data( 53 | {"message": "Unable to open websocket connection"} 54 | ) 55 | ) 56 | await response.prepare(self.request) 57 | self.response = response 58 | self.started = time.time() 59 | key = f"{self.request.remote} - {self.started}" 60 | self.socket_id = hashlib.sha224(key.encode("utf-8")).hexdigest() 61 | # 62 | # Add to the set of connected sockets 63 | self.sockets.add(self) 64 | # 65 | try: 66 | async for msg in response: 67 | if msg.type == web.WSMsgType.TEXT: 68 | await self.on_message(msg) 69 | except CONNECTION_ERRORS: 70 | logger.info("lost connection with websocket %s", self) 71 | finally: 72 | self.sockets.remove(self) 73 | return response 74 | 75 | def decode_message(self, msg: str) -> Any: 76 | """Decode JSON string message, override for different protocol""" 77 | try: 78 | return json.loads(msg) 79 | except json.JSONDecodeError: 80 | raise ProtocolError("JSON string expected") from None 81 | 82 | def encode_message(self, msg: Any) -> str: 83 | """Encode as JSON string message, override for different protocol""" 84 | try: 85 | return json.dumps(msg) 86 | except TypeError: 87 | raise ProtocolError("JSON object expected") from None 88 | 89 | async def on_message(self, msg): 90 | id_ = None 91 | rpc = None 92 | try: 93 | data = self.decode_message(msg.data) 94 | if not isinstance(data, dict): 95 | raise ProtocolError( 96 | "Malformed message; expected dictionary, " 97 | f"got {type(data).__name__}" 98 | ) 99 | id_ = data.get("id") 100 | rpc = validated_schema(RpcProtocol, data) 101 | method = getattr(self, f"ws_rpc_{rpc.method}", None) 102 | if not method: 103 | raise ValidationErrors( 104 | dict(method=f"{rpc.method} method not available") 105 | ) 106 | response = await method(rpc.payload or {}) 107 | await self.write(dict(id=rpc.id, 
method=rpc.method, response=response)) 108 | except ProtocolError as exc: 109 | logger.error("Protocol error: %s", exc) 110 | await self.error_message( 111 | str(exc), id=id_, method=rpc.method if rpc else None 112 | ) 113 | except ValidationErrors as exc: 114 | await self.error_message( 115 | "Invalid RPC parameters", 116 | errors=exc.errors, 117 | id=id_, 118 | method=rpc.method if rpc else None, 119 | ) 120 | 121 | async def error_message(self, message, *, errors=None, **kw): 122 | error = dict(message=message) 123 | if errors: 124 | error["errors"] = errors 125 | await self.write(compact(error=error, **kw)) 126 | 127 | async def write(self, msg: Dict) -> None: 128 | text = self.encode_message(msg) 129 | await self.response.send_str(text) 130 | 131 | async def close(self) -> None: 132 | await self.response.close() 133 | self.sockets.remove(self) 134 | -------------------------------------------------------------------------------- /openapi/ws/pubsub.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import cached_property 3 | from typing import TYPE_CHECKING, Any, Dict, List, Union 4 | 5 | from ..data import fields 6 | from ..data.validate import ValidationErrors 7 | from .channel import logger 8 | from .errors import CONNECTION_ERRORS, CannotPublish, ChannelCallbackError 9 | from .rpc import ws_rpc 10 | 11 | if TYPE_CHECKING: # pragma: no cover 12 | from .path import WsPathMixin 13 | 14 | 15 | @dataclass 16 | class PublishSchema: 17 | data: Union[str, List, Dict] 18 | channel: str = fields.data_field( 19 | required=True, description="Channel to publish message" 20 | ) 21 | event: str = fields.data_field(description="Channel event") 22 | 23 | 24 | @dataclass 25 | class SubscribeSchema: 26 | channel: str = fields.data_field(required=True, description="Channel to subscribe") 27 | event: str = fields.data_field(description="Channel event") 28 | 29 | 30 | class ChannelCallback: 31 | """Callback for channels""" 32 | 33 | def __init__(self, ws: "WsPathMixin"): 34 | self.ws: "WsPathMixin" = ws 35 | 36 | def __repr__(self) -> str: # pragma: no cover 37 | return self.ws.socket_id 38 | 39 | def __str__(self) -> str: 40 | return f"websocket {self.ws.socket_id}" 41 | 42 | async def __call__(self, channel: str, match: str, data: Any) -> None: 43 | try: 44 | if hasattr(data, "__call__"): 45 | data = data() 46 | await self.ws.write(dict(channel=channel, event=match, data=data)) 47 | except CONNECTION_ERRORS: 48 | logger.info("lost connection with %s", self) 49 | await self.ws.close() 50 | raise ChannelCallbackError 51 | except Exception: 52 | logger.exception("Critical exception on connection %s", self) 53 | await self.ws.close() 54 | raise ChannelCallbackError 55 | 56 | 57 | class Publish: 58 | """Mixin which implements the publish RPC method 59 | 60 | Must be used as mixin of :class:`.WsPathMixin` 61 | """ 62 | 63 | def get_publish_message(self, data: Any) -> Any: 64 | """Create the publish message from the data payload""" 65 | return data 66 | 67 | @ws_rpc(body_schema=PublishSchema) 68 | async def ws_rpc_publish(self, payload): 69 | """Publish an event on a channel""" 70 | try: 71 | event = payload.get("event") 72 | data = self.get_publish_message(payload.get("data")) 73 | await self.sockets.publish(payload["channel"], event, data) 74 | return dict(channel=payload["channel"], event=event, data=data) 75 | except CannotPublish: 76 | raise ValidationErrors(dict(channel="Cannot publish to channel")) 77 | 78 | 79 
| class Subscribe: 80 | """Mixin which implements the subscribe and unsubscribe RPC methods 81 | 82 | Must be used as mixin of :class:`.WsPathMixin` 83 | """ 84 | 85 | @cached_property 86 | def channel_callback(self) -> ChannelCallback: 87 | """The callback for :class:`.Channels`""" 88 | return ChannelCallback(self) 89 | 90 | @ws_rpc(body_schema=SubscribeSchema) 91 | async def ws_rpc_subscribe(self, payload): 92 | """Subscribe to an event on a channel""" 93 | await self.channels.register( 94 | payload["channel"], payload.get("event"), self.channel_callback 95 | ) 96 | return dict(subscribed=self.channels.get_subscribed(self.channel_callback)) 97 | 98 | @ws_rpc(body_schema=SubscribeSchema) 99 | async def ws_rpc_unsubscribe(self, payload): 100 | """Unsubscribe from an event on a channel""" 101 | await self.channels.unregister( 102 | payload["channel"], payload.get("event"), self.channel_callback 103 | ) 104 | return dict(subscribed=self.channels.get_subscribed(self.channel_callback)) 105 | -------------------------------------------------------------------------------- /openapi/ws/rpc.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import wraps 3 | from typing import Any 4 | 5 | from ..data.validate import ValidationErrors, validate 6 | 7 | 8 | @dataclass 9 | class ws_rpc: 10 | """Defines a Websocket RPC method in an OpenAPI Path""" 11 | 12 | body_schema: Any = None 13 | response_schema: Any = None 14 | 15 | def __call__(self, method): 16 | method.ws_rpc = self 17 | 18 | @wraps(method) 19 | async def _(view, payload): 20 | if self.body_schema: 21 | d = validate(self.body_schema, payload) 22 | if d.errors: 23 | raise ValidationErrors(d.errors) 24 | payload = d.data 25 | data = await method(view, payload) 26 | return view.dump(self.response_schema, data) 27 | 28 | return _ 29 | -------------------------------------------------------------------------------- 
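For illustration, a custom RPC method is just a decorated coroutine on a view (ws_rpc_echo is a hypothetical name; the ws_rpc_ prefix is what on_message in path.py dispatches on, and the view is assumed to also provide the dump method of an API path class):

    class EchoPath(WsPathMixin):
        @ws_rpc()
        async def ws_rpc_echo(self, payload):
            # invoked by a client message like {"id": "1", "method": "echo", "payload": {...}}
            return payload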
/openapi/ws/utils.py: -------------------------------------------------------------------------------- 1 | def redis_to_py_pattern(pattern): 2 | """Convert a redis glob-style pattern into a python regex string, e.g. ``"user.*"`` becomes ``"user.(.*)$"``""" 3 | return "".join(_redis_to_py_pattern(pattern)) 4 | 5 | 6 | def _redis_to_py_pattern(pattern): 7 | clear, esc = False, False 8 | s, q, op, cp, e = "*", "?", "[", "]", "\\" 9 | 10 | for v in pattern: 11 | if v == s and not esc: 12 | yield "(.*)" 13 | elif v == q and not esc: 14 | yield "." 15 | elif v == op and not esc: 16 | esc = True 17 | yield v 18 | elif v == cp and esc: 19 | esc = False 20 | yield v 21 | elif v == e: 22 | clear, esc = True, True 23 | yield v 24 | elif clear: 25 | clear, esc = False, False 26 | yield v 27 | else: 28 | yield v 29 | yield "$" 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "aio-openapi" 3 | version = "3.2.1" 4 | description = "Minimal OpenAPI asynchronous server application" 5 | documentation = "https://aio-openapi.readthedocs.io" 6 | repository = "https://github.com/quantmind/aio-openapi" 7 | authors = ["Luca "] 8 | license = "BSD-3-Clause" 9 | readme = "readme.md" 10 | packages = [ 11 | { include = "openapi" } 12 | ] 13 | classifiers = [ 14 | "Development Status :: 5 - Production/Stable", 15 | "Intended Audience :: Developers", 16 | "License :: OSI Approved :: BSD License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: JavaScript", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Topic :: Internet", 26 | "Topic :: Internet :: WWW/HTTP", 27 | "Topic :: Software Development :: Libraries :: Application Frameworks", 28 | "Topic :: Software Development :: Libraries :: Python Modules", 29 | "Topic :: Software Development :: Libraries", 30 | "Topic :: Software Development", 31 | "Typing :: Typed", 32 | "Framework :: AsyncIO", 33 | "Environment :: Web Environment", 34 | ] 35 | 36 | [tool.poetry.urls] 37 | repository = "https://github.com/quantmind/aio-openapi" 38 | issues = "https://github.com/quantmind/aio-openapi/issues" 39 | 40 | 41 | [tool.poetry.dependencies] 42 | python = ">=3.8.1,<4" 43 | aiohttp = "^3.8.0" 44 | httptools = "^0.5.0" 45 | simplejson = "^3.17.2" 46 | SQLAlchemy = { version="^2.0.8", extras=["asyncio"] } 47 | SQLAlchemy-Utils = "^0.41.1" 48 | psycopg2-binary = "^2.9.2" 49 | click = "^8.0.3" 50 | python-dateutil = "^2.8.2" 51 | PyYAML = "^6.0" 52 | email-validator = "^1.2.1" 53 | alembic = "^1.8.1" 54 | "backports.zoneinfo" = { version = "^0.2.1", python="<3.9" } 55 | asyncpg = "^0.28.0" 56 | 57 | [tool.poetry.group.dev.dependencies] 58 | black = "^23.3.0" 59 | pytest = "^7.1.1" 60 | mypy = "^1.1.1" 61 | sentry-sdk = "^1.4.3" 62 | python-dotenv = "^1.0.0" 63 | openapi-spec-validator = "^0.3.1" 64 | pytest-cov = "^4.0.0" 65 | pytest-mock = "^3.6.1" 66 | isort = "^5.10.1" 67 | types-simplejson = "^3.17.5" 68 | types-python-dateutil = "^2.8.11" 69 | factory-boy = "^3.2.1" 70 | pytest-asyncio = "^0.21.0" 71 | types-pyyaml = "^6.0.12" 72 | ruff = "^0.0.280" 73 | 74 | [tool.poetry.group.extras] 75 | optional = true 76 | 77 | [tool.poetry.group.extras.dependencies] 78 | aiodns = {version = "^3.0.0"} 79 | PyJWT = {version = "^2.3.0"} 80 | colorlog = {version = "^6.6.0"} 81 | phonenumbers = {version = "^8.12.37"} 82 | 83 | 84 | [tool.poetry.group.docs] 85 | optional = true 86 | 87 | [tool.poetry.group.docs.dependencies] 88 | Sphinx = {version = "^6.1.3"} 89 | sphinx-copybutton = {version = "^0.5.0"} 90 | sphinx-autodoc-typehints = {version = "^1.12.0"} 91 | aiohttp-theme = {version = "^0.1.6"} 92 | recommonmark = {version = "^0.7.1"} 93 | 94 | 95 | [tool.poetry.extras] 96 | dev = ["aiodns", "PyJWT", "colorlog", "phonenumbers"] 
97 | docs = [ 98 | "Sphinx", 99 | "recommonmark", 100 | "aiohttp-theme", 101 | "sphinx-copybutton", 102 | "sphinx-autodoc-typehints", 103 | ] 104 | 105 | [build-system] 106 | requires = ["poetry-core>=1.0.0"] 107 | build-backend = "poetry.core.masonry.api" 108 | 109 | [tool.pytest.ini_options] 110 | asyncio_mode = "auto" 111 | testpaths = [ 112 | "tests" 113 | ] 114 | 115 | [tool.isort] 116 | profile = "black" 117 | 118 | [tool.ruff] 119 | select = ["E", "F"] 120 | line-length = 88 121 | 122 | [tool.mypy] 123 | # strict = true 124 | disallow_untyped_calls = true 125 | disallow_untyped_defs = true 126 | warn_no_return = true 127 | 128 | [[tool.mypy.overrides]] 129 | module = "tests.*" 130 | disallow_untyped_defs = false 131 | 132 | [[tool.mypy.overrides]] 133 | module = "openapi.db.openapi.*" 134 | ignore_errors = true 135 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # aio-openapi 2 | 3 | [![PyPI version](https://badge.fury.io/py/aio-openapi.svg)](https://badge.fury.io/py/aio-openapi) 4 | [![Python versions](https://img.shields.io/pypi/pyversions/aio-openapi.svg)](https://pypi.org/project/aio-openapi) 5 | [![Build](https://github.com/quantmind/aio-openapi/workflows/build/badge.svg)](https://github.com/quantmind/aio-openapi/actions?query=workflow%3Abuild) 6 | [![codecov](https://codecov.io/github/quantmind/aio-openapi/branch/main/graph/badge.svg?token=XV2GD946QI)](https://codecov.io/github/quantmind/aio-openapi) 7 | [![Documentation Status](https://readthedocs.org/projects/aio-openapi/badge/?version=latest)](https://aio-openapi.readthedocs.io/en/latest/?badge=latest) 8 | [![Downloads](https://img.shields.io/pypi/dd/aio-openapi.svg)](https://pypi.org/project/aio-openapi/) 9 | 10 | Asynchronous web middleware for [aiohttp][] for serving REST APIs with the [OpenAPI][] v3 11 | specification, with optional [PostgreSql][] database bindings. 12 | 13 | See the [tutorial](https://aio-openapi.readthedocs.io/en/latest/tutorial.html) for a quick introduction. 
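A minimal sketch of a server entry point (only ``rest`` and the keyword arguments exercised by the test-suite below are shown; everything else is an assumption):

```python
from openapi.rest import rest

cli = rest(openapi=dict(title="My API", version="1.0"), base_path="/v1")

if __name__ == "__main__":
    cli()  # exposes the `serve` command, among others
```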
14 | 15 | 16 | [aiohttp]: https://aiohttp.readthedocs.io/en/stable/ 17 | [openapi]: https://www.openapis.org/ 18 | [postgresql]: https://www.postgresql.org/ 19 | [sqlalchemy]: https://www.sqlalchemy.org/ 20 | [click]: https://github.com/pallets/click 21 | [alembic]: http://alembic.zzzcomputing.com/en/latest/ 22 | [asyncpg]: https://github.com/MagicStack/asyncpg 23 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dotenv 4 | 5 | dotenv.load_dotenv() 6 | dotenv.load_dotenv("tests/test.env") 7 | 8 | if not os.environ.get("PYTHON_ENV"): 9 | os.environ["PYTHON_ENV"] = "test" 10 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import shutil 4 | from unittest import mock 5 | 6 | import pytest 7 | from aiohttp.test_utils import TestClient 8 | from aiohttp.web import Application 9 | from sqlalchemy.engine.url import URL 10 | from sqlalchemy_utils import create_database, database_exists 11 | 12 | from openapi.db.dbmodel import CrudDB 13 | from openapi.testing import app_cli, with_test_db 14 | 15 | from .example.db import DB 16 | from .example.main import create_app 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def sync_url() -> URL: 21 | return DB.sync_engine.url 22 | 23 | 24 | @pytest.fixture(autouse=True) 25 | def clean_migrations(): 26 | if os.path.isdir("migrations"): 27 | shutil.rmtree("migrations") 28 | 29 | 30 | @pytest.fixture(autouse=True) 31 | def sentry_mock(mocker): 32 | mm = mock.MagicMock() 33 | mocker.patch("sentry_sdk.init", mm) 34 | return mm 35 | 36 | 37 | @pytest.fixture(scope="module", autouse=True) 38 | def event_loop(): 39 | """Return an instance of the event loop.""" 40 | loop = asyncio.new_event_loop() 41 | try: 42 | yield loop 43 | finally: 44 | loop.close() 45 | 46 | 47 | @pytest.fixture(scope="session") 48 | def clear_db(sync_url: URL) -> CrudDB: 49 | if not database_exists(sync_url): 50 | # drop_database(url) 51 | create_database(sync_url) 52 | else: 53 | DB.drop_all_schemas() 54 | return DB 55 | 56 | 57 | @pytest.fixture 58 | async def cli(clear_db: CrudDB) -> TestClient: 59 | app = create_app().web() 60 | with with_test_db(app["db"]): 61 | async with app_cli(app) as cli: 62 | yield cli 63 | 64 | 65 | @pytest.fixture(scope="module") 66 | async def cli2(clear_db: CrudDB) -> TestClient: 67 | app = create_app().web() 68 | with with_test_db(app["db"]): 69 | async with app_cli(app) as cli: 70 | yield cli 71 | 72 | 73 | @pytest.fixture 74 | def test_app(cli: TestClient) -> Application: 75 | return cli.app 76 | 77 | 78 | @pytest.fixture 79 | def db(test_app: Application) -> CrudDB: 80 | return test_app["db"] 81 | -------------------------------------------------------------------------------- /tests/core/test_cli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import patch 3 | 4 | import click 5 | from click.testing import CliRunner 6 | 7 | from openapi.logger import logger 8 | from openapi.rest import rest 9 | 10 | 11 | def test_usage(): 12 | runner = CliRunner() 13 | result = runner.invoke(rest()) 14 | assert result.exit_code == 0 15 | assert result.output.startswith("Usage:") 16 | 17 | 18 | def test_version(): 19 | runner = CliRunner() 20 | result = 
runner.invoke(rest(), ["--version"]) 21 | assert result.exit_code == 0 22 | assert result.output.startswith("Open API") 23 | 24 | 25 | def test_version_openapi(): 26 | runner = CliRunner() 27 | result = runner.invoke( 28 | rest(openapi=dict(title="Test Version", version="1.0")), ["--version"] 29 | ) 30 | assert result.exit_code == 0 31 | assert result.output.startswith("Test Version 1.0") 32 | 33 | 34 | def test_serve(): 35 | runner = CliRunner() 36 | cli = rest(base_path="/v1") 37 | with patch("aiohttp.web.run_app") as mock: 38 | result = runner.invoke(cli, ["--quiet", "serve"]) 39 | assert result.exit_code == 0 40 | assert mock.call_count == 1 41 | app = mock.call_args[0][0] 42 | assert app.router is not None 43 | assert logger.level == logging.ERROR 44 | 45 | with patch("aiohttp.web.run_app") as mock: 46 | result = runner.invoke(cli, ["--verbose", "serve"]) 47 | assert result.exit_code == 0 48 | assert mock.call_count == 1 49 | app = mock.call_args[0][0] 50 | assert app.router is not None 51 | assert logger.level == logging.DEBUG 52 | 53 | 54 | def test_serve_index(): 55 | runner = CliRunner() 56 | cli = rest() 57 | with patch("aiohttp.web.run_app") as mock: 58 | result = runner.invoke(cli, ["serve", "--index", "1"]) 59 | assert result.exit_code == 0 60 | assert mock.call_count == 1 61 | app = mock.call_args[0][0] 62 | assert app.router is not None 63 | assert app["index"] == 1 64 | assert logger.level == logging.INFO 65 | 66 | 67 | def test_commands(): 68 | runner = CliRunner() 69 | cli = rest(base_path="/v1", commands=[hello]) 70 | result = runner.invoke(cli, ["hello"]) 71 | assert result.exit_code == 0 72 | assert result.output.startswith("Hello!") 73 | 74 | 75 | @click.command("hello") 76 | @click.pass_context 77 | def hello(ctx): 78 | click.echo("Hello!") 79 | -------------------------------------------------------------------------------- /tests/core/test_columns.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from openapi.db.columns import UUIDColumn 4 | 5 | 6 | def test_runtime_error_with_incorrect_params(): 7 | with pytest.raises(RuntimeError): 8 | UUIDColumn("test", primary_key=True, nullable=True) 9 | -------------------------------------------------------------------------------- /tests/core/test_cruddb.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from openapi.db import CrudDB 4 | 5 | 6 | async def test_upsert(db: CrudDB) -> None: 7 | task = await db.db_upsert(db.tasks, dict(title="Example"), dict(severity=4)) 8 | assert task.id 9 | assert task.severity == 4 10 | assert task.done is None 11 | task2 = await db.db_upsert( 12 | db.tasks, dict(title="Example"), dict(done=datetime.now()) 13 | ) 14 | assert task2.id == task.id 15 | assert task2.done 16 | assert await db.db_count(db.tasks) == 1 17 | 18 | 19 | async def test_upsert_no_data(db: CrudDB) -> None: 20 | task = await db.db_upsert(db.tasks, dict(title="Example2")) 21 | assert task.id 22 | assert task.title == "Example2" 23 | -------------------------------------------------------------------------------- /tests/core/test_db_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import sqlalchemy as sa 5 | from click.testing import CliRunner 6 | 7 | from tests.example.db.tables2 import extra 8 | 9 | 10 | def _migrate(cli, name="test", runner: Optional[CliRunner] = None) -> CliRunner: 11 | if not 
runner: 12 | runner = CliRunner() 13 | result = runner.invoke(cli.app["cli"], ["db", "init"]) 14 | assert result.exit_code == 0 15 | assert os.path.isdir("migrations") 16 | _drop(cli, runner) 17 | result = runner.invoke(cli.app["cli"], ["db", "migrate", "-m", name]) 18 | assert result.exit_code == 0 19 | return runner 20 | 21 | 22 | def _current(cli, runner: Optional[CliRunner] = None): 23 | if not runner: 24 | runner = CliRunner() 25 | result = runner.invoke(cli.app["cli"], ["db", "current"]) 26 | assert result.exit_code == 0 27 | return result.output.split()[0] 28 | 29 | 30 | def _drop(cli, runner: Optional[CliRunner] = None): 31 | if not runner: 32 | runner = CliRunner() 33 | result = runner.invoke(cli.app["cli"], ["db", "drop"]) 34 | assert result.exit_code == 0 35 | assert result.output == "tables dropped\n" 36 | result = runner.invoke(cli.app["cli"], ["db", "tables", "--db"]) 37 | assert result.exit_code == 0 38 | assert result.output == "" 39 | 40 | 41 | async def test_db(cli): 42 | runner = CliRunner() 43 | result = runner.invoke(cli.app["cli"], ["db", "--help"]) 44 | assert result.exit_code == 0 45 | assert result.output.startswith("Usage: root db [OPTIONS]") 46 | db = cli.app["db"] 47 | assert repr(db) 48 | 49 | 50 | async def test_createdb(cli): 51 | runner = CliRunner() 52 | result = runner.invoke(cli.app["cli"], ["db", "create", "testing-aio-db"]) 53 | assert result.exit_code == 0 54 | result = runner.invoke( 55 | cli.app["cli"], ["db", "create", "testing-aio-db", "--force"] 56 | ) 57 | assert result.exit_code == 0 58 | assert result.output == "database testing-aio-db created\n" 59 | result = runner.invoke(cli.app["cli"], ["db", "create", "testing-aio-db"]) 60 | assert result.exit_code == 0 61 | assert result.output == "database testing-aio-db already available\n" 62 | 63 | 64 | async def test_migration_upgrade(cli): 65 | runner = _migrate(cli) 66 | result = runner.invoke(cli.app["cli"], ["db", "upgrade"]) 67 | assert result.exit_code == 0 68 | 69 | # delete column to check if tables will be dropped and recreated 70 | db = cli.app["db"] 71 | async with db.transaction() as conn: 72 | await conn.execute(sa.text("ALTER TABLE tasks DROP COLUMN title")) 73 | 74 | result = runner.invoke(cli.app["cli"], ["db", "upgrade", "--drop-tables"]) 75 | assert result.exit_code == 0 76 | 77 | assert "title" in db.metadata.tables["tasks"].c 78 | 79 | 80 | async def test_show_migration(cli): 81 | runner = _migrate(cli) 82 | result = runner.invoke(cli.app["cli"], ["db", "show"]) 83 | assert result.exit_code == 0 84 | assert result.output.split("\n")[4].strip() == "test" 85 | 86 | 87 | async def test_history(cli): 88 | runner = _migrate(cli) 89 | result = runner.invoke(cli.app["cli"], ["db", "history"]) 90 | assert result.exit_code == 0 91 | assert result.output.strip().startswith(" -> ") 92 | 93 | 94 | async def test_upgrade(cli): 95 | runner = _migrate(cli) 96 | result = runner.invoke(cli.app["cli"], ["db", "upgrade"]) 97 | assert result.exit_code == 0 98 | 99 | 100 | async def test_downgrade(cli): 101 | runner = _migrate(cli) 102 | runner.invoke(cli.app["cli"], ["db", "upgrade", "--drop-tables"]) 103 | name = _current(cli, runner) 104 | 105 | extra(cli.app["db"].metadata) 106 | _migrate(cli, name="extra", runner=runner) 107 | 108 | # upgrade to new migration 109 | result = runner.invoke(cli.app["cli"], ["db", "upgrade"]) 110 | assert result.exit_code == 0 111 | name2 = _current(cli, runner) 112 | 113 | assert name != name2 114 | 115 | # downgrade 116 | result = runner.invoke(cli.app["cli"], 
["db", "downgrade", "--revision", name]) 117 | assert result.exit_code == 0 118 | assert result.output == f"downgraded successfully to {name}\n" 119 | assert name == _current(cli, runner) 120 | 121 | 122 | async def test_tables(cli): 123 | runner = CliRunner() 124 | result = runner.invoke(cli.app["cli"], ["db", "tables"]) 125 | assert result.exit_code == 0 126 | assert result.output == "\n".join( 127 | ("multi_key", "multi_key_unique", "randoms", "series", "tasks", "") 128 | ) 129 | 130 | 131 | async def test_drop(cli): 132 | _drop(cli) 133 | -------------------------------------------------------------------------------- /tests/core/test_db_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | async def test_get_attr(cli): 5 | db = cli.app["db"] 6 | assert db.tasks is db.metadata.tables["tasks"] 7 | with pytest.raises(AttributeError) as ex_info: 8 | db.fooooo 9 | assert "fooooo" in str(ex_info.value) 10 | 11 | 12 | async def test_db_count(cli): 13 | db = cli.app["db"] 14 | n = await db.db_count(db.tasks, {}) 15 | assert n == 0 16 | await db.db_insert(db.tasks, dict(title="testing rollback")) 17 | n = await db.db_count(db.tasks, {}) 18 | assert n == 1 19 | -------------------------------------------------------------------------------- /tests/core/test_db_path_extra.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from uuid import UUID 3 | 4 | from openapi.testing import json_body 5 | 6 | 7 | async def test_get_update(cli): 8 | response = await cli.post("/tasks", json=dict(title="test task2 2")) 9 | data = await json_body(response, 201) 10 | id_ = UUID(data["id"]) 11 | task_path = f"/tasks2/{id_.hex}" 12 | 13 | assert data["title"] == "test task2 2" 14 | # 15 | # now get it 16 | response = await cli.get(task_path) 17 | data = await json_body(response, 200) 18 | assert data["title"] == "test task2 2" 19 | # 20 | # now update 21 | response = await cli.patch(task_path, json=dict(done=datetime.now().isoformat())) 22 | data = await json_body(response, 200) 23 | assert data["id"] == id_.hex 24 | # 25 | # now delete it 26 | response = await cli.delete(task_path) 27 | assert response.status == 204 28 | response = await cli.delete(task_path) 29 | await json_body(response, 404) 30 | -------------------------------------------------------------------------------- /tests/core/test_db_single.py: -------------------------------------------------------------------------------- 1 | from openapi.testing import CrudDB, SingleConnDatabase 2 | 3 | 4 | async def __test_rollback(db: CrudDB): 5 | async with SingleConnDatabase.from_db(db) as sdb: 6 | rows = await sdb.db_insert(sdb.tasks, dict(title="testing rollback")) 7 | assert rows.rowcount == 1 8 | row = rows.first() 9 | assert row["title"] == "testing rollback" 10 | id_ = row["id"] 11 | assert id_ 12 | rows = await sdb.db_select(sdb.tasks, dict(id=id_)) 13 | assert rows.rowcount == 1 14 | rows = await sdb.db_select(sdb.tasks, dict(id=id_)) 15 | assert rows.rowcount == 0 16 | -------------------------------------------------------------------------------- /tests/core/test_dc_db.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from datetime import date, datetime 3 | 4 | from openapi.data.db import dataclass_from_table 5 | from openapi.data.dump import dump 6 | from openapi.data.fields import REQUIRED, VALIDATOR, UUIDValidator, field_dict 7 | from 
openapi.data.validate import validate 8 | 9 | 10 | def test_convert_task(db): 11 | Tasks = dataclass_from_table("Tasks", db.tasks, required=True, exclude=("random",)) 12 | fields = field_dict(Tasks) 13 | assert "random" not in fields 14 | props = {} 15 | fields["title"].metadata[VALIDATOR].openapi(props) 16 | assert props["maxLength"] == 64 17 | assert props["minLength"] == 3 18 | 19 | 20 | def test_convert_random(db): 21 | Randoms = dataclass_from_table("Randoms", db.randoms) 22 | assert Randoms 23 | fields = field_dict(Randoms) 24 | assert isinstance(fields["id"].metadata[VALIDATOR], UUIDValidator) 25 | d = validate(Randoms, dict(info="jhgjg")) 26 | assert d.errors["info"] == "expected an object" 27 | d = validate(Randoms, dict(info=dict(a="3", b="test"))) 28 | assert not d.errors 29 | 30 | 31 | def test_validate(db): 32 | Tasks = dataclass_from_table( 33 | "Tasks", db.tasks, required=True, default=("created_by",), exclude=("id",) 34 | ) 35 | d = validate(Tasks, dict(title="test")) 36 | assert len(d.errors) == 1 37 | assert d.errors["subtitle"] == "required" 38 | Tasks = dataclass_from_table( 39 | "Tasks", db.tasks, required=True, default=True, exclude=("id",) 40 | ) 41 | d = validate(Tasks, dict(title="test")) 42 | assert not d.errors 43 | d = validate(Tasks, dict(title="te")) 44 | assert d.errors["title"] == "Too short" 45 | d = validate(Tasks, dict(title="t" * 100)) 46 | assert d.errors["title"] == "Too long" 47 | d = validate(Tasks, dict(title=40)) 48 | assert d.errors["title"] == "Must be a string" 49 | 50 | 51 | def test_date(db): 52 | Randoms = dataclass_from_table("Randoms", db.randoms) 53 | d = validate(Randoms, dict(randomdate="jhgjg")) 54 | assert d.errors["randomdate"] == "jhgjg not valid format" 55 | d = validate(Randoms, dict(randomdate=date.today())) 56 | assert not d.errors 57 | v = dump(Randoms, d.data) 58 | assert v["randomdate"] == date.today().isoformat() 59 | v = dump(Randoms, {"randomdate": datetime.now()}) 60 | assert v["randomdate"] == date.today().isoformat() 61 | v = dump(Randoms, {"randomdate": date.today().isoformat()}) 62 | assert v["randomdate"] == date.today().isoformat() 63 | 64 | 65 | def test_json_list(db): 66 | Randoms = dataclass_from_table("Randoms", db.randoms) 67 | fields = field_dict(Randoms) 68 | assert fields["jsonlist"].type is typing.List 69 | d = validate(Randoms, dict(jsonlist="jhgjg")) 70 | assert d.errors["jsonlist"] == "expected a sequence" 71 | d = validate(Randoms, dict(jsonlist=["bla", "foo"])) 72 | assert not d.errors 73 | 74 | 75 | def test_include(db): 76 | Randoms = dataclass_from_table("Randoms", db.randoms, include=("price",)) 77 | fields = field_dict(Randoms) 78 | assert len(fields) == 1 79 | 80 | 81 | def test_require(db): 82 | Randoms = dataclass_from_table("Randoms", db.randoms, required=False) 83 | fields = field_dict(Randoms) 84 | assert fields 85 | for field in fields.values(): 86 | assert field.metadata[REQUIRED] is False 87 | -------------------------------------------------------------------------------- /tests/core/test_errors.py: -------------------------------------------------------------------------------- 1 | from openapi.data.validate import OBJECT_EXPECTED 2 | from openapi.testing import json_body 3 | 4 | 5 | async def test_bad_data(cli): 6 | for bad in ([1], 3, "ciao"): 7 | response = await cli.post("/tasks", json=bad) 8 | data = await json_body(response, 422) 9 | assert data["message"] == OBJECT_EXPECTED 10 | -------------------------------------------------------------------------------- 
/tests/core/test_filters.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from multidict import MultiDict 3 | 4 | from openapi.spec import OpenApiSpec 5 | from openapi.testing import json_body 6 | from tests.utils import FakeRequest 7 | 8 | tests = [ 9 | {"title": "test1", "unique_title": "thefirsttest", "severity": 1}, 10 | {"title": "test2", "unique_title": "anothertest1", "severity": 3}, 11 | {"title": "test3"}, 12 | {"title": "test4", "unique_title": "anothertest4", "severity": 5}, 13 | ] 14 | 15 | 16 | @pytest.fixture 17 | async def fixtures(cli): 18 | results = [] 19 | for test in tests: 20 | rs = await cli.post("/tasks", json=test) 21 | test = await json_body(rs, 201) 22 | test.pop("id") 23 | results.append(test) 24 | return results 25 | 26 | 27 | async def assert_query(cli, params, expected): 28 | response = await cli.get("/tasks", params=params) 29 | data = await json_body(response) 30 | for d in data: 31 | d.pop("id") 32 | assert len(data) == len(expected) 33 | assert data == expected 34 | 35 | 36 | async def test_spec(test_app): 37 | spec = OpenApiSpec() 38 | doc = spec.build(FakeRequest.from_app(test_app)) 39 | query = doc["paths"]["/tasks"]["get"]["parameters"] 40 | filters = [q["name"] for q in query] 41 | assert set(filters) == { 42 | "title", 43 | "done", 44 | "type", 45 | "search", 46 | "severity", 47 | "severity:lt", 48 | "severity:le", 49 | "severity:gt", 50 | "severity:ge", 51 | "severity:ne", 52 | "story_points", 53 | "order_by", 54 | "limit", 55 | "offset", 56 | } 57 | 58 | 59 | async def test_filters(cli, fixtures): 60 | test1, test2, test3, test4 = fixtures 61 | await assert_query(cli, {"severity:gt": 1}, [test2, test4]) 62 | await assert_query(cli, {"severity:ge": 1}, [test1, test2, test4]) 63 | await assert_query(cli, {"severity:lt": 3}, [test1]) 64 | await assert_query(cli, {"severity:le": 2}, [test1]) 65 | await assert_query(cli, {"severity:le": 3}, [test1, test2]) 66 | await assert_query(cli, {"severity:ne": 3}, [test1, test4]) 67 | await assert_query(cli, {"severity": 2}, []) 68 | await assert_query(cli, {"severity": 1}, [test1]) 69 | await assert_query(cli, {"severity": "NULL"}, [test3]) 70 | 71 | 72 | async def test_multiple(cli, fixtures): 73 | test1, test2, test3, test4 = fixtures 74 | params = MultiDict((("severity", 1), ("severity", 3))) 75 | await assert_query(cli, params, [test1, test2]) 76 | params = MultiDict((("severity:ne", 1), ("severity:ne", 3))) 77 | await assert_query(cli, params, [test4]) 78 | 79 | 80 | async def test_search(cli, fixtures): 81 | test1, test2, test3, test4 = fixtures 82 | params = {"search": "test"} 83 | await assert_query(cli, params, [test1, test2, test3, test4]) 84 | 85 | 86 | async def test_search_match_one(cli, fixtures): 87 | test2 = fixtures[1] 88 | params = {"search": "est2"} 89 | await assert_query(cli, params, [test2]) 90 | 91 | 92 | async def test_search_match_one_with_title(cli, fixtures): 93 | test2 = fixtures[1] 94 | params = {"title": "test2", "search": "est2"} 95 | await assert_query(cli, params, [test2]) 96 | 97 | 98 | async def test_search_match_none_with_title(cli, fixtures): 99 | params = {"title": "test1", "search": "est2"} 100 | await assert_query(cli, params, []) 101 | 102 | 103 | async def test_search_either_end(cli, fixtures): 104 | params = {"search": "est"} 105 | await assert_query(cli, params, fixtures) 106 | 107 | 108 | async def test_multicolumn_search(cli, fixtures): 109 | test1, test2, test3, _ = fixtures 110 | params = {"search": "est1"} 
111 | await assert_query(cli, params, [test1, test2]) 112 | -------------------------------------------------------------------------------- /tests/core/test_json.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from datetime import datetime 3 | from uuid import uuid4 4 | 5 | import pytest 6 | 7 | from openapi.json import encoder 8 | 9 | 10 | class Pippo(enum.Enum): 11 | bla = 1 12 | foo = 2 13 | 14 | 15 | def test_encoder_uuid(): 16 | uuid = uuid4() 17 | encoded = encoder(uuid) 18 | assert encoded == uuid.hex 19 | 20 | 21 | def test_encoder_enum(): 22 | assert encoder(Pippo.bla) == "bla" 23 | 24 | 25 | def test_encoder_datetime(): 26 | now = datetime.now() 27 | encoded = encoder(now) 28 | assert encoded == now.isoformat() 29 | 30 | 31 | def test_encoder_invalid_type(): 32 | with pytest.raises(TypeError): 33 | encoder("string") 34 | with pytest.raises(TypeError): 35 | encoder(123) 36 | with pytest.raises(TypeError): 37 | encoder([1, 2, 3]) 38 | -------------------------------------------------------------------------------- /tests/core/test_logger.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from click.testing import CliRunner 4 | 5 | from openapi.logger import getLogger 6 | from openapi.rest import rest 7 | 8 | 9 | def test_logger(): 10 | logger = getLogger() 11 | assert logger.name == "root" 12 | logger = getLogger("foo") 13 | assert logger.name == "foo" 14 | 15 | 16 | def test_serve(): 17 | runner = CliRunner() 18 | cli = rest(base_path="/v1") 19 | with patch("aiohttp.web.run_app") as mock: 20 | with patch("openapi.logger.logger.hasHandlers") as hasHandlers: 21 | hasHandlers.return_value = False 22 | with patch("openapi.logger.logger.addHandler") as addHandler: 23 | result = runner.invoke(cli, ["serve"]) 24 | assert result.exit_code == 0 25 | assert mock.call_count == 1 26 | assert addHandler.call_count == 1 27 | -------------------------------------------------------------------------------- /tests/core/test_paths.py: -------------------------------------------------------------------------------- 1 | from openapi.testing import json_body 2 | 3 | 4 | async def test_servers(cli): 5 | response = await cli.get("/") 6 | await json_body(response) 7 | -------------------------------------------------------------------------------- /tests/core/test_union.py: -------------------------------------------------------------------------------- 1 | from openapi.testing import json_body 2 | 3 | 4 | async def test_multicolumn_union(cli): 5 | row = {"x": 1, "y": 2} 6 | resp = await cli.post("/multikey", json=row) 7 | assert await json_body(resp, status=201) == row 8 | row2 = {"x": "ciao", "y": 2} 9 | resp = await cli.post("/multikey", json=row2) 10 | assert await json_body(resp, status=201) == row2 11 | -------------------------------------------------------------------------------- /tests/core/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import pytest 4 | 5 | from openapi import utils 6 | from openapi.db.container import Database 7 | from openapi.exc import ImproperlyConfigured, JsonHttpException 8 | from openapi.json import dumps 9 | from openapi.utils import TypingInfo 10 | 11 | TEST_ENVS = frozenset(("test", "ci")) 12 | 13 | 14 | def test_env(): 15 | assert utils.get_env() in TEST_ENVS 16 | 17 | 18 | def test_debug_flag(): 19 | assert utils.get_debug_flag() 
is False 20 | 21 | 22 | def test_json_http_exception(): 23 | ex = JsonHttpException(status=401) 24 | assert ex.status == 401 25 | assert ex.text == dumps({"message": "Unauthorized"}) 26 | assert ex.headers["content-type"] == "application/json; charset=utf-8" 27 | 28 | 29 | def test_json_http_exception_reason(): 30 | ex = JsonHttpException(status=422, reason="non lo so") 31 | assert ex.status == 422 32 | assert ex.text == dumps({"message": "non lo so"}) 33 | assert ex.headers["content-type"] == "application/json; charset=utf-8" 34 | 35 | 36 | def test_exist_database_not_configured(): 37 | db = Database() 38 | with pytest.raises(ImproperlyConfigured): 39 | db.engine 40 | 41 | 42 | def test_replace_key(): 43 | assert utils.replace_key({}, "foo", "bla") == {} 44 | assert utils.replace_key({"foo": 5}, "foo", "bla") == {"bla": 5} 45 | 46 | 47 | def test_typing_info() -> None: 48 | assert TypingInfo.get(int) == utils.TypingInfo(int) 49 | assert TypingInfo.get(float) == utils.TypingInfo(float) 50 | assert TypingInfo.get(List[int]) == utils.TypingInfo(int, list) 51 | assert TypingInfo.get(Dict[str, int]) == utils.TypingInfo(int, dict) 52 | assert TypingInfo.get(List[Dict[str, int]]) == utils.TypingInfo( 53 | utils.TypingInfo(int, dict), list 54 | ) 55 | assert TypingInfo.get(None) is None 56 | info = TypingInfo.get(List[int]) 57 | assert TypingInfo.get(info) is info 58 | 59 | 60 | def test_typing_info_dict_list() -> None: 61 | assert TypingInfo.get(Dict) == utils.TypingInfo(Any, dict) 62 | assert TypingInfo.get(List) == utils.TypingInfo(Any, list) 63 | 64 | 65 | def test_bad_typing_info() -> None: 66 | with pytest.raises(TypeError): 67 | TypingInfo.get(1) 68 | with pytest.raises(TypeError): 69 | TypingInfo.get(Dict[int, int]) 70 | with pytest.raises(TypeError): 71 | TypingInfo.get(Tuple[int, int]) 72 | -------------------------------------------------------------------------------- /tests/data/test_json_field.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields 2 | from typing import Dict, List 3 | 4 | import pytest 5 | 6 | from openapi.data.fields import VALIDATOR, JSONValidator, ValidationError, json_field 7 | from openapi.data.validate import ValidationErrors, validated_schema 8 | 9 | 10 | @dataclass 11 | class TJson: 12 | a: List = json_field() 13 | b: Dict = json_field() 14 | 15 | 16 | def test_validator(): 17 | dfields = fields(TJson) 18 | assert isinstance(dfields[0].metadata[VALIDATOR], JSONValidator) 19 | with pytest.raises(ValidationErrors): 20 | validated_schema(TJson, dict(a="{]}", b="{}")) 21 | 22 | 23 | def test_validation_fail_list(): 24 | with pytest.raises(ValidationErrors): 25 | validated_schema(TJson, dict(a="{}", b="{}")) 26 | s = validated_schema(TJson, dict(a="[]", b="{}")) 27 | assert s.a == [] 28 | assert s.b == {} 29 | 30 | 31 | def test_validation_fail_dict(): 32 | with pytest.raises(ValidationErrors): 33 | validated_schema(TJson, dict(a="[]", b="[]")) 34 | 35 | 36 | def test_json_field_error(): 37 | field = json_field() 38 | validator = field.metadata[VALIDATOR] 39 | with pytest.raises(ValidationError): 40 | validator(field, object()) 41 | -------------------------------------------------------------------------------- /tests/data/test_validate_nested.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | 3 | from openapi.data.validate import validated_schema 4 | from tests.example.models import SourcePrice 5 | 6 | 7 | def 
test_prices() -> None: 8 | d = validated_schema(SourcePrice, dict(id=5432534, prices=dict(foo=45.68564))) 9 | assert d.prices 10 | assert d.prices["foo"] == Decimal("45.6856") 11 | d = validated_schema( 12 | SourcePrice, 13 | dict( 14 | id=5432534, 15 | prices=dict(foo=45.68564), 16 | foos=[dict(text="test1", param=1), dict(text="test2", param="a")], 17 | ), 18 | ) 19 | assert len(d.foos) == 2 20 | assert d.foos[0].text == "test1" 21 | -------------------------------------------------------------------------------- /tests/data/test_validator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from typing import Dict, List, Union 3 | 4 | import pytest 5 | 6 | from openapi import json 7 | from openapi.data import fields 8 | from openapi.data.validate import ValidationErrors, validate, validated_schema 9 | from tests.example.models import Foo, Moon, Permission, Role, TaskAdd 10 | 11 | 12 | @dataclass 13 | class Foo2: 14 | text: str = fields.str_field(min_length=3, description="Just some text") 15 | param: Union[str, int] = fields.integer_field( 16 | description="String accepted but convert to int" 17 | ) 18 | done: bool = fields.bool_field(description="Is Foo done?") 19 | value: float = fields.number_field() 20 | 21 | 22 | def test_validated_schema(): 23 | data = dict(title="test", severity=1, unique_title="test") 24 | v = validated_schema(TaskAdd, data) 25 | assert len(data) == 3 26 | assert asdict(v) == dict( 27 | title="test", 28 | unique_title="test", 29 | random=None, 30 | type=None, 31 | severity=1, 32 | story_points=None, 33 | subtitle="", 34 | created_by="", 35 | ) 36 | 37 | 38 | def test_validated_schema_errors(): 39 | data = dict(severity=1) 40 | with pytest.raises(ValidationErrors) as e: 41 | validated_schema(TaskAdd, data) 42 | assert len(e.value.errors) == 1 43 | assert repr(e.value) 44 | assert repr(e.value) == json.dumps(e.value.errors, indent=4) 45 | 46 | 47 | def test_openapi_listvalidator(): 48 | validator = fields.ListValidator([fields.NumberValidator(-1, 1)]) 49 | props = {} 50 | validator.openapi(props) 51 | assert props["minimum"] == -1 52 | assert props["maximum"] == 1 53 | 54 | 55 | def test_permission(): 56 | data = dict(paths=["bla"], methods=["get"], body=dict(a="test")) 57 | d = validated_schema(Permission, data) 58 | assert d.action == "allow" 59 | assert d.paths == ["bla"] 60 | assert d.body == dict(a="test") 61 | 62 | 63 | def test_role(): 64 | data = dict( 65 | name="test", 66 | permissions=[dict(paths=["bla"], methods=["get"], body=dict(a="test"))], 67 | ) 68 | d = validated_schema(Role, data) 69 | assert isinstance(d.permissions[0], Permission) 70 | 71 | 72 | def test_post_process(): 73 | d = validate(Moon, {}) 74 | assert d.data == {} 75 | d = validate(Moon, {"names": "luca, max"}) 76 | assert d.data == {"names": ["luca", "max"]} 77 | 78 | 79 | def test_validate_list(): 80 | data = [dict(paths=["bla"], methods=["get"], body=dict(a="test"))] 81 | d = validate(List[Permission], data) 82 | assert not d.errors 83 | assert isinstance(d.data, list) 84 | 85 | 86 | def test_validate_union(): 87 | schema = Union[int, str] 88 | d = validate(schema, "3") 89 | assert d.data == "3" 90 | d = validate(schema, 3) 91 | assert d.data == 3 92 | d = validate(schema, 3.3) 93 | assert d.errors 94 | 95 | 96 | def test_validate_union_nested(): 97 | schema = Union[int, str, Dict[str, Union[int, str]]] 98 | d = validate(schema, "3") 99 | assert d.data == "3" 100 | d = validate(schema, 3) 101 | assert 
d.data == 3 102 | d = validate(schema, dict(foo=3, bla="ciao")) 103 | assert d.data == dict(foo=3, bla="ciao") 104 | 105 | 106 | def test_foo(): 107 | assert validate(Foo, {}).errors 108 | assert validate(Foo, dict(text="ciao")).errors 109 | assert validate(Foo, dict(text="ciao"), strict=False).data == dict(text="ciao") 110 | valid = dict(text="ciao", param=3) 111 | assert validate(Foo, valid).data == dict(text="ciao", param=3, done=False) 112 | d = validated_schema(List[Foo], [valid]) 113 | assert len(d) == 1 114 | assert isinstance(d[0], Foo) 115 | 116 | 117 | def test_foo2(): 118 | s = validated_schema(Foo2, dict(text="ciao", param="3", done="no")) 119 | assert s.done is False 120 | 121 | 122 | def test_float_validation(): 123 | s = validated_schema(Foo2, dict(value=24500)) 124 | assert s.value == 24500 125 | s = validated_schema(Foo2, dict(value=24500.5)) 126 | assert s.value == 24500.5 127 | -------------------------------------------------------------------------------- /tests/data/test_view.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from openapi.data.view import DataView, ValidationErrors 4 | from tests.example.models import Foo 5 | 6 | 7 | def test_error() -> None: 8 | dv = DataView() 9 | with pytest.raises(RuntimeError): 10 | dv.cleaned(Foo, {}, Error=RuntimeError) 11 | 12 | with pytest.raises(TypeError): 13 | dv.raise_bad_data() 14 | 15 | with pytest.raises(RuntimeError): 16 | dv.raise_bad_data(exc=RuntimeError) 17 | 18 | with pytest.raises(ValidationErrors): 19 | dv.raise_validation_error() 20 | -------------------------------------------------------------------------------- /tests/example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantmind/aio-openapi/afe56f7b36cadf32643569b8ffce63da29802801/tests/example/__init__.py -------------------------------------------------------------------------------- /tests/example/db/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from aiohttp.web import Application 4 | 5 | from openapi.db import CrudDB, get_db 6 | 7 | from .tables1 import meta 8 | from .tables2 import additional_meta 9 | 10 | DATASTORE = os.getenv( 11 | "DATASTORE", "postgresql+asyncpg://postgres:postgres@localhost:5432/openapi" 12 | ) 13 | 14 | 15 | def setup(app: Application) -> CrudDB: 16 | return setup_tables(get_db(app, DATASTORE)) 17 | 18 | 19 | def setup_tables(db: CrudDB) -> CrudDB: 20 | additional_meta(meta(db.metadata)) 21 | return db 22 | 23 | 24 | DB = setup_tables(CrudDB(DATASTORE)) 25 | -------------------------------------------------------------------------------- /tests/example/db/tables1.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | import sqlalchemy as sa 4 | from sqlalchemy_utils import UUIDType 5 | 6 | from openapi.data import fields 7 | from openapi.db.columns import UUIDColumn 8 | 9 | original_init = UUIDType.__init__ 10 | 11 | 12 | class TaskType(enum.Enum): 13 | todo = 0 14 | issue = 1 15 | 16 | 17 | def patch_init(self, binary=True, native=True, **kw): 18 | original_init(self, binary=binary, native=native) 19 | 20 | 21 | UUIDType.__init__ = patch_init 22 | 23 | 24 | def title_field(**kwargs): 25 | return fields.str_field(**kwargs) 26 | 27 | 28 | def meta(meta=None): 29 | """Add task related tables""" 30 | if meta is None: 31 | meta = sa.MetaData() 32 | 33 | sa.Table( 34 | 
"tasks", 35 | meta, 36 | UUIDColumn("id", make_default=True, doc="Unique ID"), 37 | sa.Column( 38 | "title", 39 | sa.String(64), 40 | nullable=False, 41 | info=dict(min_length=3, data_field=title_field), 42 | ), 43 | sa.Column("done", sa.DateTime(timezone=True)), 44 | sa.Column("severity", sa.Integer), 45 | sa.Column("created_by", sa.String, default="", nullable=False), 46 | sa.Column("type", sa.Enum(TaskType)), 47 | sa.Column("unique_title", sa.String, unique=True), 48 | sa.Column("story_points", sa.Numeric), 49 | sa.Column("random", sa.String(64)), 50 | sa.Column( 51 | "subtitle", 52 | sa.String(64), 53 | nullable=False, 54 | default="", 55 | ), 56 | ) 57 | 58 | sa.Table( 59 | "series", 60 | meta, 61 | sa.Column("date", sa.DateTime(timezone=True), nullable=False, index=True), 62 | sa.Column("group", sa.String(32), nullable=False, index=True, default=""), 63 | sa.Column("value", sa.Numeric(precision=20, scale=8)), 64 | sa.UniqueConstraint("date", "group"), 65 | ) 66 | 67 | return meta 68 | -------------------------------------------------------------------------------- /tests/example/db/tables2.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import date 3 | 4 | import sqlalchemy as sa 5 | from sqlalchemy_utils import UUIDType 6 | 7 | from openapi.db.columns import UUIDColumn 8 | from openapi.tz import utcnow 9 | 10 | 11 | def additional_meta(meta=None): 12 | """Add task related tables""" 13 | if meta is None: 14 | meta = sa.MetaData() 15 | 16 | sa.Table( 17 | "randoms", 18 | meta, 19 | sa.Column( 20 | "id", UUIDType(), primary_key=True, nullable=False, default=uuid.uuid4 21 | ), 22 | sa.Column("randomdate", sa.Date, nullable=False, default=date.today), 23 | sa.Column( 24 | "timestamp", sa.DateTime(timezone=True), nullable=False, default=utcnow 25 | ), 26 | sa.Column("price", sa.Numeric(precision=100, scale=4), nullable=False), 27 | sa.Column("tenor", sa.String(3), nullable=False), 28 | sa.Column("tick", sa.Boolean), 29 | sa.Column("info", sa.JSON), 30 | sa.Column("jsonlist", sa.JSON, default=[]), 31 | sa.Column( 32 | "task_id", sa.ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False 33 | ), 34 | ) 35 | 36 | sa.Table( 37 | "multi_key_unique", 38 | meta, 39 | sa.Column("x", sa.Integer, nullable=False), 40 | sa.Column("y", sa.Integer, nullable=False), 41 | sa.UniqueConstraint("x", "y"), 42 | ) 43 | 44 | sa.Table( 45 | "multi_key", 46 | meta, 47 | sa.Column("x", sa.JSON), 48 | sa.Column("y", sa.JSON), 49 | ) 50 | 51 | return meta 52 | 53 | 54 | def extra(meta): 55 | sa.Table( 56 | "extras", 57 | meta, 58 | UUIDColumn("id", make_default=True, doc="Unique ID"), 59 | sa.Column("name", sa.String(64), nullable=False), 60 | ) 61 | -------------------------------------------------------------------------------- /tests/example/endpoints.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from aiohttp import web 4 | from sqlalchemy.sql.expression import null 5 | 6 | from openapi.db.path import SqlApiPath 7 | from openapi.spec import op 8 | 9 | from .models import ( 10 | Task, 11 | TaskAdd, 12 | TaskOrderableQuery, 13 | TaskPathSchema, 14 | TaskQuery, 15 | TaskUpdate, 16 | ) 17 | 18 | routes = web.RouteTableDef() 19 | 20 | 21 | @routes.view("/tasks") 22 | class TasksPath(SqlApiPath): 23 | """ 24 | --- 25 | summary: Create and query Tasks 26 | tags: 27 | - Task 28 | """ 29 | 30 | table = "tasks" 31 | 32 | def filter_done(self, op, value): 33 | done = 
self.db_table.c.done 34 | return done != null() if value else done == null() 35 | 36 | @op(query_schema=TaskOrderableQuery, response_schema=List[Task]) 37 | async def get(self): 38 | """ 39 | --- 40 | summary: Retrieve Tasks 41 | description: Retrieve a list of Tasks 42 | responses: 43 | 200: 44 | description: Authenticated tasks 45 | """ 46 | paginated = await self.get_list() 47 | return paginated.json_response() 48 | 49 | @op(response_schema=Task, body_schema=TaskAdd) 50 | async def post(self): 51 | """ 52 | --- 53 | summary: Create a Task 54 | description: Create a new Task 55 | responses: 56 | 201: 57 | description: the task was successfully added 58 | 422: 59 | description: Failed validation 60 | """ 61 | data = await self.create_one() 62 | return self.json_response(data, status=201) 63 | 64 | @op(query_schema=TaskQuery) 65 | async def delete(self): 66 | """ 67 | --- 68 | summary: Delete Tasks 69 | description: Delete a group of Tasks 70 | responses: 71 | 204: 72 | description: Tasks successfully deleted 73 | """ 74 | await self.delete_list(query=dict(self.request.query)) 75 | return web.Response(status=204) 76 | 77 | 78 | @routes.view("/tasks/{id}") 79 | class TaskPath(SqlApiPath): 80 | """ 81 | --- 82 | summary: Retrieve, update and delete a single Task 83 | tags: 84 | - name: Task 85 | description: Simple description 86 | - name: Random 87 | description: Random description 88 | """ 89 | 90 | table = "tasks" 91 | path_schema = TaskPathSchema 92 | 93 | @op(response_schema=Task) 94 | async def get(self): 95 | """ 96 | --- 97 | summary: Retrieve a Task 98 | description: Retrieve a Task by ID 99 | responses: 100 | 200: 101 | description: the task 102 | """ 103 | data = await self.get_one() 104 | return self.json_response(data) 105 | 106 | @op(response_schema=Task, body_schema=TaskUpdate) 107 | async def patch(self): 108 | """ 109 | --- 110 | summary: Update a Task 111 | description: Update an existing Task by ID 112 | responses: 113 | 200: 114 | description: the updated task 115 | """ 116 | data = await self.update_one() 117 | return self.json_response(data) 118 | 119 | @op() 120 | async def delete(self): 121 | """ 122 | --- 123 | summary: Delete a Task 124 | description: Delete an existing task 125 | responses: 126 | 204: 127 | description: Task successfully deleted 128 | """ 129 | await self.delete_one() 130 | return web.Response(status=204) 131 | -------------------------------------------------------------------------------- /tests/example/endpoints_base.py: -------------------------------------------------------------------------------- 1 | from aiohttp import web 2 | 3 | from openapi.spec.server import server_urls 4 | 5 | base_routes = web.RouteTableDef() 6 | 7 | 8 | @base_routes.get("/") 9 | async def urls(request) -> web.Response: 10 | paths = set() 11 | for route in request.app.router.routes(): 12 | route_info = route.get_info() 13 | path = route_info.get("path", route_info.get("formatter", None)) 14 | paths.add(path) 15 | return web.json_response(server_urls(request, sorted(paths))) 16 | 17 | 18 | @base_routes.get("/status") 19 | async def status(request) -> web.Response: 20 | return web.json_response({}) 21 | 22 | 23 | @base_routes.get("/error") 24 | async def error(request) -> web.Response: 25 | 1 / 0 # noqa 26 | -------------------------------------------------------------------------------- /tests/example/endpoints_form.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from aiohttp import web 4 | 5 | from openapi.spec
import op 6 | from openapi.spec.path import ApiPath 7 | 8 | from .models import BundleUpload 9 | 10 | form_routes = web.RouteTableDef() 11 | 12 | 13 | @form_routes.view("/upload") 14 | class UploadPath(ApiPath): 15 | """ 16 | --- 17 | summary: Bulk manage tasks 18 | tags: 19 | - Task 20 | """ 21 | 22 | table = "tasks" 23 | 24 | @op(body_schema=BundleUpload, response_schema=Dict) 25 | async def post(self): 26 | """ 27 | --- 28 | summary: Upload a bundle 29 | description: Upload a bundle 30 | body_content: multipart/form-data 31 | responses: 32 | 201: 33 | description: Created tasks 34 | """ 35 | return self.json_response(dict(ok=True), status=201) 36 | -------------------------------------------------------------------------------- /tests/example/endpoints_pagination.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from aiohttp import web 5 | 6 | from openapi.data.db import dataclass_from_table 7 | from openapi.db.path import SqlApiPath 8 | from openapi.pagination import cursorPagination, offsetPagination 9 | from openapi.spec import op 10 | 11 | from .db import DB 12 | 13 | series_routes = web.RouteTableDef() 14 | 15 | 16 | Serie = dataclass_from_table( 17 | "Serie", DB.series, required=True, default=True, exclude=("group",) 18 | ) 19 | 20 | 21 | BaseQuery = dataclass_from_table( 22 | "BaseQuery", 23 | DB.series, 24 | default=True, 25 | include=("group", "date"), 26 | ops=dict(date=("le", "ge", "gt", "lt")), 27 | ) 28 | 29 | 30 | @dataclass 31 | class SeriesQueryCursor( 32 | BaseQuery, 33 | cursorPagination("-date"), 34 | ): 35 | """Series query with cursor pagination""" 36 | 37 | 38 | @dataclass 39 | class SeriesQueryOffset( 40 | BaseQuery, 41 | offsetPagination("-date", "date"), 42 | ): 43 | """Series query with offset pagination""" 44 | 45 | 46 | @series_routes.view("/series_cursor") 47 | class SeriesPath(SqlApiPath): 48 | """ 49 | --- 50 | summary: Get Series 51 | tags: 52 | - Series 53 | """ 54 | 55 | table = "series" 56 | 57 | @op(query_schema=SeriesQueryCursor, response_schema=List[Serie]) 58 | async def get(self): 59 | """ 60 | --- 61 | summary: Retrieve Series 62 | description: Retrieve a TimeSeries 63 | responses: 64 | 200: 65 | description: timeseries 66 | """ 67 | paginated = await self.get_list() 68 | return paginated.json_response() 69 | 70 | 71 | @series_routes.view("/series_offset") 72 | class SeriesOffsetPath(SqlApiPath): 73 | """ 74 | --- 75 | summary: Get Series 76 | tags: 77 | - Series 78 | """ 79 | 80 | table = "series" 81 | 82 | @op(query_schema=SeriesQueryOffset, response_schema=List[Serie]) 83 | async def get(self): 84 | """ 85 | --- 86 | summary: Retrieve Series 87 | description: Retrieve a TimeSeries 88 | responses: 89 | 200: 90 | description: timeseries 91 | """ 92 | paginated = await self.get_list() 93 | return paginated.json_response() 94 | -------------------------------------------------------------------------------- /tests/example/main.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from aiohttp import web 4 | 5 | from openapi.db.commands import db as db_command 6 | from openapi.middleware import json_error, sentry_middleware 7 | from openapi.rest import rest 8 | from openapi.spec import Redoc 9 | 10 | from . 
import db 11 | from .endpoints import routes 12 | from .endpoints_additional import additional_routes 13 | from .endpoints_base import base_routes 14 | from .endpoints_form import form_routes 15 | from .endpoints_pagination import series_routes 16 | from .ws import LocalBroker, ws_routes 17 | 18 | 19 | def create_app(): 20 | return rest( 21 | openapi=dict(title="Test API"), 22 | security=dict( 23 | auth_key={ 24 | "type": "apiKey", 25 | "name": "X-Meta-Api-Key", 26 | "description": ( 27 | "The authentication key is required to access most " 28 | "endpoints of the API" 29 | ), 30 | "in": "header", 31 | } 32 | ), 33 | setup_app=setup_app, 34 | commands=[db_command], 35 | redoc=Redoc(), 36 | ) 37 | 38 | 39 | def setup_app(app: web.Application) -> None: 40 | db.setup(app) 41 | app.middlewares.append(json_error()) 42 | sentry_middleware(app, f"https://{uuid.uuid4().hex}@sentry.io/1234567", "test") 43 | app.router.add_routes(base_routes) 44 | app.router.add_routes(routes) 45 | app.router.add_routes(series_routes) 46 | # 47 | # Additional routes for testing 48 | app.router.add_routes(additional_routes) 49 | app.router.add_routes(form_routes) 50 | app["web_sockets"] = LocalBroker.for_app(app) 51 | app.router.add_routes(ws_routes) 52 | 53 | 54 | if __name__ == "__main__": 55 | create_app().main() 56 | -------------------------------------------------------------------------------- /tests/example/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import datetime 3 | from decimal import Decimal 4 | from typing import Dict, List, Union 5 | 6 | from openapi.data import fields 7 | from openapi.data.db import dataclass_from_table 8 | from openapi.pagination import offsetPagination, searchable 9 | 10 | from .db import DB 11 | from .db.tables1 import TaskType 12 | 13 | 14 | @dataclass 15 | class TaskAdd( 16 | dataclass_from_table( 17 | "_TaskAdd", DB.tasks, required=True, default=True, exclude=("id", "done") 18 | ) 19 | ): 20 | @classmethod 21 | def validate(cls, data, errors): 22 | """here just for coverage""" 23 | 24 | 25 | Task = dataclass_from_table("Task", DB.tasks) 26 | 27 | 28 | @dataclass 29 | class TaskQuery(offsetPagination("title", "-title", "severity", "-severity")): 30 | title: str = fields.str_field(description="Task title") 31 | done: bool = fields.bool_field(description="done flag") 32 | type: TaskType = fields.enum_field(TaskType, description="Task type") 33 | severity: int = fields.integer_field( 34 | ops=("lt", "le", "gt", "ge", "ne"), description="Task severity" 35 | ) 36 | story_points: Decimal = fields.decimal_field(description="Story points") 37 | 38 | 39 | @dataclass 40 | class TaskOrderableQuery( 41 | TaskQuery, 42 | searchable("title", "unique_title"), 43 | ): 44 | pass 45 | 46 | 47 | @dataclass 48 | class TaskUpdate(TaskAdd): 49 | done: datetime = fields.date_time_field(description="Done timestamp") 50 | 51 | 52 | @dataclass 53 | class TaskPathSchema: 54 | id: str = fields.uuid_field(required=True, description="Task ID") 55 | 56 | 57 | # Additional models for testing 58 | 59 | 60 | @dataclass 61 | class TaskPathSchema2: 62 | task_id: str = fields.uuid_field(required=True, description="Task ID") 63 | 64 | 65 | MultiKeyUnique = dataclass_from_table("MultiKeyUnique", DB.multi_key_unique) 66 | 67 | 68 | @dataclass 69 | class MultiKey: 70 | x: Union[int, str, datetime] = fields.json_field(required=True, description="x") 71 | y: Union[int, str, datetime] = fields.json_field(required=True, 
description="y") 72 | 73 | 74 | @dataclass 75 | class Permission: 76 | paths: List[str] = fields.data_field(description="Permission paths") 77 | methods: List[str] = fields.data_field(description="Permission methods") 78 | body: Dict[str, str] = fields.json_field(description="Permission body") 79 | action: str = fields.str_field(default="allow", description="Permission action") 80 | 81 | 82 | @dataclass 83 | class Role: 84 | name: str = fields.str_field(required=True, description="Role name") 85 | permissions: List[Permission] = fields.data_field( 86 | required=True, description="List of permissions" 87 | ) 88 | 89 | 90 | @dataclass 91 | class Moon: 92 | names: str = fields.str_field( 93 | description="Comma separated list of names", 94 | post_process=lambda values: [v.strip() for v in values.split(",")], 95 | ) 96 | 97 | 98 | @dataclass 99 | class Foo: 100 | text: str 101 | param: Union[str, int] 102 | done: bool = False 103 | 104 | 105 | @dataclass 106 | class SourcePrice: 107 | """An object containing prices for a single contract""" 108 | 109 | id: int = fields.integer_field(description="ID", required=True) 110 | extra: Dict = fields.data_field(description="JSON blob") 111 | prices: Dict[str, Decimal] = fields.data_field( 112 | description="source-price mapping", 113 | items=fields.decimal_field( 114 | min_value=0, 115 | max_value=100, 116 | precision=4, 117 | description="price", 118 | ), 119 | default_factory=dict, 120 | ) 121 | foos: List[Foo] = fields.data_field(default_factory=list) 122 | 123 | 124 | @dataclass 125 | class BundleUpload: 126 | files: List[bytes] = fields.data_field(description="list of bundles to upload") 127 | -------------------------------------------------------------------------------- /tests/example/ws.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any 3 | 4 | from aiohttp import web 5 | 6 | from openapi import ws 7 | from openapi.spec.path import ApiPath 8 | from openapi.ws import CannotPublish, CannotSubscribe, pubsub 9 | from openapi.ws.manager import SocketsManager 10 | 11 | ws_routes = web.RouteTableDef() 12 | 13 | 14 | @ws_routes.view("/stream") 15 | class StreamPath(ws.WsPathMixin, pubsub.Publish, pubsub.Subscribe, ApiPath): 16 | """ 17 | --- 18 | summary: Websocket stream for Tasks 19 | tags: 20 | - Task 21 | """ 22 | 23 | async def ws_rpc_echo(self, payload): 24 | """Echo parameters""" 25 | return payload 26 | 27 | async def ws_rpc_server_info(self, payload): 28 | """Websocket server information""" 29 | return self.sockets.server_info() 30 | 31 | async def ws_rpc_cancel(self, payload): 32 | """Raise a cancellation error""" 33 | raise asyncio.CancelledError 34 | 35 | async def ws_rpc_badjson(self, payload): 36 | """Return a payload that cannot be JSON serialized""" 37 | return ApiPath 38 | 39 | 40 | class LocalBroker(SocketsManager): 41 | """A local broker for testing""" 42 | 43 | def __init__(self): 44 | self.binds = set() 45 | self.messages: asyncio.Queue = asyncio.Queue() 46 | self.worker = None 47 | self._stop = False 48 | 49 | @classmethod 50 | def for_app(cls, app: web.Application) -> "LocalBroker": 51 | broker = cls() 52 | app.on_startup.append(broker.start) 53 | app.on_shutdown.append(broker.close) 54 | return broker 55 | 56 | async def start(self, *arg): 57 | if not self.worker: 58 | self.worker = asyncio.ensure_future(self._work()) 59 | 60 | async def publish(self, channel: str, event: str, body: Any): 61 | """simulate network latency""" 62 | if channel.lower() != channel: 63 | raise CannotPublish 64 | payload = 
dict(event=event, data=self.get_data(body)) 65 | asyncio.get_event_loop().call_later( 66 | 0.01, self.messages.put_nowait, (channel, payload) 67 | ) 68 | 69 | async def subscribe(self, channel: str) -> None: 70 | if channel.lower() != channel: 71 | raise CannotSubscribe 72 | 73 | async def close(self, *arg): 74 | self._stop = True 75 | await self.close_sockets() 76 | if self.worker: 77 | self.messages.put_nowait((None, None)) 78 | await self.worker 79 | self.worker = None 80 | 81 | async def _work(self): 82 | while True: 83 | channel, body = await self.messages.get() 84 | if self._stop: 85 | break 86 | await self.channels(channel, body) 87 | 88 | def get_data(self, data: Any) -> Any: 89 | if data == "error": 90 | return self.raise_error 91 | elif data == "runtime_error": 92 | return self.raise_runtime 93 | return data 94 | 95 | def raise_error(self): 96 | raise ValueError 97 | 98 | def raise_runtime(self): 99 | raise RuntimeError 100 | -------------------------------------------------------------------------------- /tests/pagination/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantmind/aio-openapi/afe56f7b36cadf32643569b8ffce63da29802801/tests/pagination/__init__.py -------------------------------------------------------------------------------- /tests/pagination/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from openapi.db.dbmodel import CrudDB 4 | 5 | from .utils import SerieFactory 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | async def series(cli2): 10 | db: CrudDB = cli2.app["db"] 11 | series = SerieFactory.create_batch(200) 12 | await db.db_insert(db.series, series) 13 | return series 14 | -------------------------------------------------------------------------------- /tests/pagination/test_base_classes.py: -------------------------------------------------------------------------------- 1 | from openapi.pagination import Pagination, Search, create_dataclass 2 | 3 | 4 | def test_pagination(): 5 | d = create_dataclass(None, {}, Pagination) 6 | assert isinstance(d, Pagination) 7 | assert d.apply(None) is None 8 | assert d.links(None, [], None) == {} 9 | 10 | 11 | def test_search(): 12 | d = create_dataclass(None, {}, Search) 13 | assert isinstance(d, Search) 14 | assert d.apply(None) is None 15 | -------------------------------------------------------------------------------- /tests/pagination/test_cursor_pagination.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from openapi.pagination import cursorPagination 4 | from openapi.pagination.cursor import encode_cursor 5 | from openapi.testing import json_body 6 | 7 | from .utils import direction_asc, direction_desc 8 | 9 | 10 | async def test_direction_asc(cli2, series): 11 | assert await direction_asc(cli2, series, "/series_cursor") 12 | 13 | 14 | async def test_direction_desc(cli2, series): 15 | assert await direction_desc(cli2, series, "/series_cursor", direction="desc") 16 | 17 | 18 | async def test_bad_cursor(cli2): 19 | response = await cli2.get("/series_cursor", params={"_cursor": "wtf"}) 20 | assert await json_body(response, 422) == dict(message="invalid cursor") 21 | response = await cli2.get( 22 | "/series_cursor", params={"_cursor": encode_cursor([3, 4])} 23 | ) 24 | assert await json_body(response, 422) == dict(message="invalid cursor") 25 | response = await cli2.get("/series_cursor", params={"_cursor": 
encode_cursor([3])}) 26 | assert await json_body(response, 422) == dict(message="invalid cursor") 27 | 28 | 29 | def test_cursor_pagination_error(): 30 | with pytest.raises(ValueError): 31 | cursorPagination() 32 | -------------------------------------------------------------------------------- /tests/pagination/test_offset_pagination.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import pytest 4 | from yarl import URL 5 | 6 | from openapi.pagination import offsetPagination 7 | from openapi.testing import json_body 8 | 9 | from .utils import direction_asc, direction_desc 10 | 11 | OffsetPagination = offsetPagination("id") 12 | 13 | 14 | def pag_links(total: int, limit: int, offset: int) -> Dict[str, URL]: 15 | return OffsetPagination(limit=limit, offset=offset).links( 16 | URL("http://test.com/path?a=2&b=3"), [], total 17 | ) 18 | 19 | 20 | def test_last_link(): 21 | # 22 | links = pag_links(0, 25, 0) 23 | assert links == {} 24 | # 25 | links = pag_links(120, 25, 0) 26 | assert len(links) == 2 27 | assert links["next"] 28 | assert links["last"] 29 | assert links["next"].query["offset"] == "25" 30 | assert links["next"].query["limit"] == "25" 31 | assert links["last"].query["offset"] == "100" 32 | assert links["last"].query["limit"] == "25" 33 | # 34 | links = pag_links(120, 25, 75) 35 | assert len(links) == 4 36 | assert links["first"].query["offset"] == "0" 37 | assert links["prev"].query["offset"] == "50" 38 | assert links["last"].query["offset"] == "100" 39 | # 40 | links = pag_links(120, 25, 50) 41 | assert len(links) == 4 42 | assert links["first"].query["offset"] == "0" 43 | assert links["prev"].query["offset"] == "25" 44 | assert links["next"].query["offset"] == "75" 45 | assert links["last"].query["offset"] == "100" 46 | 47 | 48 | async def test_pagination_next_link(cli): 49 | response = await cli.post("/tasks", json=dict(title="bla")) 50 | await json_body(response, 201) 51 | response = await cli.post("/tasks", json=dict(title="foo")) 52 | await json_body(response, 201) 53 | response = await cli.get("/tasks") 54 | data = await json_body(response) 55 | assert "Link" not in response.headers 56 | assert len(data) == 2 57 | 58 | 59 | async def test_pagination_first_link(cli): 60 | response = await cli.post("/tasks", json=dict(title="bla")) 61 | await json_body(response, 201) 62 | response = await cli.post("/tasks", json=dict(title="foo")) 63 | await json_body(response, 201) 64 | response = await cli.get("/tasks", params={"limit": 10, "offset": 20}) 65 | url = response.url 66 | data = await json_body(response) 67 | link = response.headers["Link"] 68 | assert link == ( 69 | f'<{url.parent}{url.path}?limit=10&offset=0>; rel="first", ' 70 | f'<{url.parent}{url.path}?limit=10&offset=10>; rel="prev"' 71 | ) 72 | assert "Link" in response.headers 73 | assert len(data) == 0 74 | 75 | 76 | async def test_invalid_limit_offset(cli): 77 | response = await cli.get("/tasks", params={"limit": "wtf"}) 78 | await json_body(response, 422) 79 | response = await cli.get("/tasks", params={"limit": 0}) 80 | await json_body(response, 422) 81 | response = await cli.get("/tasks", params={"offset": "wtf"}) 82 | await json_body(response, 422) 83 | response = await cli.get("/tasks", params={"offset": -10}) 84 | await json_body(response, 422) 85 | 86 | 87 | async def test_pagination_with_forwarded_host(cli): 88 | response = await cli.post("/tasks", json=dict(title="bla")) 89 | await json_body(response, 201) 90 | response = await 
cli.post("/tasks", json=dict(title="foo")) 91 | await json_body(response, 201) 92 | response = await cli.get( 93 | "/tasks", 94 | headers={ 95 | "X-Forwarded-Proto": "https", 96 | "X-Forwarded-Host": "whenbeer.pub", 97 | "X-Forwarded-Port": "1234", 98 | }, 99 | params={"limit": 10, "offset": 20}, 100 | ) 101 | data = await json_body(response) 102 | assert len(data) == 0 103 | link = response.headers["Link"] 104 | assert link == ( 105 | '<https://whenbeer.pub:1234/tasks?limit=10&offset=0>; rel="first", ' 106 | '<https://whenbeer.pub:1234/tasks?limit=10&offset=10>; rel="prev"' 107 | ) 108 | assert response.headers["X-total-count"] == "2" 109 | # 110 | response = await cli.get( 111 | "/tasks", 112 | headers={ 113 | "X-Forwarded-Proto": "https", 114 | "X-Forwarded-Host": "whenbeer.pub", 115 | "X-Forwarded-Port": "443", 116 | }, 117 | params={"limit": 10, "offset": 20}, 118 | ) 119 | data = await json_body(response) 120 | assert len(data) == 0 121 | link = response.headers["Link"] 122 | assert link == ( 123 | '<https://whenbeer.pub/tasks?limit=10&offset=0>; rel="first", ' 124 | '<https://whenbeer.pub/tasks?limit=10&offset=10>; rel="prev"' 125 | ) 126 | assert response.headers["X-total-count"] == "2" 127 | 128 | 129 | async def test_direction_asc(cli2, series): 130 | assert await direction_asc(cli2, series, "/series_offset") 131 | 132 | 133 | async def test_direction_desc(cli2, series): 134 | assert await direction_desc(cli2, series, "/series_offset", order_by="date") 135 | 136 | 137 | def test_offset_pagination_error(): 138 | with pytest.raises(ValueError): 139 | offsetPagination() 140 | -------------------------------------------------------------------------------- /tests/pagination/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import async_timeout 4 | import yarl 5 | from factory import Factory, Faker, fuzzy 6 | from yarl import URL 7 | 8 | from openapi.testing import json_body 9 | 10 | GROUPS = ("group1", "group2", "group3") 11 | 12 | 13 | class SerieFactory(Factory): 14 | class Meta: 15 | model = dict 16 | 17 | value = Faker("pydecimal", positive=True, left_digits=6, right_digits=5) 18 | date = Faker("date_time_between", start_date="-2y") 19 | group = fuzzy.FuzzyChoice(GROUPS) 20 | 21 | 22 | async def direction_asc(cli, series: list, path: str, limit: int = 15): 23 | response = await cli.get(path) 24 | data = await json_body(response) 25 | assert len(data) == 0 26 | all_groups = defaultdict(list) 27 | total = 0 28 | for group in GROUPS: 29 | values = all_groups[group] 30 | async for data in traverse_pagination( 31 | cli, path, dict(limit=limit, group=group) 32 | ): 33 | values.extend(data) 34 | total += len(values) 35 | for d1, d2 in zip(values[:-1], values[1:]): 36 | assert d1["date"] > d2["date"] 37 | 38 | assert total == len(series) 39 | return total 40 | 41 | 42 | async def direction_desc(cli, series, path: str, limit: int = 10, **kwargs): 43 | all_groups = defaultdict(list) 44 | total = 0 45 | for group in GROUPS: 46 | values = all_groups[group] 47 | async for data in traverse_pagination( 48 | cli, path, dict(limit=limit, group=group, **kwargs) 49 | ): 50 | values.extend(data) 51 | total += len(values) 52 | for d1, d2 in zip(values[:-1], values[1:]): 53 | assert d1["date"] < d2["date"] 54 | 55 | assert total == len(series) 56 | return total 57 | 58 | 59 | async def traverse_pagination( 60 | cli, 61 | path: str, 62 | params, 63 | *, 64 | timeout: float = 10000, 65 | test_prev: bool = True, 66 | ): 67 | url = yarl.URL("https://fake.com").with_path(path).with_query(params) 68 | async with async_timeout.timeout(timeout): 69 | batch = [] 70 | while True: 71 | response = await
cli.get(url.path_qs) 72 | data = await json_body(response) 73 | yield data 74 | next = response.links.get("next") 75 | if not next: 76 | break 77 | url: URL = next["url"] 78 | batch.append((data, url)) 79 | if test_prev: 80 | prev = response.links.get("prev") 81 | while prev: 82 | url: URL = prev["url"] 83 | response = await cli.get(url.path_qs) 84 | data = await json_body(response) 85 | pdata, next_url = batch.pop() 86 | prev = response.links.get("prev") 87 | assert data == pdata 88 | assert next_url == response.links.get("next")["url"] 89 | assert batch == [] 90 | -------------------------------------------------------------------------------- /tests/spec/test_docstrings.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from tests.example.models import TaskOrderableQuery 4 | 5 | 6 | def test_search_docstring(): 7 | fields = {f.name: f for f in dataclasses.fields(TaskOrderableQuery)} 8 | assert fields["search"].metadata["description"] == ( 9 | "Search query string. " 10 | "The search is performed on ``title``, ``unique_title`` fields." 11 | ) 12 | -------------------------------------------------------------------------------- /tests/spec/test_spec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from openapi.exc import InvalidSpecException 4 | from openapi.rest import rest 5 | from openapi.spec import OpenApi, OpenApiSpec 6 | from openapi.testing import json_body 7 | from tests.example import endpoints, endpoints_additional 8 | from tests.utils import FakeRequest 9 | 10 | 11 | def create_spec_request(routes) -> FakeRequest: 12 | def setup_app(app): 13 | app.router.add_routes(routes) 14 | 15 | cli = rest(setup_app=setup_app) 16 | app = cli.web() 17 | return FakeRequest.from_app(app) 18 | 19 | 20 | def test_init(): 21 | u = OpenApi() 22 | assert u.version == "0.1.0" 23 | 24 | 25 | async def test_spec_validation(test_app): 26 | spec = OpenApiSpec() 27 | spec.build(FakeRequest.from_app(test_app)) 28 | 29 | 30 | async def test_spec_422(test_app): 31 | spec = OpenApiSpec() 32 | doc = spec.build(FakeRequest.from_app(test_app)) 33 | tasks = doc["paths"]["/tasks"] 34 | resp = tasks["post"]["responses"] 35 | assert ( 36 | resp[422]["content"]["application/json"]["schema"]["$ref"] 37 | == "#/components/schemas/ValidationErrors" 38 | ) 39 | 40 | 41 | async def test_invalid_path(): 42 | request = create_spec_request(endpoints_additional.invalid_path_routes) 43 | spec = OpenApiSpec(validate_docs=True) 44 | 45 | with pytest.raises(InvalidSpecException): 46 | spec.build(request) 47 | 48 | 49 | async def test_invalid_method_missing_summary(): 50 | request = create_spec_request(endpoints_additional.invalid_method_summary_routes) 51 | spec = OpenApiSpec(validate_docs=True) 52 | 53 | with pytest.raises(InvalidSpecException): 54 | spec.build(request) 55 | 56 | 57 | async def test_invalid_method_missing_description(): 58 | request = create_spec_request( 59 | endpoints_additional.invalid_method_description_routes 60 | ) 61 | spec = OpenApiSpec(validate_docs=True) 62 | 63 | with pytest.raises(InvalidSpecException): 64 | spec.build(request) 65 | 66 | 67 | async def test_allowed_tags_ok(): 68 | request = create_spec_request(endpoints.routes) 69 | spec = OpenApiSpec(allowed_tags=set(("Task", "Transaction", "Random"))) 70 | spec.build(request) 71 | 72 | 73 | async def test_allowed_tags_invalid(): 74 | request = create_spec_request(endpoints.routes) 75 | spec = 
OpenApiSpec(validate_docs=True, allowed_tags=set(("Task", "Transaction"))) 76 | with pytest.raises(InvalidSpecException): 77 | spec.build(request) 78 | 79 | 80 | async def test_tags_missing_description(): 81 | request = create_spec_request( 82 | endpoints_additional.invalid_tag_missing_description_routes 83 | ) 84 | spec = OpenApiSpec( 85 | validate_docs=True, allowed_tags=set(("Task", "Transaction", "Random")) 86 | ) 87 | with pytest.raises(InvalidSpecException): 88 | spec.build(request) 89 | 90 | 91 | async def test_spec_root(cli): 92 | response = await cli.get("/spec") 93 | spec = await json_body(response) 94 | assert "paths" in spec 95 | assert "tags" in spec 96 | assert len(spec["tags"]) == 6 97 | assert spec["tags"][4]["name"] == "Task" 98 | assert spec["tags"][4]["description"] == "Simple description" 99 | 100 | 101 | async def test_spec_bytes(cli): 102 | response = await cli.get("/spec") 103 | spec = await json_body(response) 104 | upload = spec["paths"]["/upload"]["post"] 105 | assert list(upload["requestBody"]["content"]) == ["multipart/form-data"] 106 | 107 | 108 | async def test_redoc(cli): 109 | response = await cli.get("/docs") 110 | docs = await response.text() 111 | assert response.status == 200 112 | assert docs 113 | -------------------------------------------------------------------------------- /tests/spec/test_spec_utils.py: -------------------------------------------------------------------------------- 1 | from openapi.spec.utils import dedent, load_yaml_from_docstring, trim_docstring 2 | from openapi.utils import compact, compact_dict 3 | 4 | 5 | def test_compact(): 6 | data = {"key1": "A", "key2": 1, "key3": None, "key4": False, "key5": True} 7 | expected = {"key1": "A", "key2": 1, "key5": True} 8 | new_data = compact(**data) 9 | assert new_data == expected 10 | 11 | 12 | def test_compact_dict(): 13 | data = {"key1": "A", "key2": 1, "key3": None, "key4": False, "key5": True} 14 | expected = {"key1": "A", "key2": 1, "key4": False, "key5": True} 15 | new_data = compact_dict(data) 16 | assert new_data == expected 17 | 18 | 19 | def test_trim_docstring(): 20 | docstring = " test docstring\nline one\nline 2\nline 3 " 21 | expected = "test docstring\nline one\nline 2\nline 3" 22 | trimmed = trim_docstring(docstring) 23 | assert trimmed == expected 24 | 25 | 26 | def test_dedent(): 27 | docstring = " line one\n line two\n line three\n line four" 28 | expected = "line one\nline two\n line three\nline four" 29 | dedented = dedent(docstring) 30 | assert dedented == expected 31 | 32 | 33 | def test_load_yaml_from_docstring(): 34 | docstring = """ 35 | This won't be on yaml: neither this 36 | --- 37 | keyA: something 38 | keyB: 39 | keyB1: nested1 40 | keyB2: nested2 41 | keyC: [item1, item2] 42 | """ 43 | expected = { 44 | "keyA": "something", 45 | "keyB": {"keyB1": "nested1", "keyB2": "nested2"}, 46 | "keyC": ["item1", "item2"], 47 | } 48 | yaml_data = load_yaml_from_docstring(docstring) 49 | assert yaml_data == expected 50 | 51 | 52 | def test_load_yaml_from_docstring_invalid(): 53 | docstring = """ 54 | this is not a valid yaml docstring: docstring 55 | something here: else 56 | """ 57 | yaml_data = load_yaml_from_docstring(docstring) 58 | assert yaml_data is None 59 | -------------------------------------------------------------------------------- /tests/spec/test_validate_spec.py: -------------------------------------------------------------------------------- 1 | from openapi_spec_validator import validate_spec 2 | 3 | from openapi.testing import json_body 4 | 5 | 6 | 
async def test_validate_spec(cli) -> None: 7 | response = await cli.get("/spec") 8 | spec = await json_body(response) 9 | validate_spec(spec) 10 | -------------------------------------------------------------------------------- /tests/test.env: -------------------------------------------------------------------------------- 1 | DBPOOL_MAX_SIZE=1 2 | DBPOOL_MAX_OVERFLOW=0 3 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, NamedTuple 2 | 3 | import yarl 4 | from aiohttp.web import Application 5 | 6 | 7 | class FakeRequest(NamedTuple): 8 | app: Application 9 | headers: Dict 10 | url: yarl.URL 11 | 12 | @classmethod 13 | def from_app(cls, app: Application) -> "FakeRequest": 14 | return cls(app, {}, yarl.URL("https://fake.com")) 15 | -------------------------------------------------------------------------------- /tests/ws/test_channels.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | 4 | import pytest 5 | from async_timeout import timeout 6 | 7 | from openapi.ws import Channels 8 | from openapi.ws.utils import redis_to_py_pattern 9 | from tests.example.ws import LocalBroker 10 | 11 | 12 | @pytest.fixture 13 | async def channels(): 14 | broker = LocalBroker() 15 | await broker.start() 16 | try: 17 | yield broker.channels 18 | finally: 19 | await broker.close() 20 | 21 | 22 | async def test_channels_properties(channels: Channels): 23 | assert channels.sockets 24 | await channels.register("foo", "*", lambda c, e, d: d) 25 | assert len(channels) == 1 26 | assert "foo" in channels 27 | 28 | 29 | async def test_channels_wildcard(channels: Channels): 30 | future = asyncio.Future() 31 | 32 | def fire(channel, event, data): 33 | future.set_result(event) 34 | 35 | await channels.register("test1", "*", fire) 36 | await channels.sockets.publish("test1", "boom", "ciao!") 37 | async with timeout(1): 38 | result = await future 39 | assert result == "boom" 40 | assert len(channels) == 1 41 | await channels.sockets.close() 42 | assert len(channels) == 0 43 | 44 | 45 | def test_redis_to_py_pattern(): 46 | p = redis_to_py_pattern("h?llo") 47 | c = re.compile(p) 48 | assert match(c, "hello") 49 | assert match(c, "hallo") 50 | assert not_match(c, "haallo") 51 | assert not_match(c, "hallox") 52 | # 53 | p = redis_to_py_pattern("h*llo") 54 | c = re.compile(p) 55 | assert match(c, "hello") 56 | assert match(c, "hallo") 57 | assert match(c, "hasjdbvhckjcvkfcdfllo") 58 | assert not_match(c, "haallox") 59 | assert not_match(c, "halloouih") 60 | # 61 | p = redis_to_py_pattern("h[ae]llo") 62 | c = re.compile(p) 63 | assert match(c, "hello") 64 | assert match(c, "hallo") 65 | assert not_match(c, "hollo") 66 | 67 | 68 | async def test_channel(channels: Channels): 69 | assert channels.sockets 70 | await channels.register("test", "foo", lambda c, e, d: d) 71 | assert channels.registered == ("test",) 72 | channel = channels.get("test") 73 | assert channel.name == "test" 74 | assert channel.events == ("foo",) 75 | assert "foo$" in channel 76 | events = list(channel) 77 | assert len(events) == 1 78 | assert await channel({}) == () 79 | 80 | 81 | def match(c, text): 82 | return c.match(text).group() == text 83 | 84 | 85 | def not_match(c, text): 86 | return c.match(text) is None 87 | --------------------------------------------------------------------------------
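
A minimal usage sketch of the validation pattern most of the tests above exercise, shown here for orientation. It uses only calls that appear in the suite (fields.str_field, fields.integer_field, validated_schema, ValidationErrors); the Pet model and test_pet_sketch are hypothetical names introduced for illustration and are not part of the repository.

from dataclasses import dataclass

import pytest

from openapi.data import fields
from openapi.data.validate import ValidationErrors, validated_schema


@dataclass
class Pet:
    # hypothetical schema, for illustration only
    name: str = fields.str_field(required=True, min_length=3, description="Pet name")
    age: int = fields.integer_field(description="Age in years")


def test_pet_sketch():
    # a valid payload returns a Pet instance with validated values
    pet = validated_schema(Pet, dict(name="rex", age=4))
    assert pet.name == "rex"
    assert pet.age == 4
    # a name shorter than min_length fails validation
    with pytest.raises(ValidationErrors):
        validated_schema(Pet, dict(name="ok", age=4))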