├── .envrc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ ├── config.yaml │ └── feature_request.yaml ├── dependabot.yml └── workflows │ ├── codeql-analysis.yml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── benchmarks ├── Makefile ├── README.md ├── backend-nats.py ├── backend-rest.py ├── docker-compose.yml ├── frontend-gateway.py └── input.json ├── doc ├── minimal.png └── readme-example-redoc.png ├── docker-compose.yml ├── examples ├── __init__.py ├── client.py ├── event_decorators.py ├── full_with_reload.py ├── minimal.py └── with_fastapi.py ├── natsapi ├── __init__.py ├── _compat.py ├── applications.py ├── asyncapi │ ├── __init__.py │ ├── constants.py │ ├── models.py │ └── utils.py ├── client │ ├── __init__.py │ ├── client.py │ └── config.py ├── context.py ├── encoders.py ├── enums.py ├── exception_handlers.py ├── exceptions.py ├── logger.py ├── models.py ├── plugin.py ├── routing.py ├── state.py ├── types.py └── utils.py ├── poetry.lock ├── pyproject.toml ├── render-asyncapi ├── go.mod ├── go.sum ├── main.go ├── redoc.asyncapi.js └── render-asyncapi ├── shell.nix └── tests ├── __init__.py ├── asyncapi ├── test_generation.py └── test_models.py ├── conftest.py ├── fixtures.py ├── plugins └── test_mock.py ├── test_applications.py ├── test_config.py ├── test_exceptions.py ├── test_fastapi.py ├── test_jsonable_encoder.py ├── test_method_type_conversion.py ├── test_models.py ├── test_publish.py ├── test_requests.py └── test_routing.py /.envrc: -------------------------------------------------------------------------------- 1 | use nix 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Create a bug report 3 | labels: [bug] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thank you for 
reporting the bug. 9 | Please provide all the required information to receive faster responses from the maintainers. 10 | - type: textarea 11 | attributes: 12 | label: Describe the bug 13 | description: A concise description of what you're experiencing. 14 | - type: textarea 15 | attributes: 16 | label: To reproduce 17 | description: Steps to reproduce this behavior. 18 | placeholder: | 19 | 1. Go to '...' 20 | validations: 21 | required: false 22 | - type: textarea 23 | attributes: 24 | label: Expected Behavior 25 | description: A concise description of what you expected to happen. 26 | validations: 27 | required: false 28 | 29 | - type: textarea 30 | attributes: 31 | label: Additional context 32 | description: | 33 | Please add screenshots, log files, links, or details that provide context about the issue. 34 | 35 | Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. 36 | validations: 37 | required: false 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Security Contact 4 | about: Please report security vulnerabilities to security@wegroup.be 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | labels: New feature 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to request a feature for NatsAPI! 9 | - type: textarea 10 | attributes: 11 | label: Is your feature request related to a problem? 12 | description: 13 | A concise description of the problem you are facing or the motivation 14 | behind this feature request.
15 | placeholder: | 16 | I experienced difficulties with... 17 | validations: 18 | required: false 19 | - type: textarea 20 | attributes: 21 | label: Describe the solution you'd like. 22 | description: A concise description of what you want to happen. 23 | placeholder: | 24 | As a ... I want ... so that ... 25 | validations: 26 | required: false 27 | - type: textarea 28 | attributes: 29 | label: Describe alternatives you've considered. 30 | description: Is there another approach to solve the problem? 31 | validations: 32 | required: false 33 | - type: textarea 34 | attributes: 35 | label: Additional context. 36 | description: | 37 | Add any other context, screenshots, links to provide more context about the feature request. 38 | validations: 39 | required: false 40 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | commit-message: 9 | prefix: ⬆ 10 | # Python 11 | - package-ecosystem: "pip" 12 | directory: "/" 13 | schedule: 14 | interval: "daily" 15 | commit-message: 16 | prefix: ⬆ 17 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '27 12 * * 4' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: ["python"] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 
65 | # modify them (or add more) to build your code if your project needs it; refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'master' 7 | pull_request: 8 | types: [opened, synchronize] 9 | 10 | jobs: 11 | ci: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 16 | poetry-version: ["1.8.3"] 17 | pydantic-version: ["pydantic<2", "pydantic>=2"] 18 | 19 | 20 | services: 21 | nats: 22 | # Docker Hub image 23 | image: nats 24 | ports: 25 | - 4222:4222 26 | 27 | steps: 28 | - uses: actions/checkout@v3 29 | - uses: actions/setup-python@v4 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Run image 33 | uses: abatilo/actions-poetry@v2 34 | with: 35 | poetry-version: ${{ matrix.poetry-version }} 36 | - name: Install Dependencies 37 | run: | 38 | poetry install 39 | pip install --upgrade "${{ matrix.pydantic-version }}" 40 | - name: Run CI scripts 41 | run: make ci 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .env 3 | .htmlcov 4 | .pytest_cache 5 | .coverage 6 | dist 7 | __pycache__/ 8 | prof/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 |
sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | node_modules 123 | .env 124 | 125 | manual_test.py 126 | manual_*.py 127 | 128 | test.py 129 | 130 | #vscode 131 | .vscode/ 132 | 133 | test.key 134 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: ^(blib2to3/|profiling/|tests/data/|\.mypy/|\.tox/) 2 | repos: 3 | - repo: 
git@github.com:pre-commit/pre-commit-hooks 4 | rev: v4.4.0 5 | hooks: 6 | - id: requirements-txt-fixer 7 | - id: check-case-conflict 8 | - id: check-json 9 | - id: debug-statements 10 | - id: check-merge-conflict 11 | - id: check-symlinks 12 | - id: end-of-file-fixer 13 | - id: check-added-large-files 14 | - id: check-toml 15 | - id: trailing-whitespace 16 | - id: pretty-format-json 17 | args: 18 | - --autofix 19 | - id: trailing-whitespace 20 | - repo: local 21 | hooks: 22 | - id: black 23 | name: black 24 | language: system 25 | entry: black 26 | types: [python] 27 | 28 | - repo: https://github.com/charliermarsh/ruff-pre-commit 29 | rev: v0.0.254 30 | hooks: 31 | - id: ruff 32 | args: 33 | - --fix 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 WeGroup 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ##---------- Preliminaries ---------------------------------------------------- 2 | .POSIX: # Get reliable POSIX behaviour 3 | .SUFFIXES: # Clear built-in inference rules 4 | 5 | ##---------- Variables -------------------------------------------------------- 6 | PREFIX = /usr/local # Default installation directory 7 | PYTEST_GENERAL_FLAGS := -vvvx --asyncio-mode=auto 8 | PYTEST_COV_FLAGS := --cov=natsapi --cov-append --cov-report=term-missing --cov-fail-under=85 9 | PYTEST_COV_ENV := COV_CORE_SOURCE=natsapi COV_CORE_CONFIG=.coveragerc COV_CORE_DATAFILE=.coverage.eager 10 | 11 | ##---------- Build targets ---------------------------------------------------- 12 | 13 | help: ## Show this help message (default) 14 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 15 | 16 | test: ## Run tests 17 | $(PYTEST_COV_ENV) poetry run pytest $(PYTEST_GENERAL_FLAGS) $(PYTEST_COV_FLAGS) 18 | 19 | testr: ## Run tests with entr 20 | find natsapi tests | entr -r poetry run pytest --disable-warnings -vvvx 21 | 22 | lint: ## Lint checks 23 | poetry run ruff check . 24 | 25 | format: ## Format checks 26 | poetry run black . 27 | 28 | style: ## Style checks 29 | poetry run black --check . 
30 | 31 | static-check: ## Static code check 32 | semgrep -v --error --config=p/ci /drone/src 33 | 34 | security: ## Run security check 35 | poetry export -f requirements.txt --without-hashes | sed 's/; .*//' > /tmp/req.txt 36 | sed -i '/^typing-extensions/d' /tmp/req.txt 37 | sed -i '/^anyio/d' /tmp/req.txt 38 | poetry run safety check -r /tmp/req.txt 39 | poetry run bandit -lll -r . 40 | poetry run vulture . --min-confidence 95 41 | 42 | ci: lint format security test ## Run all 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NatsAPI 2 | 3 | NatsAPI is a framework for developing Python 3 applications that use [nats](https://nats.io) as the communication medium instead of HTTP. With nats you get a smaller footprint, more requests per second, pub/sub, and better observability. 4 | NatsAPI is heavily inspired by [FastAPI](https://github.com/tiangolo/fastapi) and shares its development style. NatsAPI produces an [AsyncAPI](https://www.asyncapi.com/) schema out of the box; **this schema is not yet fully compatible with the standard, pending [version 3.0.0 support for the request/reply pattern](https://github.com/asyncapi/spec/issues/94)**. 5 | 6 | ## Table of Contents 7 | 8 | 9 | 10 | * [Python and pydantic support](#python-and-pydantic-support) 11 | * [Quickstart](#quickstart) 12 | * [Usage](#usage) 13 | * [Docs](#docs) 14 | * [Examples](#examples) 15 | * [Basic](#basic) 16 | * [Error handling with sentry](#error-handling-with-sentry) 17 | * [Generating documentation (asyncapi)](#generating-documentation-asyncapi) 18 | * [Plugins](#plugins) 19 | * [Roadmap](#roadmap) 20 | 21 | 22 | 23 | ## Python and pydantic support 24 | 25 | This library supports Python >= 3.9 and both pydantic v1 and v2.
26 | 27 | ## Quickstart 28 | 29 | 30 | > Optional: Run a nats server in docker so you are able to connect: `docker run --rm --name nats -dp 4222:4222 nats:latest -js -DV` 31 | 32 | ``` 33 | $ git clone git@github.com:wegroupwolves/natsapi.git && cd natsapi 34 | $ poetry install 35 | $ poetry run python examples/minimal.py 36 | $ nats request natsapi.persons.greet '{"params": {"person": {"first_name": "Foo", "last_name": "Bar"}}}' 37 | ``` 38 | 39 | the output should look as follows: 40 | 41 | ``` 42 | 18:19:00 Sending request on "natsapi.persons.greet" 43 | 18:19:00 Received on "_INBOX.dpBgTyG9XC5NhagdqRHTcp.eMkVkru8" rtt 1.052463ms 44 | {"jsonrpc": "2.0", "id": "c2bc2d20-dbd5-4e39-a22d-c22a8631c5a3", "result": {"message": "Greetings Foo Bar!"}, "error": null} 45 | ``` 46 | 47 | ## Usage 48 | 49 | ### Docs 50 | 51 | > UNDER CONSTRUCTION 52 | 53 | ### Examples 54 | 55 | #### Basic 56 | 57 | ```python 58 | from natsapi import NatsAPI, SubjectRouter 59 | from pydantic import BaseModel 60 | 61 | 62 | class Message(BaseModel): 63 | message: str 64 | 65 | 66 | class Person(BaseModel): 67 | first_name: str 68 | last_name: str 69 | 70 | 71 | app = NatsAPI("natsapi") 72 | router = SubjectRouter() 73 | 74 | 75 | @router.request("persons.greet", result=Message) 76 | async def greet_person(app: NatsAPI, person: Person): 77 | return {"message": f"Greetings {person.first_name} {person.last_name}!"} 78 | 79 | 80 | app.include_router(router) 81 | 82 | if __name__ == "__main__": 83 | app.run() 84 | ``` 85 | 86 | Run as follows: 87 | 88 | ```bash 89 | $ python app.py 90 | ``` 91 | 92 | Docs will be rendered as: 93 | 94 | ![Example of redoc](./doc/minimal.png) 95 | 96 | Send a request: 97 | 98 | ```python 99 | from natsapi import NatsAPI 100 | import asyncio 101 | 102 | 103 | async def main(): 104 | app = await NatsAPI("client").startup() 105 | 106 | params = {"person": {"first_name": "Foo", "last_name": "Bar"}} 107 | r = await app.nc.request("natsapi.persons.greet", params=params, 
timeout=5) 108 | print(r.result) 109 | 110 | asyncio.run(main()) 111 | 112 | #> {'message': 'Greetings Foo Bar!'} 113 | ``` 114 | 115 | or on the command line 116 | 117 | ```shell 118 | $ nats request natsapi.persons.greet '{"params": {"person": {"first_name": "Foo", "last_name": "Bar"}}}' 119 | 120 | 18:19:00 Sending request on "natsapi.persons.greet" 121 | 18:19:00 Received on "_INBOX.dpBgTyG9XC5NhagdqRHTcp.eMkVkru8" rtt 1.052463ms 122 | {"jsonrpc": "2.0", "id": "c2bc2d20-dbd5-4e39-a22d-c22a8631c5a3", "result": {"message": "Greetings Foo Bar!"}, "error": null} 123 | ``` 124 | 125 | #### Error handling with sentry 126 | 127 | 128 | ```python 129 | from natsapi import NatsAPI, SubjectRouter 130 | import logging 131 | from pydantic import ValidationError 132 | from sentry_sdk import configure_scope 133 | from natsapi.models import JsonRPCRequest, JsonRPCError 134 | from pydantic import BaseModel 135 | 136 | 137 | class StatusResult(BaseModel): 138 | status: str 139 | 140 | 141 | app = NatsAPI("natsapi-example") 142 | 143 | router = SubjectRouter() 144 | 145 | 146 | @router.request("healthz", result=StatusResult) 147 | async def handle_user(app: NatsAPI): 148 | return {"status": "OK"} 149 | 150 | 151 | app.include_router(router) 152 | 153 | 154 | def configure_sentry(auth): 155 | with configure_scope() as scope: 156 | scope.user = { 157 | "email": auth.get("email"), 158 | "id": auth.get("uid"), 159 | "ip_address": auth.get("ip_address"), 160 | } 161 | 162 | 163 | @app.exception_handler(ValidationError) 164 | async def handle_validation_exception(exc: ValidationError, request: JsonRPCRequest, subject: str) -> JsonRPCError: 165 | auth = request.params.get("auth") or {} 166 | configure_sentry(auth) 167 | logging.error( 168 | exc, 169 | exc_info=True, 170 | stack_info=True, 171 | extra={"auth": auth, "json": request.dict(), "subject": subject, "NATS": True, "code": -32003}, 172 | ) 173 | 174 | return JsonRPCError(code=-90001, message="VALIDATION_ERROR", 
data={"error_str": str(exc)}) 175 | 176 | 177 | if __name__ == "__main__": 178 | app.run(reload=False) 179 | ``` 180 | 181 | ### Generating documentation (asyncapi) 182 | 183 | To view the documentation, run the `render-asyncapi` binary (in the `render-asyncapi/` directory), passing the NATS port and the app's root path. The example above uses the root path `natsapi-example`, so: 184 | 185 | ```bash 186 | $ ./render-asyncapi/render-asyncapi 4222 natsapi-example 187 | 188 | Server running 189 | Docs can be found on localhost:8090 190 | connected to nats on port 4222 191 | ``` 192 | 193 | When browsing to [localhost:8090](http://127.0.0.1:8090), the documentation should look like this: 194 | 195 | ![Example of redoc](./doc/readme-example-redoc.png) 196 | 197 | ### Plugins 198 | 199 | Available plugins are defined in `natsapi/plugin.py`. 200 | 201 | - [natsapi_mock](./natsapi/plugin.py): A handy mock fixture to intercept nats requests and to fake nats responses for any subject. 202 | 203 | ### Roadmap 204 | 205 | - [ ] Add Request/Reply AsyncApi support 206 | - [ ] Hot reloading (when saving source code, application should be reloaded) 207 | - [ ] Fancy readme 208 | - [ ] Better benchmark 209 | - [ ] Add support for 'side effect' testing, so that the same mocked route can return different responses based on call order (see respx) 210 | - [ ] Better CI/CD -> with Python 3.13 211 | - [ ] Make `nkeys` optional 212 | - [x] Pydantic V2 support (nice to have) 213 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | Security is very important for NatsAPI and its community. 🔒 4 | 5 | Learn more about it below. 👇 6 | 7 | 8 | ## Reporting a Vulnerability 9 | 10 | If you think you found a vulnerability, and even if you are not sure about it, please report it right away by sending an email to: security@wegroup.be. Please try to be as explicit as possible, describing all the steps and example code to reproduce the security issue.
11 | -------------------------------------------------------------------------------- /benchmarks/Makefile: -------------------------------------------------------------------------------- 1 | ##---------- Preliminaries ---------------------------------------------------- 2 | .POSIX: # Get reliable POSIX behaviour 3 | .SUFFIXES: # Clear built-in inference rules 4 | 5 | ##---------- Variables -------------------------------------------------------- 6 | PREFIX = /usr/local # Default installation directory 7 | 8 | ##---------- Build targets ---------------------------------------------------- 9 | 10 | help: ## Show this help message (default) 11 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 12 | 13 | frontend-gateway: ## Run the frontend gateway 14 | poetry run uvicorn frontend-gateway:app --no-access-log --workers 1 --port 5000 15 | 16 | backend-rest: ## Run the backend rest service 17 | poetry run uvicorn backend-rest:app --no-access-log --workers 1 --port 5001 18 | 19 | backend-nats: ## Run the backend nats service 20 | poetry run python backend-nats.py 21 | 22 | bench-rest: ## Run benchmarks for the rest to rest connection 23 | ab -n 1000 -c 100 http://localhost:5000/rest/index 24 | ab -p input.json -T application/json -c 100 -n 1000 http://localhost:5000/rest/sum 25 | 26 | bench-nats: ## Run benchmarks for the nats to nats connection 27 | ab -n 1000 -c 100 http://localhost:5000/nats/index 28 | ab -p input.json -T application/json -c 100 -n 1000 http://localhost:5000/nats/sum 29 | 30 | # cursor: 15 del 31 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | ## Results 4 | 5 | ### Rest 6 | 7 | GET /rest/index 8 | 9 | ``` 10 | Server Software: uvicorn 11 | Server Hostname: localhost 12 | Server Port: 5000 13 | 14 | Document Path: 
/rest/index 15 | Document Length: 17 bytes 16 | 17 | Concurrency Level: 100 18 | Time taken for tests: 10.171 seconds 19 | Complete requests: 1000 20 | Failed requests: 0 21 | Total transferred: 161000 bytes 22 | HTML transferred: 17000 bytes 23 | Requests per second: 98.32 [#/sec] (mean) 24 | Time per request: 1017.057 [ms] (mean) 25 | Time per request: 10.171 [ms] (mean, across all concurrent requests) 26 | Transfer rate: 15.46 [Kbytes/sec] received 27 | 28 | Connection Times (ms) 29 | min mean[+/-sd] median max 30 | Connect: 0 1 3.0 0 14 31 | Processing: 83 1005 95.8 1018 1213 32 | Waiting: 64 988 88.3 993 1192 33 | Total: 84 1007 95.9 1018 1213 34 | 35 | Percentage of the requests served within a certain time (ms) 36 | 50% 1018 37 | 66% 1030 38 | 75% 1038 39 | 80% 1075 40 | 90% 1103 41 | 95% 1192 42 | 98% 1199 43 | 99% 1201 44 | 100% 1213 (longest request) 45 | ``` 46 | 47 | POST /rest/sum 48 | 49 | ``` 50 | Server Software: uvicorn 51 | Server Hostname: localhost 52 | Server Port: 5000 53 | 54 | Document Path: /rest/sum 55 | Document Length: 3 bytes 56 | 57 | Concurrency Level: 100 58 | Time taken for tests: 10.328 seconds 59 | Complete requests: 1000 60 | Failed requests: 0 61 | Total transferred: 146000 bytes 62 | Total body sent: 184000 63 | HTML transferred: 3000 bytes 64 | Requests per second: 96.82 [#/sec] (mean) 65 | Time per request: 1032.849 [ms] (mean) 66 | Time per request: 10.328 [ms] (mean, across all concurrent requests) 67 | Transfer rate: 13.80 [Kbytes/sec] received 68 | 17.40 kb/s sent 69 | 31.20 kb/s total 70 | 71 | Connection Times (ms) 72 | min mean[+/-sd] median max 73 | Connect: 0 1 0.7 0 3 74 | Processing: 14 1026 85.3 1008 1212 75 | Waiting: 11 1014 80.8 998 1209 76 | Total: 14 1026 85.6 1008 1214 77 | WARNING: The median and mean for the initial connection time are not within a normal deviation 78 | These results are probably not that reliable. 
79 | 80 | Percentage of the requests served within a certain time (ms) 81 | 50% 1008 82 | 66% 1054 83 | 75% 1070 84 | 80% 1071 85 | 90% 1161 86 | 95% 1162 87 | 98% 1214 88 | 99% 1214 89 | 100% 1214 (longest request) 90 | ``` 91 | 92 | ### Nats 93 | 94 | GET /nats/index 95 | 96 | ``` 97 | Server Software: uvicorn 98 | Server Hostname: localhost 99 | Server Port: 5000 100 | 101 | Document Path: /nats/index 102 | Document Length: 17 bytes 103 | 104 | Concurrency Level: 100 105 | Time taken for tests: 0.739 seconds 106 | Complete requests: 1000 107 | Failed requests: 0 108 | Total transferred: 161000 bytes 109 | HTML transferred: 17000 bytes 110 | Requests per second: 1354.05 [#/sec] (mean) 111 | Time per request: 73.852 [ms] (mean) 112 | Time per request: 0.739 [ms] (mean, across all concurrent requests) 113 | Transfer rate: 212.89 [Kbytes/sec] received 114 | 115 | Connection Times (ms) 116 | min mean[+/-sd] median max 117 | Connect: 0 1 2.3 1 11 118 | Processing: 22 70 42.7 52 197 119 | Waiting: 6 60 41.9 44 189 120 | Total: 22 72 44.7 53 204 121 | 122 | Percentage of the requests served within a certain time (ms) 123 | 50% 53 124 | 66% 54 125 | 75% 59 126 | 80% 88 127 | 90% 182 128 | 95% 203 129 | 98% 203 130 | 99% 204 131 | 100% 204 (longest request) 132 | ``` 133 | 134 | POST /nats/sum 135 | 136 | ``` 137 | Server Software: uvicorn 138 | Server Hostname: localhost 139 | Server Port: 5000 140 | 141 | Document Path: /nats/sum 142 | Document Length: 1 bytes 143 | 144 | Concurrency Level: 100 145 | Time taken for tests: 0.738 seconds 146 | Complete requests: 1000 147 | Failed requests: 0 148 | Total transferred: 144000 bytes 149 | Total body sent: 184000 150 | HTML transferred: 1000 bytes 151 | Requests per second: 1355.88 [#/sec] (mean) 152 | Time per request: 73.753 [ms] (mean) 153 | Time per request: 0.738 [ms] (mean, across all concurrent requests) 154 | Transfer rate: 190.67 [Kbytes/sec] received 155 | 243.64 kb/s sent 156 | 434.31 kb/s total 157 | 158 | 
Connection Times (ms) 159 | min mean[+/-sd] median max 160 | Connect: 0 1 0.7 1 4 161 | Processing: 8 72 12.6 68 103 162 | Waiting: 2 62 13.3 58 97 163 | Total: 8 72 12.5 68 103 164 | 165 | Percentage of the requests served within a certain time (ms) 166 | 50% 68 167 | 66% 70 168 | 75% 70 169 | 80% 79 170 | 90% 100 171 | 95% 101 172 | 98% 103 173 | 99% 103 174 | 100% 103 (longest request) 175 | ``` 176 | 177 | ## Conclusion 178 | 179 | In this setup the NATS path handled roughly 14x more requests per second than the REST path (about 1355 vs. about 97 requests per second, for both the index and sum endpoints). Note that the REST gateway opens a new `httpx.AsyncClient` for every request, which likely adds connection overhead to the REST numbers. 180 | -------------------------------------------------------------------------------- /benchmarks/backend-nats.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | 4 | from nats.aio.client import Client as NATS 5 | 6 | nc = NATS() 7 | 8 | 9 | async def handle_index(msg): 10 | data = {"hello": "world"} 11 | await nc.publish(msg.reply, json.dumps(data).encode()) 12 | 13 | 14 | async def handle_sum(msg): 15 | data = json.loads(msg.data.decode()) 16 | response = str(data["number_1"] + data["number_2"]) 17 | await nc.publish(msg.reply, response.encode()) 18 | 19 | 20 | async def main(): 21 | await nc.connect("nats://127.0.0.1:4222") 22 | await nc.subscribe("index", "", handle_index) 23 | await nc.subscribe("sum", "", handle_sum) 24 | print("Listening for requests") 25 | 26 | 27 | if __name__ == "__main__": 28 | loop = asyncio.get_event_loop() 29 | try: 30 | loop.run_until_complete(main()) 31 | loop.run_forever() 32 | loop.close() 33 | except Exception as e: 34 | print(e) 35 | -------------------------------------------------------------------------------- /benchmarks/backend-rest.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/python 2 | 3 | from fastapi import FastAPI 4 | from pydantic import BaseModel 5 | 6 | app = FastAPI() 7 | 8 | 9 | @app.get("/index") 10 | async def index(): 11 | return {"Hello": "World"} 12 | 13 | 14 | class Numbers(BaseModel): 15 | number_1: int 16 | number_2: int 17 | 18 | 19 | @app.post("/sum") 20 | async def sum(numbers: Numbers): 21 | return numbers.number_1 + numbers.number_2 22 | -------------------------------------------------------------------------------- /benchmarks/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | services: 3 | nats: 4 | image: nats:latest 5 | container_name: nats 6 | ports: 7 | - 4222:4222 8 | command: -js 9 | -------------------------------------------------------------------------------- /benchmarks/frontend-gateway.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | 3 | import json 4 | 5 | import httpx 6 | import nats 7 | from fastapi import FastAPI, Request 8 | from pydantic import BaseModel 9 | 10 | app = FastAPI() 11 | 12 | 13 | @app.on_event("startup") 14 | async def setup() -> None: 15 | try: 16 | app.nc = await nats.connect("nats://localhost:4222") 17 | except Exception as e: 18 | print(e) 19 | 20 | 21 | @app.get("/nats/index") 22 | async def nats_index(request: Request): 23 | response = await request.app.nc.request("index", b"", timeout=60) 24 | return json.loads(response.data.decode()) 25 | 26 | 27 | @app.get("/rest/index") 28 | async def rest_index(): 29 | async with httpx.AsyncClient() as client: 30 | r = await client.get("http://localhost:5001/index") 31 | return r.json() 32 | 33 | 34 | class Numbers(BaseModel): 35 | number_1: int 36 | number_2: int 37 | 38 | 39 | @app.post("/nats/sum") 40 | async def nats_sum(request: Request, numbers: Numbers): 41 | response = await request.app.nc.request("sum", numbers.json().encode(), timeout=60) 42 | return json.loads(response.data.decode()) 
43 | 44 | 45 | @app.post("/rest/sum") 46 | async def rest_sum(numbers: Numbers): 47 | async with httpx.AsyncClient() as client: 48 | r = await client.post("http://localhost:5001/sum", json=numbers.dict()) 49 | return r.text 50 | -------------------------------------------------------------------------------- /benchmarks/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "number_1": 2, 3 | "number_2": 3 4 | } 5 | -------------------------------------------------------------------------------- /doc/minimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wegroupwolves/natsapi/cb31ef79c5cea79f815ed4d11f1cc8f03deaf21c/doc/minimal.png -------------------------------------------------------------------------------- /doc/readme-example-redoc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wegroupwolves/natsapi/cb31ef79c5cea79f815ed4d11f1cc8f03deaf21c/doc/readme-example-redoc.png -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | services: 3 | nats: 4 | image: nats:2-alpine 5 | container_name: nats-server 6 | ports: 7 | - 4222:4222 8 | command: -DV 9 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wegroupwolves/natsapi/cb31ef79c5cea79f815ed4d11f1cc8f03deaf21c/examples/__init__.py -------------------------------------------------------------------------------- /examples/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from natsapi import NatsAPI 4 | 5 | 6 | async def main(): 7 | app = await 
NatsAPI("client").startup() 8 | 9 | params = {"person": {"first_name": "Foo", "last_name": "Bar"}} 10 | r = await app.nc.request("natsapi.persons.greet", params=params, timeout=5) 11 | print(r.result) 12 | 13 | 14 | asyncio.run(main()) 15 | -------------------------------------------------------------------------------- /examples/event_decorators.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from natsapi import NatsAPI 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | 7 | app = NatsAPI("natsapi.dev") 8 | 9 | 10 | @app.on_startup 11 | async def setup(): 12 | logging.info("Connect to db") 13 | 14 | 15 | @app.on_shutdown 16 | async def teardown(): 17 | logging.info("Disconnect from db") 18 | 19 | 20 | if __name__ == "__main__": 21 | app.run() 22 | -------------------------------------------------------------------------------- /examples/full_with_reload.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import BaseModel 4 | 5 | from natsapi import NatsAPI 6 | from natsapi.models import JsonRPCError, JsonRPCRequest 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | 11 | class StatusResult(BaseModel): 12 | status: str 13 | 14 | 15 | app = NatsAPI("natsapi.development") 16 | 17 | 18 | @app.on_startup 19 | async def setup(): 20 | logging.info("[STARTUP]") 21 | 22 | 23 | @app.on_shutdown 24 | async def teardown(): 25 | logging.info("[TEARDOWN]") 26 | 27 | 28 | @app.publish("foo.do", description="This is the description of the publish request") 29 | async def _(app, bar: int): 30 | bar = bar 31 | logging.info("Publish called!") 32 | 33 | 34 | @app.request("healthz.get", description="Is the server still up?", result=StatusResult) 35 | async def _(app: NatsAPI): 36 | return {"status": "OK"} 37 | 38 | 39 | @app.exception_handler(Exception) 40 | async def handle_exception_custom(exc: Exception, request: JsonRPCRequest, subject: 
str) -> JsonRPCError: 41 | return JsonRPCError(code=-90001, message="VALIDATION_ERROR", data={"error_str": str(exc)}) 42 | 43 | 44 | if __name__ == "__main__": 45 | app.run(reload=False) 46 | -------------------------------------------------------------------------------- /examples/minimal.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from natsapi import NatsAPI, SubjectRouter 4 | 5 | 6 | class Message(BaseModel): 7 | message: str 8 | 9 | 10 | class Person(BaseModel): 11 | first_name: str 12 | last_name: str 13 | 14 | 15 | app = NatsAPI("natsapi") 16 | router = SubjectRouter() 17 | 18 | 19 | @router.request("persons.greet", result=Message) 20 | async def greet_person(app: NatsAPI, person: Person): 21 | return {"message": f"Greetings {person.first_name} {person.last_name}!"} 22 | 23 | 24 | app.include_router(router) 25 | 26 | if __name__ == "__main__": 27 | app.run() 28 | -------------------------------------------------------------------------------- /examples/with_fastapi.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from fastapi import FastAPI 5 | 6 | from natsapi import NatsAPI 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | fastapi = FastAPI() 11 | natsapi = NatsAPI("natsapi.dev") 12 | 13 | 14 | @fastapi.on_event("startup") 15 | async def setup(): 16 | loop = asyncio.get_running_loop() 17 | await natsapi.startup(loop=loop) 18 | 19 | logging.info("Connect to db") 20 | 21 | 22 | @fastapi.on_event("shutdown") 23 | async def teardown(): 24 | await natsapi.shutdown() 25 | 26 | logging.info("Disconnect from db") 27 | 28 | 29 | # uvicorn example.with_fastapi:fastapi 30 | -------------------------------------------------------------------------------- /natsapi/__init__.py: -------------------------------------------------------------------------------- 1 | from .applications import NatsAPI 2 | 
from .exceptions import JsonRPCException 3 | from .models import JsonRPCReply, JsonRPCRequest 4 | from .routing import Pub, Sub, SubjectRouter 5 | 6 | __all__ = ["NatsAPI", "JsonRPCException", "Pub", "Sub", "SubjectRouter", "JsonRPCRequest", "JsonRPCReply"] 7 | -------------------------------------------------------------------------------- /natsapi/_compat.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from functools import lru_cache 5 | from typing import Annotated, Any, Literal, Optional, Union 6 | 7 | from pydantic import BaseModel 8 | from pydantic.version import VERSION as PYDANTIC_VERSION 9 | 10 | from natsapi.asyncapi.constants import REF_PREFIX 11 | 12 | PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) 13 | PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 14 | 15 | ModelNameMap = Union[dict[type[BaseModel], type[Enum], str]] 16 | 17 | if PYDANTIC_V2: 18 | from pydantic import ( 19 | RootModel, # noqa 20 | TypeAdapter, 21 | ) 22 | from pydantic import ValidationError as ValidationError 23 | from pydantic._internal._utils import lenient_issubclass as lenient_issubclass 24 | from pydantic.deprecated.json import ENCODERS_BY_TYPE 25 | from pydantic.fields import FieldInfo 26 | from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema 27 | from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue 28 | from pydantic_core import PydanticUndefined, PydanticUndefinedType 29 | from pydantic_settings import BaseSettings 30 | 31 | Undefined = PydanticUndefined 32 | UndefinedType = PydanticUndefinedType 33 | 34 | @dataclass 35 | class ModelField: 36 | field_info: FieldInfo 37 | name: str 38 | mode: Literal["validation", "serialization"] = "validation" 39 | sub_fields: Optional[str] = None 40 | 41 | @property 42 | def alias(self) -> str: 43 | a = 
self.field_info.alias 44 | return a if a is not None else self.name 45 | 46 | @property 47 | def required(self) -> bool: 48 | return self.field_info.is_required() 49 | 50 | @property 51 | def default(self) -> Any: 52 | return self.get_default() 53 | 54 | @property 55 | def type_(self) -> Any: 56 | return self.field_info.annotation 57 | 58 | def __post_init__(self) -> None: 59 | self._type_adapter: TypeAdapter[Any] = TypeAdapter(Annotated[self.field_info.annotation, self.field_info]) 60 | 61 | def get_default(self) -> Any: 62 | if self.field_info.is_required(): 63 | return Undefined 64 | return self.field_info.get_default(call_default_factory=True) 65 | 66 | def validate( 67 | self, 68 | value: Any, 69 | values: dict[str, Any] = {}, # noqa: B006 70 | *, 71 | loc: tuple[Union[int, str], ...] = (), 72 | ) -> tuple[Any, list[dict[str, Any]]]: 73 | try: 74 | return ( 75 | self._type_adapter.validate_python(value, from_attributes=True), 76 | None, 77 | ) 78 | except ValidationError as exc: 79 | return None, _regenerate_error_with_loc(errors=exc.errors(include_url=False), loc_prefix=loc) 80 | 81 | def serialize( 82 | self, 83 | value: Any, 84 | *, 85 | mode: Literal["json", "python"] = "json", 86 | by_alias: bool = True, 87 | exclude_unset: bool = False, 88 | exclude_defaults: bool = False, 89 | exclude_none: bool = False, 90 | ) -> Any: 91 | # What calls this code passes a value that already called 92 | return self._type_adapter.dump_python( 93 | value, 94 | mode=mode, 95 | by_alias=by_alias, 96 | exclude_unset=exclude_unset, 97 | exclude_defaults=exclude_defaults, 98 | exclude_none=exclude_none, 99 | ) 100 | 101 | def __hash__(self) -> int: 102 | # Each ModelField is unique for our purposes, to allow making a dict from 103 | # ModelField to its JSON Schema. 
104 | return id(self) 105 | 106 | def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]: 107 | return errors # type: ignore[return-value] 108 | 109 | def get_model_fields(model: type[BaseModel]) -> list[ModelField]: 110 | return [ModelField(field_info=field_info, name=name) for name, field_info in model.model_fields.items()] 111 | 112 | def get_compat_model_name_map(fields: list[ModelField]): 113 | return {} 114 | 115 | def get_definitions( 116 | *, 117 | fields: list[ModelField], 118 | schema_generator: GenerateJsonSchema, 119 | model_name_map: ModelNameMap, 120 | separate_input_output_schemas: bool = True, 121 | ) -> tuple[ 122 | dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], 123 | dict[str, dict[str, Any]], 124 | ]: 125 | override_mode: Optional[Literal["validation"]] = None if separate_input_output_schemas else "validation" 126 | inputs = [(field, override_mode or field.mode, field._type_adapter.core_schema) for field in fields] 127 | field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) 128 | return field_mapping, definitions # type: ignore[return-value] 129 | 130 | else: 131 | from pydantic import ( 132 | BaseModel, 133 | BaseSettings, # noqa: F401 134 | ) 135 | from pydantic.error_wrappers import ( # type: ignore[no-redef] 136 | ErrorWrapper as ErrorWrapper, # noqa: F401 137 | ) 138 | from pydantic.fields import ( # type: ignore[no-redef,attr-defined] 139 | ModelField as ModelField, # noqa: F401 140 | ) 141 | from pydantic.json import ENCODERS_BY_TYPE # noqa: F401 142 | from pydantic.schema import ( 143 | get_flat_models_from_fields, 144 | get_model_name_map, 145 | model_process_schema, 146 | ) 147 | from pydantic.utils import ( # noqa 148 | lenient_issubclass as lenient_issubclass, # noqa: F401 149 | ) 150 | 151 | class RootModel(BaseModel): 152 | __root__: str 153 | 154 | GetJsonSchemaHandler = Any # type: ignore[assignment,misc] 155 | JsonSchemaValue = dict[str, Any] # type: 
ignore[misc] 156 | CoreSchema = Any # type: ignore[assignment,misc] 157 | 158 | @dataclass 159 | class GenerateJsonSchema: # type: ignore[no-redef] 160 | ref_template: str 161 | 162 | def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]: 163 | use_errors: list[Any] = [] 164 | for error in errors: 165 | if isinstance(error, ErrorWrapper): 166 | new_errors = ValidationError(errors=[error]).errors() # type: ignore[call-arg] 167 | use_errors.extend(new_errors) 168 | elif isinstance(error, list): 169 | use_errors.extend(_normalize_errors(error)) 170 | else: 171 | use_errors.append(error) 172 | return use_errors 173 | 174 | def get_model_fields(model: type[BaseModel]) -> list[ModelField]: 175 | return list(model.__fields__.values()) # type: ignore[attr-defined] 176 | 177 | def get_compat_model_name_map(fields: list[ModelField]): 178 | models = get_flat_models_from_fields(fields, known_models=set()) 179 | return get_model_name_map(models) # type: ignore[no-any-return] 180 | 181 | def get_model_definitions( 182 | *, 183 | flat_models: Union[set[type[BaseModel], type[Enum]]], 184 | model_name_map: Union[dict[type[BaseModel], type[Enum], str]], 185 | ) -> dict[str, Any]: 186 | definitions: dict[str, dict[str, Any]] = {} 187 | for model in flat_models: 188 | m_schema, m_definitions, m_nested_models = model_process_schema( 189 | model, 190 | model_name_map=model_name_map, 191 | ref_prefix=REF_PREFIX, 192 | ) 193 | definitions.update(m_definitions) 194 | model_name = model_name_map[model] 195 | if "description" in m_schema: 196 | m_schema["description"] = m_schema["description"].split("\f")[0] 197 | definitions[model_name] = m_schema 198 | return definitions 199 | 200 | def get_definitions( 201 | *, 202 | fields: list[ModelField], 203 | schema_generator: GenerateJsonSchema, 204 | model_name_map: ModelNameMap, 205 | separate_input_output_schemas: bool = True, 206 | ) -> tuple[ 207 | dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], 
208 | dict[str, dict[str, Any]], 209 | ]: 210 | models = get_flat_models_from_fields(fields, known_models=set()) 211 | return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map) 212 | 213 | 214 | def _regenerate_error_with_loc( 215 | *, 216 | errors: Sequence[Any], 217 | loc_prefix: tuple[Union[str, int], ...], 218 | ) -> list[dict[str, Any]]: 219 | updated_loc_errors: list[Any] = [ 220 | {**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors) 221 | ] 222 | 223 | return updated_loc_errors 224 | 225 | 226 | @lru_cache 227 | def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]: 228 | return get_model_fields(model) 229 | 230 | 231 | class MyGenerateJsonSchema(GenerateJsonSchema): 232 | def sort(self, value: JsonSchemaValue, *args) -> JsonSchemaValue: 233 | """ 234 | No-op, we don't want to sort schema values at all. 235 | https://docs.pydantic.dev/latest/concepts/json_schema/#json-schema-sorting 236 | """ 237 | return value 238 | -------------------------------------------------------------------------------- /natsapi/applications.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import inspect 3 | import signal 4 | from collections.abc import Callable 5 | from typing import Any, Optional, Union 6 | 7 | from pydantic import BaseModel, ValidationError 8 | 9 | from natsapi.asyncapi import Errors, ExternalDocumentation 10 | from natsapi.asyncapi.models import AsyncAPI 11 | from natsapi.asyncapi.utils import get_asyncapi 12 | from natsapi.client import Config, NatsClient 13 | from natsapi.client.config import default_config 14 | from natsapi.exception_handlers import handle_internal_error, handle_jsonrpc_exception, handle_validation_error 15 | from natsapi.exceptions import DuplicateRouteException, JsonRPCException 16 | from natsapi.logger import logger 17 | from natsapi.routing import Pub, Publish, Request, Sub, SubjectRouter 18 | from 
natsapi.state import State 19 | from natsapi.types import DecoratedCallable 20 | 21 | 22 | class NatsAPI: 23 | def __init__( 24 | self, 25 | root_path: str, 26 | *, 27 | app: Any = None, 28 | client_config: Optional[Config] = None, 29 | rpc_methods: Optional[list[str]] = None, 30 | exception_handlers: Optional[dict[type[Exception], Callable[[type[Exception]], JsonRPCException]]] = None, 31 | title: str = "NatsAPI", 32 | version: str = "0.1.0", 33 | description: str = None, 34 | tags: Optional[list[dict[str, Any]]] = None, 35 | servers: Optional[dict[str, Union[str, Any]]] = None, 36 | domain_errors: Optional[dict[str, Any]] = None, 37 | external_docs: Optional[dict[str, Any]] = None, 38 | ): 39 | """ 40 | Parameters 41 | ---------- 42 | root_path: str The path that every application-specific subject 43 | app: FastAPI Must be a FastAPI instance or None. If none the app is the NatsAPI instance itself 44 | """ 45 | self.routes: dict[str, Request] = {} 46 | self.root_path = root_path 47 | self._root_paths = [root_path] 48 | self.rpc_methods = rpc_methods if rpc_methods else None 49 | self.title = title 50 | self.version = version 51 | self.description = description 52 | self.asyncapi_servers = servers or {} 53 | self.asyncapi_external_docs = external_docs 54 | self.asyncapi_tags = tags or [] 55 | self.asyncapi_version = "2.0.0" 56 | self.domain_errors: Errors = domain_errors 57 | self.asyncapi_schema: Optional[dict[str, Any]] = None 58 | self.nc: NatsClient = None 59 | self.subs: set[Sub] = set() 60 | self.pubs: set[Pub] = set() 61 | self.state = State() 62 | self._on_startup_method = None 63 | self._on_shutdown_method = None 64 | self.client_config = client_config or default_config 65 | self._exception_handlers: dict[type[Exception], Callable[[type[Exception]], JsonRPCException]] = ( 66 | {} if exception_handlers is None else dict(exception_handlers) 67 | ) 68 | self._exception_handlers.setdefault(JsonRPCException, handle_jsonrpc_exception) 69 | 
self._exception_handlers.setdefault(ValidationError, handle_validation_error) 70 | self._exception_handlers.setdefault(Exception, handle_internal_error) 71 | try: 72 | self.loop = asyncio.get_running_loop() 73 | self._sharing_loop = True 74 | except RuntimeError: 75 | self.loop = asyncio.get_event_loop() 76 | self._sharing_loop = False 77 | 78 | if app is not None: 79 | app_type = str(type(app)) 80 | assert ( 81 | "FastAPI" in app_type or "Sanic" in app_type 82 | ), f"App must be a FastAPI or Sanic instance, got {app_type}" 83 | self.app = app 84 | else: 85 | self.app = self 86 | 87 | async def __aenter__(self): 88 | await self.startup(self.loop) 89 | return self 90 | 91 | async def __aexit__(self, *args): 92 | await self.shutdown() 93 | 94 | def include_router(self, router: type[SubjectRouter], root_path: str = None) -> None: 95 | current_root_path = root_path or self.root_path 96 | if current_root_path not in self._root_paths: 97 | self._root_paths.append(current_root_path) 98 | 99 | for subject in router.routes: 100 | if self.rpc_methods: 101 | method = subject.subject.split(".")[-1] 102 | assert ( 103 | method in self.rpc_methods 104 | ), f"'{method}' is an invalid request method for handler {subject.endpoint.__name__}. 
Allowed methods: {self.rpc_methods}" 105 | key_name = ".".join([current_root_path, subject.subject]) 106 | if key_name in self.routes: 107 | raise DuplicateRouteException(f"{key_name} is defined twice!") 108 | 109 | self.routes[key_name] = subject 110 | 111 | self.subs = self.subs | router.subs 112 | self.pubs = self.pubs | router.pubs 113 | 114 | def generate_asyncapi(self) -> dict[str, Any]: 115 | if not self.asyncapi_schema: 116 | self.asyncapi_schema = get_asyncapi( 117 | title=self.title, 118 | version=self.version, 119 | asyncapi_version=self.asyncapi_version, 120 | description=self.description, 121 | routes=self.routes, 122 | subs=self.subs, 123 | pubs=self.pubs, 124 | errors=self.domain_errors, 125 | servers=self.asyncapi_servers, 126 | external_docs=self.asyncapi_external_docs, 127 | ) 128 | return self.asyncapi_schema 129 | 130 | def _add_asyncapi_route(self) -> dict[str, Any]: 131 | """ 132 | Adds default route to retrieve the asyncapi schema. 133 | """ 134 | 135 | @self.request("schema.RETRIEVE", result=AsyncAPI, include_schema=False) 136 | def retrieve_asyncapi_schema(app): 137 | return self.generate_asyncapi() 138 | 139 | def on_startup(self, method: Callable) -> Callable[[DecoratedCallable], DecoratedCallable]: 140 | self._on_startup_method = method 141 | 142 | def on_shutdown(self, method: Callable) -> Callable[[DecoratedCallable], DecoratedCallable]: 143 | self._on_shutdown_method = method 144 | 145 | async def startup(self, loop=None): 146 | if loop: 147 | self.loop = loop 148 | self._sharing_loop = True 149 | else: 150 | self._listen_to_signals() 151 | 152 | self.nc = NatsClient( 153 | self.routes, 154 | app=self.app, 155 | config=self.client_config, 156 | exception_handlers=self._exception_handlers, 157 | ) 158 | await self.nc.connect() 159 | logger.info("Connected to NATS server") 160 | 161 | if self._on_startup_method: 162 | if inspect.iscoroutinefunction(self._on_startup_method): 163 | await self._on_startup_method() 164 | else: 165 | 
self._on_startup_method() 166 | 167 | for path in self._root_paths: 168 | sub_path = ".".join([path, ">"]) 169 | await self.nc.root_path_subscribe( 170 | sub_path, 171 | cb=self.nc.handle_request, 172 | queue=self.client_config.subscribe.queue, 173 | ) 174 | self.include_subs( 175 | [ 176 | Sub( 177 | sub_path, 178 | queue=self.client_config.subscribe.queue, 179 | summary=f"Sub to root path {sub_path}", 180 | tags=["automatic subs"], 181 | ), 182 | ], 183 | ) 184 | logger.info(f"Subscribed to {sub_path}") 185 | self._add_asyncapi_route() 186 | logger.info(f"Asyncapi schema can be found on {self.root_path}.schema.RETRIEVE") 187 | 188 | return self 189 | 190 | def _listen_to_signals(self): 191 | signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) 192 | for s in signals: 193 | self.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(self.shutdown(signal=s))) 194 | 195 | async def shutdown(self, signal=None): 196 | if signal: 197 | logger.info("Received kill signal") 198 | 199 | logger.warning("Cleanup coroutines that handle NATS messages.") 200 | cleanup = [x for x in asyncio.all_tasks() if x.get_name().startswith("natsapi_")] 201 | await asyncio.gather(*cleanup, return_exceptions=True) 202 | logger.warning(f"Cleaned up {len(cleanup)} coroutines.") 203 | 204 | await self.nc.shutdown() 205 | 206 | if self._on_shutdown_method: 207 | logger.info("Invoking shutdown of application instance.") 208 | if inspect.iscoroutinefunction(self._on_shutdown_method): 209 | await self._on_shutdown_method() 210 | else: 211 | self._on_shutdown_method() 212 | logger.info("Shutdown of application instance completed.") 213 | 214 | if not self._sharing_loop: 215 | logger.info("Self-managed loop: cancelling remaining asyncio tasks.") 216 | tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] 217 | [task.cancel() for task in tasks] 218 | results = await asyncio.gather(*tasks, return_exceptions=True) 219 | logger.info(f"Finished cancelling tasks, results: 
{results}") 220 | self.loop.stop() 221 | 222 | def run(self): 223 | self.loop.run_until_complete(self.startup()) 224 | self.loop.run_forever() 225 | 226 | def add_request( 227 | self, 228 | subject: str, 229 | endpoint: Callable[..., Any], 230 | *, 231 | result=type[Any], 232 | skip_validation: Optional[bool] = False, 233 | description: Optional[str] = None, 234 | deprecated: Optional[bool] = None, 235 | tags: Optional[list[str]] = None, 236 | summary: Optional[str] = None, 237 | suggested_timeout: Optional[float] = None, 238 | include_schema: Optional[bool] = True, 239 | ) -> None: 240 | request = Request( 241 | subject=subject, 242 | endpoint=endpoint, 243 | result=result, 244 | skip_validation=skip_validation, 245 | description=description, 246 | deprecated=deprecated, 247 | tags=tags, 248 | summary=summary, 249 | suggested_timeout=suggested_timeout, 250 | include_schema=include_schema, 251 | ) 252 | if self.rpc_methods: 253 | method = subject.split(".")[-1] 254 | assert ( 255 | method in self.rpc_methods 256 | ), f"'{method}' is an invalid request method in handler '{endpoint.__name__}'. 
Allowed methods: {self.rpc_methods}" 257 | 258 | key_name = ".".join([self.root_path, request.subject]) 259 | if key_name in self.routes: 260 | raise DuplicateRouteException(f"{key_name} is defined twice!") 261 | self.routes[key_name] = request 262 | 263 | def add_publish( 264 | self, 265 | subject: str, 266 | endpoint: Callable[..., Any], 267 | *, 268 | skip_validation: Optional[bool] = False, 269 | description: Optional[str] = None, 270 | deprecated: Optional[bool] = None, 271 | tags: Optional[list[str]] = None, 272 | summary: Optional[str] = None, 273 | include_schema: Optional[bool] = True, 274 | ) -> None: 275 | publish = Publish( 276 | subject=subject, 277 | endpoint=endpoint, 278 | skip_validation=skip_validation, 279 | description=description, 280 | deprecated=deprecated, 281 | tags=tags, 282 | summary=summary, 283 | include_schema=include_schema, 284 | ) 285 | if self.rpc_methods: 286 | method = subject.split(".")[-1] 287 | assert ( 288 | method in self.rpc_methods 289 | ), f"'{method}' is an invalid request method in handler '{endpoint.__name__}'. 
Allowed methods: {self.rpc_methods}" 290 | 291 | key_name = ".".join([self.root_path, publish.subject]) 292 | if key_name in self.routes: 293 | raise DuplicateRouteException(f"{key_name} is defined twice!") 294 | self.routes[key_name] = publish 295 | 296 | def request( 297 | self, 298 | subject: str, 299 | *, 300 | result=type[Any], 301 | skip_validation: Optional[bool] = False, 302 | description: Optional[str] = None, 303 | deprecated: Optional[bool] = None, 304 | tags: Optional[list[str]] = None, 305 | summary: Optional[str] = None, 306 | suggested_timeout: Optional[float] = None, 307 | include_schema: Optional[bool] = True, 308 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 309 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 310 | self.add_request( 311 | subject=subject, 312 | endpoint=func, 313 | result=result, 314 | skip_validation=skip_validation, 315 | description=description, 316 | deprecated=deprecated, 317 | tags=tags, 318 | summary=summary, 319 | suggested_timeout=suggested_timeout, 320 | include_schema=include_schema, 321 | ) 322 | return func 323 | 324 | return decorator 325 | 326 | def publish( 327 | self, 328 | subject: str, 329 | *, 330 | skip_validation: Optional[bool] = False, 331 | description: Optional[str] = None, 332 | deprecated: Optional[bool] = None, 333 | tags: Optional[list[str]] = None, 334 | summary: Optional[str] = None, 335 | include_schema: Optional[bool] = True, 336 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 337 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 338 | self.add_publish( 339 | subject=subject, 340 | endpoint=func, 341 | skip_validation=skip_validation, 342 | description=description, 343 | deprecated=deprecated, 344 | tags=tags, 345 | summary=summary, 346 | include_schema=include_schema, 347 | ) 348 | return func 349 | 350 | return decorator 351 | 352 | def add_pub( 353 | self, 354 | subject: str, 355 | params: BaseModel, 356 | *, 357 | summary: Optional[str] = None, 
358 | description: Optional[str] = None, 359 | tags: Optional[list[str]] = None, 360 | externalDocs: Optional[ExternalDocumentation] = None, 361 | ) -> None: 362 | """ 363 | Include pub in asyncapi schema 364 | """ 365 | pub = Pub( 366 | subject, 367 | params, 368 | summary=summary, 369 | description=description, 370 | tags=tags or None, 371 | externalDocs=externalDocs, 372 | ) 373 | self.pubs.add(pub) 374 | 375 | def pub( 376 | self, 377 | subject: str, 378 | *, 379 | params=type[Any], 380 | description: Optional[str] = None, 381 | tags: Optional[list[str]] = None, 382 | summary: Optional[str] = None, 383 | externalDocs: Optional[ExternalDocumentation] = None, 384 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 385 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 386 | self.add_pub( 387 | subject=subject, 388 | params=params, 389 | summary=summary, 390 | description=description, 391 | tags=tags, 392 | externalDocs=externalDocs, 393 | ) 394 | return func 395 | 396 | return decorator 397 | 398 | def include_subs(self, subs: list[Sub]): 399 | for sub in subs: 400 | self.subs.add(sub) 401 | 402 | def add_sub( 403 | self, 404 | subject: str, 405 | *, 406 | queue: Optional[str] = None, 407 | summary: Optional[str] = None, 408 | description: Optional[str] = None, 409 | tags: Optional[list[str]] = None, 410 | externalDocs: Optional[ExternalDocumentation] = None, 411 | ) -> None: 412 | """ 413 | Include sub in asyncapi schema 414 | """ 415 | sub = Sub( 416 | subject, 417 | queue=queue, 418 | summary=summary, 419 | description=description, 420 | tags=tags or None, 421 | externalDocs=externalDocs, 422 | ) 423 | self.subs.add(sub) 424 | 425 | def sub( 426 | self, 427 | subject: str, 428 | *, 429 | queue: Optional[str] = None, 430 | description: Optional[str] = None, 431 | tags: Optional[list[str]] = None, 432 | summary: Optional[str] = None, 433 | externalDocs: Optional[ExternalDocumentation] = None, 434 | ) -> Callable[[DecoratedCallable], 
DecoratedCallable]: 435 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 436 | self.add_sub( 437 | subject=subject, 438 | queue=queue, 439 | summary=summary, 440 | description=description, 441 | tags=tags, 442 | externalDocs=externalDocs, 443 | ) 444 | return func 445 | 446 | return decorator 447 | 448 | def include_pubs(self, pubs: list[Pub]): 449 | for pub in pubs: 450 | self.pubs.add(pub) 451 | 452 | def add_exception_handler( 453 | self, 454 | exc_class: type[Exception], 455 | handler: Callable[[Exception], JsonRPCException], 456 | ) -> None: 457 | self._exception_handlers[exc_class] = handler 458 | 459 | def exception_handler(self, exc_class: type[Exception]) -> Callable: 460 | def decorator(func: Callable[[Exception], JsonRPCException]) -> Callable: 461 | self.add_exception_handler(exc_class, func) 462 | return func 463 | 464 | return decorator 465 | -------------------------------------------------------------------------------- /natsapi/asyncapi/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import BaseModel, validator 4 | 5 | from .models import ExternalDocumentation, Server 6 | 7 | 8 | class Errors(BaseModel): 9 | class Config: 10 | arbitrary_types_allowed = True 11 | 12 | upper_bound: int 13 | lower_bound: int 14 | errors: list[Any] 15 | 16 | @validator("lower_bound") 17 | def upper_bigger_than_lower(v, values): 18 | assert v < values["upper_bound"], "upper bound is smaller than lower bound" 19 | return v 20 | 21 | 22 | __all__ = ["ExternalDocumentation", "Server", "Errors"] 23 | -------------------------------------------------------------------------------- /natsapi/asyncapi/constants.py: -------------------------------------------------------------------------------- 1 | REF_PREFIX = "#/components/schemas/" 2 | REF_TEMPLATE = "#/components/schemas/{model}" 3 | -------------------------------------------------------------------------------- 
/natsapi/asyncapi/models.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | from pydantic import BaseModel, Field, validator 4 | 5 | # TODO: Extend some classes with specifications object 6 | 7 | 8 | class ExternalDocumentation(BaseModel): 9 | description: Optional[str] 10 | url: str 11 | 12 | 13 | class Discriminator(BaseModel): 14 | propertyName: dict[str, Any] 15 | mapping: Optional[dict[str, str]] = None 16 | 17 | 18 | class SchemaBase(BaseModel): 19 | ref: Optional[str] = Field(None, alias="$ref") 20 | title: Optional[str] = None 21 | multipleOf: Optional[float] = None 22 | maximum: Optional[float] = None 23 | exclusiveMaximum: Optional[float] = None 24 | minimum: Optional[float] = None 25 | exclusiveMinimum: Optional[float] = None 26 | maxLength: Optional[int] = Field(None, gte=0) 27 | minLength: Optional[int] = Field(None, gte=0) 28 | pattern: Optional[str] = None 29 | maxItems: Optional[int] = Field(None, gte=0) 30 | minItems: Optional[int] = Field(None, gte=0) 31 | uniqueItems: Optional[bool] = None 32 | maxProperties: Optional[int] = Field(None, gte=0) 33 | minProperties: Optional[int] = Field(None, gte=0) 34 | required: Optional[list[str]] = None 35 | enum: Optional[list[Any]] = None 36 | type: Optional[str] = None 37 | allOf: Optional[list[Any]] = None 38 | oneOf: Optional[list[Any]] = None 39 | anyOf: Optional[list[Any]] = None 40 | not_: Any = Field(None, alias="not") 41 | items: Any = None 42 | properties: Optional[dict[str, Any]] = None 43 | additionalProperties: Optional[Union[dict[str, Any], bool]] = None 44 | description: Optional[str] = None 45 | format: Optional[str] = None 46 | default: Any = None 47 | nullable: Optional[bool] = None 48 | discriminator: Optional[Discriminator] = None 49 | readOnly: Optional[bool] = None 50 | writeOnly: Optional[bool] = None 51 | externalDocs: Optional[ExternalDocumentation] = None 52 | example: Any = None 53 | deprecated: 
Optional[bool] = None 54 | 55 | 56 | class Schema(SchemaBase): 57 | allOf: Optional[list[SchemaBase]] = None 58 | oneOf: Optional[list[SchemaBase]] = None 59 | anyOf: Optional[list[SchemaBase]] = None 60 | not_: Optional[SchemaBase] = Field(None, alias="not") 61 | items: Optional[SchemaBase] = None 62 | properties: Optional[dict[str, SchemaBase]] = None 63 | additionalProperties: Optional[Union[dict[str, Any], bool]] = None 64 | 65 | 66 | class Contact(BaseModel): 67 | name: Optional[str] 68 | url: Optional[str] 69 | email: Optional[str] 70 | 71 | 72 | class License(BaseModel): 73 | name: str 74 | url: Optional[str] 75 | 76 | 77 | class Info(BaseModel): 78 | title: str 79 | version: str 80 | description: Optional[str] = None 81 | termsOfService: Optional[str] = None 82 | contact: Optional[Contact] = None 83 | licence: Optional[License] = None 84 | 85 | 86 | class Reference(BaseModel): 87 | ref: str = Field(..., alias="$ref") 88 | 89 | 90 | class Tag(BaseModel): 91 | name: str 92 | description: Optional[str] = None 93 | externalDocs: Optional[ExternalDocumentation] = None 94 | 95 | 96 | # TODO: Add nats bindings 97 | class NatsBindings(BaseModel): 98 | pass 99 | 100 | 101 | class BindingsBase(BaseModel): 102 | nats: Union[NatsBindings, Reference, None] 103 | 104 | 105 | class ServerBindings(BindingsBase): 106 | pass 107 | 108 | 109 | class ChannelBindings(BindingsBase): 110 | pass 111 | 112 | 113 | class OperationBindings(BindingsBase): 114 | pass 115 | 116 | 117 | class MessageBindings(BindingsBase): 118 | pass 119 | 120 | 121 | class CorrelationId(BaseModel): 122 | location: str 123 | description: Optional[str] 124 | 125 | 126 | class MessageTrait(BaseModel): 127 | headers: Union[Schema, Reference, None] 128 | correlationId: Union[CorrelationId, Reference, None] 129 | schemaFormat: Optional[str] 130 | contentType: Optional[str] 131 | name: Optional[str] 132 | title: Optional[str] 133 | summary: Optional[str] 134 | description: Optional[str] 135 | tags: 
Optional[list[Tag]] 136 | externalDocs: Optional[ExternalDocumentation] 137 | bindings: Union[MessageBindings, Reference, None] 138 | examples: Optional[dict[str, Any]] 139 | 140 | 141 | class Message(BaseModel): 142 | headers: Union[Schema, Reference, None] = None 143 | payload: Any = None 144 | correlationId: Union[CorrelationId, Reference, None] = None 145 | schemaFormat: Optional[str] = None 146 | contentType: Optional[str] = None 147 | name: Optional[str] = None 148 | title: Optional[str] = None 149 | summary: Optional[str] = None 150 | description: Optional[str] = None 151 | tags: Optional[list[Tag]] = None 152 | externalDocs: Optional[ExternalDocumentation] = None 153 | bindings: Union[MessageBindings, Reference, None] = None 154 | examples: Optional[dict[str, Any]] = None 155 | traits: Union[MessageTrait, Reference, None] = None 156 | 157 | 158 | class OperationTrait(BaseModel): 159 | operationId: Optional[str] = None 160 | summary: Optional[str] = None 161 | description: Optional[str] = None 162 | tags: Optional[list[Tag]] = None 163 | externalDocs: Optional[ExternalDocumentation] = None 164 | bindings: Union[OperationBindings, Reference, None] = None 165 | 166 | 167 | class Operation(BaseModel): 168 | operationId: Optional[str] = None 169 | summary: Optional[str] = None 170 | description: Optional[str] = None 171 | tags: Optional[list[Tag]] = None 172 | externalDocs: Optional[ExternalDocumentation] = None 173 | bindings: Union[OperationBindings, Reference, None] = None 174 | traits: Optional[list[Union[OperationTrait, Reference]]] = None 175 | message: Union[Message, Reference, None] = None 176 | 177 | 178 | class SubscribeOperation(Operation): 179 | queue: Optional[str] = Field(None, alias="x-queue") 180 | 181 | 182 | class RequestOperation(Operation): 183 | replies: Optional[list[Union[Message, Reference]]] = None 184 | suggestedTimeout: Optional[float] = Field(None, alias="x-suggested-timeout") 185 | 186 | 187 | class Parameter(BaseModel): 188 | 
description: Optional[str] = None 189 | schema_: Union[dict[str, Any], Reference] = Field(..., alias="schema") 190 | location: Optional[str] = None 191 | 192 | 193 | class ChannelItem(BaseModel): 194 | description: Optional[str] = None 195 | subscribe: Optional[SubscribeOperation] = None 196 | publish: Optional[Operation] = None 197 | request: Optional[RequestOperation] = None 198 | parameters: Optional[dict[str, Union[Parameter, Reference]]] = None 199 | bindings: Union[ChannelBindings, Reference, None] = None 200 | deprecated: Optional[bool] = None 201 | 202 | 203 | class ServerVariable(BaseModel): 204 | enum: Optional[list[str]] = None 205 | default: Optional[str] = None 206 | description: Optional[str] = None 207 | examples: Optional[list[str]] = None 208 | 209 | 210 | class Server(BaseModel): 211 | url: str 212 | protocol: str 213 | protocolVersion: Optional[str] 214 | description: Optional[str] 215 | variables: Optional[dict[str, ServerVariable]] 216 | bindings: Union[ServerBindings, Reference, None] = None 217 | 218 | 219 | class Components(BaseModel): 220 | schemas: Optional[dict[str, Union[Schema, Reference]]] = None 221 | messages: Optional[dict[str, Union[Message, Reference]]] = None 222 | parameters: Optional[dict[str, Union[dict[str, Parameter], Reference]]] = None 223 | correlationIds: Optional[dict[str, Union[CorrelationId, Reference]]] = None 224 | operationTraits: Optional[dict[str, Union[OperationTrait, Reference]]] = None 225 | messageTraits: Optional[dict[str, Union[MessageTrait, Reference]]] = None 226 | serverBindings: Optional[dict[str, Union[ServerBindings, Reference]]] = None 227 | channelBindings: Optional[dict[str, Union[ChannelBindings, Reference]]] = None 228 | operationBindings: Optional[dict[str, Union[OperationBindings, Reference]]] = None 229 | messageBindings: Optional[dict[str, Union[MessageBindings, Reference]]] = None 230 | 231 | 232 | class Range(BaseModel): 233 | upper: int 234 | lower: int 235 | 236 | @validator("lower") 237 
| def upper_bigger_than_lower(v, values): 238 | assert v < values["upper"], "upper bound is smaller than lower bound" 239 | return v 240 | 241 | 242 | class Errors(BaseModel): 243 | range: Range = Field(..., alias="range") 244 | items: list[Any] 245 | 246 | 247 | class AsyncAPI(BaseModel): 248 | asyncapi: str 249 | id: Optional[str] = None 250 | info: Info 251 | servers: Optional[dict[str, Server]] = None 252 | defaultContentType: Optional[str] = "application/json" 253 | channels: Optional[dict[str, Union[ChannelItem, Reference]]] = None 254 | components: Optional[Components] = None 255 | tags: Optional[list[Tag]] = None 256 | externalDocs: Optional[ExternalDocumentation] = None 257 | errors: Optional[Errors] = None 258 | -------------------------------------------------------------------------------- /natsapi/asyncapi/utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from collections.abc import Sequence 3 | from enum import Enum 4 | from typing import Any, Optional, Union 5 | 6 | from pydantic import BaseModel 7 | 8 | from natsapi._compat import ( 9 | ModelField, 10 | MyGenerateJsonSchema, 11 | get_cached_model_fields, 12 | get_compat_model_name_map, 13 | get_definitions, 14 | lenient_issubclass, 15 | ) 16 | from natsapi.asyncapi.constants import REF_PREFIX, REF_TEMPLATE 17 | from natsapi.encoders import jsonable_encoder 18 | from natsapi.models import JsonRPCError 19 | from natsapi.routing import Pub, Publish, Request, Sub 20 | 21 | from . 
import Errors, ExternalDocumentation, Server 22 | from .models import AsyncAPI 23 | 24 | 25 | def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]: 26 | if not fields: 27 | return fields 28 | first_field = fields[0] 29 | 30 | if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): 31 | fields_to_extract = get_cached_model_fields(first_field.type_) 32 | return fields_to_extract 33 | return fields 34 | 35 | 36 | def get_fields_from_routes(routes: Sequence[Request], pubs: Sequence[Pub]) -> Union[set[type[BaseModel], type[Enum]]]: 37 | replies_from_routes: set[ModelField] = set() 38 | requests_from_routes: set[ModelField] = set() 39 | messages_from_pubs: set[ModelField] = set() 40 | for route in routes: 41 | if getattr(route, "include_schema", True) and isinstance(route, Request): 42 | if route.result: 43 | replies_from_routes.add(route.request_field) 44 | if route.params: 45 | replies_from_routes.add(route.reply_field) 46 | elif getattr(route, "include_schema", True) and isinstance(route, Publish): 47 | if route.params: 48 | replies_from_routes.add(route.reply_field) 49 | for pub in pubs: 50 | messages_from_pubs.add(pub.params_field) 51 | 52 | fields = replies_from_routes | requests_from_routes | messages_from_pubs 53 | 54 | return fields 55 | 56 | 57 | def get_flat_response_models(r) -> list[type[BaseModel]]: 58 | """ 59 | Returns flattened collection of response models of a route. 60 | If the response models are of typing.Union, a list of possible response models is returned. 
61 | 62 | :r Single or multiple response models 63 | """ 64 | if type(r) is typing._UnionGenericAlias: 65 | return list(r.__args__) 66 | else: 67 | return [r] 68 | 69 | 70 | def get_asyncapi_request_operation_metadata(operation: Request) -> dict[str, Any]: 71 | metadata: dict[str, Any] = {} 72 | metadata["summary"] = operation.summary.replace("_", " ").title() 73 | metadata["description"] = operation.description 74 | 75 | if operation.tags: 76 | metadata["tags"] = operation.tags 77 | metadata["operationId"] = operation.operation_id 78 | if operation.deprecated: 79 | metadata["deprecated"] = operation.deprecated 80 | if timeout := getattr(operation, "suggested_timeout", None): 81 | metadata["x-suggested-timeout"] = timeout 82 | return metadata 83 | 84 | 85 | def generate_asyncapi_request_channel(operation: Request, model_name_map: dict[str, Any]) -> Any: 86 | 87 | operation_results = get_flat_response_models(operation.result) 88 | 89 | request_field_ref: str = REF_PREFIX + operation.params.__name__ 90 | reply_field_refs: str = [REF_PREFIX + o.__name__ for o in operation_results] 91 | failed_reply_ref: str = REF_PREFIX + JsonRPCError.__name__ 92 | 93 | operation_schema = get_asyncapi_request_operation_metadata(operation) 94 | payload = {"payload": {"$ref": request_field_ref}} 95 | operation_schema["message"] = payload 96 | replies = [] 97 | if len(reply_field_refs) > 1: 98 | anyofs = [{"$ref": r} for r in reply_field_refs] 99 | replies.append({"payload": {"anyOf": anyofs}}) 100 | else: 101 | replies.append({"payload": {"$ref": reply_field_refs[0]}}) 102 | replies.append({"payload": {"$ref": failed_reply_ref}}) 103 | operation_schema["tags"] = [{"name": tag} for tag in operation.tags] 104 | 105 | operation_schema["replies"] = replies 106 | return {"request": operation_schema, "deprecated": operation.deprecated} 107 | 108 | 109 | def generate_asyncapi_publish_channel(operation: Publish, model_name_map: dict[str, Any]) -> Any: 110 | request_field_ref: str = REF_PREFIX + 
operation.params.__name__ 111 | 112 | operation_schema = get_asyncapi_request_operation_metadata(operation) 113 | payload = {"payload": {"$ref": request_field_ref}} 114 | operation_schema["message"] = payload 115 | operation_schema["tags"] = [{"name": tag} for tag in operation.tags] 116 | return {"publish": operation_schema, "deprecated": operation.deprecated} 117 | 118 | 119 | def domain_errors_schema(lower_bound: int, upper_bound: int, exceptions: list[Exception]): 120 | schema = {} 121 | schema["range"] = {"upper": upper_bound, "lower": lower_bound} 122 | errors = [] 123 | for exc in exceptions: 124 | try: 125 | error = exc(data="") 126 | except Exception: 127 | # Quick fix for FormattedException 128 | error = exc(detail="") 129 | 130 | if hasattr(error, "rpc_code"): 131 | code = error.rpc_code 132 | elif hasattr(error, "code"): 133 | code = error.code 134 | else: 135 | raise AttributeError(f"'{exc}' has no 'code' or 'rpc_code' attribute") 136 | 137 | if hasattr(error, "msg"): 138 | message = error.msg 139 | elif hasattr(error, "message"): 140 | message = error.message 141 | else: 142 | raise AttributeError(f"'{exc}' has no 'message' or 'msg' attribute") 143 | 144 | errors.append({"code": code, "message": message}) 145 | schema["items"] = errors 146 | return schema 147 | 148 | 149 | def get_sub_operation_schema(sub: Sub) -> tuple[str, dict[str, Any]]: 150 | _sub = { 151 | "summary": sub.summary, 152 | "description": sub.description, 153 | "tags": [{"name": tag} for tag in sub.tags] if len(sub.tags) > 0 else None, 154 | "externalDocs": sub.externalDocs, 155 | "message": {"summary": sub.summary}, 156 | } 157 | op = {"subscribe": _sub} 158 | return sub.subject, op 159 | 160 | 161 | def get_pub_operation_schema(pub: Pub, model_name_map: dict[str, Any]) -> tuple[str, dict[str, Any]]: 162 | pub_payload: str = REF_PREFIX + pub.params.__name__ 163 | 164 | _pub = { 165 | "summary": pub.summary, 166 | "description": pub.description, 167 | "externalDocs": pub.externalDocs, 
168 | "message": {"payload": {"$ref": pub_payload}, "pub": pub.summary}, 169 | "tags": [{"name": tag} for tag in pub.tags] if len(pub.tags) > 0 else None, 170 | } 171 | op = {"publish": _pub} 172 | return pub.subject, op 173 | 174 | 175 | def get_asyncapi( 176 | title: str, 177 | version: str, 178 | asyncapi_version: str, 179 | external_docs: ExternalDocumentation, 180 | errors: Errors, 181 | routes: dict[str, Request], 182 | subs: list[Sub], 183 | pubs: list[Pub], 184 | description: Optional[str] = None, 185 | servers: Optional[dict[str, Server]] = None, 186 | ) -> dict[str, Any]: 187 | subjects: dict[str, dict[str, Any]] = {} 188 | info = {"title": title, "version": version} 189 | info["description"] = description if description else None 190 | components: dict[str, dict[str, Any]] = {} 191 | 192 | output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info} 193 | 194 | all_fields = get_fields_from_routes(routes.values(), pubs) 195 | model_name_map = get_compat_model_name_map(all_fields) 196 | schema_generator = MyGenerateJsonSchema(ref_template=REF_TEMPLATE) 197 | 198 | # TODO: <26-02-25, Sebastiaan Van Hoecke> # Where to use the first paramter (see https://github.com/fastapi/fastapi/blob/master/fastapi/openapi/utils.py#L493) 199 | _, definitions = get_definitions( 200 | fields=all_fields, 201 | schema_generator=schema_generator, 202 | model_name_map=model_name_map, 203 | ) 204 | definitions[JsonRPCError.__name__] = JsonRPCError.schema() 205 | components["schemas"] = definitions 206 | 207 | subjects: dict[str, dict[str, Any]] = {} 208 | for subject, endpoint in routes.items(): 209 | if getattr(endpoint, "include_schema", None) and isinstance(endpoint, Request): 210 | result = generate_asyncapi_request_channel(endpoint, model_name_map) 211 | subjects[subject] = result 212 | elif getattr(endpoint, "include_schema", None) and isinstance(endpoint, Publish): 213 | result = generate_asyncapi_publish_channel(endpoint, model_name_map) 214 | subjects[subject] = 
result 215 | 216 | for sub in subs: 217 | channel, operation = get_sub_operation_schema(sub) 218 | 219 | subjects[channel] = operation 220 | 221 | for pub in pubs: 222 | channel, operation = get_pub_operation_schema(pub, model_name_map) 223 | subjects[channel] = operation 224 | 225 | info["description"] = description if description else None 226 | output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info} 227 | output["servers"] = servers if {n: s.dict() for n, s in servers.items()} else None 228 | 229 | output["externalDocs"] = external_docs.dict() if external_docs else None 230 | output["errors"] = domain_errors_schema(errors.lower_bound, errors.upper_bound, errors.errors) if errors else None 231 | 232 | output["channels"] = subjects if len(subjects) > 0 else None 233 | output["components"] = components 234 | 235 | return jsonable_encoder(AsyncAPI(**output), by_alias=True, exclude_none=True) 236 | -------------------------------------------------------------------------------- /natsapi/client/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import NatsClient 2 | from .config import Config, ConnectConfig, SubscribeConfig 3 | 4 | __all__ = ["NatsClient", "Config", "SubscribeConfig", "ConnectConfig"] 5 | -------------------------------------------------------------------------------- /natsapi/client/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import inspect 3 | import json 4 | import logging 5 | import secrets 6 | from collections.abc import Callable 7 | from ssl import create_default_context 8 | from typing import Any, Optional 9 | from uuid import uuid4 10 | 11 | from nats.aio.client import Client as NATS 12 | 13 | from natsapi.context import CTX_JSONRPC_ID 14 | from natsapi.exceptions import JsonRPCException, JsonRPCUnknownMethodException 15 | from natsapi.models import JsonRPCError, JsonRPCReply, JsonRPCRequest 16 | from 
natsapi.routing import Request 17 | 18 | from .config import Config, default_config 19 | 20 | 21 | class NatsClient: 22 | def __init__( 23 | self, 24 | routes: dict[str, Request], 25 | app: Any = None, 26 | config: Optional[Config] = None, 27 | exception_handlers: Optional[dict[type[Exception], Callable[[type[Exception]], JsonRPCException]]] = None, 28 | ) -> None: 29 | self.routes = routes 30 | self.app = app 31 | self.config = config or default_config 32 | self._exception_handlers = exception_handlers 33 | self.nats = NATS() 34 | 35 | async def connect(self) -> None: 36 | cfg = self.config.connect 37 | cfg.error_cb = cfg.error_cb or self._error_cb 38 | cfg.closed_cb = cfg.closed_cb or self._closed_cb 39 | cfg.reconnected_cb = cfg.reconnected_cb or self._reconnected_cb 40 | cfg.tls = cfg.tls or create_default_context() 41 | 42 | await self.nats.connect(**(cfg.dict())) 43 | 44 | async def root_path_subscribe(self, subject: str, cb: Callable, queue: str = ""): 45 | await self.nats.subscribe(subject, cb=cb, **(self.config.subscribe.dict())) 46 | 47 | async def publish(self, subject: str, params: dict[str, Any], method: str = None, reply=None, headers: dict = None): 48 | """ 49 | method: legacy attribute, used for backwards compatibility 50 | """ 51 | json_rpc_payload = JsonRPCRequest(id=uuid4(), params=params, method=method, timeout=-1) 52 | await self.nats.publish(subject, json_rpc_payload.json().encode(), reply=reply, headers=headers) 53 | 54 | async def publish_on_reply(self, subject, payload): 55 | await self.nats.publish(subject, payload) 56 | 57 | async def request( 58 | self, 59 | subject: str, 60 | params: dict[str, Any] = dict(), 61 | timeout=60, 62 | method: str = None, 63 | headers: dict = None, 64 | ) -> JsonRPCReply: 65 | """ 66 | method: legacy attribute, used for backwards compatibility 67 | """ 68 | json_rpc_payload = JsonRPCRequest(params=params, method=method, timeout=timeout) 69 | reply_raw = await self.nats.request(subject, 
json_rpc_payload.json().encode(), timeout, headers=headers) 70 | reply = JsonRPCReply.parse_raw(reply_raw.data) 71 | return reply 72 | 73 | async def handle_request(self, msg): 74 | if msg.reply and msg.reply != "None": 75 | asyncio.create_task(self._handle_request(msg), name="natsapi_" + secrets.token_hex(16)) 76 | else: 77 | asyncio.create_task(self._handle_publish(msg), name="natsapi_" + secrets.token_hex(16)) 78 | 79 | async def _handle_publish(self, msg): 80 | request = JsonRPCRequest.parse_raw(msg.data) 81 | request.id = request.id or uuid4() 82 | 83 | subject = msg.subject 84 | 85 | if subject not in self.routes and request.method: 86 | subject = ".".join([subject, request.method]) 87 | 88 | try: 89 | logging.debug(f"Handling: {subject}") 90 | route: Request = self.routes[subject] 91 | except KeyError as e: 92 | raise JsonRPCUnknownMethodException(data=f"No such endpoint available for {subject}") from e 93 | 94 | handler = route.endpoint 95 | params_model = route.params 96 | params = request.params if route.skip_validation else vars(params_model.parse_obj(request.params)) 97 | 98 | if inspect.iscoroutinefunction(handler): 99 | await handler(app=self.app, **params) 100 | else: 101 | handler(self.app, **params) 102 | 103 | async def _handle_request(self, msg): 104 | request = result = None 105 | try: 106 | request = JsonRPCRequest.parse_raw(msg.data) 107 | request.id = request.id or uuid4() 108 | 109 | CTX_JSONRPC_ID.set(request.id) 110 | subject = msg.subject 111 | 112 | if subject not in self.routes and request.method: 113 | subject = ".".join([subject, request.method]) 114 | 115 | try: 116 | logging.debug(f"Handling: {subject}") 117 | route: Request = self.routes[subject] 118 | except KeyError as e: 119 | raise JsonRPCUnknownMethodException(data=f"No such endpoint available. 
Checked for {subject}") from e 120 | 121 | handler = route.endpoint 122 | params_model = route.params 123 | params = request.params if route.skip_validation else vars(params_model.parse_obj(request.params)) 124 | 125 | if inspect.iscoroutinefunction(handler): 126 | result = await handler(app=self.app, **params) 127 | else: 128 | result = handler(self.app, **params) 129 | 130 | if not isinstance(result, dict): 131 | if hasattr(result, "dict"): 132 | result = result.dict() 133 | elif hasattr(result, "json"): 134 | result = json.loads(result.json()) 135 | 136 | reply = JsonRPCReply(id=request.id, result=result) 137 | except Exception as exc: 138 | if not request: 139 | request = JsonRPCRequest(params={}, timeout=60) 140 | exception_handler = self._lookup_exception_handler(exc) 141 | if inspect.iscoroutinefunction(exception_handler): 142 | error: JsonRPCError = await exception_handler(exc, request, msg.subject) 143 | else: 144 | error: JsonRPCError = exception_handler(exc, request, msg.subject) 145 | reply = JsonRPCReply(id=request.id, error=error) 146 | finally: 147 | await self.publish_on_reply(msg.reply, reply.json().encode()) 148 | 149 | def _lookup_exception_handler(self, exc: Exception) -> Optional[Callable]: 150 | """ 151 | Gets list of all the types the exception instance inherits from and checks if 152 | exception type is in the 'exception_handlers' dict. 153 | 154 | e.g. your handler throws a BrokerNotFoundException, the generated list will be: 155 | [BrokerNotFoundException, FormattedException, Exception, BaseException] 156 | 157 | If you have written an handler for FormattedException, this method will return that handler. 158 | Worst-case is getting the default handler for Exception 159 | 160 | The method will only return 'None' if you have a custom exception inheriting from BaseException. 161 | But inheriting from BaseException is bad practice, and your application should crash if you do this. 
162 | """ 163 | for cls in type(exc).__mro__: 164 | if cls in self._exception_handlers: 165 | return self._exception_handlers[cls] 166 | return None 167 | 168 | async def _error_cb(self, e): 169 | logging.exception(e) 170 | 171 | async def _closed_cb(self): 172 | logging.warning("NATS CLOSED") 173 | 174 | async def _reconnected_cb(self): 175 | logging.warning(f"Got reconnected to {self.nats.connected_url.netloc}") 176 | 177 | async def shutdown(self, signal=None): 178 | await self.nats.drain() 179 | logging.info("All NATS connections put in drain state.") 180 | await self.nats.close() 181 | logging.info("All NATS connections closed.") 182 | -------------------------------------------------------------------------------- /natsapi/client/config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | from nats.aio.client import ( 4 | DEFAULT_CONNECT_TIMEOUT, 5 | DEFAULT_DRAIN_TIMEOUT, 6 | DEFAULT_MAX_FLUSHER_QUEUE_SIZE, 7 | DEFAULT_MAX_OUTSTANDING_PINGS, 8 | DEFAULT_MAX_RECONNECT_ATTEMPTS, 9 | DEFAULT_PENDING_SIZE, 10 | DEFAULT_PING_INTERVAL, 11 | DEFAULT_RECONNECT_TIME_WAIT, 12 | DEFAULT_SUB_PENDING_BYTES_LIMIT, 13 | DEFAULT_SUB_PENDING_MSGS_LIMIT, 14 | ) 15 | 16 | from natsapi._compat import BaseSettings 17 | 18 | 19 | class ConnectConfig(BaseSettings): 20 | servers: Union[str, list[str]] = ["nats://127.0.0.1:4222"] 21 | error_cb: Any = None 22 | closed_cb: Any = None 23 | reconnected_cb: Any = None 24 | disconnected_cb: Any = None 25 | discovered_server_cb: Any = None 26 | name: Any = None 27 | pedantic: Any = False 28 | verbose: Any = False 29 | allow_reconnect: Any = True 30 | connect_timeout: Any = DEFAULT_CONNECT_TIMEOUT 31 | reconnect_time_wait: Any = DEFAULT_RECONNECT_TIME_WAIT 32 | max_reconnect_attempts: Any = DEFAULT_MAX_RECONNECT_ATTEMPTS 33 | ping_interval: Any = DEFAULT_PING_INTERVAL 34 | max_outstanding_pings: Any = DEFAULT_MAX_OUTSTANDING_PINGS 35 | dont_randomize: Any = 
False 36 | flusher_queue_size: Any = DEFAULT_MAX_FLUSHER_QUEUE_SIZE 37 | no_echo: Any = False 38 | tls: Any = None 39 | tls_hostname: Any = None 40 | user: Any = None 41 | password: Any = None 42 | token: Any = None 43 | drain_timeout: Any = DEFAULT_DRAIN_TIMEOUT 44 | signature_cb: Any = None 45 | user_jwt_cb: Any = None 46 | user_credentials: Any = None 47 | nkeys_seed: Optional[str] = None 48 | flush_timeout: Optional[float] = None 49 | pending_size: int = DEFAULT_PENDING_SIZE 50 | 51 | 52 | class SubscribeConfig(BaseSettings): 53 | queue: Any = "" 54 | future: Any = None 55 | max_msgs: Any = 0 56 | pending_msgs_limit: Any = DEFAULT_SUB_PENDING_MSGS_LIMIT 57 | pending_bytes_limit: Any = DEFAULT_SUB_PENDING_BYTES_LIMIT 58 | 59 | 60 | class Config(BaseSettings): 61 | connect: ConnectConfig = ConnectConfig() 62 | subscribe: SubscribeConfig = SubscribeConfig() 63 | 64 | 65 | default_config = Config() 66 | -------------------------------------------------------------------------------- /natsapi/context.py: -------------------------------------------------------------------------------- 1 | import contextvars 2 | 3 | CTX_JSONRPC_ID = contextvars.ContextVar("jsonrpc_id") 4 | 5 | __all__ = ["CTX_JSONRPC_ID"] 6 | -------------------------------------------------------------------------------- /natsapi/encoders.py: -------------------------------------------------------------------------------- 1 | """Yanked from fastapi.encoders""" 2 | 3 | from collections import defaultdict 4 | from collections.abc import Callable 5 | from enum import Enum 6 | from pathlib import PurePath 7 | from types import GeneratorType 8 | from typing import Any, Union 9 | 10 | from pydantic import BaseModel 11 | 12 | from natsapi._compat import ENCODERS_BY_TYPE, PYDANTIC_V2 13 | 14 | SetIntStr = set[Union[int, str]] 15 | DictIntStrAny = dict[Union[int, str], Any] 16 | 17 | 18 | def generate_encoders_by_class_tuples( 19 | type_encoder_map: dict[Any, Callable[[Any], Any]], 20 | ) -> 
dict[Callable[[Any], Any], tuple[Any, ...]]: 21 | encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) 22 | for type_, encoder in type_encoder_map.items(): 23 | encoders_by_class_tuples[encoder] += (type_,) 24 | return encoders_by_class_tuples 25 | 26 | 27 | encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) 28 | 29 | 30 | def jsonable_encoder( 31 | obj: Any, 32 | include: Union[SetIntStr, DictIntStrAny, None] = None, 33 | exclude: Union[SetIntStr, DictIntStrAny, None] = None, 34 | by_alias: bool = True, 35 | exclude_unset: bool = False, 36 | exclude_defaults: bool = False, 37 | exclude_none: bool = False, 38 | custom_encoder: dict[Any, Callable[[Any], Any]] = dict(), 39 | sqlalchemy_safe: bool = True, 40 | ) -> Any: 41 | if include is not None and not isinstance(include, set): 42 | include = set(include) 43 | if exclude is not None and not isinstance(exclude, set): 44 | exclude = set(exclude) 45 | if isinstance(obj, BaseModel): 46 | 47 | if PYDANTIC_V2: 48 | encoder = obj.model_config.get("json_encoders", {}) 49 | else: 50 | encoder = getattr(obj.__config__, "json_encoders", {}) 51 | 52 | if custom_encoder: 53 | encoder.update(custom_encoder) 54 | obj_dict = obj.dict( 55 | include=include, # type: ignore # in Pydantic 56 | exclude=exclude, # type: ignore # in Pydantic 57 | by_alias=by_alias, 58 | exclude_unset=exclude_unset, 59 | exclude_none=exclude_none, 60 | exclude_defaults=exclude_defaults, 61 | ) 62 | if "__root__" in obj_dict: 63 | obj_dict = obj_dict["__root__"] 64 | return jsonable_encoder( 65 | obj_dict, 66 | exclude_none=exclude_none, 67 | exclude_defaults=exclude_defaults, 68 | custom_encoder=encoder, 69 | sqlalchemy_safe=sqlalchemy_safe, 70 | ) 71 | if isinstance(obj, Enum): 72 | return obj.value 73 | if isinstance(obj, PurePath): 74 | return str(obj) 75 | if isinstance(obj, (str, int, float, type(None))): 76 | return obj 77 | if isinstance(obj, dict): 78 | encoded_dict = {} 79 | for 
key, value in obj.items(): 80 | if ( 81 | (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) 82 | and (value is not None or not exclude_none) 83 | and ((include and key in include) or not exclude or key not in exclude) 84 | ): 85 | encoded_key = jsonable_encoder( 86 | key, 87 | by_alias=by_alias, 88 | exclude_unset=exclude_unset, 89 | exclude_none=exclude_none, 90 | custom_encoder=custom_encoder, 91 | sqlalchemy_safe=sqlalchemy_safe, 92 | ) 93 | encoded_value = jsonable_encoder( 94 | value, 95 | by_alias=by_alias, 96 | exclude_unset=exclude_unset, 97 | exclude_none=exclude_none, 98 | custom_encoder=custom_encoder, 99 | sqlalchemy_safe=sqlalchemy_safe, 100 | ) 101 | encoded_dict[encoded_key] = encoded_value 102 | return encoded_dict 103 | if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): 104 | encoded_list = [] 105 | for item in obj: 106 | encoded_list.append( 107 | jsonable_encoder( 108 | item, 109 | include=include, 110 | exclude=exclude, 111 | by_alias=by_alias, 112 | exclude_unset=exclude_unset, 113 | exclude_defaults=exclude_defaults, 114 | exclude_none=exclude_none, 115 | custom_encoder=custom_encoder, 116 | sqlalchemy_safe=sqlalchemy_safe, 117 | ), 118 | ) 119 | return encoded_list 120 | 121 | if custom_encoder: 122 | if type(obj) in custom_encoder: 123 | return custom_encoder[type(obj)](obj) 124 | else: 125 | for encoder_type, encoder in custom_encoder.items(): 126 | if isinstance(obj, encoder_type): 127 | return encoder(obj) 128 | 129 | if type(obj) in ENCODERS_BY_TYPE: 130 | return ENCODERS_BY_TYPE[type(obj)](obj) 131 | for encoder, classes_tuple in encoders_by_class_tuples.items(): 132 | if isinstance(obj, classes_tuple): 133 | return encoder(obj) 134 | 135 | errors: list[Exception] = [] 136 | try: 137 | data = dict(obj) 138 | except Exception as e: 139 | errors.append(e) 140 | try: 141 | data = vars(obj) 142 | except Exception as e: 143 | errors.append(e) 144 | raise ValueError(errors) from e 145 | return 
jsonable_encoder( 146 | data, 147 | by_alias=by_alias, 148 | exclude_unset=exclude_unset, 149 | exclude_defaults=exclude_defaults, 150 | exclude_none=exclude_none, 151 | custom_encoder=custom_encoder, 152 | sqlalchemy_safe=sqlalchemy_safe, 153 | ) 154 | -------------------------------------------------------------------------------- /natsapi/enums.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | JSON_RPC_VERSION = Literal["2.0"] 4 | -------------------------------------------------------------------------------- /natsapi/exception_handlers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import ValidationError 4 | 5 | from natsapi.exceptions import JsonRPCException 6 | from natsapi.models import ErrorData, JsonRPCError, JsonRPCRequest 7 | 8 | 9 | def get_validation_target(e): 10 | err_loc_tree = [str(loc_part) for loc_part in e["loc"]] 11 | target = ".".join(err_loc_tree) 12 | return target 13 | 14 | 15 | def handle_jsonrpc_exception(exc: JsonRPCException, request: JsonRPCRequest, subject: str) -> JsonRPCError: 16 | try: 17 | data = ErrorData.parse_obj(exc.data) 18 | except Exception: 19 | data = ErrorData(type=type(exc).__name__, errors=[]) 20 | logging.error( 21 | exc, 22 | exc_info=True, 23 | stack_info=True, 24 | extra={ 25 | "json_rpc_id": request.id, 26 | "auth": request.params.get("auth"), 27 | "json": request.dict(), 28 | "subject": subject, 29 | "NATS": True, 30 | "code": exc.code, 31 | }, 32 | ) 33 | return JsonRPCError(code=exc.code, message=exc.message, data=data) 34 | 35 | 36 | def handle_validation_error(exc: ValidationError, request: JsonRPCRequest, subject: str) -> JsonRPCError: 37 | errors = [] 38 | for ve in exc.errors(): 39 | detail_target = get_validation_target(ve) 40 | detail_msg = ve["msg"] 41 | errors.append({"type": type(exc).__name__, "target": detail_target, "message": detail_msg}) 42 | 43 
| logging.error( 44 | exc, 45 | exc_info=True, 46 | stack_info=True, 47 | extra={ 48 | "json_rpc_id": request.id, 49 | "auth": request.params.get("auth"), 50 | "json": request.dict(), 51 | "subject": subject, 52 | "NATS": True, 53 | "code": -40001, 54 | }, 55 | ) 56 | data = ErrorData(type=type(exc).__name__, errors=errors) 57 | msg = "Invalid data was provided or some data is missing." 58 | return JsonRPCError(code=-40001, message=msg, data=data) 59 | 60 | 61 | def handle_internal_error(exc: Exception, request: JsonRPCRequest, subject: str) -> JsonRPCError: 62 | code = -40000 63 | try: 64 | # Try parsing FormattedException variants 65 | code = exc.rpc_code 66 | message = f"{exc.msg}: {exc.detail}" 67 | except Exception: 68 | message = str(exc) 69 | code = -40000 70 | logging.error( 71 | exc, 72 | exc_info=True, 73 | stack_info=True, 74 | extra={ 75 | "json_rpc_id": request.id, 76 | "auth": request.params.get("auth"), 77 | "json": request.dict(), 78 | "subject": subject, 79 | "NATS": True, 80 | "code": code, 81 | }, 82 | ) 83 | data = ErrorData(type=type(exc).__name__, errors=[]) 84 | return JsonRPCError(code=code, message=message, data=data) 85 | -------------------------------------------------------------------------------- /natsapi/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class NatsAPIError(RuntimeError): ... 5 | 6 | 7 | class DuplicateRouteException(NatsAPIError): 8 | """ 9 | Raised when a route with the same nats subject is added to app. 
10 | """ 11 | 12 | def __init__(self, msg: str): 13 | self.msg = msg 14 | 15 | def __str__(self): 16 | return f"{self.__class__.__name__} {self.msg}" 17 | 18 | 19 | class JsonRPCException(Exception): 20 | def __init__(self, code: int, message: str, data: Any = None): 21 | super().__init__(message) 22 | self.message = message 23 | self.code = code 24 | self.data = data 25 | 26 | 27 | class JsonRPCRequestException(JsonRPCException): 28 | def __init__(self, data: Any = None): 29 | self.code = -32600 30 | self.message = "INVALID_REQUEST_FORMAT" 31 | self.data = data 32 | 33 | 34 | class JsonRPCUnknownMethodException(JsonRPCException): 35 | def __init__(self, data: Any = None): 36 | self.code = -32601 37 | self.message = "NO_SUCH_ENDPOINT" 38 | self.data = data 39 | 40 | 41 | class JsonRPCInvalidMethodParamsException(JsonRPCException): 42 | def __init__(self, message, data: Any = None): 43 | self.code = -32602 44 | self.message = message 45 | self.data = data 46 | 47 | 48 | class JsonRPCInternalErrorException(JsonRPCException): 49 | def __init__(self, data: Any = None): 50 | self.code = -32603 51 | self.message = "INTERNAL_ERROR" 52 | self.data = data 53 | 54 | 55 | class JsonRPCInvalidParamsException(JsonRPCException): 56 | def __init__(self, data: Any = None): 57 | self.code = -32602 58 | self.message = "INVALID_PARAMETERS_RECEIVED" 59 | self.data = data 60 | -------------------------------------------------------------------------------- /natsapi/logger.py: -------------------------------------------------------------------------------- 1 | # Make logging prettier with this 2 | # https://github.com/encode/uvicorn/blob/master/uvicorn/logging.py 3 | import logging 4 | 5 | logger = logging.getLogger("natsapi") 6 | logging.basicConfig(level=logging.INFO) 7 | -------------------------------------------------------------------------------- /natsapi/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 
2 | from typing import Any, Optional 3 | from uuid import UUID, uuid4 4 | 5 | from pydantic import BaseModel, Field, create_model, root_validator, validator 6 | 7 | from natsapi.enums import JSON_RPC_VERSION 8 | 9 | 10 | class ErrorDetail(BaseModel): 11 | type: str 12 | target: Optional[str] = None 13 | message: str 14 | 15 | 16 | class ErrorData(BaseModel): 17 | type: Optional[str] = None 18 | errors: list[ErrorDetail] = [] 19 | 20 | 21 | class JsonRPCError(BaseModel): 22 | code: int = Field(..., description="Error code that falls in the predefined error range for this type of exception") 23 | message: str = Field( 24 | ..., 25 | description="A message providing a short description of the error. SHOULD be limited to a concise single sentence", 26 | ) 27 | timestamp: datetime = Field(default_factory=datetime.now, description="Timestamp of when the error occured") 28 | data: Any = Field(None, description="Additional information about the error") 29 | 30 | 31 | class JsonRPCReply(BaseModel): 32 | jsonrpc: JSON_RPC_VERSION = Field("2.0") 33 | id: UUID = Field(...) 34 | result: Optional[dict[str, Any]] = Field(None) 35 | error: Optional[JsonRPCError] = Field(None) 36 | 37 | @root_validator(pre=True) 38 | def check_result_and_error(cls, values): 39 | result, error = values.get("result"), values.get("error") 40 | assert result or error, "A result or error should be required" 41 | if result and error: 42 | raise AttributeError( 43 | "An RPC reply MUST NOT have an error and a result. Based on the result, you should provide only one.", 44 | ) 45 | return values 46 | 47 | 48 | class JsonRPCRequest(BaseModel): 49 | jsonrpc: Optional[JSON_RPC_VERSION] = Field("2.0") 50 | timeout: Optional[float] = Field( 51 | None, 52 | description="Timeout set by client, should be equal to the timeout set when doing nc.request, if publish use '-1'", 53 | ) 54 | method: Optional[str] = Field(None, description="Request method used") 55 | params: dict[str, Any] = Field(...) 
56 | id: Optional[UUID] = Field(None, alias="id", description="UUID created at the creation of the request") 57 | 58 | @validator("id", pre=True, always=True) 59 | def set_id(cls, id): 60 | return id or uuid4() 61 | 62 | @classmethod 63 | def with_params(self, params: BaseModel): 64 | return create_model("JsonRPC" + params.__name__, __base__=self, params=(params, ...)) 65 | -------------------------------------------------------------------------------- /natsapi/plugin.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | from collections import defaultdict 5 | from typing import Any 6 | from uuid import uuid4 7 | 8 | import pytest 9 | from nats.aio.client import Client as NATS 10 | from nats.aio.client import Msg 11 | 12 | from natsapi.exceptions import JsonRPCUnknownMethodException 13 | from natsapi.models import JsonRPCReply 14 | 15 | 16 | class NatsapiMock: 17 | """ 18 | A simple, yet powerful utility to mock nats requests. 19 | define subject with response and error fakes, and NatsMock 20 | will give you the possibility to assert against payloads provided 21 | and the fakes that were given. 
22 | 23 | A mock that can be used as follows: 24 | 25 | ``` 26 | async def test(natsapi_mock): 27 | await natsapi_mock.request("", response={"foo": "bar"}) 28 | 29 | reply = await nats.request("", {"id": 1}, timeout=1) 30 | 31 | # use the fakes in your app 32 | assert reply.result == {"foo": "bar"} 33 | 34 | # test against the payload provided in the nats request 35 | assert nats_service_mock.payloads[""][0]["params"]["id"] == 1 36 | ``` 37 | """ 38 | 39 | def __init__(self, host: str) -> None: 40 | self.host = host 41 | self.nats: NATS = NATS() 42 | self.subjects: dict[str, tuple[Any, Any]] = {} 43 | self.responses: dict[str, tuple[Any, Any]] = {} 44 | self.payloads: dict[str, Any] = defaultdict(list) 45 | 46 | async def lifespan(self) -> None: 47 | await self.nats.connect(self.host, verbose=True, ping_interval=5) 48 | 49 | async def wait_startup(self): 50 | counter = 0 51 | while self.nats.is_connected is False: 52 | await asyncio.sleep(0.1) 53 | counter += 1 54 | if counter == 10: 55 | raise Exception("Waited to long for nats to connect!") 56 | 57 | async def handle(self, message: Msg) -> None: 58 | payload = json.loads(message.data.decode("utf-8")) 59 | 60 | try: 61 | result, error = self.responses[message.subject] 62 | 63 | if not isinstance(result, dict): 64 | if hasattr(result, "dict"): 65 | result = result.dict() 66 | elif hasattr(result, "json"): 67 | result = json.loads(result.json()) 68 | 69 | response = JsonRPCReply(jsonrpc="2.0", id=uuid4(), error=error, result=result) 70 | except KeyError: 71 | exc = JsonRPCUnknownMethodException() 72 | response = JsonRPCReply(jsonrpc="2.0", id=uuid4(), error={"code": -1, "message": str(exc.message)}) 73 | except Exception as e: 74 | response = JsonRPCReply(jsonrpc="2.0", id=uuid4(), error={"code": -1, "message": f"{type(e)}: {str(e)}"}) 75 | finally: 76 | if message.subject not in self.payloads: 77 | self.payloads[message.subject] = [] 78 | self.payloads[message.subject].append(payload) 79 | 80 | if message.reply: 81 
| await self.nats.publish(message.reply, response.json().encode()) 82 | 83 | async def request(self, subject: str, *, response: Any = None, error: dict[str, Any] = None) -> None: 84 | assert response or error, "Need a response of an error" 85 | self.responses[subject] = response, error 86 | await self.nats.subscribe(subject, cb=self.handle) 87 | await self.nats.flush(timeout=5) 88 | 89 | async def publish(self, subject: str) -> None: 90 | await self.nats.subscribe(subject, cb=self.handle) 91 | await self.nats.flush(timeout=5) 92 | 93 | async def __aenter__(self): 94 | await self.lifespan() 95 | await self.wait_startup() 96 | return self 97 | 98 | async def __aexit__(self, *args): 99 | await self.nats.drain() 100 | await self.nats.close() 101 | 102 | 103 | @pytest.fixture 104 | async def natsapi_mock() -> NatsapiMock: 105 | host = os.environ.get("HOST_NATS", "nats://127.0.0.1:4222") 106 | async with NatsapiMock(host=host) as mock: 107 | yield mock 108 | -------------------------------------------------------------------------------- /natsapi/routing.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections.abc import Callable 3 | from typing import Any, Optional 4 | 5 | from pydantic import BaseModel 6 | 7 | from natsapi.asyncapi import ExternalDocumentation 8 | from natsapi.types import DecoratedCallable 9 | from natsapi.utils import create_field, generate_operation_id_for_subject, get_request_model, get_summary 10 | 11 | 12 | class Request: 13 | def __init__( 14 | self, 15 | subject: str, 16 | endpoint: Callable[..., Any], 17 | *, 18 | result=type[Any], 19 | skip_validation: Optional[bool] = False, 20 | description: Optional[str] = None, 21 | deprecated: Optional[bool] = None, 22 | tags: Optional[list[str]] = None, 23 | summary: Optional[str] = None, 24 | include_schema: Optional[bool] = True, 25 | suggested_timeout: Optional[float] = None, 26 | ): 27 | self.subject = subject 28 | self.endpoint = 
endpoint 29 | self.skip_validation = skip_validation 30 | self.summary = summary or get_summary(endpoint) or subject 31 | self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject) 32 | self.result = result 33 | self.params = get_request_model(self.endpoint, subject, self.skip_validation) 34 | reply_name = "Reply_" + self.operation_id 35 | request_name = "Request_" + self.operation_id 36 | self.reply_field = create_field(name=reply_name, type_=self.params) 37 | self.request_field = create_field(name=request_name, type_=self.result) 38 | 39 | self.tags = tags or [] 40 | self.description = description or "" 41 | self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") 42 | self.deprecated = deprecated 43 | self.include_schema = include_schema 44 | self.suggested_timeout = suggested_timeout 45 | 46 | assert callable(endpoint), "An endpoint must be callable" 47 | 48 | 49 | class Publish: 50 | def __init__( 51 | self, 52 | subject: str, 53 | endpoint: Callable[..., Any], 54 | *, 55 | skip_validation: Optional[bool] = False, 56 | description: Optional[str] = None, 57 | deprecated: Optional[bool] = None, 58 | tags: Optional[list[str]] = None, 59 | summary: Optional[str] = None, 60 | include_schema: Optional[bool] = True, 61 | ): 62 | self.subject = subject 63 | self.endpoint = endpoint 64 | self.skip_validation = skip_validation 65 | self.summary = summary or get_summary(endpoint) or subject 66 | self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject) 67 | self.params = get_request_model(self.endpoint, subject, self.skip_validation) 68 | reply_name = "Reply_" + self.operation_id 69 | self.reply_field = create_field(name=reply_name, type_=self.params) 70 | 71 | self.tags = tags or [] 72 | self.description = description or "" 73 | self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") 74 | self.deprecated = deprecated 75 | self.include_schema = 
include_schema 76 | 77 | assert callable(endpoint), "An endpoint must be callable" 78 | 79 | 80 | class Sub: 81 | def __init__( 82 | self, 83 | subject: str, 84 | *, 85 | queue: Optional[str] = None, 86 | summary: Optional[str] = None, 87 | description: Optional[str] = None, 88 | tags: Optional[list[str]] = None, 89 | externalDocs: Optional[ExternalDocumentation] = None, 90 | ): 91 | 92 | self.subject = subject 93 | self.queue = queue 94 | self.summary = summary 95 | self.description = description 96 | self.tags = tags or [] 97 | self.externalDocs = externalDocs 98 | 99 | 100 | class Pub: 101 | def __init__( 102 | self, 103 | subject: str, 104 | params: BaseModel, 105 | *, 106 | summary: Optional[str] = None, 107 | description: Optional[str] = None, 108 | tags: Optional[list[str]] = None, 109 | externalDocs: Optional[ExternalDocumentation] = None, 110 | ): 111 | self.subject = subject 112 | self.summary = summary 113 | self.description = description 114 | self.tags = tags or [] 115 | self.externalDocs = externalDocs 116 | self.params = params 117 | self.params_field = create_field(name="Publish_" + subject, type_=self.params) 118 | 119 | 120 | class SubjectRouter: 121 | def __init__( 122 | self, 123 | *, 124 | prefix: str = None, 125 | tags: Optional[list[str]] = None, 126 | routes: Optional[list[Request]] = None, 127 | subs: Optional[set[Sub]] = None, 128 | pubs: Optional[set[Pub]] = None, 129 | deprecated: Optional[bool] = None, 130 | include_in_schema: bool = True, 131 | ) -> None: 132 | self.prefix = prefix 133 | self.routes = routes or [] 134 | self.pubs = pubs or set() 135 | self.subs = subs or set() 136 | self.tags: list[str] = tags or [] 137 | self.deprecated = deprecated 138 | self.include_in_schema = include_in_schema 139 | 140 | def add_request( 141 | self, 142 | subject: str, 143 | endpoint: Callable[..., Any], 144 | *, 145 | result=type[Any], 146 | skip_validation: Optional[bool] = False, 147 | description: Optional[str] = None, 148 | deprecated: 
Optional[bool] = None, 149 | tags: Optional[list[str]] = None, 150 | summary: Optional[str] = None, 151 | suggested_timeout: Optional[float] = None, 152 | include_schema: Optional[bool] = True, 153 | ) -> None: 154 | current_tags = self.tags.copy() 155 | if tags: 156 | current_tags.extend(tags) 157 | current_subject = ".".join([self.prefix, subject]) if self.prefix is not None else subject 158 | subject = Request( 159 | subject=current_subject, 160 | endpoint=endpoint, 161 | result=result, 162 | skip_validation=skip_validation, 163 | description=description, 164 | deprecated=deprecated if deprecated is not None else self.deprecated, 165 | tags=current_tags, 166 | summary=summary, 167 | suggested_timeout=suggested_timeout, 168 | include_schema=include_schema, 169 | ) 170 | self.routes.append(subject) 171 | 172 | def add_publish( 173 | self, 174 | subject: str, 175 | endpoint: Callable[..., Any], 176 | *, 177 | skip_validation: Optional[bool] = False, 178 | description: Optional[str] = None, 179 | deprecated: Optional[bool] = None, 180 | tags: Optional[list[str]] = None, 181 | summary: Optional[str] = None, 182 | include_schema: Optional[bool] = True, 183 | ) -> None: 184 | current_tags = self.tags.copy() 185 | if tags: 186 | current_tags.extend(tags) 187 | current_subject = ".".join([self.prefix, subject]) if self.prefix is not None else subject 188 | subject = Publish( 189 | subject=current_subject, 190 | endpoint=endpoint, 191 | skip_validation=skip_validation, 192 | description=description, 193 | deprecated=deprecated if deprecated is not None else self.deprecated, 194 | tags=current_tags, 195 | summary=summary, 196 | include_schema=include_schema, 197 | ) 198 | self.routes.append(subject) 199 | 200 | def request( 201 | self, 202 | subject: str, 203 | *, 204 | result=type[Any], 205 | skip_validation: Optional[bool] = False, 206 | description: Optional[str] = None, 207 | deprecated: Optional[bool] = None, 208 | tags: Optional[list[str]] = None, 209 | summary: 
Optional[str] = None, 210 | suggested_timeout: Optional[float] = None, 211 | include_schema: Optional[bool] = True, 212 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 213 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 214 | self.add_request( 215 | subject=subject, 216 | endpoint=func, 217 | result=result, 218 | skip_validation=skip_validation, 219 | description=description, 220 | deprecated=deprecated, 221 | tags=tags, 222 | summary=summary, 223 | suggested_timeout=suggested_timeout, 224 | include_schema=include_schema, 225 | ) 226 | return func 227 | 228 | return decorator 229 | 230 | def publish( 231 | self, 232 | subject: str, 233 | *, 234 | result=type[Any], 235 | skip_validation: Optional[bool] = False, 236 | description: Optional[str] = None, 237 | deprecated: Optional[bool] = None, 238 | tags: Optional[list[str]] = None, 239 | summary: Optional[str] = None, 240 | include_schema: Optional[bool] = True, 241 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 242 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 243 | self.add_publish( 244 | subject=subject, 245 | endpoint=func, 246 | skip_validation=skip_validation, 247 | description=description, 248 | deprecated=deprecated, 249 | tags=tags, 250 | summary=summary, 251 | include_schema=include_schema, 252 | ) 253 | return func 254 | 255 | return decorator 256 | 257 | def add_pub( 258 | self, 259 | subject: str, 260 | params: BaseModel, 261 | *, 262 | summary: Optional[str] = None, 263 | description: Optional[str] = None, 264 | tags: Optional[list[str]] = None, 265 | externalDocs: Optional[ExternalDocumentation] = None, 266 | ) -> None: 267 | """ 268 | Include pub in asyncapi schema 269 | """ 270 | pub = Pub( 271 | subject, 272 | params, 273 | summary=summary, 274 | description=description, 275 | tags=tags or None, 276 | externalDocs=externalDocs, 277 | ) 278 | self.pubs.add(pub) 279 | 280 | def pub( 281 | self, 282 | subject: str, 283 | *, 284 | params=type[Any], 285 
| description: Optional[str] = None, 286 | tags: Optional[list[str]] = None, 287 | summary: Optional[str] = None, 288 | externalDocs: Optional[ExternalDocumentation] = None, 289 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 290 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 291 | self.add_pub( 292 | subject=subject, 293 | params=params, 294 | summary=summary, 295 | description=description, 296 | tags=tags, 297 | externalDocs=externalDocs, 298 | ) 299 | return func 300 | 301 | return decorator 302 | 303 | def add_sub( 304 | self, 305 | subject: str, 306 | *, 307 | queue: Optional[str] = None, 308 | summary: Optional[str] = None, 309 | description: Optional[str] = None, 310 | tags: Optional[list[str]] = None, 311 | externalDocs: Optional[ExternalDocumentation] = None, 312 | ) -> None: 313 | """ 314 | Include sub in asyncapi schema 315 | """ 316 | sub = Sub( 317 | subject, 318 | queue=queue, 319 | summary=summary, 320 | description=description, 321 | tags=tags or None, 322 | externalDocs=externalDocs, 323 | ) 324 | self.subs.add(sub) 325 | 326 | def sub( 327 | self, 328 | subject: str, 329 | *, 330 | queue: Optional[str] = None, 331 | description: Optional[str] = None, 332 | tags: Optional[list[str]] = None, 333 | summary: Optional[str] = None, 334 | externalDocs: Optional[ExternalDocumentation] = None, 335 | ) -> Callable[[DecoratedCallable], DecoratedCallable]: 336 | def decorator(func: DecoratedCallable) -> DecoratedCallable: 337 | self.add_sub( 338 | subject=subject, 339 | queue=queue, 340 | summary=summary, 341 | description=description, 342 | tags=tags, 343 | externalDocs=externalDocs, 344 | ) 345 | return func 346 | 347 | return decorator 348 | -------------------------------------------------------------------------------- /natsapi/state.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | 4 | class State: 5 | """ 6 | Yanked from starlette.datascructures 7 | 8 | An object that 
can be used to store arbitrary state. 9 | 10 | Used for `app.state`. 11 | """ 12 | 13 | def __init__(self, state: dict = None): 14 | if state is None: 15 | state = {} 16 | super(State, self).__setattr__("_state", state) 17 | 18 | def __setattr__(self, key: typing.Any, value: typing.Any) -> None: 19 | self._state[key] = value 20 | 21 | def __getattr__(self, key: typing.Any) -> typing.Any: 22 | try: 23 | return self._state[key] 24 | except KeyError as e: 25 | message = "'{}' object has no attribute '{}'" 26 | raise AttributeError(message.format(self.__class__.__name__, key)) from e 27 | 28 | def __delattr__(self, key: typing.Any) -> None: 29 | del self._state[key] 30 | -------------------------------------------------------------------------------- /natsapi/types.py: -------------------------------------------------------------------------------- 1 | """Yanked from FastApi.typing""" 2 | 3 | from collections.abc import Callable 4 | from typing import Any, TypeVar 5 | 6 | DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) 7 | -------------------------------------------------------------------------------- /natsapi/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import inspect 3 | import re 4 | from collections.abc import Callable 5 | from enum import Enum 6 | from typing import Any, Optional, Union 7 | 8 | from pydantic import BaseConfig, BaseModel, create_model 9 | from pydantic.fields import FieldInfo 10 | from pydantic.v1.schema import model_process_schema 11 | 12 | from natsapi._compat import PYDANTIC_V2, ModelField 13 | from natsapi.asyncapi.constants import REF_PREFIX 14 | from natsapi.exceptions import NatsAPIError 15 | 16 | 17 | def get_summary(endpoint: Callable) -> str: 18 | """Yanked from `Starlette.routing`""" 19 | if inspect.isfunction(endpoint) or inspect.isclass(endpoint): 20 | return endpoint.__name__ if endpoint.__name__ != "_" else None 21 | return 
endpoint.__class__.__name__ 22 | 23 | 24 | def generate_operation_id_for_subject(*, summary: str, subject: str) -> str: 25 | operation_id = summary + "_" + subject 26 | operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) 27 | return operation_id 28 | 29 | 30 | def create_field( 31 | name: str, 32 | type_: type[Any], 33 | class_validators: Optional[dict[str, Any]] = None, 34 | model_config: type[BaseConfig] = BaseConfig, 35 | field_info: Optional[FieldInfo] = None, 36 | ) -> ModelField: 37 | """ 38 | Yanked from fastapi.utils 39 | Create a new reply field. Raises if type_ is invalid. 40 | """ 41 | class_validators = class_validators or {} 42 | 43 | field_info = (field_info or FieldInfo(annotation=type_)) if PYDANTIC_V2 else (field_info or FieldInfo()) 44 | 45 | kwargs = {"name": name, "field_info": field_info} 46 | 47 | if not PYDANTIC_V2: 48 | kwargs.update( 49 | { 50 | "type_": type_, 51 | "class_validators": class_validators, 52 | "model_config": model_config, 53 | "required": True, 54 | }, 55 | ) 56 | 57 | try: 58 | return ModelField(**kwargs) 59 | except RuntimeError as e: 60 | raise NatsAPIError( 61 | f"Invalid args for reply field! Hint: check that {type_} is a valid pydantic field type", 62 | ) from e 63 | 64 | 65 | def get_model_definitions( 66 | *, 67 | flat_models: Union[set[type[BaseModel], type[Enum]]], 68 | model_name_map: Union[dict[type[BaseModel], type[Enum], str]], 69 | ) -> dict[str, Any]: 70 | definitions: dict[str, dict[str, Any]] = {} 71 | for model in flat_models: 72 | m_schema, m_definitions, m_nested_models = model_process_schema( 73 | model, 74 | model_name_map=model_name_map, 75 | ref_prefix=REF_PREFIX, 76 | ) 77 | definitions.update(m_definitions) 78 | try: 79 | model_name = model_name_map[model] 80 | except KeyError as exc: 81 | method_name = str(exc.args[0]).replace("", "") 82 | raise NatsAPIError( 83 | f"Could not generate schema. Two or more functions share the name '{method_name}'. 
Make sure methods don't share the same name", 84 | ) from exc 85 | definitions[model_name] = m_schema 86 | return definitions 87 | 88 | 89 | def get_request_model(func: Callable, subject: str, skip_validation: bool): 90 | parameters = collections.OrderedDict(inspect.signature(func).parameters) 91 | name_prefix = func.__name__ if func.__name__ != "_" else subject 92 | 93 | if skip_validation: 94 | assert ( 95 | "kwargs" in parameters 96 | ), f"Add '**kwargs' to the '{name_prefix}' method as extra arguments can be passed in payload and won't be filtered out." 97 | 98 | param_fields = {} 99 | valid_app_types = ("FastAPI", "NatsAPI", "_empty") 100 | for i, parameter in enumerate(parameters.values()): 101 | if i == 0: 102 | assert parameter.name == "app", "First parameter should be named 'app'" 103 | if parameter.annotation == Any: 104 | continue 105 | else: 106 | assert ( 107 | parameter.annotation.__name__ in valid_app_types 108 | ), f"Valid types for app are: NatsAPI, FastAPI, or Any. Got {parameter.annotation.__name__}" 109 | continue 110 | 111 | if parameter.name in ["args", "kwargs"] and skip_validation: 112 | continue 113 | else: 114 | assert parameter.annotation is not inspect._empty, f"{parameter.name} has no type" 115 | default = ... 
if parameter.default is inspect._empty else parameter.default 116 | param_fields[parameter.name] = (parameter.annotation, default) 117 | 118 | model = create_model(f"{name_prefix}_params", **param_fields) 119 | return model 120 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "natsapi" 3 | version = "0.1.1" 4 | description = "A Python microservice framework that speaks nats.io with asyncapi spec generation capability" 5 | authors = ["WeGroup NV "] 6 | readme = "README.md" 7 | license = "MIT" 8 | homepage = "https://github.com/wegroupwolves/natsapi" 9 | repository = "https://github.com/wegroupwolves/natsapi" 10 | 11 | [tool.poetry.urls] 12 | Pypi = "https://pypi.org/project/natsapi/" 13 | 14 | [tool.poetry.build] 15 | generate-setup-file = false 16 | 17 | [tool.poetry.dependencies] 18 | python = "^3.9" 19 | pydantic = ">=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0" 20 | nats-py = {extras = ["nkeys"], version = "^2.2.0"} 21 | pydantic-settings = "^2.8.1" 22 | 23 | [tool.poetry.group.dev.dependencies] 24 | pre-commit = "^2.12.1" 25 | 26 | [tool.poetry.group.test.dependencies] 27 | pytest-cov = "^2.11.1" 28 | black = "^24.10.0" 29 | bandit = "^1.7.0" 30 | piprot = "^0.9.11" 31 | safety = "^1.10.3" 32 | pytest-asyncio = "^0.17.2" 33 | vulture = "^2.13" 34 | ruff = "^0.8.6" 35 | 36 | 37 | [tool.poetry.plugins."pytest11"] 38 | natsapi = "natsapi.plugin" 39 | 40 | [tool.pytest.ini_options] 41 | asyncio_mode = "auto" 42 | testpaths = [ 43 | "tests" 44 | ] 45 | 46 | [build-system] 47 | requires = ["poetry-core>=1.1.0"] 48 | build-backend = "poetry.core.masonry.api" 49 | 50 | [tool.ruff] 51 | line-length = 120 52 | exclude = ["scripts"] 53 | target-version = "py39" 54 | 55 | [tool.ruff.lint] 56 | extend-select = ["C90", "I", "B", "Q", "UP", "S", "COM", "C4", "T10", "SIM", "TID", "PTH", "ERA"] 57 | ignore = 
["S101", "UP038", "B017", "UP007", "UP008", "B006", "C408"] 58 | 59 | [tool.ruff.lint.mccabe] 60 | max-complexity = 25 61 | 62 | [tool.black] 63 | line-length = 120 64 | 65 | [tool.coverage.report] 66 | omit=["**/site-packages/**"] 67 | -------------------------------------------------------------------------------- /render-asyncapi/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/wegroupwolves/natsapi/render-asyncapi 2 | 3 | go 1.16 4 | 5 | require github.com/nats-io/nats.go v1.11.0 // indirect 6 | -------------------------------------------------------------------------------- /render-asyncapi/go.sum: -------------------------------------------------------------------------------- 1 | github.com/nats-io/nats.go v1.11.0 h1:L263PZkrmkRJRJT2YHU8GwWWvEvmr9/LUKuJTXsF32k= 2 | github.com/nats-io/nats.go v1.11.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= 3 | github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= 4 | github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= 5 | github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= 6 | github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= 7 | golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b h1:wSOdpTq0/eI46Ez/LkDwIsAKA71YP2SRKBODiRWM0as= 8 | golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= 9 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 10 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 11 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 12 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 13 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 14 | -------------------------------------------------------------------------------- /render-asyncapi/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "os" 9 | "time" 10 | 11 | "github.com/nats-io/nats.go" 12 | ) 13 | 14 | type Handler struct { 15 | nc *nats.Conn 16 | rootPath string 17 | } 18 | 19 | func (h *Handler) asyncAPISchemaHandler(w http.ResponseWriter, req *http.Request) { 20 | data := map[string]interface{}{"jsonrpc": "2.0", "id": "ad1f2612-4f11-4667-bf7d-d20c2b5c8285", "params": map[string]string{}} 21 | payload, _ := json.Marshal(data) 22 | path := h.rootPath + ".schema.RETRIEVE" 23 | reply_raw, err := h.nc.Request(path, payload, 500*time.Millisecond) 24 | 25 | if err != nil { 26 | // If not head is not 200, html generator will not work and display the error 27 | w.WriteHeader(http.StatusOK) 28 | fmt.Fprintf(w, ` 29 | { 30 | "asyncapi": "2.0", 31 | "info": { 32 | "title": "ERROR", 33 | "description": "

%s

Full subject: '%s'

", 34 | }, 35 | "channels": {} 36 | } 37 | `, err.Error(), path) 38 | return 39 | } 40 | 41 | reply := make(map[string]interface{}) 42 | json_err := json.Unmarshal(reply_raw.Data, &reply) 43 | 44 | schema, _ := json.Marshal(reply["result"]) 45 | 46 | if reply["result"] == nil || json_err != nil { 47 | // If not head is not 200, html generator will not work and display the error 48 | w.WriteHeader(http.StatusOK) 49 | fmt.Fprintf(w, ` 50 | { 51 | "asyncapi": "2.0", 52 | "info": { 53 | "title": "ERROR", 54 | "description": "

Was able to get reply from NATS service, but not able to retrieve schema. Make sure you pass the correct root path as an argument.

Your full path is: '%s'

", 55 | }, 56 | "channels": {} 57 | } 58 | `, path) 59 | return 60 | } 61 | w.WriteHeader(http.StatusOK) 62 | fmt.Fprintf(w, "%s", schema) 63 | } 64 | 65 | func asyncAPIRedocHandler(w http.ResponseWriter, req *http.Request) { 66 | w.WriteHeader(http.StatusOK) 67 | 68 | fmt.Fprintf(w, ` 69 | 70 | 71 | 72 | 73 | `) 74 | } 75 | 76 | func main() { 77 | if len(os.Args) < 2 { 78 | log.Fatal("No natsport and root path given, should be `./app 4222 master.service-staging`") 79 | } 80 | natsPort := os.Args[1] 81 | rootPath := os.Args[2] 82 | 83 | nc, err := nats.Connect("nats://127.0.0.1:" + natsPort) 84 | if err != nil { 85 | log.Fatal(err) 86 | } 87 | h := &Handler{nc: nc, rootPath: rootPath} 88 | 89 | http.HandleFunc("/asyncapi.json", h.asyncAPISchemaHandler) 90 | http.HandleFunc("/", asyncAPIRedocHandler) 91 | 92 | fmt.Println("Server running") 93 | fmt.Println("Docs can be found on localhost:8090") 94 | fmt.Println("Connected to nats on port " + natsPort) 95 | http.ListenAndServe(":8090", nil) 96 | } 97 | -------------------------------------------------------------------------------- /render-asyncapi/render-asyncapi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wegroupwolves/natsapi/cb31ef79c5cea79f815ed4d11f1cc8f03deaf21c/render-asyncapi/render-asyncapi -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | let 2 | pkgs = import {}; 3 | in pkgs.mkShell { 4 | shellHook = '' 5 | export PIP_NO_BINARY="ruff" 6 | set -a; source .env; set +a 7 | echo "SHELLHOOK LOG: .env loaded to ENV variables" 8 | ''; 9 | 10 | packages = with pkgs; [ 11 | python311 12 | 13 | ruff 14 | rustc 15 | cargo 16 | 17 | (poetry.override { python3 = python311; }) 18 | 19 | (python311.withPackages (p: with p; [ 20 | pip 21 | python-lsp-server 22 | pynvim 23 | pyls-isort 24 | python-lsp-black 25 | ])) 26 | 27 | 
]; 28 | 29 | env.LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [ 30 | pkgs.stdenv.cc.cc.lib 31 | ]; 32 | } 33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wegroupwolves/natsapi/cb31ef79c5cea79f815ed4d11f1cc8f03deaf21c/tests/__init__.py -------------------------------------------------------------------------------- /tests/asyncapi/test_generation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | from uuid import uuid4 3 | 4 | import pytest 5 | from pydantic import BaseModel, Field 6 | 7 | from natsapi import NatsAPI, Pub, SubjectRouter 8 | from natsapi.asyncapi import Errors 9 | from natsapi.asyncapi.models import ExternalDocumentation, Server 10 | from natsapi.exceptions import JsonRPCException 11 | 12 | pytestmark = pytest.mark.asyncio 13 | 14 | production_server = Server( 15 | **{ 16 | "url": "nats://127.0.0.1:{port}", 17 | "description": "Production NATS", 18 | "protocol": "nats", 19 | "protocolVersion": "2.0", 20 | "variables": {"port": {"default": "422"}}, 21 | }, 22 | ) 23 | 24 | staging_server = Server( 25 | **{ 26 | "url": "nats://127.0.0.1:{port}", 27 | "description": "Staging NATS", 28 | "protocol": "nats", 29 | "protocolVersion": "2.0", 30 | "variables": {"port": {"default": "422"}}, 31 | }, 32 | ) 33 | 34 | servers_schema = {"production": production_server, "staging": staging_server} 35 | 36 | 37 | class Pagination(BaseModel): 38 | page: int 39 | pagelen: int 40 | 41 | 42 | class CreateResult(BaseModel): 43 | id: Any 44 | pagination: Pagination 45 | 46 | 47 | class BrokerAlreadyExists(JsonRPCException): 48 | def __init__(self, data=None): 49 | self.code = -27001 50 | self.message = "BROKER_EXISTS" 51 | self.data = data 52 | 53 | 54 | class SignatureException(JsonRPCException): 55 | def __init__(self, data=None) -> 
None: 56 | self.code = -27002 57 | self.message = "SIGNATURE_ERROR" 58 | self.data = data 59 | 60 | 61 | domain_errors = Errors(lower_bound=-27000, upper_bound=-3000, errors=[BrokerAlreadyExists, SignatureException]) 62 | 63 | external_docs = ExternalDocumentation( 64 | description="This client uses the JsonRPC standard for the payloads. Requests follow the company guidelines.", 65 | url="https://github.com/wegroupwolves", 66 | ) 67 | 68 | 69 | def test_generate_minimal_asyncapi_schema_should_generate(): 70 | client = NatsAPI("natsapi.development") 71 | client.generate_asyncapi() 72 | schema = client.asyncapi_schema 73 | assert schema["asyncapi"] == "2.0.0" 74 | assert schema["defaultContentType"] == "application/json" 75 | assert schema["info"]["title"] == "NatsAPI" 76 | assert schema["info"]["version"] == "0.1.0" 77 | 78 | 79 | def test_asyncapi_schema_generation_should_be_cached(monkeypatch): 80 | client = NatsAPI("natsapi.development") 81 | assert client.asyncapi_schema is None 82 | 83 | client.generate_asyncapi() 84 | s1 = client.asyncapi_schema 85 | assert s1 is not None 86 | 87 | # WHEN generating schema again 88 | def patched(): 89 | return "Invalid schema" 90 | 91 | monkeypatch.setattr("natsapi.asyncapi.utils.get_asyncapi.__code__", patched.__code__) 92 | 93 | client.generate_asyncapi() 94 | s2 = client.asyncapi_schema 95 | assert id(s2) == id(s1) 96 | 97 | 98 | def test_asyncapi_schema_w_personal_title_should_generate(): 99 | client = NatsAPI( 100 | "natsapi.development", 101 | title="My Nats Client", 102 | description="This is my nats client", 103 | version="2.4.3", 104 | ) 105 | client.generate_asyncapi() 106 | schema = client.asyncapi_schema 107 | assert schema["info"]["title"] == "My Nats Client" 108 | assert schema["info"]["description"] == "This is my nats client" 109 | assert schema["info"]["version"] == "2.4.3" 110 | 111 | 112 | def test_generate_schema_w_servers_should_generate(): 113 | client = NatsAPI("natsapi.development", 
servers=servers_schema) 114 | client.generate_asyncapi() 115 | schema = client.asyncapi_schema 116 | 117 | assert len(schema["servers"]) == len(servers_schema) 118 | assert schema["servers"]["production"] == servers_schema["production"].dict(exclude_none=True) 119 | assert schema["servers"]["staging"] == servers_schema["staging"].dict(exclude_none=True) 120 | 121 | 122 | def test_generate_schema_w_external_docs_should_generate(): 123 | client = NatsAPI("natsapi.development", external_docs=external_docs) 124 | client.generate_asyncapi() 125 | schema = client.asyncapi_schema 126 | assert schema["externalDocs"] == external_docs.dict() 127 | 128 | 129 | async def test_generate_shema_w_requests_should_generate(app: NatsAPI): 130 | class BaseUser(BaseModel): 131 | email: str = Field(..., description="Unique email of user", example="foo@bar.com") 132 | password: str = Field(..., description="Password of user", example="Supers3cret") 133 | 134 | user_router = SubjectRouter(prefix="v1", tags=["users"], deprecated=True) 135 | 136 | @user_router.request( 137 | "users.CREATE", 138 | result=CreateResult, 139 | description="Creates user that can be used throughout the app", 140 | tags=["auth"], 141 | suggested_timeout=0.5, 142 | ) 143 | def create_base_user(app): 144 | return {"id": uuid4()} 145 | 146 | app.include_router(user_router) 147 | app.generate_asyncapi() 148 | schema = app.asyncapi_schema 149 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["x-suggested-timeout"] 150 | assert ( 151 | schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["operationId"] 152 | == "create_base_user_v1_users_CREATE" 153 | ) 154 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["summary"] 155 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["description"] 156 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["tags"] == [ 157 | {"name": "users"}, 158 | {"name": 
"auth"}, 159 | ] 160 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["message"]["payload"] 161 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["request"]["replies"] 162 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["deprecated"] 163 | 164 | schema_from_request = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 165 | assert schema_from_request == schema 166 | 167 | 168 | async def test_generate_shema_w_publishes_should_generate(app: NatsAPI): 169 | class BaseUser(BaseModel): 170 | email: str = Field(..., description="Unique email of user", example="foo@bar.com") 171 | password: str = Field(..., description="Password of user", example="Supers3cret") 172 | 173 | user_router = SubjectRouter(prefix="v1", tags=["users"], deprecated=True) 174 | 175 | @user_router.publish( 176 | "users.CREATE", 177 | description="Creates user that can be used throughout the app", 178 | tags=["auth"], 179 | ) 180 | def create_base_user(app): 181 | return {"id": uuid4()} 182 | 183 | app.include_router(user_router) 184 | app.generate_asyncapi() 185 | schema = app.asyncapi_schema 186 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"] 187 | assert ( 188 | schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"]["operationId"] 189 | == "create_base_user_v1_users_CREATE" 190 | ) 191 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"]["summary"] 192 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"]["description"] 193 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"]["tags"] == [ 194 | {"name": "users"}, 195 | {"name": "auth"}, 196 | ] 197 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"]["message"]["payload"] 198 | assert schema["channels"]["natsapi.development.v1.users.CREATE"]["publish"] 199 | 200 | schema_from_request = (await 
app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 201 | assert schema_from_request == schema 202 | 203 | 204 | def test_dont_include_in_schema_should_generate(): 205 | router = SubjectRouter() 206 | 207 | @router.request("not_included_router", result=BaseModel, include_schema=False) 208 | def test_not_included_subject(): 209 | pass 210 | 211 | client = NatsAPI("natsapi.development") 212 | client.include_router(router) 213 | client.generate_asyncapi() 214 | schema = client.asyncapi_schema 215 | assert "channels" not in schema 216 | 217 | 218 | def test_generate_domain_errors_schema_should_generate(): 219 | app = NatsAPI("natsapi", domain_errors=domain_errors) 220 | 221 | app.generate_asyncapi() 222 | assert app.asyncapi_schema["errors"]["range"]["upper"] == domain_errors.dict()["upper_bound"] 223 | assert app.asyncapi_schema["errors"]["range"]["lower"] == domain_errors.dict()["lower_bound"] 224 | assert len(app.asyncapi_schema["errors"]["items"]) == len(domain_errors.dict()["errors"]) 225 | 226 | 227 | def test_root_subs_in_schema_should_be_in_schema(app: NatsAPI): 228 | app.generate_asyncapi() 229 | root_path_sub = app.asyncapi_schema["channels"]["natsapi.development.>"] 230 | assert root_path_sub["subscribe"]["summary"] 231 | 232 | 233 | async def test_included_pub_in_schema_should_be_in_schema(app: NatsAPI): 234 | app.include_pubs([Pub("some.subject", Server)]) 235 | 236 | assert len(app.pubs) == 1 237 | 238 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 239 | assert schema["channels"]["some.subject"] 240 | assert schema["components"]["schemas"]["Server"] 241 | 242 | 243 | def test_routes_use_identically_named_class_in_different_modules_should_reference_correctly(): 244 | app = NatsAPI("natsapi.development") 245 | 246 | class SomeClassA(BaseModel): 247 | foo: str 248 | bar: int 249 | 250 | class SomeClassB(BaseModel): 251 | foo: str 252 | 253 | @app.request("subject_a", result=SomeClassA) 254 | def 
get_subject_a(app): 255 | return {"foo": "str", "bar": 2} 256 | 257 | @app.request("subject_b", result=SomeClassB) 258 | def get_subject_b(app): 259 | return {"foo": "str"} 260 | 261 | app.generate_asyncapi() 262 | schema = app.asyncapi_schema 263 | subject_a_reply_ref = schema["channels"]["natsapi.development.subject_a"]["request"]["replies"][0]["payload"][ 264 | "$ref" 265 | ].split("/")[-1] 266 | subject_b_reply_ref = schema["channels"]["natsapi.development.subject_b"]["request"]["replies"][0]["payload"][ 267 | "$ref" 268 | ].split("/")[-1] 269 | 270 | assert subject_a_reply_ref in schema["components"]["schemas"] 271 | assert subject_b_reply_ref in schema["components"]["schemas"] 272 | 273 | 274 | async def test_generate_shema_w_docstring_should_generate_proper_description(app: NatsAPI): 275 | user_router = SubjectRouter(prefix="v1", tags=["users"]) 276 | 277 | @user_router.request( 278 | "users.CREATE", 279 | result=CreateResult, 280 | ) 281 | def create_base_user(app): 282 | """ 283 | should be generated 284 | """ 285 | return {} 286 | 287 | app.include_router(user_router) 288 | 289 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 290 | assert ( 291 | schema["channels"]["natsapi.development.v1.users.CREATE"]["request"].get("description") == "should be generated" 292 | ) 293 | 294 | 295 | async def test_generate_shema_w_description_in_route_should_overwrite_description(app: NatsAPI): 296 | user_router = SubjectRouter(prefix="v1", tags=["users"]) 297 | 298 | @user_router.request( 299 | "users.CREATE", 300 | result=CreateResult, 301 | description="Creates user that can be used throughout the app", 302 | ) 303 | def create_base_user(app): 304 | """ 305 | should be generated 306 | """ 307 | return {} 308 | 309 | app.include_router(user_router) 310 | 311 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 312 | assert ( 313 | 
schema["channels"]["natsapi.development.v1.users.CREATE"]["request"].get("description") 314 | == "Creates user that can be used throughout the app" 315 | ) 316 | 317 | 318 | async def test_generate_shema_w_routers_that_have_union_typing_should_generate(app: NatsAPI): 319 | class A(BaseModel): 320 | a: int 321 | 322 | class B(BaseModel): 323 | b: int 324 | 325 | r = SubjectRouter(prefix="v1") 326 | 327 | @r.request("union", result=Union[A, B]) 328 | def union_result(app): 329 | return A(a=4) 330 | 331 | app.include_router(r) 332 | 333 | app.generate_asyncapi() 334 | schema = app.asyncapi_schema 335 | 336 | union_route = schema["channels"]["natsapi.development.v1.union"] 337 | union_route_replies = union_route["request"]["replies"] 338 | assert len(union_route_replies) == 2 339 | union_opts = union_route_replies[0]["payload"]["anyOf"] 340 | assert len(union_opts) == 2 341 | response_models = [m["$ref"].split("/")[-1] for m in union_opts] 342 | assert response_models == ["A", "B"] 343 | 344 | schema_from_request = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 345 | assert schema_from_request == schema 346 | 347 | 348 | async def test_given_no_func_name_should_use_subject_as_summary(app: NatsAPI): 349 | @app.request("foo", result=CreateResult) 350 | def _(app): 351 | return {} 352 | 353 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 354 | assert schema["channels"]["natsapi.development.foo"]["request"]["summary"] == "Foo" 355 | 356 | 357 | async def test_pub_and_req_should_render_properly(app: NatsAPI): 358 | @app.request("req", result=CreateResult) 359 | def _(app): 360 | return {} 361 | 362 | @app.publish("pub") 363 | def _(app): 364 | return {} 365 | 366 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 367 | assert schema["channels"]["natsapi.development.req"]["request"]["summary"] == "Req" 368 | assert schema["channels"]["natsapi.development.pub"]["publish"]["summary"] 
== "Pub" 369 | -------------------------------------------------------------------------------- /tests/asyncapi/test_models.py: -------------------------------------------------------------------------------- 1 | from natsapi.asyncapi.models import RequestOperation 2 | 3 | 4 | def test_suggested_timeout_should_generate_timeout(): 5 | expected = {"x-suggested-timeout": 0.5} 6 | actual = RequestOperation(**expected) 7 | assert actual.suggestedTimeout == expected["x-suggested-timeout"] 8 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = [ 2 | "tests.fixtures", 3 | ] 4 | -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import pytest 5 | 6 | from natsapi import NatsAPI 7 | from natsapi.client import Config 8 | from natsapi.client.config import ConnectConfig 9 | 10 | 11 | @pytest.fixture(scope="session") 12 | def client_config(): 13 | connect = ConnectConfig( 14 | servers=os.environ.get("HOST_NATS", "nats://127.0.0.1:4222"), 15 | nkeys_seed=os.environ.get("NATS_CREDENTIALS_FILE", None), 16 | ) 17 | return Config(connect=connect) 18 | 19 | 20 | @pytest.fixture(scope="function") 21 | async def app(client_config, event_loop): 22 | """ 23 | Clean NatsAPI instance with rootpath 24 | """ 25 | app = NatsAPI("natsapi.development", client_config=client_config) 26 | await app.startup(loop=event_loop) 27 | yield app 28 | await app.shutdown(app) 29 | 30 | 31 | @pytest.fixture(scope="session") 32 | def event_loop(): 33 | yield asyncio.get_event_loop() 34 | -------------------------------------------------------------------------------- /tests/plugins/test_mock.py: -------------------------------------------------------------------------------- 1 | import 
asyncio 2 | 3 | import pytest 4 | 5 | from natsapi.plugin import NatsapiMock 6 | 7 | pytestmark = pytest.mark.asyncio 8 | ch = "service" 9 | 10 | 11 | async def test_nats_mock_wrong_host_should_raise_error(): 12 | with pytest.raises(Exception): 13 | mock = NatsapiMock(host="foobar", channel="foobar") 14 | await mock.wait_startup() 15 | 16 | 17 | async def test_nats_mock_should_respond_with_mocked_response(app, natsapi_mock): 18 | # given: 19 | await natsapi_mock.request(f"{ch}.items.retrieve", response={"items": [{"id": 1}]}) 20 | 21 | # when: 22 | reply = await app.nc.request(f"{ch}.items.retrieve", timeout=1) 23 | 24 | # then: 25 | assert not reply.error 26 | assert reply.result == {"items": [{"id": 1}]}, reply.result 27 | 28 | 29 | async def test_nats_mock_should_respond_with_mocked_response_given_a_model(app, natsapi_mock): 30 | from pydantic import BaseModel 31 | 32 | class Foo(BaseModel): 33 | items: list[str] 34 | 35 | # given: 36 | await natsapi_mock.request(f"{ch}.items.retrieve", response=Foo(items=["a", "b"])) 37 | 38 | # when: 39 | reply = await app.nc.request(f"{ch}.items.retrieve", timeout=1) 40 | 41 | # then: 42 | assert not reply.error 43 | assert reply.result == {"items": ["a", "b"]}, reply.result 44 | 45 | 46 | async def test_be_able_to_intercept_nats_request_payload(app, natsapi_mock): 47 | # given: 48 | await natsapi_mock.request(f"{ch}.items.id.retrieve", response={"id": 1}) 49 | 50 | # when: 51 | payload = {"id": 1} 52 | await app.nc.request(f"{ch}.items.id.retrieve", payload, timeout=1) 53 | 54 | # then: 55 | assert len(natsapi_mock.payloads[f"{ch}.items.id.retrieve"]) == 1, natsapi_mock.payloads 56 | payload = natsapi_mock.payloads[f"{ch}.items.id.retrieve"][0] 57 | assert payload["params"]["id"] == 1 58 | 59 | 60 | async def test_nats_mock_should_respond_with_mocked_error_when_error(app, natsapi_mock): 61 | # given: 62 | await natsapi_mock.request(f"{ch}.items.retrieve", error={"code": -1, "message": "FOOBAR"}) 63 | 64 | # when: 65 | reply 
= await app.nc.request(f"{ch}.items.retrieve", timeout=1) 66 | 67 | # then: 68 | assert reply.error.message == "FOOBAR" 69 | 70 | 71 | async def test_nats_mock_should_raise_error_when_invalid_error_response(app, natsapi_mock): 72 | # given: an invalid JsonRPCError response that is not a dict 73 | await natsapi_mock.request(f"{ch}.items.retrieve", error="ERROR") 74 | 75 | # when: 76 | reply = await app.nc.request(f"{ch}.items.retrieve", timeout=1) 77 | 78 | # then: 79 | assert "valid dict" in reply.error.message 80 | 81 | 82 | async def test_be_able_to_intercept_nats_publish_event_payload(app, natsapi_mock): 83 | # given: 84 | subject = "a.publish.event" 85 | await natsapi_mock.publish(subject) 86 | 87 | # when: 88 | payload = {"id": 1} 89 | await app.nc.publish(subject, payload) 90 | 91 | # wait 1 second to be sure publish event is picked up 92 | await asyncio.sleep(1) 93 | 94 | # then: 95 | assert len(natsapi_mock.payloads) == 1 96 | assert natsapi_mock.payloads[subject][0]["params"] == payload 97 | -------------------------------------------------------------------------------- /tests/test_applications.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from natsapi import NatsAPI, Pub, Sub, SubjectRouter 5 | 6 | 7 | class NotificationParams(BaseModel): 8 | notification: str 9 | service: str 10 | 11 | 12 | class StatusResult(BaseModel): 13 | status: str 14 | 15 | 16 | class ThemeCreateCmd(BaseModel): 17 | name: str 18 | 19 | 20 | async def test_minimal_app_should_have_route_with_minimal_schema(app): 21 | themes_router = SubjectRouter() 22 | 23 | @themes_router.request("themes.CREATE", result=StatusResult) 24 | def create_theme(app, data: ThemeCreateCmd): 25 | return {"status": "OK"} 26 | 27 | app.include_router(themes_router) 28 | 29 | reply = await app.nc.request("natsapi.development.themes.CREATE", {"data": {"name": "orange"}}) 30 | 31 | assert reply.result["status"] 
== "OK" 32 | 33 | 34 | async def test_async_startup_should_run_with_setup_method(client_config, event_loop): 35 | app = NatsAPI("natsapi.development", client_config=client_config) 36 | 37 | @app.on_startup 38 | async def setup(): 39 | db = "connected" 40 | assert db == "connected" 41 | 42 | await app.startup(loop=event_loop) 43 | await app.shutdown() 44 | 45 | 46 | async def test_async_shutdown(client_config, event_loop): 47 | app = NatsAPI("natsapi.development", client_config=client_config) 48 | 49 | @app.on_shutdown 50 | async def teardown(): 51 | db = "disconnected" 52 | assert db == "disconnected" 53 | 54 | await app.startup(loop=event_loop) 55 | await app.shutdown() 56 | 57 | 58 | def test_pass_unrecognized_type_for_app_attr_should_fail(app): 59 | with pytest.raises(AssertionError): 60 | _ = NatsAPI("natsapi.development", app=42) 61 | 62 | 63 | def test_app_request_decorator_should_add_request_to_routes(): 64 | app = NatsAPI("natsapi.development") 65 | 66 | @app.request("request_without_router", result=StatusResult) 67 | def new_route(app): 68 | return {"status": "OK"} 69 | 70 | assert len(app.routes) == 1 71 | assert app.routes["natsapi.development.request_without_router"] 72 | assert app.routes["natsapi.development.request_without_router"].subject == "request_without_router" 73 | 74 | 75 | def test_subs_to_root_paths_should_be_documented(app: NatsAPI): 76 | app.include_subs([Sub("*.subject.>", queue="natsapi")]) 77 | assert len(app.subs) == 2 78 | 79 | 80 | def test_include_pub_should_add_pub_to_app(app: NatsAPI): 81 | app.include_pubs([Pub("some.subject", NotificationParams)]) 82 | assert len(app.pubs) == 1 83 | 84 | 85 | def test_add_pub_should_add_pub_to_app(app: NatsAPI): 86 | @app.pub("notifications.CREATE", params=NotificationParams) 87 | @app.request("subject.RETRIEVE", result=StatusResult) 88 | async def _(app: NatsAPI): 89 | await app.nc.publish("notifications.CREATE", {"notification": "Hi", "service": "SMT"}) 90 | return {"status": "OK"} 91 | 92 | 
assert len(app.pubs) == 1 93 | 94 | 95 | def test_add_sub_should_add_sub_to_app(app: NatsAPI): 96 | @app.sub("*.subject.>", queue="natsapi") 97 | @app.request("subject.RETRIEVE", result=StatusResult) 98 | async def handle_request_with_sub(app: NatsAPI): 99 | await app.nc.subscribe("*.subject.>", queue="natsapi", cb=lambda _: print("some callback")) 100 | return {"status": "OK"} 101 | 102 | assert len(app.subs) == 2 103 | 104 | 105 | def test_add_subject_that_doesnt_end_in_rpc_method_should_fail(): 106 | JSON_RPC_METHODS = ["CREATE", "RETRIEVE", "UPDATE", "DELETE", "CONVERT", "EXPORT", "CALCULATE", "VERIFY"] 107 | app = NatsAPI("natsapi.development", rpc_methods=JSON_RPC_METHODS) 108 | 109 | with pytest.raises(AssertionError) as exc: 110 | 111 | @app.request("subject.UNKNOWN_RPC_METHOD", result=BaseModel) 112 | def handle_request(app): 113 | return {"Status": "OK"} 114 | 115 | assert str(JSON_RPC_METHODS) in str(exc) 116 | assert "invalid request method" in str(exc) 117 | 118 | 119 | async def test_with_method_should_run_commands_within_with_block_and_shutdown_afterwards(client_config): 120 | async with NatsAPI("with_block", client_config=client_config) as app: 121 | reply = await app.nc.request("with_block.schema.RETRIEVE") 122 | assert reply.result 123 | 124 | 125 | async def test_send_a_pub_should_send_without_error(app): 126 | await app.nc.publish("natsapi.development.test.CREATE", {}) 127 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from natsapi import NatsAPI 2 | from natsapi.client import Config, SubscribeConfig 3 | 4 | 5 | def test_customize_config_should_use_customized_config(): 6 | subscribe = SubscribeConfig(queue="smt") 7 | client_config = Config(subscribe=subscribe) 8 | 9 | app = NatsAPI("natsapi.development", client_config=client_config) 10 | 11 | assert app.client_config.subscribe.queue == 
client_config.subscribe.queue 12 | -------------------------------------------------------------------------------- /tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | from pydantic.main import BaseModel 2 | 3 | from natsapi import NatsAPI 4 | from natsapi.exceptions import JsonRPCException 5 | from natsapi.models import JsonRPCError, JsonRPCRequest 6 | 7 | 8 | async def test_overwrite_default_exception_should_use_custom_method(app): 9 | @app.exception_handler(JsonRPCException) 10 | def handle_custom_jsonrpc(exc: JsonRPCException, request: JsonRPCRequest, subject: str) -> JsonRPCError: 11 | return JsonRPCError(code=exc.code, message=exc.message) 12 | 13 | assert app._exception_handlers[JsonRPCException] == handle_custom_jsonrpc 14 | 15 | @app.request("test.RETRIEVE", result=BaseModel) 16 | def _(app: NatsAPI): 17 | raise JsonRPCException(code=42, message="hello world") 18 | 19 | reply = await app.nc.request("natsapi.development.test.RETRIEVE", {}) 20 | assert reply.error.code == 42 21 | assert reply.error.message == "hello world" 22 | 23 | 24 | async def test_add_custom_exception_should_use_handler_when_exception_is_thrown(app): 25 | class CustomException(Exception): 26 | def __init__(self, msg: str, rpc_code: int): 27 | self.rpc_code = rpc_code 28 | self.msg = msg 29 | 30 | @app.exception_handler(CustomException) 31 | async def handle_custom_exception(exc: CustomException, request: JsonRPCRequest, subject: str) -> JsonRPCError: 32 | return JsonRPCError(code=exc.rpc_code, message=exc.msg) 33 | 34 | assert app._exception_handlers[CustomException] == handle_custom_exception 35 | 36 | @app.request("test.RETRIEVE", result=BaseModel) 37 | async def _(app: NatsAPI): 38 | raise CustomException(rpc_code=500, msg="custom_message") 39 | 40 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 41 | assert reply.code == 500 42 | assert reply.message == "custom_message" 43 | 44 | 45 | async 
def test_throw_derived_custom_exception_should_use_base_exception_handler(app): 46 | class CustomException(Exception): 47 | def __init__(self, msg: str, rpc_code: int): 48 | self.rpc_code = rpc_code 49 | self.msg = msg 50 | 51 | class DerivedException(CustomException): 52 | def __init__(self, msg: str, rpc_code: int): 53 | super().__init__(msg, rpc_code) 54 | 55 | @app.exception_handler(CustomException) 56 | async def handle_custom_exception(exc: CustomException, request: JsonRPCRequest, subject: str) -> JsonRPCError: 57 | return JsonRPCError(code=exc.rpc_code, message=exc.msg) 58 | 59 | assert app._exception_handlers[CustomException] == handle_custom_exception 60 | 61 | @app.request("test.RETRIEVE", result=BaseModel) 62 | async def _(app: NatsAPI): 63 | raise DerivedException(rpc_code=500, msg="custom_message") 64 | 65 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 66 | assert reply.code == 500 67 | assert reply.message == "custom_message" 68 | 69 | 70 | async def test_default_jsonrpc_exception_handler_should_handle_exception_and_return_default_error_reply(app): 71 | @app.request("test.RETRIEVE", result=BaseModel) 72 | async def _(app: NatsAPI): 73 | raise JsonRPCException(code=500, message="custom_message") 74 | 75 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 76 | 77 | assert reply.code == 500 78 | assert reply.message == "custom_message" 79 | assert reply.data 80 | assert reply.data["type"] == "JsonRPCException" 81 | assert reply.data["errors"] == [] 82 | 83 | 84 | async def test_default_validation_error_handler_should_handle_exception_and_return_default_error_reply(app): 85 | class TestClass(BaseModel): 86 | a: int 87 | b: str 88 | 89 | @app.request("test.RETRIEVE", result=BaseModel) 90 | async def _(app: NatsAPI): 91 | TestClass(a="abc") 92 | 93 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 94 | 95 | assert reply.code == -40001 96 | assert reply.message 97 | 
assert reply.data 98 | assert reply.data["type"] == "ValidationError" 99 | assert reply.data["errors"] != [] 100 | assert len(reply.data["errors"]) == 2 101 | 102 | for e in reply.data["errors"]: 103 | assert e["type"] == "ValidationError" 104 | assert e["target"] 105 | assert e["message"] 106 | 107 | 108 | async def test_default_exception_handler_should_handle_exception_and_return_default_error_reply(app): 109 | @app.request("test.RETRIEVE", result=BaseModel) 110 | async def _(app: NatsAPI): 111 | raise Exception("Hello world") 112 | 113 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 114 | 115 | assert reply.code == -40000 116 | assert reply.message == "Hello world" 117 | assert reply.data 118 | assert reply.data["type"] == "Exception" 119 | assert reply.data["errors"] == [] 120 | 121 | 122 | async def test_default_exception_handler_should_write_error_log(app, caplog): 123 | @app.request("test.RETRIEVE", result=BaseModel) 124 | async def _(app: NatsAPI): 125 | raise Exception("Hello world") 126 | 127 | await app.nc.request("natsapi.development.test.RETRIEVE", {}) 128 | 129 | errors = [r for r in caplog.records if r.levelname in ["ERROR"]] 130 | assert len(errors) == 1 131 | assert "Hello world" in str(errors[0]) 132 | 133 | 134 | async def test_default_jsonrpc_exception_handler_should_write_warning_log(app, caplog): 135 | @app.request("test.RETRIEVE", result=BaseModel) 136 | async def _(app: NatsAPI): 137 | raise JsonRPCException(code=42, message="Hello world") 138 | 139 | await app.nc.request("natsapi.development.test.RETRIEVE", {}) 140 | 141 | errors = [r for r in caplog.records if r.levelname in ["ERROR"]] 142 | assert len(errors) == 1 143 | assert "Hello world" in str(errors[0]) 144 | 145 | 146 | async def test_default_validation_error_handler_should_write_warning_log(app, caplog): 147 | class TestClass(BaseModel): 148 | a: int 149 | b: str 150 | 151 | @app.request("test.RETRIEVE", result=BaseModel) 152 | async def _(app: 
NatsAPI): 153 | TestClass(a="abc") 154 | 155 | await app.nc.request("natsapi.development.test.RETRIEVE", {}) 156 | 157 | errors = [r for r in caplog.records if r.levelname in ["ERROR"]] 158 | assert len(errors) == 1 159 | assert "validation errors" in str(errors[0]), f"Expected validation error message, got: {str(errors[0])}" 160 | 161 | 162 | async def test_default_exception_handler_should_handle_formatted_exception_and_return_default_error_reply(app): 163 | class FormattedException(Exception): 164 | def __init__( 165 | self, 166 | msg, 167 | domain=None, 168 | detail=None, 169 | code=None, 170 | rpc_code=None, 171 | ): 172 | self.msg = msg 173 | self.domain = (domain,) 174 | self.detail = detail 175 | self.code = code 176 | self.rpc_code = rpc_code 177 | 178 | @app.request("test.RETRIEVE", result=BaseModel) 179 | async def _(app: NatsAPI): 180 | raise FormattedException( 181 | msg="NATS_ERROR", 182 | detail="Something went wrong while working with NATS.", 183 | domain="Some service", 184 | code=500, 185 | rpc_code=-27000, 186 | ) 187 | 188 | reply = (await app.nc.request("natsapi.development.test.RETRIEVE", {})).error 189 | 190 | assert reply.code == -27000 191 | assert reply.message == "NATS_ERROR: Something went wrong while working with NATS." 
192 | assert reply.data 193 | assert reply.data["type"] == "FormattedException" 194 | assert reply.data["errors"] == [] 195 | -------------------------------------------------------------------------------- /tests/test_fastapi.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | from natsapi import NatsAPI 7 | 8 | 9 | class FastAPI(BaseModel): 10 | pass 11 | 12 | 13 | @pytest.mark.skip(reason="Removed fastapi dependency.") 14 | async def test_run_natsapi_on_same_loop_should_run_simultaneously(): 15 | fastapi = FastAPI() 16 | natsapi = NatsAPI("natsapi.dev") 17 | app = NatsAPI("test") 18 | await app.startup() 19 | 20 | @fastapi.on_event("startup") 21 | async def setup(): 22 | loop = asyncio.get_running_loop() 23 | await natsapi.startup(loop=loop) 24 | 25 | @fastapi.on_event("shutdown") 26 | async def teardown(): 27 | await natsapi.shutdown() 28 | 29 | 30 | @pytest.mark.skip(reason="Removed fastapi dependency.") 31 | def test_pass_fastapi_instance_as_app_should_work(): 32 | fastapi = FastAPI() 33 | 34 | fastapi.state.db = "postgresql" 35 | fastapi.controllers = "dirty_object" 36 | 37 | natsapi = NatsAPI("natsapi.development", app=fastapi) 38 | 39 | assert type(natsapi.app) is type(fastapi) 40 | assert natsapi.app.controllers == "dirty_object" 41 | assert natsapi.app.state.db == "postgresql" 42 | -------------------------------------------------------------------------------- /tests/test_jsonable_encoder.py: -------------------------------------------------------------------------------- 1 | """yanked from fastapi""" 2 | 3 | from datetime import datetime, timezone 4 | from enum import Enum 5 | from pathlib import PurePath, PurePosixPath, PureWindowsPath 6 | from typing import Optional 7 | 8 | import pytest 9 | from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model 10 | 11 | from natsapi._compat import RootModel 12 | from 
natsapi.encoders import jsonable_encoder 13 | 14 | 15 | class Person: 16 | def __init__(self, name: str): 17 | self.name = name 18 | 19 | 20 | class Pet: 21 | def __init__(self, owner: Person, name: str): 22 | self.owner = owner 23 | self.name = name 24 | 25 | 26 | class DictablePerson(Person): 27 | def __iter__(self): 28 | return ((k, v) for k, v in self.__dict__.items()) 29 | 30 | 31 | class DictablePet(Pet): 32 | def __iter__(self): 33 | return ((k, v) for k, v in self.__dict__.items()) 34 | 35 | 36 | class Unserializable: 37 | def __iter__(self): 38 | raise NotImplementedError() 39 | 40 | @property 41 | def __dict__(self): 42 | raise NotImplementedError() 43 | 44 | 45 | class ModelWithCustomEncoder(BaseModel): 46 | dt_field: datetime 47 | 48 | class Config: 49 | json_encoders = {datetime: lambda dt: dt.replace(microsecond=0, tzinfo=timezone.utc).isoformat()} 50 | 51 | 52 | class ModelWithCustomEncoderSubclass(ModelWithCustomEncoder): 53 | class Config: 54 | pass 55 | 56 | 57 | class RoleEnum(Enum): 58 | admin = "admin" 59 | normal = "normal" 60 | 61 | 62 | class ModelWithConfig(BaseModel): 63 | role: Optional[RoleEnum] = None 64 | 65 | class Config: 66 | use_enum_values = True 67 | 68 | 69 | class ModelWithAlias(BaseModel): 70 | foo: str = Field(..., alias="Foo") 71 | 72 | 73 | class ModelWithDefault(BaseModel): 74 | foo: str = ... 
# type: ignore 75 | bar: str = "bar" 76 | bla: str = "bla" 77 | 78 | 79 | class ModelWithRoot(RootModel): 80 | pass 81 | 82 | 83 | @pytest.fixture(name="model_with_path", params=[PurePath, PurePosixPath, PureWindowsPath]) 84 | def fixture_model_with_path(request): 85 | class Config: 86 | arbitrary_types_allowed = True 87 | 88 | ModelWithPath = create_model( 89 | # type: ignore 90 | "ModelWithPath", 91 | path=(request.param, ...), 92 | __config__=Config, 93 | ) 94 | return ModelWithPath(path=request.param("/foo", "bar")) 95 | 96 | 97 | def test_encode_class(): 98 | person = Person(name="Foo") 99 | pet = Pet(owner=person, name="Firulais") 100 | assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}} 101 | 102 | 103 | def test_encode_dictable(): 104 | person = DictablePerson(name="Foo") 105 | pet = DictablePet(owner=person, name="Firulais") 106 | assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}} 107 | 108 | 109 | def test_encode_unsupported(): 110 | unserializable = Unserializable() 111 | with pytest.raises(ValueError): 112 | jsonable_encoder(unserializable) 113 | 114 | 115 | def test_encode_custom_json_encoders_model(): 116 | model = ModelWithCustomEncoder(dt_field=datetime(2019, 1, 1, 8)) 117 | assert jsonable_encoder(model) == {"dt_field": "2019-01-01T08:00:00+00:00"} 118 | 119 | 120 | def test_encode_custom_json_encoders_model_subclass(): 121 | model = ModelWithCustomEncoderSubclass(dt_field=datetime(2019, 1, 1, 8)) 122 | assert jsonable_encoder(model) == {"dt_field": "2019-01-01T08:00:00+00:00"} 123 | 124 | 125 | def test_encode_model_with_config(): 126 | model = ModelWithConfig(role=RoleEnum.admin) 127 | assert jsonable_encoder(model) == {"role": "admin"} 128 | 129 | 130 | def test_encode_model_with_alias_raises(): 131 | with pytest.raises(ValidationError): 132 | ModelWithAlias(foo="Bar") 133 | 134 | 135 | def test_encode_model_with_alias(): 136 | model = ModelWithAlias(Foo="Bar") 137 | assert 
jsonable_encoder(model) == {"Foo": "Bar"} 138 | 139 | 140 | def test_encode_model_with_default(): 141 | model = ModelWithDefault(foo="foo", bar="bar") 142 | assert jsonable_encoder(model) == {"foo": "foo", "bar": "bar", "bla": "bla"} 143 | assert jsonable_encoder(model, exclude_unset=True) == {"foo": "foo", "bar": "bar"} 144 | assert jsonable_encoder(model, exclude_defaults=True) == {"foo": "foo"} 145 | assert jsonable_encoder(model, exclude_unset=True, exclude_defaults=True) == {"foo": "foo"} 146 | 147 | 148 | def test_custom_encoders(): 149 | class safe_datetime(datetime): 150 | pass 151 | 152 | class MyModel(BaseModel): 153 | model_config = ConfigDict(arbitrary_types_allowed=True) 154 | dt_field: safe_datetime 155 | 156 | instance = MyModel(dt_field=safe_datetime.now()) 157 | 158 | encoded_instance = jsonable_encoder(instance, custom_encoder={safe_datetime: lambda o: o.isoformat()}) 159 | assert encoded_instance["dt_field"] == instance.dt_field.isoformat() 160 | 161 | 162 | def test_encode_model_with_path(model_with_path): 163 | expected = "\\foo\\bar" if isinstance(model_with_path.path, PureWindowsPath) else "/foo/bar" 164 | assert jsonable_encoder(model_with_path) == {"path": expected} 165 | 166 | 167 | def test_encode_root(): 168 | model = ModelWithRoot(__root__="Foo") 169 | assert jsonable_encoder(model) == "Foo" 170 | -------------------------------------------------------------------------------- /tests/test_method_type_conversion.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | from natsapi import NatsAPI, SubjectRouter 7 | 8 | 9 | class StatusResult(BaseModel): 10 | status: typing.Any 11 | 12 | 13 | class FastAPI(BaseModel): 14 | pass 15 | 16 | 17 | async def test_method_parameters_should_get_parsed_to_correct_typing(app): 18 | class ThemesCreateCmd(BaseModel): 19 | primary: str 20 | color: typing.Union[str, None] = None 21 | 22 | 
@app.request("themes.CREATE", result=StatusResult) 23 | async def create_theme(app, data: ThemesCreateCmd): 24 | return {"status": data.primary} 25 | 26 | reply = await app.nc.request("natsapi.development.themes.CREATE", {"data": {"primary": "blue"}}) 27 | 28 | assert reply.result["status"] == "blue" 29 | 30 | 31 | def test_app_parameter_typing_should_validate_type(): 32 | router = SubjectRouter() 33 | 34 | @router.request("themes.CONVERT", result=StatusResult) 35 | async def convert_theme(app): 36 | return {"status": "OK"} 37 | 38 | @router.request("themes.CREATE", result=StatusResult) 39 | async def create_theme(app: NatsAPI): 40 | return {"status": "OK"} 41 | 42 | @router.request("themes.DELETE", result=StatusResult) 43 | async def delete_theme(app: FastAPI): 44 | return {"status": "OK"} 45 | 46 | @router.request("themes.UPDATE", result=StatusResult) 47 | async def update_theme(app: typing.Any): 48 | return {"status": "OK"} 49 | 50 | with pytest.raises(AssertionError) as exc: 51 | 52 | @router.request("themes.CALCULATE", result=StatusResult) 53 | async def calculate_theme(app: int): 54 | return {"status": "OK"} 55 | 56 | assert "Got int" in str(exc) 57 | 58 | 59 | class TypeResult(BaseModel): 60 | typing: str 61 | 62 | 63 | async def test_exotic_typing_should_convert_to_correct_type(app): 64 | @app.request("themes.CONVERT", result=TypeResult) 65 | async def convert_theme(app, param: typing.Union[list[str], int]): 66 | return {"typing": type(param).__name__} 67 | 68 | reply = await app.nc.request("natsapi.development.themes.CONVERT", {"param": ["foo", "bar", "baz"]}) 69 | assert reply.result["typing"] == "list" 70 | 71 | reply = await app.nc.request("natsapi.development.themes.CONVERT", {"param": 42}) 72 | assert reply.result["typing"] == "int" 73 | 74 | reply = await app.nc.request("natsapi.development.themes.CONVERT", {"param": "NOT A LIST OR INT"}) 75 | assert reply.error.code == -40001 # Validation Error 76 | 
-------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from uuid import uuid4 3 | 4 | import pytest 5 | from pydantic import BaseModel, ValidationError 6 | from pydantic.fields import Field 7 | 8 | from natsapi.models import JsonRPCError, JsonRPCReply, JsonRPCRequest 9 | 10 | 11 | def test_change_param_type_of_model_should_change(): 12 | class Params(BaseModel): 13 | foo: int = Field(...) 14 | bar: Optional[str] = Field(None) 15 | 16 | new_model = JsonRPCRequest.with_params(Params) 17 | 18 | payload = {"params": {"foo": 22}, "timeout": 60} 19 | created_request = JsonRPCRequest.parse_obj(payload) 20 | d = new_model.parse_raw(created_request.json()) 21 | actual = new_model.parse_obj(d).params 22 | expected = Params 23 | 24 | assert type(actual) is expected 25 | 26 | 27 | def test_result_or_error_should_be_provided_in_jsonrpcreply(): 28 | with pytest.raises(ValidationError) as e: 29 | JsonRPCReply(id=uuid4()) 30 | assert "A result or error should be required" in str(e.value), e.value 31 | 32 | 33 | def test_result_and_error_should_not_be_provided_at_same_time_in_jsonrpcreply(): 34 | with pytest.raises(AttributeError) as e: 35 | JsonRPCReply(error={"code": 1, "message": "foobar"}, result={"status": "OK"}) 36 | assert "An RPC reply MUST NOT have an error and a result" in str(e) 37 | 38 | 39 | def test_jsponrpcerror_timestamp_is_generated_on_creation(): 40 | error_1 = JsonRPCError(code=1, message="", data=None) 41 | error_2 = JsonRPCError(code=1, message="", data=None) 42 | assert error_1.timestamp != error_2.timestamp 43 | -------------------------------------------------------------------------------- /tests/test_publish.py: -------------------------------------------------------------------------------- 1 | from natsapi import NatsAPI 2 | 3 | 4 | async def test_send_publish_should_be_called(client_config, 
event_loop): 5 | app = NatsAPI("natsapi.development", client_config=client_config) 6 | app.count = 0 7 | 8 | @app.publish(subject="foo") 9 | async def _(app): 10 | app.count += 1 11 | 12 | await app.startup(loop=event_loop) 13 | await app.nc.publish("natsapi.development.foo", {}) 14 | await app.shutdown(app) 15 | assert app.count == 1 16 | -------------------------------------------------------------------------------- /tests/test_requests.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from natsapi import NatsAPI, SubjectRouter 4 | from natsapi.context import CTX_JSONRPC_ID 5 | from natsapi.exceptions import JsonRPCException 6 | from natsapi.models import JsonRPCReply, JsonRPCRequest 7 | 8 | 9 | class StatusResult(BaseModel): 10 | status: str 11 | 12 | 13 | class BrokerAlreadyExists(JsonRPCException): 14 | def __init__(self, data=None): 15 | self.code = -27001 16 | self.message = "BROKER_EXISTS" 17 | self.data = data 18 | 19 | 20 | async def test_send_request_should_get_successful_reply(app): 21 | @app.request(subject="foo") 22 | async def _(app): 23 | return {"status": "OK"} 24 | 25 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 26 | assert reply.result["status"] == "OK" 27 | 28 | 29 | async def test_nonexistant_subject_should_get_failed_reply(app): 30 | reply = await app.nc.request("natsapi.development.nonexistant.CREATE", {}) 31 | assert reply.error.message == "NO_SUCH_ENDPOINT" 32 | assert reply.error.code == -32601 33 | 34 | 35 | async def test_incorrect_payload_should_get_failed_reply(app): 36 | @app.request(subject="foo") 37 | async def test(app, foo: int): 38 | return {"status": "OK", "foo": foo} 39 | 40 | reply = await app.nc.request("natsapi.development.foo", {"foo": "str"}) 41 | assert reply.error.code == -40001 42 | 43 | 44 | async def test_incorrect_request_format_should_fail(app): 45 | @app.request(subject="foo") 46 | async def test(app, foo: int, 
bar: str = None): 47 | return {"status": "OK", "foo": foo, "bar": bar} 48 | 49 | payload = {"timeout": 60, "foo": "str"} 50 | reply = await app.nc.request("natsapi.development.foo", payload) 51 | assert reply.error.code == -40001 52 | 53 | 54 | async def test_payload_with_request_method_in_payload_should_find_endpoint(app): 55 | @app.request(subject="foo") 56 | async def test(app, foo: int): 57 | return {"status": "OK", "foo": foo} 58 | 59 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 60 | assert reply.result["status"] == "OK" 61 | 62 | 63 | async def test_payload_with_request_method_in_subject_and_payload_should_prioritize_subject(app): 64 | @app.request(subject="foo") 65 | async def _(app, foo: int): 66 | return {"status": "OK", "foo": foo} 67 | 68 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 69 | assert reply.result["status"] == "OK" 70 | 71 | 72 | async def test_no_such_endpoint(app): 73 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 74 | assert reply.error.message == "NO_SUCH_ENDPOINT" 75 | assert reply.error.code == -32601 76 | 77 | 78 | async def test_payload_with_empty_request_method_and_method__in_subject_get_successful_reply(app): 79 | @app.request(subject="foo") 80 | async def _(app, foo: int): 81 | return {"status": "OK", "foo": foo} 82 | 83 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 84 | assert reply.result["status"] == "OK" 85 | 86 | 87 | async def test_payload_with_empty_request_method_and_method__in_subject_get_successful_reply_with_return_model(app): 88 | @app.request(subject="foo") 89 | async def _(app, foo: int): 90 | return StatusResult(status="OK") 91 | 92 | reply = await app.nc.request("natsapi.development.foo", {"foo": 1}) 93 | assert reply.result["status"] == "OK" 94 | 95 | 96 | async def test_unhandled_application_error_should_get_failed_reply(app): 97 | expected = EOFError("Unhandled exception, e.g. 
UniqueViolationError") 98 | router = SubjectRouter() 99 | 100 | @router.request("error.CONVERT", result=BaseModel) 101 | def raise_exception(app): 102 | raise expected 103 | 104 | app.include_router(router) 105 | 106 | reply = await app.nc.request("natsapi.development.error.CONVERT", {}) 107 | actual = reply.error.data 108 | 109 | assert type(expected).__name__ == actual["type"] 110 | assert reply.error.code == -40000 111 | 112 | 113 | async def test_method_raised_domain_specific_error_should_get_failed_reply_in_domain_error_range(app): 114 | router = SubjectRouter() 115 | 116 | @router.request("error.CONVERT", result=BaseModel) 117 | def raise_exception(app): 118 | raise BrokerAlreadyExists("Broker with UUID already exists in DB") 119 | 120 | app.include_router(router) 121 | 122 | reply = await app.nc.request("natsapi.development.error.CONVERT", {}) 123 | assert reply.error.message == "BROKER_EXISTS" 124 | assert reply.error.code == -27001 125 | 126 | 127 | async def test_handle_request_meant_for_multiple_services_should_get_successful_reply(client_config, event_loop): 128 | # given 129 | app = NatsAPI("natsapi.development", client_config=client_config) 130 | 131 | router = SubjectRouter() 132 | 133 | @router.request(subject="bar") 134 | async def _(app): 135 | return {"status": "OK"} 136 | 137 | app.include_router(router, root_path="bar") 138 | 139 | await app.startup(loop=event_loop) 140 | 141 | # when 142 | reply = await app.nc.request("bar.bar", {}) 143 | 144 | # then 145 | assert reply.result["status"] == "OK" 146 | 147 | # cleanup 148 | await app.shutdown(app) 149 | 150 | 151 | async def test_skip_validation_should_pass_original_dict_in_validator_and_have_model_in_schema(app): 152 | class SomeParams(BaseModel): 153 | foo: str 154 | 155 | @app.request("skip_validation.CREATE", result=StatusResult, skip_validation=True) 156 | def _(app, data: SomeParams, **kwargs): 157 | kwargs = kwargs 158 | _ = data.get("foo") # Should throw error if this is a pydantic model 
159 | return {"status": str(type(data))} 160 | 161 | reply = await app.nc.request( 162 | "natsapi.development.skip_validation.CREATE", 163 | {"data": {"foo": "string", "undocumented_param": "bar"}, "extra_param": "string"}, 164 | ) 165 | assert "dict" in reply.result["status"] 166 | 167 | schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result 168 | assert "SomeParams" in schema["components"]["schemas"] 169 | 170 | 171 | async def test_each_nats_request_should_have_different_id(app, natsapi_mock): 172 | # given: 173 | await natsapi_mock.request("foobar", response={"status": "OK"}) 174 | 175 | # when: 2 reqeusts 176 | await app.nc.request("foobar", timeout=1) 177 | await app.nc.request("foobar", timeout=1) 178 | 179 | # then: 180 | assert natsapi_mock.payloads["foobar"][0]["id"] != natsapi_mock.payloads["foobar"][1]["id"] 181 | 182 | 183 | async def test_send_request_should_store_jsonrpc_id_in_contextvars(app): 184 | @app.request(subject="foo") 185 | async def _(app): 186 | return {"jsonrpc_id": CTX_JSONRPC_ID.get()} 187 | 188 | json_rpc_payload = JsonRPCRequest(params={"foo": 1}, method=None, timeout=60) 189 | reply_raw = await app.nc.nats.request("natsapi.development.foo", json_rpc_payload.json().encode(), 60, headers=None) 190 | reply = JsonRPCReply.parse_raw(reply_raw.data) 191 | 192 | assert reply.result["jsonrpc_id"] == str(json_rpc_payload.id) 193 | -------------------------------------------------------------------------------- /tests/test_routing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from natsapi import NatsAPI, SubjectRouter 5 | 6 | 7 | class NotificationParams(BaseModel): 8 | notification: str 9 | service: str 10 | 11 | 12 | class StatusResult(BaseModel): 13 | status: str 14 | 15 | 16 | def test_use_request_method_decorated_should_add_subject_to_router(): 17 | app = NatsAPI("natsapi.development") 18 | 19 | router = 
SubjectRouter(prefix="v1") 20 | 21 | @router.request("test.CREATE", result=StatusResult) 22 | def _(app) -> StatusResult: 23 | return {"status": "OK"} 24 | 25 | app.include_router(router) 26 | 27 | assert len(app.routes) == len(router.routes) 28 | 29 | 30 | def test_add_pub_should_add_pub_to_router(app: NatsAPI): 31 | router = SubjectRouter() 32 | 33 | @router.pub("*.subject.>", params=NotificationParams) 34 | @router.request("subject.RETRIEVE", result=StatusResult) 35 | async def _(app: NatsAPI): 36 | await app.nc.publish("notifications.CREATE", {"notification": "Hi", "service": "SMT"}) 37 | return {"status": "OK"} 38 | 39 | app.include_router(router) 40 | assert len(app.pubs) == 1 41 | 42 | 43 | def test_add_sub_should_add_sub_to_router(app: NatsAPI): 44 | router = SubjectRouter() 45 | 46 | @router.sub("*.subject.>", queue="natsapi") 47 | @router.request("subject.RETRIEVE", result=StatusResult) 48 | async def _(app: NatsAPI): 49 | await app.nc.subscribe("*.subject.>", queue="natsapi", cb=lambda _: print("some callback")) 50 | return {"status": "OK"} 51 | 52 | app.include_router(router) 53 | assert len(app.subs) == 2 54 | 55 | 56 | def test_add_route_w_skip_validation_but_no_args_kwargs_should_throw_error(app): 57 | with pytest.raises(AssertionError): 58 | 59 | @app.request("skip_validation.CREATE", result=StatusResult, skip_validation=True) 60 | def skip_validation(app): 61 | return {"status": "OK"} 62 | 63 | 64 | def test_two_routes_with_same_subject_should_throw_clear_exception_w_subject_router(app: NatsAPI): 65 | router = SubjectRouter() 66 | 67 | @router.request("foo", result=StatusResult) 68 | async def _(app: NatsAPI): # noqa: F811 69 | return {"status": "OK"} 70 | 71 | @router.request("foo", result=StatusResult) 72 | async def _(app: NatsAPI): # noqa: F811 73 | return {"status": "OK"} 74 | 75 | with pytest.raises(Exception) as e: 76 | app.include_router(router, root_path="foo") 77 | assert "defined twice" in str(e.value) 78 | 79 | 80 | def 
test_pub_and_sub_with_same_subject_should_throw_clear_exception(app: NatsAPI): 81 | @app.publish("foo") 82 | async def _(app: NatsAPI): # noqa: F811 83 | return {"status": "OK"} 84 | 85 | with pytest.raises(Exception) as e: 86 | 87 | @app.request("foo", result=StatusResult) 88 | async def _(app: NatsAPI): # noqa: F811 89 | return {"status": "OK"} 90 | 91 | assert "defined twice" in str(e.value) 92 | 93 | 94 | def test_two_pubs_with_same_subject_should_throw_clear_exception(app: NatsAPI): 95 | @app.publish("foo") 96 | async def _(app: NatsAPI): # noqa: F811 97 | return {"status": "OK"} 98 | 99 | with pytest.raises(Exception) as e: 100 | 101 | @app.publish("foo") 102 | async def _(app: NatsAPI): # noqa: F811 103 | return {"status": "OK"} 104 | 105 | assert "defined twice" in str(e.value) 106 | 107 | 108 | def test_two_reqs_with_same_subject_should_throw_clear_exception(app: NatsAPI): 109 | @app.request("foo") 110 | async def _(app: NatsAPI): # noqa: F811 111 | return {"status": "OK"} 112 | 113 | with pytest.raises(Exception) as e: 114 | 115 | @app.request("foo") 116 | async def _(app: NatsAPI): # noqa: F811 117 | return {"status": "OK"} 118 | 119 | assert "defined twice" in str(e.value) 120 | --------------------------------------------------------------------------------