├── rolo ├── py.typed ├── testing │ ├── __init__.py │ └── pytest.py ├── resource.py ├── serving │ └── __init__.py ├── dispatcher.py ├── __init__.py ├── websocket │ ├── __init__.py │ ├── websocket.py │ ├── errors.py │ ├── adapter.py │ └── request.py ├── routing │ ├── __init__.py │ ├── converter.py │ ├── pydantic.py │ ├── handler.py │ ├── resource.py │ └── rules.py ├── router.py ├── gateway │ ├── __init__.py │ ├── wsgi.py │ ├── gateway.py │ ├── asgi.py │ └── handlers.py ├── response.py ├── client.py ├── proxy.py └── request.py ├── docs ├── _static │ ├── .gitkeep │ └── rolo.png ├── _templates │ └── .gitkeep ├── requirements.txt ├── Makefile ├── make.bat ├── gateway.md ├── index.md ├── conf.py ├── websockets.md ├── serving.md ├── getting_started.md ├── router.md └── handler_chain.md ├── tests ├── gateway │ ├── __init__.py │ ├── conftest.py │ ├── test_context.py │ ├── test_websocket.py │ ├── test_headers.py │ ├── test_handlers.py │ └── test_chain.py ├── serving │ ├── __init__.py │ └── test_twisted.py ├── static │ ├── __init__.py │ ├── test.txt │ └── index.html ├── websocket │ ├── __init__.py │ └── test_websockets.py ├── conftest.py ├── __init__.py ├── test_client.py ├── test_response.py ├── test_dispatcher.py ├── test_resource.py ├── test_pydantic.py └── test_request.py ├── CODEOWNERS ├── .github ├── workflows │ ├── docs-publish.yml │ └── build.yml └── PULL_REQUEST_TEMPLATE.md ├── Makefile ├── .gitignore ├── pyproject.toml ├── examples └── json-rpc │ └── server.py ├── README.md └── LICENSE /rolo/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /rolo/testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/gateway/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/serving/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/static/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/websocket/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/static/test.txt: -------------------------------------------------------------------------------- 1 | hello world 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | furo 3 | myst_parser 4 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = [ 2 | "rolo.testing.pytest", 3 | ] 4 | -------------------------------------------------------------------------------- /tests/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | hello 3 | 4 | -------------------------------------------------------------------------------- /CODEOWNERS: 
-------------------------------------------------------------------------------- 1 | # by default, add all repo maintainers as reviewers 2 | * @thrau 3 | -------------------------------------------------------------------------------- /docs/_static/rolo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/localstack/rolo/HEAD/docs/_static/rolo.png -------------------------------------------------------------------------------- /rolo/resource.py: -------------------------------------------------------------------------------- 1 | """ 2 | DEPRECATED: use ``from rolo.routing import Resource, resource`` instead 3 | """ 4 | from .routing.resource import Resource, resource 5 | 6 | __all__ = [ 7 | "Resource", 8 | "resource", 9 | ] 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # This module cannot be named http, since PyCharm cannot run tests in the module above anymore in debug mode 2 | # - https://youtrack.jetbrains.com/issue/PY-54265 3 | # - https://youtrack.jetbrains.com/issue/PY-35220 4 | -------------------------------------------------------------------------------- /rolo/serving/__init__.py: -------------------------------------------------------------------------------- 1 | """This module holds glue code between web servers and rolo. Any ASGI-compliant server can serve rolo with 2 | a few patches via ``rolo.serving.asgi``.
Moreover, rolo can be served through twisted via 3 | ``rolo.serving.twisted``.""" 4 | -------------------------------------------------------------------------------- /rolo/dispatcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | DEPRECATED: use ``from rolo.routing import handler_dispatcher`` instead 3 | """ 4 | from .routing.handler import ResultValue, handler_dispatcher 5 | 6 | __all__ = [ 7 | "handler_dispatcher", 8 | "ResultValue", 9 | ] 10 | -------------------------------------------------------------------------------- /rolo/__init__.py: -------------------------------------------------------------------------------- 1 | from .request import Request 2 | from .response import Response 3 | from .routing.resource import Resource, resource 4 | from .routing.router import Router, route 5 | 6 | __all__ = [ 7 | "route", 8 | "resource", 9 | "Resource", 10 | "Router", 11 | "Response", 12 | "Request", 13 | ] 14 | -------------------------------------------------------------------------------- /.github/workflows/docs-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] # branch to trigger deployment 6 | 7 | jobs: 8 | pages: 9 | runs-on: ubuntu-latest 10 | environment: 11 | name: github-pages 12 | url: ${{ steps.deployment.outputs.page_url }} 13 | permissions: 14 | pages: write 15 | id-token: write 16 | steps: 17 | - id: deployment 18 | uses: sphinx-notes/pages@v3 19 | -------------------------------------------------------------------------------- /rolo/websocket/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapter import WebSocketEnvironment, WebSocketListener 2 | from .errors import WebSocketDisconnectedError, WebSocketError, WebSocketProtocolError 3 | from .request import WebSocket, WebSocketRequest 4 | 5 | __all__ = [ 6 | "WebSocket", 7 | 
"WebSocketDisconnectedError", 8 | "WebSocketEnvironment", 9 | "WebSocketError", 10 | "WebSocketListener", 11 | "WebSocketProtocolError", 12 | "WebSocketRequest", 13 | ] 14 | -------------------------------------------------------------------------------- /rolo/routing/__init__.py: -------------------------------------------------------------------------------- 1 | from .converter import PortConverter, RegexConverter 2 | from .handler import handler_dispatcher 3 | from .router import RequestArguments, Router, route 4 | from .rules import RuleAdapter, RuleGroup, WithHost 5 | 6 | __all__ = [ 7 | "PortConverter", 8 | "RegexConverter", 9 | "RequestArguments", 10 | "Router", 11 | "RuleAdapter", 12 | "RuleGroup", 13 | "WithHost", 14 | "handler_dispatcher", 15 | "route", 16 | ] 17 | -------------------------------------------------------------------------------- /rolo/router.py: -------------------------------------------------------------------------------- 1 | """ 2 | DEPRECATED: use ``rolo.routing`` instead. 
3 | """ 4 | 5 | from .routing.converter import PortConverter, RegexConverter 6 | from .routing.router import RequestArguments, Router, route 7 | from .routing.rules import RuleAdapter, RuleGroup, WithHost 8 | 9 | __all__ = [ 10 | "Router", 11 | "route", 12 | "RegexConverter", 13 | "PortConverter", 14 | "RuleAdapter", 15 | "RuleGroup", 16 | "WithHost", 17 | "RequestArguments", 18 | ] 19 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | ## Motivation 3 | 4 | 5 | 6 | ## Changes 7 | 8 | 9 | 13 | 14 | 15 | 23 | -------------------------------------------------------------------------------- /rolo/websocket/websocket.py: -------------------------------------------------------------------------------- 1 | """TODO: remove with release (just keeping this so localstack doesn't blow up)""" 2 | from .adapter import WebSocketEnvironment, WebSocketListener 3 | from .errors import WebSocketDisconnectedError, WebSocketError, WebSocketProtocolError 4 | from .request import WebSocket, WebSocketRequest 5 | 6 | __all__ = [ 7 | "WebSocket", 8 | "WebSocketDisconnectedError", 9 | "WebSocketEnvironment", 10 | "WebSocketError", 11 | "WebSocketListener", 12 | "WebSocketProtocolError", 13 | "WebSocketRequest", 14 | ] 15 | -------------------------------------------------------------------------------- /rolo/gateway/__init__.py: -------------------------------------------------------------------------------- 1 | from .chain import ( 2 | CompositeExceptionHandler, 3 | CompositeFinalizer, 4 | CompositeHandler, 5 | CompositeResponseHandler, 6 | ExceptionHandler, 7 | Handler, 8 | HandlerChain, 9 | RequestContext, 10 | ) 11 | from .gateway import Gateway 12 | 13 | __all__ = [ 14 | "Gateway", 15 | "HandlerChain", 16 | "RequestContext", 17 | "Handler", 18 | "ExceptionHandler", 19 | "CompositeHandler", 20 | "CompositeExceptionHandler", 21 | 
"CompositeResponseHandler", 22 | "CompositeFinalizer", 23 | ] 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/gateway/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rolo.gateway import Gateway 4 | from rolo.testing.pytest import Server 5 | 6 | 7 | @pytest.fixture 8 | def serve_gateway(request): 9 | def _serve(gateway: Gateway) -> Server: 10 | try: 11 | gw_type = request.param 12 | except AttributeError: 13 | gw_type = "wsgi" 14 | 15 | if gw_type == "asgi": 16 | fixture = request.getfixturevalue("serve_asgi_gateway") 17 | elif gw_type == "twisted": 18 | fixture = request.getfixturevalue("serve_twisted_gateway") 19 | else: 20 | fixture = request.getfixturevalue("serve_wsgi_gateway") 21 | 22 | return fixture(gateway) 23 | 24 | yield _serve 25 | -------------------------------------------------------------------------------- /rolo/websocket/errors.py: -------------------------------------------------------------------------------- 1 | class WebSocketError(IOError): 2 | """Base 
class for websocket errors""" 3 | 4 | pass 5 | 6 | 7 | class WebSocketDisconnectedError(WebSocketError): 8 | """Raised when the client has disconnected while the server is still trying to receive data.""" 9 | 10 | default_code = 1005 11 | """https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event-ws""" 12 | 13 | def __init__(self, code: int = None): 14 | self.code = code if code is not None else self.default_code 15 | super().__init__(f"Websocket disconnected code={self.code}") 16 | 17 | 18 | class WebSocketProtocolError(WebSocketError): 19 | """Raised if there is a problem in the interaction between app and the websocket server.""" 20 | 21 | pass 22 | -------------------------------------------------------------------------------- /tests/gateway/test_context.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rolo.gateway import RequestContext 4 | 5 | 6 | def test_set_and_access_data(): 7 | context = RequestContext() 8 | 9 | context.some_data = "foo" 10 | assert context.some_data == "foo" 11 | 12 | 13 | def test_access_non_existing_data(): 14 | context = RequestContext() 15 | 16 | with pytest.raises(AttributeError): 17 | assert context.some_data 18 | 19 | 20 | def test_get_non_existing_data(): 21 | context = RequestContext() 22 | 23 | assert context.get("some_data") is None 24 | 25 | 26 | def test_set_and_get_data(): 27 | context = RequestContext() 28 | 29 | context.some_data = "foo" 30 | assert context.get("some_data") == "foo" 31 | context.some_data = "bar" 32 | assert context.get("some_data") == "bar" 33 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/gateway/test_websocket.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import websocket 3 | 4 | from rolo import Router, route 5 | from rolo.gateway import Gateway 6 | from rolo.gateway.handlers import RouterHandler 7 | from rolo.websocket.request import WebSocketRequest 8 | 9 | 10 | @pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True) 11 | def test_gateway_router_websocket_integration(serve_gateway): 12 | @route("/ws", methods=["WEBSOCKET"]) 13 | def _handler(request: WebSocketRequest, args): 14 | with request.accept() as ws: 15 | ws.send("hello") 16 | ws.send(ws.receive()) 17 | 18 | router = Router() 19 | router.add(_handler) 20 | 21 | server = serve_gateway(Gateway(request_handlers=[RouterHandler(router)])) 22 | 23 | client = websocket.WebSocket() 24 | client.connect(server.url.replace("http://", "ws://") + "/ws") 25 | 26 | assert client.recv() == "hello" 27 | client.send("foobar") 28 | assert client.recv() == "foobar" 29 | client.close() 30 | -------------------------------------------------------------------------------- /tests/gateway/test_headers.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | import requests 5 | 6 | from rolo.gateway import Gateway, HandlerChain, RequestContext 7 | 8 | 9 | @pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True) 10 | def test_raw_header_handling(serve_gateway): 11 | def handler(chain: HandlerChain, context: RequestContext, response): 12 | response.data = json.dumps({"headers": dict(context.request.headers)}) 13 | response.mimetype = "application/json" 14 | response.headers["X-fOO_bar"] = "FooBar" 15 | return response 16 | 17 | gateway = Gateway(request_handlers=[handler]) 18 | 19 | srv = serve_gateway(gateway) 20 | 21 | response = requests.get( 22 | srv.url, 23 | headers={"x-mIxEd-CaSe": "myheader", "X-UPPER__CASE": "uppercase"}, 24 | ) 25 | returned_headers = response.json()["headers"] 26 | assert "X-UPPER__CASE" in returned_headers 27 | assert "x-mIxEd-CaSe" in returned_headers 28 | assert "X-fOO_bar" in dict(response.headers) 29 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | paths-ignore: 6 | - 'README.md' 7 | - 'docs/**' 8 | branches: 9 | - main 10 | pull_request: 11 | branches: 12 | - main 13 | 14 | jobs: 15 | test: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | python-version: [ "3.10", "3.11", "3.12" ] 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v3 26 | 27 | - name: Set up Python 28 | id: setup-python 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | 33 | - name: Run install 34 | run: | 35 | make install 36 | 37 | - name: Run linting 38 | run: | 39 | make lint 40 | 41 | - name: Run tests 42 | run: | 43 | make test-coverage 44 | 45 
| - name: Report coverage 46 | run: | 47 | make coveralls 48 | -------------------------------------------------------------------------------- /rolo/routing/converter.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from werkzeug.routing import BaseConverter, Map 4 | 5 | 6 | class RegexConverter(BaseConverter): 7 | """ 8 | A converter that can be used to inject a regex as parameter, e.g., ``path=/<regex('[a-z]+'):name>``. 9 | When using groups in regex, make sure they are non-capturing ``(?:[a-z]+)`` 10 | """ 11 | 12 | def __init__(self, map: Map, *args: t.Any, **kwargs: t.Any) -> None: 13 | super().__init__(map, *args, **kwargs) 14 | self.regex = args[0] 15 | 16 | 17 | class PortConverter(BaseConverter): 18 | """ 19 | Useful to optionally match ports for host patterns, like ``localstack.localhost.cloud<port:port>``. Notice how you 20 | don't need to specify the colon. The regex matches it if the port is there, and will remove the colon if matched. 21 | The converter converts the port to an int, or returns None if there's no port in the input string. 22 | """ 23 | 24 | regex = r"(?::[0-9]{1,5})?" 25 | 26 | def to_python(self, value: str) -> t.Any: 27 | if value: 28 | return int(value[1:]) 29 | return None 30 | -------------------------------------------------------------------------------- /docs/gateway.md: -------------------------------------------------------------------------------- 1 | Gateway 2 | ======= 3 | 4 | The `Gateway` serves as the factory for `HandlerChain` instances. 5 | It also serves as the interface to HTTP servers. 6 | 7 | ## Creating a Gateway 8 | 9 | ```python 10 | from rolo.gateway import Gateway 11 | 12 | gateway = Gateway( 13 | request_handlers=[ 14 | ... 15 | ], 16 | response_handlers=[ 17 | ... 18 | ], 19 | exception_handlers=[ 20 | ... 21 | ], 22 | finalizers=[ 23 | ...
24 | ] 25 | ) 26 | ``` 27 | 28 | ## Protocol adapters 29 | 30 | You can use `rolo.gateway.wsgi` or `rolo.gateway.asgi` to expose a `Gateway` as either a WSGI or ASGI app. 31 | 32 | Read more in the [serving](serving.md) section. 33 | 34 | ## Custom `RequestContext` 35 | 36 | You can add a custom request context with type hints or your own methods by setting the `context_class` parameter in the constructor. 37 | First, define a request context subclass: 38 | 39 | ```python 40 | from rolo.gateway import RequestContext 41 | 42 | class MyContext(RequestContext): 43 | myattr: str 44 | ``` 45 | 46 | Then, when you instantiate the Gateway: 47 | 48 | ```python 49 | gateway = Gateway( 50 | request_handlers=[ 51 | ... 52 | ], 53 | context_class=MyContext 54 | ) 55 | ``` 56 | 57 | In your handlers, you can now reference your context: 58 | 59 | ```python 60 | def handler(chain: HandlerChain, context: MyContext, response: Response): 61 | ... 62 | ``` 63 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VENV_BIN = python3 -m venv 2 | VENV_DIR ?= .venv 3 | VENV_ACTIVATE = $(VENV_DIR)/bin/activate 4 | VENV_RUN = . $(VENV_ACTIVATE) 5 | ROOT_MODULE = rolo 6 | 7 | venv: $(VENV_ACTIVATE) 8 | 9 | $(VENV_ACTIVATE): pyproject.toml 10 | test -d .venv || $(VENV_BIN) .venv 11 | $(VENV_RUN); pip install --upgrade pip setuptools wheel 12 | $(VENV_RUN); pip install -e .[dev] 13 | touch $(VENV_DIR)/bin/activate 14 | 15 | install: venv 16 | 17 | clean: 18 | rm -rf .venv 19 | rm -rf build/ 20 | rm -rf .eggs/ 21 | rm -rf *.egg-info/ 22 | 23 | format: 24 | $(VENV_RUN); python -m ruff check --show-source --fix .; python -m black . 25 | 26 | lint: 27 | $(VENV_RUN); python -m ruff check --show-source . && python -m black --check . 
28 | 29 | test: venv 30 | $(VENV_RUN); python -m pytest 31 | 32 | test-coverage: venv 33 | $(VENV_RUN); coverage run --source=$(ROOT_MODULE) -m pytest tests/ 34 | 35 | coveralls: venv 36 | $(VENV_RUN); coveralls 37 | 38 | $(VENV_DIR)/.docs-install: pyproject.toml $(VENV_ACTIVATE) 39 | $(VENV_RUN); pip install -e .[docs] 40 | touch $(VENV_DIR)/.docs-install 41 | 42 | install-docs: $(VENV_DIR)/.docs-install 43 | 44 | docs: install-docs 45 | $(VENV_RUN); cd docs && make html 46 | 47 | dist: venv 48 | $(VENV_RUN); pip install --upgrade build; python -m build 49 | 50 | publish: clean-dist venv test dist 51 | $(VENV_RUN); pip install --upgrade twine; twine upload dist/* 52 | 53 | clean-dist: clean 54 | rm -rf dist/ 55 | 56 | .PHONY: clean clean-dist 57 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | Rolo documentation 2 | ================== 3 | 4 |

5 | Rolo HTTP 6 |

7 |

8 | Rolo HTTP: A Python framework for building HTTP-based server applications. 9 |

 10 | 11 | ## Introduction 12 | 13 | Rolo is a flexible framework and library to build HTTP-based server applications beyond microservices and REST APIs. 14 | You can build HTTP-based RPC servers, WebSocket proxies, or other server types that typical web frameworks are not designed for. 15 | Rolo was originally designed to build the AWS RPC protocol server in [LocalStack](https://github.com/localstack/localstack). 16 | 17 | Rolo extends [Werkzeug](https://github.com/pallets/werkzeug/), a flexible Python HTTP server library, so you can use concepts you are already familiar with, like ``@route``, ``Request``, or ``Response``. 18 | It introduces the concepts of a ``Gateway`` and a ``HandlerChain``, the latter being an implementation variant of the [chain-of-responsibility pattern](https://en.wikipedia.org/wiki/Chain-of-responsibility_pattern). 19 | 20 | Rolo is designed for environments that do not use asyncio, but still require asynchronous HTTP features like HTTP/2, SSE, or WebSockets. 21 | To allow asynchronous communication, Rolo introduces an ASGI/WSGI bridge that allows you to serve Rolo applications through ASGI servers like Hypercorn.
 22 | 23 | ## Table of Contents 24 | 25 | ```{toctree} 26 | :caption: Quickstart 27 | :maxdepth: 2 28 | 29 | getting_started 30 | ``` 31 | 32 | ```{toctree} 33 | :caption: User Guide 34 | :maxdepth: 2 35 | 36 | router 37 | handler_chain 38 | gateway 39 | websockets 40 | serving 41 | ``` 42 | 43 | ```{toctree} 44 | :caption: Tutorials 45 | :maxdepth: 1 46 | 47 | tutorials/jsonrpc-server 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /rolo/gateway/wsgi.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from werkzeug.datastructures import Headers, MultiDict 4 | from werkzeug.wrappers import Request 5 | 6 | from rolo.response import Response 7 | 8 | if t.TYPE_CHECKING: 9 | from _typeshed.wsgi import StartResponse, WSGIEnvironment 10 | 11 | import logging 12 | 13 | from .gateway import Gateway 14 | 15 | LOG = logging.getLogger(__name__) 16 | 17 | 18 | class WsgiGateway: 19 | """ 20 | Exposes a ``Gateway`` as a WSGI application.
21 | """ 22 | 23 | gateway: Gateway 24 | 25 | def __init__(self, gateway: Gateway) -> None: 26 | super().__init__() 27 | self.gateway = gateway 28 | 29 | def __call__( 30 | self, environ: "WSGIEnvironment", start_response: "StartResponse" 31 | ) -> t.Iterable[bytes]: 32 | # create request from environment 33 | LOG.debug( 34 | "%s %s%s", 35 | environ["REQUEST_METHOD"], 36 | environ.get("HTTP_HOST"), 37 | environ.get("RAW_URI"), 38 | ) 39 | request = Request(environ) 40 | 41 | raw_headers = environ.get("rolo.headers") or environ.get("asgi.headers") 42 | if raw_headers: 43 | # restores raw headers from ASGI scope, which allows dashes in header keys 44 | # see https://github.com/pallets/werkzeug/issues/940 45 | 46 | request.headers = Headers( 47 | MultiDict([(k.decode("latin-1"), v.decode("latin-1")) for (k, v) in raw_headers]) 48 | ) 49 | else: 50 | # by default, werkzeug requests from environ are immutable 51 | request.headers = Headers(request.headers) 52 | 53 | # prepare response 54 | response = Response() 55 | 56 | self.gateway.process(request, response) 57 | 58 | return response(environ, start_response) 59 | -------------------------------------------------------------------------------- /tests/serving/test_twisted.py: -------------------------------------------------------------------------------- 1 | import http.client 2 | import io 3 | import json 4 | 5 | import requests 6 | 7 | from rolo import Request, Router, route 8 | from rolo.dispatcher import handler_dispatcher 9 | from rolo.gateway import Gateway 10 | from rolo.gateway.handlers import RouterHandler 11 | 12 | 13 | def test_large_file_upload(serve_twisted_gateway): 14 | router = Router(handler_dispatcher()) 15 | 16 | @route("/hello", methods=["POST"]) 17 | def hello(request: Request): 18 | return "ok" 19 | 20 | router.add(hello) 21 | 22 | gateway = Gateway(request_handlers=[RouterHandler(router, True)]) 23 | server = serve_twisted_gateway(gateway) 24 | 25 | response = requests.post(server.url + "/hello", 
io.BytesIO(b"0" * 100001)) 26 | 27 | assert response.status_code == 200 28 | 29 | 30 | def test_full_absolute_form_uri(serve_twisted_gateway): 31 | router = Router(handler_dispatcher()) 32 | 33 | @route("/hello", methods=["GET"]) 34 | def hello(request: Request): 35 | return {"path": request.path, "raw_uri": request.environ["RAW_URI"]} 36 | 37 | router.add(hello) 38 | 39 | gateway = Gateway(request_handlers=[RouterHandler(router, True)]) 40 | server = serve_twisted_gateway(gateway) 41 | host = server.url 42 | 43 | conn = http.client.HTTPConnection(host="127.0.0.1", port=server.port) 44 | 45 | # This is what is sent: 46 | # send: b'GET http://localhost:<port>/hello HTTP/1.1\r\nHost: localhost:<port>\r\nAccept-Encoding: identity\r\n\r\n' 47 | # note the full URI in the HTTP request 48 | conn.request("GET", url=f"{host}/hello") 49 | response = conn.getresponse() 50 | 51 | assert response.status == 200 52 | response_data = json.loads(response.read()) 53 | assert response_data["path"] == "/hello" 54 | assert response_data["raw_uri"].startswith("http") 55 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | from pytest_httpserver import HTTPServer 2 | from werkzeug import Request as WerkzeugRequest 3 | from werkzeug.datastructures import Headers 4 | 5 | from rolo import Response 6 | from rolo.client import SimpleRequestsClient 7 | from rolo.request import Request 8 | 9 | 10 | def echo_request_metadata_handler(request: WerkzeugRequest) -> Response: 11 | """ 12 | Simple request handler that returns the incoming request metadata (method, path, url, headers).
13 | 14 | :param request: the incoming HTTP request 15 | :return: an HTTP response 16 | """ 17 | response = Response() 18 | response.set_json( 19 | { 20 | "method": request.method, 21 | "path": request.path, 22 | "url": request.url, 23 | "headers": dict(Headers(request.headers)), 24 | } 25 | ) 26 | return response 27 | 28 | 29 | class TestSimpleRequestClient: 30 | def test_empty_accept_encoding_header(self, httpserver: HTTPServer): 31 | httpserver.expect_request("/").respond_with_handler(echo_request_metadata_handler) 32 | 33 | url = httpserver.url_for("/") 34 | 35 | request = Request(path="/", method="GET") 36 | 37 | with SimpleRequestsClient() as client: 38 | response = client.request(request, url) 39 | 40 | assert "Accept-Encoding" not in response.json["headers"] 41 | assert "accept-encoding" not in response.json["headers"] 42 | 43 | def test_multi_values_headers(self, httpserver: HTTPServer): 44 | def multi_values_handler(_request: WerkzeugRequest) -> Response: 45 | multi_headers = Headers() 46 | multi_headers.add("Set-Cookie", "value1") 47 | multi_headers.add("Set-Cookie", "value2") 48 | assert multi_headers.getlist("Set-Cookie") == ["value1", "value2"] 49 | 50 | return Response(headers=multi_headers) 51 | 52 | httpserver.expect_request("/").respond_with_handler(multi_values_handler) 53 | 54 | url = httpserver.url_for("/") 55 | 56 | request = Request(path="/", method="GET") 57 | 58 | with SimpleRequestsClient() as client: 59 | response = client.request(request, url) 60 | assert response.headers.getlist("Set-Cookie") == ["value1", "value2"] 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.iml 3 | *~ 4 | 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | # Thumbnails 14 | ._* 15 | 16 | # Files that might appear in the root of a volume 17 | 
.DocumentRevisions-V100 18 | .fseventsd 19 | .Spotlight-V100 20 | .TemporaryItems 21 | .Trashes 22 | .VolumeIcon.icns 23 | .com.apple.timemachine.donotpresent 24 | 25 | # Directories potentially created on remote AFP share 26 | .AppleDB 27 | .AppleDesktop 28 | Network Trash Folder 29 | Temporary Items 30 | .apdisk 31 | 32 | # Byte-compiled / optimized / DLL files 33 | __pycache__/ 34 | *.py[cod] 35 | *$py.class 36 | 37 | # C extensions 38 | *.so 39 | 40 | # Distribution / packaging 41 | .Python 42 | build/ 43 | develop-eggs/ 44 | dist/ 45 | downloads/ 46 | eggs/ 47 | .eggs/ 48 | lib/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | wheels/ 54 | *.egg-info/ 55 | .installed.cfg 56 | *.egg 57 | MANIFEST 58 | 59 | # PyInstaller 60 | # Usually these files are written by a python script from a template 61 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 62 | *.manifest 63 | *.spec 64 | 65 | # Installer logs 66 | pip-log.txt 67 | pip-delete-this-directory.txt 68 | 69 | # Unit test / coverage reports 70 | htmlcov/ 71 | .tox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | .hypothesis/ 79 | 80 | # Translations 81 | *.mo 82 | *.pot 83 | 84 | # Django stuff: 85 | *.log 86 | .static_storage/ 87 | .media/ 88 | local_settings.py 89 | 90 | # Flask stuff: 91 | instance/ 92 | .webassets-cache 93 | 94 | # Scrapy stuff: 95 | .scrapy 96 | 97 | # Sphinx documentation 98 | docs/_build/ 99 | 100 | # PyBuilder 101 | target/ 102 | 103 | # Jupyter Notebook 104 | .ipynb_checkpoints 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # celery beat schedule file 110 | celerybeat-schedule 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 
| /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | -------------------------------------------------------------------------------- /rolo/gateway/gateway.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing as t 3 | 4 | from ..request import Request 5 | from ..response import Response 6 | from ..websocket.request import WebSocketRequest 7 | from .chain import ExceptionHandler, Handler, HandlerChain, RequestContext 8 | 9 | LOG = logging.getLogger(__name__) 10 | 11 | 12 | class Gateway: 13 | """ 14 | A gateway creates new HandlerChain instances for each request and processes requests through them. 15 | """ 16 | 17 | request_handlers: list[Handler] 18 | response_handlers: list[Handler] 19 | finalizers: list[Handler] 20 | exception_handlers: list[ExceptionHandler] 21 | 22 | def __init__( 23 | self, 24 | request_handlers: list[Handler] = None, 25 | response_handlers: list[Handler] = None, 26 | finalizers: list[Handler] = None, 27 | exception_handlers: list[ExceptionHandler] = None, 28 | context_class: t.Type[RequestContext] = None, 29 | ) -> None: 30 | super().__init__() 31 | self.request_handlers = request_handlers if request_handlers is not None else [] 32 | self.response_handlers = response_handlers if response_handlers is not None else [] 33 | self.exception_handlers = exception_handlers if exception_handlers is not None else [] 34 | self.finalizers = finalizers if finalizers is not None else [] 35 | self.context_class = context_class or RequestContext 36 | 37 | def new_chain(self) -> HandlerChain: 38 | return HandlerChain( 39 | self.request_handlers, 40 | self.response_handlers, 41 | self.finalizers, 42 | self.exception_handlers, 43 | ) 44 | 45 | def process(self, request: Request, response: Response): 46 | chain = self.new_chain() 47 | 48 | context = self.context_class(request) 49 | 50 | chain.handle(context, response) 51 | 52 | def accept(self, request: WebSocketRequest): 53 | response = 
Response(status=101) 54 | self.process(request, response) 55 | 56 | # only send the populated response if the websocket hasn't already done so before 57 | if response.status_code != 101: 58 | if request.is_upgraded(): 59 | return 60 | if request.is_rejected(): 61 | return 62 | request.reject(response) 63 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Rolo' 10 | copyright = '2024, LocalStack' 11 | author = 'Thomas Rausch' 12 | release = '0.6.x' 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | 'myst_parser' 19 | ] 20 | 21 | templates_path = ['_templates'] 22 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 23 | 24 | 25 | # -- Options for HTML output ------------------------------------------------- 26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 27 | 28 | html_theme = "furo" 29 | html_static_path = ["_static"] 30 | html_title = "Rolo documentation" 31 | 32 | html_logo = "_static/rolo.png" 33 | html_theme_options = { 34 | "top_of_page_buttons": ["view", "edit"], 35 | "source_repository": "https://github.com/localstack/rolo/", 36 | "source_branch": "main", 37 | "source_directory": "docs/", 38 | "footer_icons": [ 39 | { 40 | "name": "GitHub", 41 | "url": "https://github.com/localstack/rolo", 42 | "html": """ 43 | 44 | 45 | 46 
| """, 47 | "class": "", 48 | }, 49 | ], 50 | } 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rolo" 7 | authors = [ 8 | { name = "LocalStack Contributors", email = "info@localstack.cloud" } 9 | ] 10 | version = "0.7.6" 11 | description = "A Python framework for building HTTP-based server applications" 12 | dependencies = [ 13 | "requests>=2.20", 14 | "werkzeug>=3.0" 15 | ] 16 | requires-python = ">=3.10" 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Development Status :: 5 - Production/Stable", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | "Topic :: Software Development :: Libraries" 27 | ] 28 | dynamic = ["readme"] 29 | 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | "black==23.10.0", 34 | "pytest>=7.0", 35 | "hypercorn", 36 | "pydantic", 37 | "pytest_httpserver", 38 | "websocket-client>=1.7.0", 39 | "coverage[toml]>=5.0.0", 40 | "coveralls>=3.3", 41 | "localstack-twisted", 42 | "ruff==0.1.0" 43 | ] 44 | docs = [ 45 | "sphinx", 46 | "furo", 47 | "myst_parser", 48 | ] 49 | 50 | [tool.setuptools] 51 | include-package-data = false 52 | 53 | [tool.setuptools.dynamic] 54 | readme = {file = ["README.md"], content-type = "text/markdown"} 55 | 56 | [tool.setuptools.packages.find] 57 | include = ["rolo*"] 58 | exclude = ["tests*"] 59 | 60 | [tool.setuptools.package-data] 61 | "*" = ["*.md"] 62 | 63 | [tool.black] 64 | line_length = 100 65 | include = '((rolo)/.*\.py$|tests/.*\.py$)' 66 | #extend_exclude = '()' 67 | 68 | [tool.ruff] 69 | # 
Always generate Python 3.10-compatible code. 70 | target-version = "py310" 71 | line-length = 110 72 | select = ["B", "C", "E", "F", "I", "W", "T", "B9"] 73 | ignore = [ 74 | "E501", # E501 Line too long - handled by black, see https://docs.astral.sh/ruff/faq/#is-ruff-compatible-with-black 75 | ] 76 | exclude = [ 77 | ".venv*", 78 | "venv*", 79 | "dist", 80 | "build", 81 | "target", 82 | "*.egg-info", 83 | ".git", 84 | ] 85 | 86 | [tool.coverage.report] 87 | exclude_lines = [ 88 | "if __name__ == .__main__.:", 89 | "raise NotImplemented.", 90 | "return NotImplemented", 91 | "def __repr__", 92 | "__all__", 93 | ] 94 | -------------------------------------------------------------------------------- /docs/websockets.md: -------------------------------------------------------------------------------- 1 | Websockets 2 | ========== 3 | 4 | Rolo supports Websockets through ASGI and Twisted (see [serving](serving.md)). 5 | 6 | ## Websocket requests 7 | 8 | Rolo introduces an HTTP method called `WEBSOCKET`, which can be used to register routes that deal with websocket 9 | requests. 10 | 11 | ```python 12 | from rolo import route 13 | from rolo.websocket import WebSocketRequest 14 | 15 | 16 | @route("/stream", methods=["WEBSOCKET"]) 17 | def handler(request: WebSocketRequest, name: str): 18 | ... 19 | ``` 20 | 21 | You can add such a route to a `Router`, but the Router needs to be handled through a `Gateway` using the `RouterHandler`, and 22 | served through an ASGI webserver or Twisted. 23 | 24 | With a tool like [websocat](https://github.com/vi/websocat), you could now connect to the websocket. 25 | 26 | ### Accepting or rejecting the connection 27 | 28 | The websocket connection needs to be either accepted or rejected via `WebSocketRequest`. 29 | When calling ``WebSocketRequest.accept``, an upgrade response will be sent to the client, and the protocol will be 30 | switched to the bidirectional WebSocket protocol.
31 | If ``WebSocketRequest.reject`` is called, the server immediately returns an HTTP response and closes the connection. 32 | 33 | You may want to do this when doing authorization, for example: 34 | 35 | ```python 36 | def app(request: WebSocketRequest): 37 | # example: do authorization first 38 | auth = request.headers.get("Authorization") 39 | if not is_authorized(auth): 40 | request.reject(Response("no dice", 403)) 41 | return 42 | 43 | # then continue working with the websocket 44 | with request.accept() as websocket: 45 | websocket.send("hello world!") 46 | data = websocket.receive() 47 | # ... 48 | ``` 49 | 50 | ## Websocket object 51 | 52 | `WebSocketRequest.accept` also returns a `WebSocket` object that can then be used to send and receive data. 53 | 54 | You can explicitly call `WebSocket.receive`, or you can simply iterate over the `WebSocket` object. 55 | Here is an example: 56 | 57 | ```python 58 | 59 | from rolo import route 60 | from rolo.websocket import WebSocketRequest 61 | 62 | 63 | @route("/echo/<name>", methods=["WEBSOCKET"]) 64 | def handler(request: WebSocketRequest, name: str): 65 | with request.accept() as websocket: 66 | websocket.send(f"thanks for connecting {name}") 67 | for line in websocket: 68 | websocket.send(f"echo: {line}") 69 | if line == "exit": 70 | websocket.send("ok bye!") 71 | return 72 | ``` 73 | -------------------------------------------------------------------------------- /docs/serving.md: -------------------------------------------------------------------------------- 1 | Serving 2 | ======= 3 | 4 | This guide shows you how to serve Rolo components through different Python web server technologies. 5 | 6 | WSGI 7 | ---- 8 | 9 | ### Serving a Router as WSGI app 10 | 11 | If you only need a `Router` instance to serve your application, you can convert it to a WSGI app using the `Router.wsgi()` method.
12 | 13 | ```python 14 | from rolo import Router, route 15 | from rolo.routing import handler_dispatcher 16 | 17 | @route("/") 18 | def index(request): 19 | return "hello world" 20 | 21 | router = Router(dispatcher=handler_dispatcher()) 22 | router.add(index) 23 | 24 | app = router.wsgi() 25 | ``` 26 | 27 | Now you can use any WSGI-compliant server to serve the application. 28 | For example, if this file is stored in `myapp.py`, you can serve it with gunicorn: 29 | 30 | ```sh 31 | pip install gunicorn 32 | gunicorn -w 4 myapp:app 33 | ``` 34 | 35 | ### Serving a Gateway as WSGI app 36 | 37 | Unless you need Websockets, the Rolo Request object is fully WSGI-compliant, so you can also use any WSGI server to serve a `Gateway`. 38 | Simply use the `WsgiGateway` adapter. 39 | 40 | ```python 41 | from rolo.gateway import Gateway 42 | from rolo.gateway.wsgi import WsgiGateway 43 | 44 | gateway: Gateway = ... 45 | 46 | app = WsgiGateway(gateway) 47 | ``` 48 | 49 | Similar to the previous example, you can serve the `app` object through any WSGI-compliant server. 50 | 51 | ASGI 52 | ---- 53 | 54 | ASGI servers like Hypercorn allow asynchronous server communication, which is needed for HTTP/2 streaming or Websockets. 55 | Gateways can be served through the `AsgiGateway` adapter, which exposes a `Gateway` as an ASGI3 application. 56 | Under the hood, it uses our own ASGI/WSGI bridge (`ASGIAdapter`), which converts ASGI calls to WSGI calls for regular HTTP requests and uses ASGI websockets to serve rolo websockets. 57 | File `myapp.py`: 58 | 59 | ```python 60 | from rolo.gateway import Gateway 61 | from rolo.gateway.asgi import AsgiGateway 62 | 63 | gateway: Gateway = ... 64 | 65 | app = AsgiGateway(gateway) 66 | ``` 67 | 68 | Now you can use Hypercorn or other ASGI servers to serve the `app` object.
69 | 70 | ```sh 71 | pip install hypercorn 72 | hypercorn myapp:app 73 | ``` 74 | 75 | Twisted 76 | ------- 77 | 78 | Rolo can be served through [Twisted](https://twisted.org/), which supports both WSGI and Websockets. 79 | You will need twisted, and wsproto installed `pip install twisted wsproto`. 80 | 81 | ```python 82 | from rolo.gateway import Gateway 83 | from rolo.serving.twisted import TwistedGateway 84 | from twisted.internet import endpoints, reactor 85 | 86 | gateway: Gateway = ... 87 | 88 | # Rolo/Twisted adapter, that exposes a Rolo Gateway as a twisted.web.server.Site object 89 | site = TwistedGateway(gateway) 90 | 91 | endpoint = endpoints.TCP4ServerEndpoint(reactor, 8000) 92 | endpoint.listen(site) 93 | 94 | reactor.run() 95 | ``` 96 | -------------------------------------------------------------------------------- /tests/test_response.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import pytest 4 | from werkzeug.exceptions import NotFound 5 | 6 | from rolo import Response 7 | from tests import static 8 | 9 | 10 | def test_for_resource_html(): 11 | response = Response.for_resource(static, "index.html") 12 | assert response.content_type == "text/html; charset=utf-8" 13 | assert response.get_data() == b'\nhello\n\n' 14 | assert response.status == "200 OK" 15 | 16 | 17 | def test_for_resource_txt(): 18 | response = Response.for_resource(static, "test.txt") 19 | assert response.content_type == "text/plain; charset=utf-8" 20 | assert response.get_data() == b"hello world\n" 21 | assert response.status == "200 OK" 22 | 23 | 24 | def test_for_resource_with_custom_response_status_and_headers(): 25 | response = Response.for_resource(static, "test.txt", status=201, headers={"X-Foo": "Bar"}) 26 | assert response.content_type == "text/plain; charset=utf-8" 27 | assert response.get_data() == b"hello world\n" 28 | assert response.status == "201 CREATED" 29 | assert response.headers.get("X-Foo") == "Bar" 30 
| 31 | 32 | def test_for_resource_not_found(): 33 | with pytest.raises(NotFound): 34 | Response.for_resource(static, "doesntexist.txt") 35 | 36 | 37 | def test_for_json(): 38 | response = Response.for_json( 39 | {"foo": "bar", "420": 69, "isTrue": True}, 40 | ) 41 | assert response.content_type == "application/json" 42 | assert response.get_data() == b'{"foo": "bar", "420": 69, "isTrue": true}' 43 | assert response.status == "200 OK" 44 | 45 | 46 | def test_for_json_with_custom_response_status_and_headers(): 47 | response = Response.for_json( 48 | {"foo": "bar", "420": 69, "isTrue": True}, 49 | status=201, 50 | headers={"X-Foo": "Bar"}, 51 | ) 52 | assert response.content_type == "application/json" 53 | assert response.get_data() == b'{"foo": "bar", "420": 69, "isTrue": true}' 54 | assert response.status == "201 CREATED" 55 | assert response.headers.get("X-Foo") == "Bar" 56 | 57 | 58 | @pytest.mark.parametrize( 59 | argnames="data", 60 | argvalues=[ 61 | b"foobar", 62 | "foobar", 63 | io.BytesIO(b"foobar"), 64 | [b"foo", b"bar"], 65 | ], 66 | ) 67 | def test_set_response(data): 68 | response = Response() 69 | response.set_response(data) 70 | assert response.get_data() == b"foobar" 71 | 72 | 73 | def test_update_from(): 74 | original = Response( 75 | [b"foo", b"bar"], 202, headers={"X-Foo": "Bar"}, mimetype="application/octet-stream" 76 | ) 77 | 78 | response = Response() 79 | response.update_from(original) 80 | 81 | assert response.get_data() == b"foobar" 82 | assert response.status_code == 202 83 | assert response.headers.get("X-Foo") == "Bar" 84 | assert response.content_type == "application/octet-stream" 85 | -------------------------------------------------------------------------------- /rolo/gateway/asgi.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures.thread 3 | from asyncio import AbstractEventLoop 4 | from typing import Optional 5 | 6 | from rolo.asgi import ASGIAdapter, 
ASGILifespanListener 7 | from rolo.websocket.request import WebSocketRequest 8 | 9 | from .gateway import Gateway 10 | from .wsgi import WsgiGateway 11 | 12 | 13 | class _ThreadPool(concurrent.futures.thread.ThreadPoolExecutor): 14 | """ 15 | This thread pool executor removes the threads it creates from the global ``_thread_queues`` of 16 | ``concurrent.futures.thread``, which joins all created threads at python exit and will block interpreter shutdown if 17 | any threads are still running, even if they are daemon threads. 18 | """ 19 | 20 | def _adjust_thread_count(self) -> None: 21 | super()._adjust_thread_count() 22 | 23 | for t in self._threads: 24 | if not t.daemon: 25 | continue 26 | try: 27 | del concurrent.futures.thread._threads_queues[t] 28 | except KeyError: 29 | pass 30 | 31 | 32 | class AsgiGateway: 33 | """ 34 | Exposes a Gateway as an ASGI3 application. Under the hood, it uses a WsgiGateway with a threading async/sync bridge. 35 | """ 36 | 37 | default_thread_count = 1000 38 | 39 | gateway: Gateway 40 | 41 | def __init__( 42 | self, 43 | gateway: Gateway, 44 | event_loop: Optional[AbstractEventLoop] = None, 45 | threads: int = None, 46 | lifespan_listener: Optional[ASGILifespanListener] = None, 47 | websocket_listener=None, 48 | ) -> None: 49 | self.gateway = gateway 50 | 51 | self.event_loop = event_loop or asyncio.get_event_loop() 52 | self.executor = _ThreadPool( 53 | threads or self.default_thread_count, thread_name_prefix="asgi_gw" 54 | ) 55 | self.adapter = ASGIAdapter( 56 | WsgiGateway(gateway), 57 | event_loop=event_loop, 58 | executor=self.executor, 59 | lifespan_listener=lifespan_listener, 60 | websocket_listener=websocket_listener or WebSocketRequest.listener(gateway.accept), 61 | ) 62 | self._closed = False 63 | 64 | async def __call__(self, scope, receive, send) -> None: 65 | """ 66 | ASGI3 application interface. 
67 | 68 | :param scope: the ASGI request scope 69 | :param receive: the receive callable 70 | :param send: the send callable 71 | """ 72 | if self._closed: 73 | raise RuntimeError("Cannot except new request on closed ASGIGateway") 74 | 75 | return await self.adapter(scope, receive, send) 76 | 77 | def close(self): 78 | """ 79 | Close the ASGIGateway by shutting down the underlying executor. 80 | """ 81 | self._closed = True 82 | self.executor.shutdown(wait=False, cancel_futures=True) 83 | -------------------------------------------------------------------------------- /rolo/gateway/handlers.py: -------------------------------------------------------------------------------- 1 | """Several gateway handlers""" 2 | import typing as t 3 | 4 | from werkzeug.datastructures import Headers 5 | from werkzeug.exceptions import HTTPException, NotFound 6 | 7 | from rolo.response import Response 8 | from rolo.routing import Router 9 | 10 | from .chain import HandlerChain, RequestContext 11 | 12 | 13 | class RouterHandler: 14 | """ 15 | Adapter to serve a ``Router`` as a ``Handler``. 16 | """ 17 | 18 | router: Router 19 | respond_not_found: bool 20 | 21 | def __init__(self, router: Router, respond_not_found: bool = False) -> None: 22 | self.router = router 23 | self.respond_not_found = respond_not_found 24 | 25 | def __call__(self, chain: HandlerChain, context: RequestContext, response: Response): 26 | try: 27 | router_response = self.router.dispatch(context.request) 28 | response.update_from(router_response) 29 | chain.stop() 30 | except NotFound: 31 | if self.respond_not_found: 32 | chain.respond(404, "not found") 33 | 34 | 35 | class EmptyResponseHandler: 36 | """ 37 | Handler that creates a default response if the response in the context is empty. 
38 | """ 39 | 40 | status_code: int 41 | body: bytes 42 | headers: dict 43 | 44 | def __init__(self, status_code: int = 404, body: bytes = None, headers: Headers = None): 45 | self.status_code = status_code 46 | self.body = body or b"" 47 | self.headers = headers or Headers() 48 | 49 | def __call__(self, chain: HandlerChain, context: RequestContext, response: Response): 50 | if self.is_empty_response(response): 51 | self.populate_default_response(response) 52 | 53 | def is_empty_response(self, response: Response): 54 | return response.status_code in [0, None] and not response.response 55 | 56 | def populate_default_response(self, response: Response): 57 | response.status_code = self.status_code 58 | response.data = self.body 59 | response.headers.update(self.headers) 60 | 61 | 62 | class WerkzeugExceptionHandler: 63 | def __init__(self, output_format: t.Literal["json", "html"] = None) -> None: 64 | self.format = output_format or "json" 65 | 66 | def __call__( 67 | self, chain: HandlerChain, exception: Exception, context: RequestContext, response: Response 68 | ): 69 | if not isinstance(exception, HTTPException): 70 | return 71 | 72 | headers = Headers(exception.get_headers()) # FIXME 73 | headers.pop() 74 | 75 | if self.format == "html": 76 | chain.respond(status_code=exception.code, headers=headers, payload=exception.get_body()) 77 | elif self.format == "json": 78 | chain.respond( 79 | status_code=exception.code, 80 | headers=headers, 81 | payload={"code": exception.code, "description": exception.description}, 82 | ) 83 | else: 84 | raise ValueError(f"unknown rendering format {self.format}") 85 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | ## Installation 5 | 6 | Rolo is hosted on [pypi](https://pypi.org/project/rolo/) and can be installed via pip. 
7 | 8 | ```sh 9 | pip install rolo 10 | ``` 11 | 12 | ## Hello World 13 | 14 | Rolo provides different ways of building a web application. 15 | It provides familiar concepts such as Router and `@route`, but also more flexible concepts like a Handler Chain. 16 | 17 | ### Router 18 | 19 | Here is a simple [`Router`](router.md) that can be served as WSGI application using the Werkzeug dev server. 20 | If you are familiar with Flask, `@route` works in a similar way. 21 | 22 | ```python 23 | from werkzeug import Request 24 | from werkzeug.serving import run_simple 25 | 26 | from rolo import Router, route 27 | from rolo.routing import handler_dispatcher 28 | 29 | @route("/") 30 | def hello(request: Request): 31 | return {"message": "Hello World"} 32 | 33 | router = Router(dispatcher=handler_dispatcher()) 34 | router.add(hello) 35 | 36 | run_simple("localhost", 8000, router.wsgi()) 37 | ``` 38 | 39 | And to test: 40 | ```console 41 | curl localhost:8000/ 42 | ``` 43 | Should yield 44 | ```json 45 | {"message": "Hello World"} 46 | ``` 47 | 48 | `rolo.Request` and `rolo.Response` objects work in the same way as Werkzeug's [Request / Response](https://werkzeug.palletsprojects.com/en/latest/wrappers/) wrappers. 49 | 50 | ### Gateway 51 | 52 | A Gateway holds a set of handlers that are combined into a handler chain. 53 | Here is a simple example with a single request handler that dynamically creates a response object similar to httpbin. 
54 | 55 | ```python 56 | from werkzeug.serving import run_simple 57 | 58 | from rolo import Response 59 | from rolo.gateway import Gateway, RequestContext, HandlerChain 60 | from rolo.gateway.wsgi import WsgiGateway 61 | 62 | 63 | def echo_handler(chain: HandlerChain, context: RequestContext, response: Response): 64 | response.status_code = 200 65 | response.set_json( 66 | { 67 | "method": context.request.method, 68 | "path": context.request.path, 69 | "query": context.request.args, 70 | "headers": dict(context.request.headers), 71 | } 72 | ) 73 | chain.stop() 74 | 75 | 76 | gateway = Gateway( 77 | request_handlers=[echo_handler], 78 | ) 79 | 80 | run_simple("localhost", 8000, WsgiGateway(gateway)) 81 | ``` 82 | 83 | And to test: 84 | ```console 85 | curl -s -X POST "localhost:8000/foo/bar?a=1&b=2" | jq . 86 | ``` 87 | Should give you: 88 | ```json 89 | { 90 | "method": "POST", 91 | "path": "/foo/bar", 92 | "query": { 93 | "a": "1", 94 | "b": "2" 95 | }, 96 | "headers": { 97 | "Host": "localhost:8000", 98 | "User-Agent": "curl/7.81.0", 99 | "Accept": "*/*" 100 | } 101 | } 102 | ``` 103 | 104 | ## Next Steps 105 | 106 | Learn how to 107 | * Use the [Router](router.md) 108 | * Use the [Handler Chain](handler_chain.md) 109 | * [Serve](serving.md) rolo through your favorite web server 110 | -------------------------------------------------------------------------------- /tests/test_dispatcher.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | from werkzeug.exceptions import NotFound 5 | 6 | from rolo import Request, Response, Router 7 | from rolo.routing import handler_dispatcher 8 | 9 | 10 | class TestHandlerDispatcher: 11 | def test_handler_dispatcher(self): 12 | router = Router(dispatcher=handler_dispatcher()) 13 | 14 | def handler_foo(_request: Request) -> Response: 15 | return Response("ok") 16 | 17 | def handler_bar(_request: Request, bar, baz) -> Response: 18 | response 
= Response() 19 | response.set_json({"bar": bar, "baz": baz}) 20 | return response 21 | 22 | router.add("/foo", handler_foo) 23 | router.add("/bar//", handler_bar) 24 | 25 | assert router.dispatch(Request("GET", "/foo")).data == b"ok" 26 | assert router.dispatch(Request("GET", "/bar/420/ed")).json == {"bar": 420, "baz": "ed"} 27 | 28 | with pytest.raises(NotFound): 29 | assert router.dispatch(Request("GET", "/bar/asfg/ed")) 30 | 31 | def test_handler_dispatcher_invalid_signature(self): 32 | router = Router(dispatcher=handler_dispatcher()) 33 | 34 | def handler(_request: Request, arg1) -> Response: # invalid signature 35 | return Response("ok") 36 | 37 | router.add("/foo//", handler) 38 | 39 | with pytest.raises(TypeError): 40 | router.dispatch(Request("GET", "/foo/a/b")) 41 | 42 | def test_handler_dispatcher_with_dict_return(self): 43 | router = Router(dispatcher=handler_dispatcher()) 44 | 45 | def handler(_request: Request, arg1) -> Dict[str, Any]: 46 | return {"arg1": arg1, "hello": "there"} 47 | 48 | router.add("/foo/", handler) 49 | assert router.dispatch(Request("GET", "/foo/a")).json == {"arg1": "a", "hello": "there"} 50 | 51 | def test_handler_dispatcher_with_list_return(self): 52 | router = Router(dispatcher=handler_dispatcher()) 53 | 54 | def handler(_request: Request, arg1) -> Dict[str, Any]: 55 | return [{"arg1": arg1, "hello": "there"}, 1, 2, "3", [4, 5]] 56 | 57 | router.add("/foo/", handler) 58 | assert router.dispatch(Request("GET", "/foo/a")).json == [ 59 | {"arg1": "a", "hello": "there"}, 60 | 1, 61 | 2, 62 | "3", 63 | [4, 5], 64 | ] 65 | 66 | def test_handler_dispatcher_with_text_return(self): 67 | router = Router(dispatcher=handler_dispatcher()) 68 | 69 | def handler(_request: Request, arg1) -> str: 70 | return f"hello: {arg1}" 71 | 72 | router.add("/", handler) 73 | assert router.dispatch(Request("GET", "/world")).data == b"hello: world" 74 | 75 | def test_handler_dispatcher_with_none_return(self): 76 | router = 
Router(dispatcher=handler_dispatcher()) 77 | 78 | def handler(_request: Request): 79 | return None 80 | 81 | router.add("/", handler) 82 | assert router.dispatch(Request("GET", "/")).status_code == 200 83 | -------------------------------------------------------------------------------- /tests/gateway/test_handlers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | from werkzeug import Request 4 | from werkzeug.exceptions import BadRequest 5 | 6 | from rolo import Response, Router 7 | from rolo.gateway import Gateway, HandlerChain, RequestContext 8 | from rolo.gateway.handlers import EmptyResponseHandler, RouterHandler, WerkzeugExceptionHandler 9 | 10 | 11 | def _echo_handler(request: Request, args): 12 | return Response.for_json( 13 | { 14 | "path": request.path, 15 | "method": request.method, 16 | "headers": dict(request.headers), 17 | } 18 | ) 19 | 20 | 21 | @pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True) 22 | class TestWerkzeugExceptionHandler: 23 | def test_json_output_format(self, serve_gateway): 24 | def handler(chain: HandlerChain, context: RequestContext, response: Response): 25 | if context.request.method != "GET": 26 | raise BadRequest("oh noes") 27 | 28 | chain.respond(payload="ok") 29 | 30 | server = serve_gateway( 31 | Gateway( 32 | request_handlers=[ 33 | handler, 34 | ], 35 | exception_handlers=[ 36 | WerkzeugExceptionHandler(), 37 | ], 38 | ) 39 | ) 40 | 41 | resp = requests.get(server.url) 42 | assert resp.status_code == 200 43 | assert resp.text == "ok" 44 | 45 | resp = requests.post(server.url) 46 | assert resp.status_code == 400 47 | assert resp.json() == {"code": 400, "description": "oh noes"} 48 | 49 | 50 | @pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True) 51 | class TestRouterHandler: 52 | def test_router_handler_with_respond_not_found(self, serve_gateway): 53 | router = Router() 54 | 
router.add("/foo", _echo_handler) 55 | 56 | server = serve_gateway( 57 | Gateway( 58 | request_handlers=[ 59 | RouterHandler(router, True), 60 | ], 61 | ) 62 | ) 63 | 64 | doc = requests.get(server.url + "/foo", headers={"Foo-Bar": "foobar"}).json() 65 | assert doc["path"] == "/foo" 66 | assert doc["method"] == "GET" 67 | assert doc["headers"]["Foo-Bar"] == "foobar" 68 | 69 | response = requests.get(server.url + "/bar") 70 | assert response.status_code == 404 71 | assert response.text == "not found" 72 | 73 | 74 | @pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True) 75 | class TestEmptyResponseHandler: 76 | def test_empty_response_handler(self, serve_gateway): 77 | def _handler(chain, context, response): 78 | if context.request.method == "GET": 79 | chain.respond(202, "ok") 80 | else: 81 | response.status_code = 0 82 | 83 | server = serve_gateway( 84 | Gateway( 85 | request_handlers=[_handler], 86 | response_handlers=[EmptyResponseHandler(status_code=412, body=b"teapot?")], 87 | ) 88 | ) 89 | 90 | response = requests.get(server.url) 91 | assert response.text == "ok" 92 | assert response.status_code == 202 93 | 94 | response = requests.post(server.url) 95 | assert response.text == "teapot?" 96 | assert response.status_code == 412 97 | -------------------------------------------------------------------------------- /rolo/routing/pydantic.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing as t 3 | 4 | import pydantic 5 | 6 | from rolo.request import Request 7 | from rolo.response import Response 8 | 9 | from .handler import Handler, HandlerDispatcher, ResultValue 10 | from .router import RequestArguments 11 | 12 | 13 | def _get_model_argument(endpoint: Handler) -> t.Optional[tuple[str, t.Type[pydantic.BaseModel]]]: 14 | """ 15 | Inspects the endpoint function using Python reflection to find in its signature a ``pydantic.BaseModel`` attribute. 
16 | 17 | :param endpoint: the endpoint to inspect 18 | :return: a tuple containing the name and class, or None 19 | """ 20 | if not inspect.isfunction(endpoint) and not inspect.ismethod(endpoint): 21 | # cannot yet dispatch to other callables (e.g. an object with a `__call__` method) 22 | return None 23 | 24 | # finds the first pydantic.BaseModel subclass in the endpoint's annotations. 25 | # ``def foo(request: Request, id: int, item: MyItem)`` would yield ``('item', MyItem)`` 26 | for arg_name, arg_type in endpoint.__annotations__.items(): 27 | if arg_name in ("self", "return"): 28 | continue 29 | if not inspect.isclass(arg_type): 30 | continue 31 | try: 32 | if issubclass(arg_type, pydantic.BaseModel): 33 | return arg_name, arg_type 34 | except TypeError: 35 | # FIXME: needed for Python 3.10 support - presumably because inspect.isclass() there also accepts generic aliases such as list[int], which issubclass() rejects (TODO confirm) 36 | continue 37 | 38 | return None 39 | 40 | 41 | def _try_parse_pydantic_request_body( 42 | request: Request, endpoint: Handler 43 | ) -> t.Optional[tuple[str, pydantic.BaseModel]]:  # (argument name, validated model), or None if the endpoint takes no model 44 | arg = _get_model_argument(endpoint) 45 | 46 | if not arg: 47 | return  # the endpoint has no pydantic.BaseModel parameter, so there is nothing to parse 48 | 49 | arg_name, arg_type = arg 50 | 51 | if not request.content_length: 52 | # forces a ValidationError "Invalid JSON: EOF while parsing a value at line 1 column 0" 53 | arg_type.model_validate_json(b"") 54 | 55 | # will raise a werkzeug.BadRequest error if the JSON is invalid 56 | obj = request.get_json(force=True) 57 | 58 | return arg_name, arg_type.model_validate(obj) 59 | 60 | 61 | class PydanticHandlerDispatcher(HandlerDispatcher): 62 | """ 63 | Special HandlerDispatcher that knows how to serialize and deserialize pydantic models.
64 | """ 65 | 66 | def invoke_endpoint( 67 | self, 68 | request: Request, 69 | endpoint: t.Callable, 70 | request_args: RequestArguments, 71 | ) -> t.Any: 72 | # prepare request args 73 | try: 74 | arg = _try_parse_pydantic_request_body(request, endpoint) 75 | except pydantic.ValidationError as e: 76 | return Response(e.json(), mimetype="application/json", status=400) 77 | 78 | if arg: 79 | arg_name, model = arg 80 | request_args = {**request_args, arg_name: model} 81 | 82 | return super().invoke_endpoint(request, endpoint, request_args) 83 | 84 | def populate_response(self, response: Response, value: ResultValue): 85 | # try to convert any pydantic types to dicts before handing them to the parent implementation 86 | if isinstance(value, pydantic.BaseModel): 87 | value = value.model_dump() 88 | elif isinstance(value, (list, tuple)): 89 | converted = [] 90 | for element in value: 91 | if isinstance(element, pydantic.BaseModel): 92 | converted.append(element.model_dump()) 93 | else: 94 | converted.append(element) 95 | value = converted 96 | 97 | super().populate_response(response, value) 98 | -------------------------------------------------------------------------------- /rolo/routing/handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import typing as t 4 | 5 | from werkzeug import Response as WerkzeugResponse 6 | 7 | try: 8 | import pydantic # noqa 9 | 10 | ENABLE_PYDANTIC = True 11 | except ImportError: 12 | ENABLE_PYDANTIC = False 13 | 14 | from rolo.request import Request 15 | from rolo.response import Response 16 | 17 | from .router import Dispatcher, RequestArguments 18 | 19 | LOG = logging.getLogger(__name__) 20 | 21 | ResultValue = t.Union[ 22 | WerkzeugResponse, 23 | str, 24 | bytes, 25 | dict[str, t.Any], # a JSON dict 26 | list[t.Any], 27 | ] 28 | 29 | 30 | class Handler(t.Protocol): 31 | """ 32 | A protocol used by a ``Router`` together with the dispatcher created with 
``handler_dispatcher``. Endpoints added 33 | with this protocol take as first argument the HTTP request object, and then as keyword arguments the request 34 | parameters added in the rule. This makes it work very similar to flask routes. 35 | 36 | Example code could look like this:: 37 | 38 | def my_route(request: Request, organization: str, repo: str): 39 | return {"something": "returned as json response"} 40 | 41 | router = Router(dispatcher=handler_dispatcher) 42 | router.add("//", endpoint=my_route) 43 | 44 | """ 45 | 46 | def __call__(self, request: Request, **kwargs) -> ResultValue: 47 | """ 48 | Handle the given request. 49 | 50 | :param request: the HTTP request object 51 | :param kwargs: the url request parameters 52 | :return: a string or bytes value, a dict to create a json response, or a raw werkzeug Response object. 53 | """ 54 | raise NotImplementedError 55 | 56 | 57 | class HandlerDispatcher: 58 | def __init__(self, json_encoder: t.Type[json.JSONEncoder] = None): 59 | self.json_encoder = json_encoder 60 | 61 | def __call__( 62 | self, request: Request, endpoint: t.Callable, request_args: RequestArguments 63 | ) -> Response: 64 | result = self.invoke_endpoint(request, endpoint, request_args) 65 | return self.to_response(result) 66 | 67 | def invoke_endpoint( 68 | self, 69 | request: Request, 70 | endpoint: t.Callable, 71 | request_args: RequestArguments, 72 | ) -> t.Any: 73 | return endpoint(request, **request_args) 74 | 75 | def to_response(self, value: ResultValue) -> Response: 76 | if isinstance(value, WerkzeugResponse): 77 | return value 78 | 79 | response = Response() 80 | if value is None: 81 | return response 82 | 83 | self.populate_response(response, value) 84 | return response 85 | 86 | def populate_response(self, response: Response, value: ResultValue): 87 | if isinstance(value, (str, bytes, bytearray)): 88 | response.data = value 89 | elif isinstance(value, (dict, list)): 90 | response.data = json.dumps(value, cls=self.json_encoder) 91 | 
response.mimetype = "application/json" 92 | else: 93 | raise ValueError("unhandled result type %s", type(value)) 94 | 95 | 96 | def handler_dispatcher(json_encoder: t.Type[json.JSONEncoder] = None) -> Dispatcher[Handler]: 97 | """ 98 | Creates a Dispatcher that treats endpoints like callables of the ``Handler`` Protocol. 99 | 100 | :param json_encoder: optionally the json encoder class to use for translating responses 101 | :return: a new dispatcher 102 | """ 103 | if ENABLE_PYDANTIC: 104 | from rolo.routing.pydantic import PydanticHandlerDispatcher 105 | 106 | return PydanticHandlerDispatcher(json_encoder) 107 | 108 | return HandlerDispatcher(json_encoder) 109 | -------------------------------------------------------------------------------- /rolo/routing/resource.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module enables the resource class pattern, where each respective ``on_`` method of a class is 3 | treated like an endpoint for the respective HTTP method. 
The following shows an example of how the pattern is used:: 4 | 5 | class Foo: 6 | def on_get(self, request: Request): 7 | return {"ok": "GET it"} 8 | 9 | def on_post(self, request: Request): 10 | return {"ok": "it was POSTed"} 11 | 12 | 13 | router = Router(dispatcher=handler_dispatcher()) 14 | router.add(Resource("/foo", Foo())) 15 | """ 16 | from typing import Any, Iterable, Optional, Type 17 | 18 | from werkzeug.routing import Map, Rule, RuleFactory 19 | 20 | from .router import route 21 | 22 | _resource_methods = [ 23 | "on_head", # it's important that HEAD rules are added first (werkzeug matching order) 24 | "on_get", 25 | "on_post", 26 | "on_put", 27 | "on_patch", 28 | "on_delete", 29 | "on_options", 30 | "on_trace", 31 | ] 32 | 33 | 34 | def resource(path: str, host: Optional[str] = None, **kwargs): 35 | """ 36 | Class decorator that turns every method that follows the pattern ``on_<method>`` into a route, 37 | where the allowed HTTP method for that route is derived from the suffix of the function name (``on_get`` -> ``GET``). Example 38 | when using a Router with the ``handler_dispatcher``:: 39 | 40 | @resource("/myresource/<resource_id>") 41 | class MyResource: 42 | def on_get(self, request: Request, resource_id: str) -> Response: 43 | return Response(f"GET called on {resource_id}") 44 | 45 | def on_post(self, request: Request, resource_id: str) -> Response: 46 | return Response(f"POST called on {resource_id}") 47 | 48 | This class can then be added to a router via ``router.add(MyResource())``. 49 | 50 | Note that, if an on_get method is present in the resource but on_head is not, then HEAD requests are automatically 51 | routed to ``on_get``. This replicates Werkzeug's behavior https://werkzeug.palletsprojects.com/en/2.2.x/routing/. 52 | 53 | :param path: the path pattern to match 54 | :param host: an optional host matching pattern.
If no pattern is given, the rule matches any host 55 | :param kwargs: any other argument that can be passed to ``werkzeug.routing.Rule`` (a ``methods`` argument is ignored, since the methods are derived from the method names) 56 | :return: a class where each matching function is wrapped as a ``_RouteEndpoint`` 57 | """ 58 | kwargs.pop("methods", None) 59 | 60 | def _wrapper(cls: Type): 61 | for name in _resource_methods: 62 | member = getattr(cls, name, None) 63 | if member is None: 64 | continue 65 | 66 | http_method = name[3:].upper() 67 | setattr(cls, name, route(path, host, methods=[http_method], **kwargs)(member)) 68 | 69 | return cls 70 | 71 | return _wrapper 72 | 73 | 74 | class Resource(RuleFactory): 75 | """ 76 | Exposes a given object that follows the "Resource" class pattern as a ``RuleFactory`` that can then be added to a 77 | Router. Example use when using a Router with the ``handler_dispatcher``:: 78 | 79 | class MyResource: 80 | def on_get(self, request: Request, resource_id: str) -> Response: 81 | return Response(f"GET called on {resource_id}") 82 | 83 | def on_post(self, request: Request, resource_id: str) -> Response: 84 | return Response(f"POST called on {resource_id}") 85 | 86 | router.add(Resource("/myresource/<resource_id>", MyResource())) 87 | 88 | Note that, if an on_get method is present in the resource but on_head is not, then HEAD requests are automatically 89 | routed to ``on_get``. This replicates Werkzeug's behavior https://werkzeug.palletsprojects.com/en/2.2.x/routing/.
90 | """ 91 | 92 | def __init__(self, path: str, obj: Any, host: Optional[str] = None, **kwargs): 93 | self.path = path 94 | self.obj = obj 95 | self.host = host 96 | self.kwargs = kwargs 97 | 98 | def get_rules(self, map: Map) -> Iterable[Rule]: 99 | rules = [] 100 | for name in _resource_methods: 101 | member = getattr(self.obj, name, None) 102 | if member is None: 103 | continue 104 | 105 | http_method = name[3:].upper() 106 | rules.append( 107 | Rule( 108 | self.path, endpoint=member, methods=[http_method], host=self.host, **self.kwargs 109 | ) 110 | ) 111 | return rules 112 | -------------------------------------------------------------------------------- /examples/json-rpc/server.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import logging 4 | from typing import Callable 5 | 6 | from werkzeug.exceptions import BadRequest 7 | from werkzeug.serving import run_simple 8 | 9 | from rolo import Response 10 | from rolo.gateway import Gateway, HandlerChain, RequestContext 11 | from rolo.gateway.wsgi import WsgiGateway 12 | 13 | LOG = logging.getLogger(__name__) 14 | 15 | 16 | @dataclasses.dataclass 17 | class RpcRequest: 18 | jsonrpc: str 19 | method: str 20 | id: str | int | None 21 | params: dict | list | None = None 22 | 23 | 24 | class RpcError(Exception): 25 | code: int 26 | message: str 27 | 28 | 29 | class ParseError(RpcError): 30 | code = -32700 31 | message = "Parse error" 32 | 33 | 34 | class InvalidRequest(RpcError): 35 | code = -32600 36 | message = "Invalid params" 37 | 38 | 39 | class MethodNotFoundError(RpcError): 40 | code = -32601 41 | message = "Method not found" 42 | 43 | 44 | class InternalError(RpcError): 45 | code = -32603 46 | message = "Internal error" 47 | 48 | 49 | def parse_request(chain: HandlerChain, context: RequestContext, response: Response): 50 | context.rpc_request_id = None 51 | 52 | try: 53 | doc = context.request.get_json() 54 | except BadRequest as 
e: 55 | raise ParseError() from e 56 | 57 | try: 58 | context.rpc_request_id = doc["id"] 59 | context.rpc_request = RpcRequest( 60 | doc["jsonrpc"], 61 | doc["method"], 62 | doc["id"], 63 | doc.get("params"), 64 | ) 65 | except KeyError as e: 66 | raise ParseError() from e 67 | 68 | 69 | def log_request(chain: HandlerChain, context: RequestContext, response: Response): 70 | if context.rpc_request: 71 | LOG.info("RPC request object: %s", context.rpc_request) 72 | 73 | 74 | def serialize_rpc_error( 75 | chain: HandlerChain, 76 | exception: Exception, 77 | context: RequestContext, 78 | response: Response, 79 | ): 80 | if not isinstance(exception, RpcError): 81 | return 82 | 83 | response.set_json( 84 | { 85 | "jsonrpc": "2.0", 86 | "error": {"code": exception.code, "message": exception.message}, 87 | "id": context.rpc_request_id, 88 | } 89 | ) 90 | 91 | 92 | def log_exception( 93 | chain: HandlerChain, 94 | exception: Exception, 95 | context: RequestContext, 96 | response: Response, 97 | ): 98 | LOG.error("Exception in handler chain", exc_info=exception) 99 | 100 | 101 | class Registry: 102 | methods: dict[str, Callable] 103 | 104 | def __init__(self, methods: dict[str, Callable]): 105 | self.methods = methods 106 | 107 | def __call__( 108 | self, chain: HandlerChain, context: RequestContext, response: Response 109 | ): 110 | try: 111 | context.method = self.methods[context.rpc_request.method] 112 | except KeyError as e: 113 | raise MethodNotFoundError() from e 114 | 115 | 116 | def dispatch(chain: HandlerChain, context: RequestContext, response: Response): 117 | request: RpcRequest = context.rpc_request 118 | 119 | if isinstance(request.params, list): 120 | args = request.params 121 | kwargs = {} 122 | elif isinstance(request.params, dict): 123 | args = [] 124 | kwargs = request.params 125 | else: 126 | raise InvalidRequest() 127 | 128 | try: 129 | context.result = context.method(*args, **kwargs) 130 | except RpcError: 131 | # if the method raises an RpcError, just 
re-raise it since it will be handled later 132 | raise 133 | except Exception as e: 134 | # all other exceptions are considered unhandled and therefore "Internal" 135 | raise InternalError() from e 136 | 137 | 138 | def serialize_result(chain: HandlerChain, context: RequestContext, response: Response): 139 | if not context.rpc_request_id: 140 | # this is a notification, so we don't want to respond 141 | return 142 | 143 | response.set_json( 144 | { 145 | "jsonrpc": "2.0", 146 | "result": json.dumps(context.result), 147 | "id": context.rpc_request_id, 148 | } 149 | ) 150 | 151 | 152 | def main(): 153 | logging.basicConfig(level=logging.DEBUG) 154 | 155 | def subtract(subtrahend: int, minuend: int): 156 | return subtrahend - minuend 157 | 158 | locate_method = Registry( 159 | { 160 | "subtract": subtract, 161 | } 162 | ) 163 | 164 | gateway = Gateway( 165 | request_handlers=[ 166 | parse_request, 167 | log_request, 168 | locate_method, 169 | dispatch, 170 | ], 171 | exception_handlers=[ 172 | log_exception, 173 | serialize_rpc_error, 174 | ], 175 | ) 176 | 177 | run_simple("localhost", 8000, WsgiGateway(gateway)) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /tests/test_resource.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from werkzeug.exceptions import MethodNotAllowed 3 | 4 | from rolo import Request, Resource, Response, Router, resource 5 | from rolo.routing import handler_dispatcher 6 | 7 | 8 | class TestResource: 9 | def test_resource_decorator_dispatches_correctly(self): 10 | router = Router(dispatcher=handler_dispatcher()) 11 | 12 | requests = [] 13 | 14 | @resource("/_localstack/health") 15 | class TestResource: 16 | def on_get(self, req): 17 | requests.append(req) 18 | return "GET/OK" 19 | 20 | def on_post(self, req): 21 | requests.append(req) 22 | return {"ok": "POST"} 23 | 24 | def on_head(self, 
req): 25 | # this is ignored 26 | requests.append(req) 27 | return "HEAD/OK" 28 | 29 | router.add(TestResource()) 30 | 31 | request1 = Request("GET", "/_localstack/health") 32 | request2 = Request("POST", "/_localstack/health") 33 | request3 = Request("HEAD", "/_localstack/health") 34 | assert router.dispatch(request1).get_data(True) == "GET/OK" 35 | assert router.dispatch(request1).get_data(True) == "GET/OK" 36 | assert router.dispatch(request2).json == {"ok": "POST"} 37 | assert router.dispatch(request3).get_data(True) == "HEAD/OK" 38 | assert len(requests) == 4 39 | assert requests[0] is request1 40 | assert requests[1] is request1 41 | assert requests[2] is request2 42 | assert requests[3] is request3 43 | 44 | def test_resource_dispatches_correctly(self): 45 | router = Router(dispatcher=handler_dispatcher()) 46 | 47 | class TestResource: 48 | def on_get(self, req): 49 | return "GET/OK" 50 | 51 | def on_post(self, req): 52 | return "POST/OK" 53 | 54 | def on_head(self, req): 55 | return "HEAD/OK" 56 | 57 | router.add(Resource("/_localstack/health", TestResource())) 58 | 59 | request1 = Request("GET", "/_localstack/health") 60 | request2 = Request("POST", "/_localstack/health") 61 | request3 = Request("HEAD", "/_localstack/health") 62 | assert router.dispatch(request1).get_data(True) == "GET/OK" 63 | assert router.dispatch(request2).get_data(True) == "POST/OK" 64 | assert router.dispatch(request3).get_data(True) == "HEAD/OK" 65 | 66 | def test_dispatch_to_non_existing_method_raises_exception(self): 67 | router = Router(dispatcher=handler_dispatcher()) 68 | 69 | @resource("/_localstack/health") 70 | class TestResource: 71 | def on_post(self, request): 72 | return "POST/OK" 73 | 74 | router.add(TestResource()) 75 | 76 | with pytest.raises(MethodNotAllowed): 77 | assert router.dispatch(Request("GET", "/_localstack/health")) 78 | assert router.dispatch(Request("POST", "/_localstack/health")).get_data(True) == "POST/OK" 79 | 80 | def 
test_resource_with_default_dispatcher(self): 81 | router = Router() 82 | 83 | @resource("/_localstack/") 84 | class TestResource: 85 | def on_get(self, req, args): 86 | return Response.for_json({"message": "GET/OK", "path": args["path"]}) 87 | 88 | def on_post(self, req, args): 89 | return Response.for_json({"message": "POST/OK", "path": args["path"]}) 90 | 91 | router.add(TestResource()) 92 | assert router.dispatch(Request("GET", "/_localstack/health")).json == { 93 | "message": "GET/OK", 94 | "path": "health", 95 | } 96 | assert router.dispatch(Request("POST", "/_localstack/foobar")).json == { 97 | "message": "POST/OK", 98 | "path": "foobar", 99 | } 100 | 101 | def test_resource_overwrite_with_resource_wrapper(self): 102 | router = Router(dispatcher=handler_dispatcher()) 103 | 104 | @resource("/_localstack/health") 105 | class TestResourceHealth: 106 | def on_get(self, req): 107 | return Response.for_json({"message": "GET/OK", "path": req.path}) 108 | 109 | def on_post(self, req): 110 | return Response.for_json({"message": "POST/OK", "path": req.path}) 111 | 112 | endpoints = TestResourceHealth() 113 | router.add(endpoints) 114 | router.add(Resource("/health", endpoints)) 115 | 116 | assert router.dispatch(Request("GET", "/_localstack/health")).json == { 117 | "message": "GET/OK", 118 | "path": "/_localstack/health", 119 | } 120 | assert router.dispatch(Request("POST", "/_localstack/health")).json == { 121 | "message": "POST/OK", 122 | "path": "/_localstack/health", 123 | } 124 | 125 | assert router.dispatch(Request("GET", "/health")).json == { 126 | "message": "GET/OK", 127 | "path": "/health", 128 | } 129 | assert router.dispatch(Request("POST", "/health")).json == { 130 | "message": "POST/OK", 131 | "path": "/health", 132 | } 133 | -------------------------------------------------------------------------------- /docs/router.md: -------------------------------------------------------------------------------- 1 | Router 2 | ====== 3 | 4 | Routers are based on 
Werkzeug's [URL Map](https://werkzeug.palletsprojects.com/en/2.3.x/routing/), but dispatch to handler functions directly. 5 | All features from Werkzeug's URL routing are inherited, including the [rule string format](https://werkzeug.palletsprojects.com/en/latest/routing/#rule-format) and [type converters](https://werkzeug.palletsprojects.com/en/latest/routing/#built-in-converters). 6 | 7 | `@route` 8 | -------- 9 | 10 | The `@route` decorator works similarly to the route decorators of Flask or FastAPI, but routes are not tied to an application object. 11 | Instead, you can define routes on functions or methods, and then add them directly to the router. 12 | 13 | ```python 14 | from rolo import Router, route, Response 15 | from werkzeug import Request 16 | from werkzeug.serving import run_simple 17 | 18 | @route("/users") 19 | def list_users(_request: Request, args): 20 | assert not args 21 | return Response("user") 22 | 23 | @route("/users/<int:user_id>") 24 | def get_user_by_id(_request: Request, args): 25 | assert args 26 | return Response(f"{args['user_id']}") 27 | 28 | router = Router() 29 | router.add(list_users) 30 | router.add(get_user_by_id) 31 | 32 | # convert Router to a WSGI app and serve it through werkzeug 33 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True) 34 | ``` 35 | 36 | Depending on the _dispatcher_ your Router uses, the signature of your endpoints will look different. 37 | 38 | Handler dispatcher 39 | ------------------ 40 | 41 | Routers use dispatchers to dispatch the request to functions. 42 | In the previous example, the default dispatcher calls the function with the `Request` object and the request arguments as a dictionary. 43 | The "handler dispatcher" lets you write Flask- or FastAPI-like functions that receive the URL parameters as keyword arguments and can return values that are automatically transformed into responses.
44 | 45 | ```python 46 | from rolo import Router, route 47 | from rolo.routing import handler_dispatcher 48 | 49 | from werkzeug import Request 50 | from werkzeug.serving import run_simple 51 | 52 | @route("/users") 53 | def list_users(request: Request): 54 | # query from db using the ?q= query string 55 | query = request.args["q"] 56 | # ... 57 | return [{"user_id": ...}, ...] 58 | 59 | @route("/users/<int:user_id>") 60 | def get_user_by_id(_request: Request, user_id: int): 61 | return {"user_id": user_id, "name": ...} 62 | 63 | router = Router(dispatcher=handler_dispatcher()) 64 | router.add(list_users) 65 | router.add(get_user_by_id) 66 | 67 | # convert Router to a WSGI app and serve it through werkzeug 68 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True) 69 | ``` 70 | 71 | Using classes 72 | ------------- 73 | 74 | Unlike Flask or FastAPI, Rolo allows you to use classes to organize your routes. 75 | The above example can also be written as follows: 76 | 77 | ```python 78 | from rolo import Router, route, Request 79 | from rolo.routing import handler_dispatcher 80 | 81 | class UserResource: 82 | 83 | @route("/users/") 84 | def list_users(self, _request: Request): 85 | return "user" 86 | 87 | @route("/users/<int:user_id>") 88 | def get_user_by_id(self, _request: Request, user_id: int): 89 | return f"{user_id}" 90 | 91 | router = Router(dispatcher=handler_dispatcher()) 92 | router.add(UserResource()) 93 | ``` 94 | The router will scan the instantiated `UserResource` for methods decorated with `@route` and add them automatically. 95 | 96 | Resource classes 97 | ---------------- 98 | 99 | If you prefer the RESTful style that [Falcon](https://falconframework.org/) implements, you can use the `@resource` decorator on a class. 100 | This will automatically create routes for all `on_<method>` methods.
101 | Here is an example: 102 | 103 | ```python 104 | from rolo import Router, resource, Request 105 | from rolo.routing import handler_dispatcher 106 | 107 | @resource("/users/<int:user_id>") 108 | class UserResource: 109 | 110 | def on_get(self, request: Request, user_id: int): 111 | return {"user_id": user_id, "user": ...} 112 | 113 | def on_post(self, request: Request, user_id: int): 114 | data = request.json 115 | # ... do something 116 | 117 | router = Router(dispatcher=handler_dispatcher()) 118 | router.add(UserResource()) 119 | ``` 120 | 121 | Pydantic integration 122 | -------------------- 123 | 124 | Here's what the default example from the FastAPI documentation looks like with rolo (pydantic models are parsed and serialized by the handler dispatcher): 125 | 126 | ```python 127 | import pydantic 128 | from rolo import Request, Router, route 129 | from rolo.routing import handler_dispatcher 130 | 131 | 132 | class Item(pydantic.BaseModel): 133 | name: str 134 | price: float 135 | is_offer: bool | None = None 136 | 137 | 138 | @route("/", methods=["GET"]) 139 | def read_root(request: Request): 140 | return {"Hello": "World"} 141 | 142 | 143 | @route("/items/<int:item_id>", methods=["GET"]) 144 | def read_item(request: Request, item_id: int): 145 | return {"item_id": item_id, "q": request.args.get("q")} 146 | 147 | 148 | @route("/items/<int:item_id>", methods=["PUT"]) 149 | def update_item(request: Request, item_id: int, item: Item): 150 | return {"item_name": item.name, "item_id": item_id} 151 | 152 | 153 | router = Router(dispatcher=handler_dispatcher()) 154 | router.add(read_root) 155 | router.add(read_item) 156 | router.add(update_item) 157 | ``` 158 | -------------------------------------------------------------------------------- /docs/handler_chain.md: -------------------------------------------------------------------------------- 1 | Handler Chain 2 | ============= 3 | 4 | The rolo handler chain implements a variant of the chain-of-responsibility pattern to process an incoming HTTP request. 5 | It is meant to be used together with a [`Gateway`](gateway.md), which is responsible for creating `HandlerChain` instances.
6 | 7 | Handler chains are a powerful abstraction to create complex HTTP server behavior, while keeping code cleanly encapsulated and the high-level logic easy to understand. 8 | You can find a simple example of how to create a handler chain in the [Getting Started](getting_started.md) guide. 9 | 10 | ## Behavior 11 | 12 | A handler chain consists of: 13 | * request handlers: process the request and attempt to create an initial response 14 | * response handlers: process the response 15 | * finalizers: handlers that are always executed at the end of running a handler chain 16 | * exception handlers: run when an exception occurs during the execution of a handler 17 | 18 | Each HTTP request coming into the server has its own `HandlerChain` instance, since the handler chain holds state for the handling of a request. 19 | A handler chain can be in one of three states, which are controlled by the handlers: 20 | 21 | * Running - the implicit state in which _all_ handlers are executed sequentially 22 | * Stopped - a handler has called `chain.stop()`. This skips all remaining request handlers and 23 | proceeds immediately to executing the response handlers. Response handlers and finalizers will be run, 24 | even if the chain has been stopped. 25 | * Terminated - a handler has called `chain.terminate()`. This skips all remaining request 26 | handlers and all response handlers, but still runs the finalizers at the end. 27 | 28 | If an exception occurs during the execution of request handlers, the chain is stopped by default, 29 | then each exception handler is run, and finally the response handlers are run. 30 | Exceptions that happen during the execution of response or exception handlers are logged but do not modify the control flow of the chain. 31 | 32 | ## Request Context 33 | 34 | The `RequestContext` object holds the HTTP `Request` object in `context.request`, as well as any arbitrary data you would like to pass down to other handlers.
35 | It's a universal attribute store, so you can simply write `context.myattr = "foo"` to set a value. 36 | You can add type hints for your request context; see [gateway](gateway.md). 37 | 38 | ## Handlers 39 | 40 | Request handlers, response handlers, and finalizers need to satisfy the `Handler` protocol: 41 | 42 | ```python 43 | from rolo import Response 44 | from rolo.gateway import HandlerChain, RequestContext 45 | 46 | def handle(chain: HandlerChain, context: RequestContext, response: Response): 47 | ... 48 | ``` 49 | 50 | * `chain`: the HandlerChain instance currently being executed. For example, a handler can call `chain.stop()` to skip all remaining request handlers. 51 | * `context`: the RequestContext contains the rolo `Request` object, as well as a universal property store. You can simply write `context.myattr = ...` to pass a value down to the next handler. 52 | * `response`: Handlers of a handler chain don't return a response; instead, the response being populated is handed down from handler to handler, and can thus be enriched. 53 | 54 | ### Exception Handlers 55 | 56 | Exception handlers are similar, except that they are also passed the `Exception` that was raised in the handler chain. 57 | 58 | ```python 59 | from rolo import Response 60 | from rolo.gateway import HandlerChain, RequestContext 61 | 62 | def handle(chain: HandlerChain, exception: Exception, context: RequestContext, response: Response): 63 | ... 64 | ``` 65 | 66 | ## Builtin Handlers 67 | 68 | ### Router handler 69 | 70 | Sometimes you have a `Gateway` but also want to use the [`Router`](router.md). 71 | You can use the `RouterHandler` adapter to make a `Router` look like a handler chain `Handler`, and then pass it as a handler to a Gateway. 72 | 73 | ```python 74 | from rolo import Router 75 | from rolo.gateway import Gateway 76 | from rolo.gateway.handlers import RouterHandler 77 | 78 | router: Router = ... 79 | gateway: Gateway = ...
80 | 81 | gateway.request_handlers.append(RouterHandler(router)) 82 | ``` 83 | 84 | ### Empty response handler 85 | 86 | The `EmptyResponseHandler` response handler automatically creates a default response if the response in the chain is empty. 87 | By default, it creates an empty 404 response, but it can be customized: 88 | 89 | ```python 90 | from rolo.gateway.handlers import EmptyResponseHandler 91 | 92 | gateway.response_handlers.append(EmptyResponseHandler(status_code=404, body=b'404 Not Found')) 93 | ``` 94 | 95 | ### Werkzeug exception handler 96 | 97 | Werkzeug has a very useful [HTTP exception hierarchy](https://werkzeug.palletsprojects.com/en/latest/exceptions/) that can be used to programmatically trigger HTTP errors. 98 | For instance, a request handler may raise a `NotFound` error. 99 | To get the Gateway to automatically handle those exceptions and render them into JSON objects or HTML, you can use the `WerkzeugExceptionHandler`. 100 | 101 | ```python 102 | from rolo.gateway.handlers import WerkzeugExceptionHandler 103 | 104 | gateway.exception_handlers.append(WerkzeugExceptionHandler(output_format="json")) 105 | ``` 106 | 107 | In your request handlers, you can now raise any exception from `werkzeug.exceptions`, and it will be rendered accordingly.
108 | -------------------------------------------------------------------------------- /rolo/websocket/adapter.py: -------------------------------------------------------------------------------- 1 | """Adapter API between high-level rolo.websocket.request and an underlying IO framework like ASGI or 2 | twisted.""" 3 | import dataclasses 4 | import typing as t 5 | 6 | from werkzeug.datastructures import Headers 7 | 8 | WebSocketEnvironment: t.TypeAlias = t.Dict[str, t.Any] 9 | """Special WSGIEnvironment that has a ``rolo.websocket`` key that stores a `Websocket` instance.""" 10 | 11 | 12 | class Event: 13 | """A websocket event (subset of ``wsproto.events``).""" 14 | 15 | pass 16 | 17 | 18 | @dataclasses.dataclass 19 | class Message(Event): 20 | data: bytes | str 21 | 22 | 23 | @dataclasses.dataclass 24 | class TextMessage(Message): 25 | data: str 26 | 27 | 28 | @dataclasses.dataclass 29 | class BytesMessage(Message): 30 | data: bytes 31 | 32 | 33 | @dataclasses.dataclass 34 | class CreateConnection(Event): 35 | """ 36 | This indicates the first event of the websocket after a connection upgrade. For example, in wsproto 37 | this corresponds to a ``Request`` event, or to the ``websocket.connect`` event in ASGI. 38 | """ 39 | 40 | pass 41 | 42 | 43 | @dataclasses.dataclass 44 | class AcceptConnection(Event): 45 | subprotocol: t.Optional[str] = None 46 | extensions: list[str] = dataclasses.field(default_factory=list) 47 | extra_headers: list[tuple[bytes, bytes]] = dataclasses.field(default_factory=list) 48 | 49 | 50 | class WebSocketAdapter: 51 | """ 52 | Adapter to plug the high-level interfaces ``WebSocket`` and ``WebSocketRequest`` into an IO framework. 53 | It doesn't cover the full websocket protocol API (for instance, there are no Ping/Pong events), 54 | under the assumption that the lower-level IO framework will abstract them away.
55 | """ 56 | 57 | def accept( 58 | self, 59 | subprotocol: str = None, 60 | extensions: list[str] = None, 61 | extra_headers: Headers = None, 62 | timeout: float = None, 63 | ): 64 | """ 65 | Accept the websocket upgrade request and send an accept message back to the client. This or 66 | ``reject`` must be called first (the ``WebSocketListener`` example receives the initial ``CreateConnection`` event before accepting). 67 | 68 | :param subprotocol: the accepted subprotocol 69 | :param extensions: any accepted extensions to use 70 | :param extra_headers: headers to pass to the accept response 71 | :param timeout: optional timeout 72 | """ 73 | raise NotImplementedError 74 | 75 | def reject( 76 | self, 77 | status_code: int, 78 | headers: Headers = None, 79 | body: t.Iterable[bytes] = None, 80 | timeout: float = None, 81 | ): 82 | """ 83 | Reject the websocket request. This means sending an actual HTTP response back to the client, i.e., 84 | not upgrading the connection. This only makes sense before any call to ``receive`` was made. 85 | 86 | :param status_code: the HTTP response status code 87 | :param headers: the HTTP response headers 88 | :param body: the body 89 | :param timeout: optional timeout 90 | """ 91 | raise NotImplementedError 92 | 93 | def receive( 94 | self, 95 | timeout: float = None, 96 | ) -> CreateConnection | Message: 97 | """Blocking IO method to wait for the next ``Message`` or, if not initialized yet, the first 98 | ``CreateConnection`` event.""" 99 | raise NotImplementedError 100 | 101 | def send( 102 | self, 103 | event: Message, 104 | timeout: float = None, 105 | ): 106 | """ 107 | Send the given message to the websocket. 108 | 109 | :param event: the message to send 110 | :param timeout: optional timeout 111 | :return: 112 | """ 113 | raise NotImplementedError 114 | 115 | def close(self, code: int = 1001, reason: str = None, timeout: float = None): 116 | """ 117 | Close the websocket connection with the given ``code`` and ``reason``. If the underlying websocket connection has already been closed, this call is ignored, so it's safe 118 | to always call.
119 | """ 120 | raise NotImplementedError 121 | 122 | 123 | class WebSocketListener(t.Protocol): 124 | """ 125 | Similar protocol to a WSGIApplication, only it expects a ``WebSocketEnvironment`` instead of a WSGIEnvironment. 126 | """ 127 | 128 | def __call__(self, environ: WebSocketEnvironment): 129 | """ 130 | Called when a new Websocket connection is established. To initiate the connection, you need to perform 131 | the connect handshake yourself. First, receive the ``CreateConnection`` event (``websocket.connect`` in ASGI), and then 132 | call ``accept`` (``websocket.accept`` in ASGI). Here's a minimal example:: 133 | 134 | def accept(self, environ: WebSocketEnvironment): 135 | websocket: WebSocketAdapter = environ['rolo.websocket'] 136 | event = websocket.receive() 137 | if isinstance(event, CreateConnection): 138 | websocket.accept() 139 | else: 140 | websocket.close(1002) # protocol error 141 | return 142 | 143 | while True: 144 | event = websocket.receive() 145 | if isinstance(event, CloseConnection): 146 | return 147 | print(event) 148 | 149 | In reality, you wouldn't be using the websocket adapter directly; the server would probably create a 150 | ``rolo.websocket.WebSocketRequest`` and serve it accordingly through a ``Gateway``.
151 | 152 | :param environ: The new Websocket environment 153 | """ 154 | raise NotImplementedError 155 | -------------------------------------------------------------------------------- /rolo/response.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mimetypes 3 | import typing as t 4 | from importlib import resources 5 | 6 | from werkzeug.exceptions import NotFound 7 | from werkzeug.wrappers import Response as WerkzeugResponse 8 | 9 | if t.TYPE_CHECKING: 10 | from types import ModuleType 11 | 12 | 13 | class _StreamIterableWrapper(t.Iterable[bytes]): 14 | """ 15 | This can wrap an IO[bytes] stream to return an Iterable with a default chunk size of 65536 bytes 16 | """ 17 | 18 | def __init__(self, stream: t.IO[bytes], chunk_size: int = 65536): 19 | self.stream = stream 20 | self._chunk_size = chunk_size 21 | 22 | def __iter__(self) -> t.Iterator[bytes]: 23 | """ 24 | When passing a stream back to the WSGI server, it will often iterate only 1 byte at a time. Using this chunking 25 | mechanism allows us to bypass this issue. 26 | The caller needs to call `close()` to properly close the file descriptor 27 | :return: 28 | """ 29 | while data := self.stream.read(self._chunk_size): 30 | if not data: 31 | return b"" 32 | 33 | yield data 34 | 35 | def close(self): 36 | if hasattr(self.stream, "close"): 37 | self.stream.close() 38 | 39 | 40 | class Response(WerkzeugResponse): 41 | """ 42 | An HTTP Response object, which simply extends werkzeug's Response object with a few convenience methods. 43 | """ 44 | 45 | def update_from(self, other: WerkzeugResponse): 46 | """ 47 | Updates this response object with the data from the given response object. It reads the status code, 48 | the response data, and updates its own headers (overwrites existing headers, but does not remove ones 49 | not present in the given object). Also updates ``call_on_close`` callbacks in the same way. 
50 | 51 | :param other: the response object to read from 52 | """ 53 | self.status_code = other.status_code 54 | self.response = other.response 55 | self._on_close.extend(other._on_close) 56 | self.headers.update(other.headers) 57 | 58 | def set_json(self, doc: t.Any, cls: t.Type[json.JSONEncoder] = None): 59 | """ 60 | Serializes the given dictionary using localstack's ``CustomEncoder`` into a json response, and sets the 61 | mimetype automatically to ``application/json``. 62 | 63 | :param doc: the response dictionary to be serialized as JSON 64 | :param cls: the JSON encoder class to use for serializing the passed document 65 | """ 66 | self.data = json.dumps(doc, cls=cls) 67 | self.mimetype = "application/json" 68 | 69 | def set_response(self, response: t.Union[str, bytes, bytearray, t.Iterable[bytes]]): 70 | """ 71 | Function to set the low-level ``response`` object. This is copied from the werkzeug Response constructor. The 72 | response attribute always holds an iterable of bytes. Passing a str, bytes or bytearray is equivalent to 73 | calling ``response.data = ``. If None is passed, then it will create an empty list. If anything 74 | else is passed, the value is set directly. This value can be a list of bytes, and iterator that returns bytes 75 | (e.g., a generator), which can be used by the underlying server to stream responses to the client. Anything else 76 | (like passing dicts) will result in errors at lower levels of the server. 77 | 78 | :param response: the response value 79 | """ 80 | if response is None: 81 | self.response = [] 82 | elif isinstance(response, (str, bytes, bytearray)): 83 | self.data = response 84 | else: 85 | self.response = response 86 | 87 | return self 88 | 89 | def to_readonly_response_dict(self) -> t.Dict: 90 | """ 91 | Returns a read-only version of a response dictionary as it is often expected by other libraries like boto. 
92 | """ 93 | return { 94 | "body": self.stream if self.is_streamed else self.data, 95 | "status_code": self.status_code, 96 | "headers": dict(self.headers), 97 | } 98 | 99 | @classmethod 100 | def for_json(cls, doc: t.Any, *args, **kwargs) -> "Response": 101 | """ 102 | Creates a new JSON response from the given document. It automatically sets the mimetype to ``application/json``. 103 | 104 | :param doc: the document to serialize into JSON 105 | :param args: arguments passed to the ``Response`` constructor 106 | :param kwargs: keyword arguments passed to the ``Response`` constructor 107 | :return: a new Response object 108 | """ 109 | response = cls(*args, **kwargs) 110 | response.set_json(doc) 111 | return response 112 | 113 | @classmethod 114 | def for_resource(cls, module: "ModuleType", path: str, *args, **kwargs) -> "Response": 115 | """ 116 | Looks up the given file in the given module, and creates a new Response object with the contents of that 117 | file. It guesses the mimetype of the file and sets it in the response accordingly. If the file does not exist 118 | ,it raises a ``NotFound`` error. 119 | 120 | :param module: the module to look up the file in 121 | :param path: the path/file name 122 | :return: a new Response object 123 | """ 124 | resource = resources.files(module).joinpath(path) 125 | if not resource.is_file(): 126 | raise NotFound() 127 | mimetype = mimetypes.guess_type(resource.name) 128 | mimetype = mimetype[0] if mimetype and mimetype[0] else "application/octet-stream" 129 | 130 | return cls(_StreamIterableWrapper(resource.open("rb")), *args, mimetype=mimetype, **kwargs) 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Rolo HTTP 3 |

4 |

5 | Rolo HTTP: A Python framework for building HTTP-based server applications. 6 |

7 | 8 | # Rolo HTTP 9 | 10 |

11 | CI badge 12 | PyPI Version 13 | 14 | PyPI License 15 | Code style: black 16 |

17 | 18 | Rolo is a flexible framework and library to build HTTP-based server applications beyond microservices and REST APIs. 19 | You can build HTTP-based RPC servers, websocket proxies, or other server types that typical web frameworks are not designed for. 20 | 21 | Rolo extends [Werkzeug](https://github.com/pallets/werkzeug/), a flexible Python HTTP server library, so that you can use familiar concepts like `Router`, `Request`, `Response`, or `@route`. 22 | It introduces the concepts of a `Gateway` and a `HandlerChain`, an implementation variant of the [chain-of-responsibility pattern](https://en.wikipedia.org/wiki/Chain-of-responsibility_pattern). 23 | 24 | Rolo is designed for environments that do not use asyncio, but still require asynchronous HTTP features like HTTP/2, SSE, or WebSockets. 25 | To allow asynchronous communication, Rolo introduces an ASGI/WSGI bridge that allows you to serve Rolo applications through ASGI servers like Hypercorn. 26 | 27 | ## Usage 28 | 29 | ### Default router example 30 | 31 | Routers are based on Werkzeug's [URL Map](https://werkzeug.palletsprojects.com/en/2.3.x/routing/), but dispatch to handler functions directly. 32 | The `@route` decorator works similarly to its counterparts in Flask or FastAPI, but routes are not tied to an application object. 33 | Instead, you can define routes on functions or methods, and then add them directly to the router.
34 | 35 | ```python 36 | from rolo import Router, route, Response 37 | from werkzeug import Request 38 | from werkzeug.serving import run_simple 39 | 40 | @route("/users") 41 | def user(_request: Request, args): 42 | assert not args 43 | return Response("user") 44 | 45 | @route("/users/<int:user_id>") 46 | def user_id(_request: Request, args): 47 | assert args 48 | return Response(f"{args['user_id']}") 49 | 50 | router = Router() 51 | router.add(user) 52 | router.add(user_id) 53 | 54 | # convert Router to a WSGI app and serve it through werkzeug 55 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True) 56 | ``` 57 | 58 | ### Pydantic integration 59 | 60 | Routers use dispatchers to dispatch the request to functions. 61 | In the previous example, the default dispatcher calls the function with the `Request` object and the request arguments as a dictionary. 62 | The "handler dispatcher" can transform functions into more Flask- or FastAPI-like functions that also allow you to integrate with Pydantic.
63 | Here's how the default example from the FastAPI documentation would look with rolo: 64 | 65 | ```python 66 | import pydantic 67 | from werkzeug import Request 68 | from werkzeug.serving import run_simple 69 | 70 | from rolo import Router, route 71 | from rolo.routing import handler_dispatcher 72 | 73 | class Item(pydantic.BaseModel): 74 | name: str 75 | price: float 76 | is_offer: bool | None = None 77 | 78 | 79 | @route("/", methods=["GET"]) 80 | def read_root(request: Request): 81 | return {"Hello": "World"} 82 | 83 | 84 | @route("/items/<int:item_id>", methods=["GET"]) 85 | def read_item(request: Request, item_id: int): 86 | return {"item_id": item_id, "q": request.query_string} 87 | 88 | 89 | @route("/items/<int:item_id>", methods=["PUT"]) 90 | def update_item(request: Request, item_id: int, item: Item): 91 | return {"item_name": item.name, "item_id": item_id} 92 | 93 | 94 | router = Router(dispatcher=handler_dispatcher()) 95 | router.add(read_root) 96 | router.add(read_item) 97 | router.add(update_item) 98 | 99 | # convert Router to a WSGI app and serve it through werkzeug 100 | run_simple("localhost", 8080, router.wsgi(), use_reloader=True) 101 | ``` 102 | 103 | ### Gateway & Handler Chain 104 | 105 | A rolo `Gateway` holds a set of request, response, and exception handlers, as well as request finalizers. 106 | Gateway instances can then be served as WSGI or ASGI applications by using the appropriate serving adapter. 107 | Here is a simple example of a Gateway with just one handler that returns the URL and method that was invoked.
108 | 109 | ```python 110 | from werkzeug import run_simple 111 | 112 | from rolo import Response 113 | from rolo.gateway import Gateway, HandlerChain, RequestContext 114 | from rolo.gateway.wsgi import WsgiGateway 115 | 116 | 117 | def echo_handler(chain: HandlerChain, context: RequestContext, response: Response): 118 | response.status_code = 200 119 | response.set_json({"url": context.request.url, "method": context.request.method}) 120 | 121 | 122 | gateway = Gateway(request_handlers=[echo_handler]) 123 | 124 | app = WsgiGateway(gateway) 125 | run_simple("localhost", 8080, app, use_reloader=True) 126 | ``` 127 | 128 | Serving this will yield: 129 | 130 | ```console 131 | curl http://localhost:8080/hello-world 132 | {"url": "http://localhost:8080/hello-world", "method": "GET"} 133 | ``` 134 | 135 | 136 | ## Develop 137 | 138 | ### Quickstart 139 | 140 | To install the Python and other developer requirements into a venv, run: 141 | 142 | make install 143 | 144 | ### Format code 145 | 146 | We use black and isort as code style tools.
147 | To execute them, run: 148 | 149 | make format 150 | 151 | ### Build distribution 152 | 153 | To build a wheel and source distribution, simply run 154 | 155 | make dist 156 | -------------------------------------------------------------------------------- /tests/websocket/test_websockets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import threading 3 | from queue import Queue 4 | 5 | import pytest 6 | import websocket 7 | from _pytest.fixtures import SubRequest 8 | from werkzeug.datastructures import Headers 9 | 10 | from rolo import Response, Router 11 | from rolo.websocket.request import ( 12 | WebSocketDisconnectedError, 13 | WebSocketProtocolError, 14 | WebSocketRequest, 15 | ) 16 | 17 | 18 | @pytest.fixture(params=["asgi", "twisted"]) 19 | def serve_websocket_listener(request: SubRequest): 20 | def _serve(listener): 21 | if request.param == "asgi": 22 | srv = request.getfixturevalue("serve_asgi_adapter") 23 | return srv(wsgi_app=None, websocket_listener=listener) 24 | else: 25 | srv = request.getfixturevalue("serve_twisted_websocket_listener") 26 | return srv(listener) 27 | 28 | yield _serve 29 | 30 | 31 | def test_websocket_basic_interaction(serve_websocket_listener): 32 | raised = threading.Event() 33 | 34 | @WebSocketRequest.listener 35 | def app(request: WebSocketRequest): 36 | with request.accept() as ws: 37 | ws.send("hello") 38 | assert ws.receive() == "foobar" 39 | ws.send("world") 40 | 41 | with pytest.raises(WebSocketDisconnectedError): 42 | ws.receive() 43 | 44 | raised.set() 45 | 46 | server = serve_websocket_listener(app) 47 | 48 | client = websocket.WebSocket() 49 | client.connect(server.url.replace("http://", "ws://")) 50 | assert client.recv() == "hello" 51 | client.send("foobar") 52 | assert client.recv() == "world" 53 | client.close() 54 | 55 | assert raised.wait(timeout=3) 56 | 57 | 58 | def test_websocket_disconnect_while_iter(serve_websocket_listener): 59 | """Makes sure that 
the ``for line in iter(ws)`` pattern works smoothly when the client disconnects.""" 60 | returned = threading.Event() 61 | received = [] 62 | 63 | @WebSocketRequest.listener 64 | def app(request: WebSocketRequest): 65 | with request.accept() as ws: 66 | for line in iter(ws): 67 | received.append(line) 68 | 69 | returned.set() 70 | 71 | server = serve_websocket_listener(app) 72 | 73 | client = websocket.WebSocket() 74 | client.connect(server.url.replace("http://", "ws://")) 75 | 76 | client.send("foo") 77 | client.send("bar") 78 | client.close() 79 | 80 | assert returned.wait(timeout=3) 81 | assert received[0] == "foo" 82 | assert received[1] == "bar" 83 | 84 | 85 | def test_websocket_headers(serve_websocket_listener): 86 | @WebSocketRequest.listener 87 | def echo_headers(request: WebSocketRequest): 88 | with request.accept(headers=Headers({"x-foo-bar": "foobar"})) as ws: 89 | ws.send(json.dumps(dict(request.headers))) 90 | 91 | server = serve_websocket_listener(echo_headers) 92 | 93 | client = websocket.WebSocket() 94 | client.connect( 95 | server.url.replace("http://", "ws://"), 96 | header=["Authorization: Basic let-me-in", "CasedHeader: hello"], 97 | ) 98 | 99 | assert client.handshake_response.status == 101 100 | assert client.getheaders()["x-foo-bar"] == "foobar" 101 | doc = client.recv() 102 | headers = json.loads(doc) 103 | assert headers["Connection"] == "Upgrade" 104 | assert headers["Authorization"] == "Basic let-me-in" 105 | assert headers["CasedHeader"] == "hello" 106 | 107 | 108 | def test_websocket_reject(serve_websocket_listener): 109 | @WebSocketRequest.listener 110 | def respond(request: WebSocketRequest): 111 | request.reject(Response("nope", 403)) 112 | 113 | server = serve_websocket_listener(respond) 114 | 115 | socket = websocket.WebSocket() 116 | with pytest.raises(websocket.WebSocketBadStatusException) as e: 117 | socket.connect(server.url.replace("http://", "ws://")) 118 | 119 | assert e.value.status_code == 403 120 | assert 
e.value.resp_body == b"nope" 121 | 122 | 123 | def test_binary_and_text_mode(serve_websocket_listener): 124 | received = Queue() 125 | 126 | @WebSocketRequest.listener 127 | def echo_headers(request: WebSocketRequest): 128 | with request.accept() as ws: 129 | ws.send(b"foo") 130 | ws.send("textfoo") 131 | received.put(ws.receive()) 132 | received.put(ws.receive()) 133 | 134 | server = serve_websocket_listener(echo_headers) 135 | 136 | client = websocket.WebSocket() 137 | client.connect(server.url.replace("http://", "ws://")) 138 | 139 | assert client.handshake_response.status == 101 140 | data = client.recv() 141 | assert data == b"foo" 142 | 143 | data = client.recv() 144 | assert data == "textfoo" 145 | 146 | client.send("textbar") 147 | client.send_binary(b"bar") 148 | 149 | assert received.get(timeout=5) == "textbar" 150 | assert received.get(timeout=5) == b"bar" 151 | 152 | 153 | def test_send_non_confirming_data(serve_websocket_listener): 154 | match = Queue() 155 | 156 | @WebSocketRequest.listener 157 | def echo_headers(request: WebSocketRequest): 158 | with request.accept() as ws: 159 | with pytest.raises(WebSocketProtocolError) as e: 160 | ws.send({"foo": "bar"}) 161 | match.put(e) 162 | 163 | server = serve_websocket_listener(echo_headers) 164 | 165 | client = websocket.WebSocket() 166 | client.connect(server.url.replace("http://", "ws://")) 167 | 168 | e = match.get(timeout=5) 169 | assert e.match("Cannot send data type over websocket") 170 | 171 | 172 | def test_router_integration(serve_websocket_listener): 173 | router = Router() 174 | 175 | def _handler(request: WebSocketRequest, request_args: dict): 176 | with request.accept() as ws: 177 | ws.send("foo") 178 | ws.send(f"id={request_args['id']}") 179 | ws.send(json.dumps(dict(request.headers))) 180 | 181 | router.add("/foo/", _handler) 182 | 183 | server = serve_websocket_listener(WebSocketRequest.listener(router.dispatch)) 184 | client = websocket.WebSocket() 185 | client.connect( 186 | 
server.url.replace("http://", "ws://") + "/foo/bar", header=["CasedHeader: hello"] 187 | ) 188 | assert client.recv() == "foo" 189 | assert client.recv() == "id=bar" 190 | assert "CasedHeader" in json.loads(client.recv()) 191 | -------------------------------------------------------------------------------- /rolo/client.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from urllib.parse import urlparse 3 | 4 | import requests 5 | import urllib3.util 6 | from werkzeug import Request, Response 7 | from werkzeug.datastructures import Headers 8 | 9 | from .request import get_raw_base_url, get_raw_current_url, get_raw_path, restore_payload 10 | 11 | 12 | class HttpClient(abc.ABC): 13 | """ 14 | An HTTP client that can make http requests using werkzeug's request object. 15 | """ 16 | 17 | @abc.abstractmethod 18 | def request(self, request: Request, server: str | None = None) -> Response: 19 | """ 20 | Make the given HTTP as a client. 21 | 22 | :param request: the request to make 23 | :param server: the URL to send the request to, which defaults to the host component of the original Request. 24 | :return: the response. 25 | """ 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def close(self): 30 | """ 31 | Close any underlying resources the client may need. 32 | """ 33 | pass 34 | 35 | def __enter__(self): 36 | return self 37 | 38 | def __exit__(self, *args): 39 | self.close() 40 | 41 | 42 | class _VerifyRespectingSession(requests.Session): 43 | """ 44 | A class which wraps requests.Session to circumvent https://github.com/psf/requests/issues/3829. 45 | This ensures that if `REQUESTS_CA_BUNDLE` or `CURL_CA_BUNDLE` are set, the request does not perform the TLS 46 | verification if `session.verify` is set to `False. 
47 | """ 48 | 49 | def merge_environment_settings(self, url, proxies, stream, verify, *args, **kwargs): 50 | if self.verify is False: 51 | verify = False 52 | 53 | return super(_VerifyRespectingSession, self).merge_environment_settings( 54 | url, proxies, stream, verify, *args, **kwargs 55 | ) 56 | 57 | 58 | class SimpleRequestsClient(HttpClient): 59 | session: requests.Session 60 | follow_redirects: bool 61 | 62 | def __init__(self, session: requests.Session = None, follow_redirects: bool = True): 63 | self.session = session or _VerifyRespectingSession() 64 | self.follow_redirects = follow_redirects 65 | 66 | @staticmethod 67 | def _get_destination_url(request: Request, server: str | None = None) -> str: 68 | if server: 69 | # accepts "http://localhost:5000" or "localhost:5000" 70 | if "://" in server: 71 | parts = urlparse(server) 72 | scheme, server = parts.scheme, parts.netloc 73 | else: 74 | scheme = request.scheme 75 | return get_raw_current_url(scheme, server, request.root_path, get_raw_path(request)) 76 | 77 | return get_raw_base_url(request) 78 | 79 | @staticmethod 80 | def _transform_response_headers(response: requests.Response) -> Headers: 81 | """ 82 | `requests` by default concatenate headers in response under a single header separated by a comma 83 | This behavior is generally the same as having the same header multiple times with different values in a 84 | response. 85 | However, specific headers like `Set-Cookie` needs to be defined multiple times. By directly using the raw 86 | `urllib3` response that still contains non-concatenate values, we can follow more closely the response. 87 | """ 88 | headers = Headers() 89 | for k, v in response.raw.headers.iteritems(): 90 | headers.add(k, v) 91 | return headers 92 | 93 | def request(self, request: Request, server: str | None = None) -> Response: 94 | """ 95 | Very naive implementation to make the given HTTP request using the requests library, i.e., process the request 96 | as a client. 
97 | 98 | :param request: the request to perform 99 | :param server: the URL to send the request to, which defaults to the host component of the original Request. 100 | :param allow_redirects: allow the request to follow redirects 101 | :return: the response. 102 | """ 103 | 104 | url = self._get_destination_url(request, server) 105 | 106 | headers = dict(request.headers.items()) 107 | 108 | # urllib3 (used by requests) will set an Accept-Encoding header ("gzip,deflate") 109 | # - See urllib3.util.request.ACCEPT_ENCODING 110 | # - The solution to this, provided by urllib3, is to use `urllib3.util.SKIP_HEADER` 111 | # to prevent the header from being added. 112 | if not request.headers.get("accept-encoding"): 113 | headers["accept-encoding"] = urllib3.util.SKIP_HEADER 114 | 115 | response = self.session.request( 116 | method=request.method, 117 | # use raw base url to preserve path url encoding 118 | url=url, 119 | # request.args are only the url parameters 120 | params=list(request.args.items(multi=True)), 121 | headers=headers, 122 | data=restore_payload(request), 123 | stream=True, 124 | allow_redirects=self.follow_redirects, 125 | ) 126 | 127 | if request.method == "HEAD": 128 | # for HEAD requests we have to keep the original content-length, but it will be re-calculated when creating 129 | # the final_response object 130 | final_response = Response( 131 | response=response.content, 132 | status=response.status_code, 133 | headers=self._transform_response_headers(response), 134 | ) 135 | final_response.content_length = response.headers.get("Content-Length", 0) 136 | return final_response 137 | 138 | response_headers = self._transform_response_headers(response) 139 | 140 | if "chunked" in (transfer_encoding := response_headers.get("Transfer-Encoding", "")): 141 | response_headers.pop("Content-Length", None) 142 | # We should not set `Transfer-Encoding` in a Response, because it is the responsibility of the webserver 143 | # to do so, if there are no 
Content-Length. However, gzip behavior is more related to the actual content of 144 | # the response, so we keep that one. 145 | transfer_encoding_values = [v.strip() for v in transfer_encoding.split(",")] 146 | transfer_encoding_no_chunked = [ 147 | v for v in transfer_encoding_values if v.lower() != "chunked" 148 | ] 149 | response_headers.setlist("Transfer-Encoding", transfer_encoding_no_chunked) 150 | 151 | final_response = Response( 152 | response=(chunk for chunk in response.raw.stream(1024, decode_content=False)), 153 | status=response.status_code, 154 | headers=response_headers, 155 | ) 156 | 157 | return final_response 158 | 159 | def close(self): 160 | self.session.close() 161 | 162 | 163 | def make_request(request: Request) -> Response: 164 | """ 165 | Convenience method to make the given HTTP as a client. 166 | 167 | :param request: the request to make 168 | :return: the response. 169 | """ 170 | with SimpleRequestsClient() as client: 171 | return client.request(request) 172 | -------------------------------------------------------------------------------- /rolo/proxy.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from typing import Mapping, Union 3 | from urllib.parse import urlparse 4 | 5 | from werkzeug import Request, Response 6 | from werkzeug.datastructures import Headers 7 | from werkzeug.test import EnvironBuilder 8 | 9 | from .client import HttpClient, SimpleRequestsClient 10 | from .request import get_raw_path, restore_payload, set_environment_headers 11 | 12 | 13 | def forward( 14 | request: Request, 15 | forward_base_url: str, 16 | forward_path: str = None, 17 | headers: Union[Headers, Mapping[str, str]] = None, 18 | ) -> Response: 19 | """ 20 | Convenience method that creates a new Proxy and immediately calls proxy.forward(...). See ``Proxy`` for more 21 | information. 
22 | """ 23 | with Proxy(forward_base_url=forward_base_url) as proxy: 24 | return proxy.forward(request, forward_path=forward_path, headers=headers) 25 | 26 | 27 | class Proxy(HttpClient): 28 | preserve_host: bool 29 | 30 | def __init__( 31 | self, forward_base_url: str, client: HttpClient = None, preserve_host: bool = True 32 | ): 33 | """ 34 | Creates a new HTTP Proxy which can be used to forward incoming requests according to the configuration. 35 | 36 | :param forward_base_url: the base url (backend) to forward the requests to. 37 | :param client: the HTTP Client used to make the requests 38 | :param preserve_host: True to ensure that the Host header of the incoming request is preserved. 39 | If False, then the Host header will be set to the Host from the perspective of the Proxy. 40 | """ 41 | self.forward_base_url = forward_base_url 42 | self.client = client or SimpleRequestsClient() 43 | self.preserve_host = preserve_host 44 | 45 | def request(self, request: Request, server: str | None = None) -> Response: 46 | """ 47 | Compatibility with HttpClient interface. A call is equivalent to ``Proxy.forward(request, None, None)``. 48 | 49 | :param request: the request to proxy 50 | :param server: ignored for a proxy, since the server is already set by `forward_base_url`. 51 | :return: the proxied response 52 | """ 53 | return self.forward(request) 54 | 55 | def forward( 56 | self, 57 | request: Request, 58 | forward_path: str = None, 59 | headers: Union[Headers, Mapping[str, str]] = None, 60 | ) -> Response: 61 | """ 62 | Uses the client to forward the given request according to the proxy's configuration. 63 | 64 | :param request: the base request to forward (with the original URL and path data) 65 | :param forward_path: the path to forward the request to. 
if set, the original path will be replaced completely, 66 | otherwise the original path will be used 67 | :param headers: additional custom headers to send as part of the proxy request 68 | :return: the proxied response 69 | """ 70 | headers = Headers(headers) if headers else Headers() 71 | 72 | if client_ip := request.remote_addr: 73 | if xff := request.headers.get("X-Forwarded-For"): 74 | headers["X-Forwarded-For"] = f"{xff}, {client_ip}" 75 | else: 76 | headers["X-Forwarded-For"] = f"{client_ip}" 77 | 78 | if forward_path is None: 79 | forward_path = get_raw_path(request) 80 | if forward_path: 81 | forward_path = "/" + forward_path.lstrip("/") 82 | 83 | proxy_request = _copy_request(request, self.forward_base_url, forward_path, headers) 84 | 85 | if self.preserve_host and "Host" in request.headers: 86 | proxy_request.headers["Host"] = request.headers["Host"] 87 | 88 | target = urlparse(self.forward_base_url) 89 | return self.client.request(proxy_request, server=f"{target.scheme}://{target.netloc}") 90 | 91 | def close(self): 92 | self.client.close() 93 | 94 | 95 | class ProxyHandler: 96 | """ 97 | A dispatcher Handler which can be used in a ``Router[Handler]`` that proxies incoming requests according to the 98 | configuration. 99 | 100 | The Handler is expected to be used together with a route that uses a ``path`` parameter named ``path`` in the URL. 
101 | Fir example: if you want to forward all requests from ``/foobar/`` to ``http://localhost:8080/v1/``, 102 | you would do the following:: 103 | 104 | router = Router(dispatcher=handler_dispatcher()) 105 | router.add("/foobar/", ProxyHandler("http://localhost:8080/v1") 106 | 107 | This is similar to the common nginx configuration where proxy_pass is a URI:: 108 | 109 | location /foobar { 110 | proxy_pass http://localhost:8080/v1/; 111 | } 112 | """ 113 | 114 | def __init__(self, forward_base_url: str, client: HttpClient = None): 115 | """ 116 | Creates a new Proxy with the given ``forward_base_url`` (see ``Proxy``). 117 | 118 | :param forward_base_url: the base url (backend) to forward the requests to. 119 | :param client: the HTTP Client used to make the requests 120 | """ 121 | self.proxy = Proxy(forward_base_url=forward_base_url, client=client) 122 | 123 | def __call__(self, request: Request, **kwargs) -> Response: 124 | return self.proxy.forward(request, forward_path=kwargs.get("path", "")) 125 | 126 | def close(self): 127 | self.proxy.close() 128 | 129 | 130 | def _copy_request( 131 | request: Request, 132 | base_url: str = None, 133 | path: str = None, 134 | headers: Union[Headers, Mapping[str, str]] = None, 135 | ) -> Request: 136 | """ 137 | Creates a new request from the given one that can be used to perform a proxy call. 138 | 139 | :param request: the original request 140 | :param base_url: the url to forward the request to (e.g., http://localhost:8080) 141 | :param path: the path to forward the request to (e.g., /foobar), if set to None, the original path will be used 142 | :param headers: optional headers to overwrite 143 | :return: a new request with slightly modified underlying environment but the same data stream 144 | """ 145 | # ensure that the headers in the env are set on the environment 146 | # FIXME: we should preserve header casing like we do with the `asgi.headers` property in the asgi/wsgi bridge to 147 | # pass through the raw headers. 
148 | set_environment_headers(request.environ, request.headers) 149 | builder = EnvironBuilder.from_environ(request.environ) 150 | 151 | if base_url: 152 | builder.base_url = base_url 153 | builder.headers["Host"] = builder.host 154 | 155 | if path is not None: 156 | builder.path = path 157 | 158 | if headers: 159 | builder.headers.update(headers) 160 | 161 | # FIXME: unfortunately, EnvironBuilder expects the input stream to be seekable, but we don't have that when using 162 | # the asgi/wsgi bridge. we need a better way of dealing with IO! 163 | data = restore_payload(request) 164 | builder.input_stream = BytesIO(data) 165 | builder.content_length = len(data) 166 | # Since the payload is completely restored, the proxy forwarding is not streamed. 167 | # Therefore, we need to remove a potential "chunked" Transfer-Encoding 168 | if builder.headers.get("Transfer-Encoding", None) == "chunked": 169 | builder.headers.pop("Transfer-Encoding") 170 | 171 | new_request = builder.get_request() 172 | 173 | # explicitly set the path in the environment and in the newly created request 174 | if path is not None: 175 | new_request.environ["RAW_URI"] = path or "/" 176 | 177 | # copy headers s.t. 
they are no longer immutable (by default, EnvironHeaders are used) 178 | new_request.headers = Headers(new_request.headers) 179 | 180 | return new_request 181 | -------------------------------------------------------------------------------- /tests/test_pydantic.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | 3 | import pydantic 4 | import pytest 5 | from werkzeug.exceptions import BadRequest 6 | 7 | from rolo import Request, Router, resource 8 | from rolo.routing import handler as routing_handler 9 | from rolo.routing import handler_dispatcher 10 | 11 | pydantic_version = pydantic.version.version_short() 12 | 13 | 14 | class MyItem(pydantic.BaseModel): 15 | name: str 16 | price: float 17 | is_offer: bool = None 18 | 19 | 20 | class TestPydanticHandlerDispatcher: 21 | def test_request_arg(self): 22 | router = Router(dispatcher=handler_dispatcher()) 23 | 24 | def handler(_request: Request, item: MyItem) -> dict: 25 | return {"item": item.model_dump()} 26 | 27 | router.add("/items", handler) 28 | 29 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}') 30 | assert router.dispatch(request).get_json(force=True) == { 31 | "item": { 32 | "name": "rolo", 33 | "price": 420.69, 34 | "is_offer": None, 35 | }, 36 | } 37 | 38 | def test_request_args(self): 39 | router = Router(dispatcher=handler_dispatcher()) 40 | 41 | def handler(_request: Request, item_id: int, item: MyItem) -> dict: 42 | return {"item_id": item_id, "item": item.model_dump()} 43 | 44 | router.add("/items/", handler) 45 | 46 | request = Request("POST", "/items/123", body=b'{"name":"rolo","price":420.69}') 47 | assert router.dispatch(request).get_json(force=True) == { 48 | "item_id": 123, 49 | "item": { 50 | "name": "rolo", 51 | "price": 420.69, 52 | "is_offer": None, 53 | }, 54 | } 55 | 56 | def test_request_args_empty_body(self): 57 | router = Router(dispatcher=handler_dispatcher()) 58 | 59 | def 
handler(_request: Request, item_id: int, item: MyItem) -> dict: 60 | return {"item_id": item_id, "item": item.model_dump()} 61 | 62 | router.add("/items/", handler) 63 | 64 | request = Request("POST", "/items/123", body=b"") 65 | assert router.dispatch(request).get_json(force=True) == [ 66 | { 67 | "type": "json_invalid", 68 | "loc": [], 69 | "msg": "Invalid JSON: EOF while parsing a value at line 1 column 0", 70 | "ctx": {"error": "EOF while parsing a value at line 1 column 0"}, 71 | "input": "", 72 | "url": f"https://errors.pydantic.dev/{pydantic_version}/v/json_invalid", 73 | } 74 | ] 75 | 76 | def test_response(self): 77 | router = Router(dispatcher=handler_dispatcher()) 78 | 79 | def handler(_request: Request, item_id: int) -> MyItem: 80 | return MyItem(name="rolo", price=420.69) 81 | 82 | router.add("/items/", handler) 83 | 84 | request = Request("GET", "/items/123") 85 | assert router.dispatch(request).get_json() == { 86 | "name": "rolo", 87 | "price": 420.69, 88 | "is_offer": None, 89 | } 90 | 91 | def test_response_list(self): 92 | router = Router(dispatcher=handler_dispatcher()) 93 | 94 | def handler(_request: Request) -> list[MyItem]: 95 | return [ 96 | MyItem(name="rolo", price=420.69), 97 | MyItem(name="twiks", price=1.23, is_offer=True), 98 | ] 99 | 100 | router.add("/items", handler) 101 | 102 | request = Request("GET", "/items") 103 | assert router.dispatch(request).get_json() == [ 104 | { 105 | "name": "rolo", 106 | "price": 420.69, 107 | "is_offer": None, 108 | }, 109 | { 110 | "name": "twiks", 111 | "price": 1.23, 112 | "is_offer": True, 113 | }, 114 | ] 115 | 116 | def test_request_arg_validation_error(self): 117 | router = Router(dispatcher=handler_dispatcher()) 118 | 119 | def handler(_request: Request, item_id: int, item: MyItem) -> str: 120 | return item.model_dump_json() 121 | 122 | router.add("/items/", handler) 123 | 124 | request = Request("POST", "/items/123", body=b'{"name":"rolo"}') 125 | assert router.dispatch(request).get_json() == 
[ 126 | { 127 | "type": "missing", 128 | "loc": ["price"], 129 | "msg": "Field required", 130 | "input": {"name": "rolo"}, 131 | "url": f"https://errors.pydantic.dev/{pydantic_version}/v/missing", 132 | } 133 | ] 134 | 135 | def test_request_arg_invalid_json(self): 136 | router = Router(dispatcher=handler_dispatcher()) 137 | 138 | def handler(_request: Request, item_id: int, item: MyItem) -> str: 139 | return item.model_dump_json() 140 | 141 | router.add("/items/", handler) 142 | 143 | request = Request("POST", "/items/123", body=b'{"}') 144 | with pytest.raises(BadRequest): 145 | assert router.dispatch(request) 146 | 147 | def test_missing_annotation(self): 148 | router = Router(dispatcher=handler_dispatcher()) 149 | 150 | # without an annotation, we cannot be sure what type to serialize into, so the dispatcher doesn't pass 151 | # anything into ``item``. 152 | def handler(_request: Request, item=None) -> dict: 153 | return {"item": item} 154 | 155 | router.add("/items", handler) 156 | 157 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}') 158 | assert router.dispatch(request).get_json(force=True) == {"item": None} 159 | 160 | def test_with_pydantic_disabled(self, monkeypatch): 161 | monkeypatch.setattr(routing_handler, "ENABLE_PYDANTIC", False) 162 | router = Router(dispatcher=handler_dispatcher()) 163 | 164 | def handler(_request: Request, item: MyItem) -> dict: 165 | return {"item": item.model_dump()} 166 | 167 | router.add("/items", handler) 168 | 169 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}') 170 | with pytest.raises(TypeError): 171 | # "missing 1 required positional argument: 'item'" 172 | assert router.dispatch(request) 173 | 174 | def test_with_resource(self): 175 | router = Router(dispatcher=handler_dispatcher()) 176 | 177 | @resource("/items/") 178 | class MyResource: 179 | def on_get(self, request: Request, item_id: int): 180 | return MyItem(name="rolo", price=420.69) 181 | 182 | def 
on_post(self, request: Request, item_id: int, item: MyItem): 183 | return {"item_id": item_id, "item": item.model_dump()} 184 | 185 | router.add(MyResource()) 186 | 187 | response = router.dispatch(Request("GET", "/items/123")) 188 | assert response.get_json() == { 189 | "name": "rolo", 190 | "price": 420.69, 191 | "is_offer": None, 192 | } 193 | 194 | response = router.dispatch( 195 | Request("POST", "/items/123", body=b'{"name":"rolo","price":420.69}') 196 | ) 197 | assert response.get_json() == { 198 | "item": {"is_offer": None, "name": "rolo", "price": 420.69}, 199 | "item_id": 123, 200 | } 201 | 202 | def test_with_generic_type_alias(self): 203 | router = Router(dispatcher=handler_dispatcher()) 204 | 205 | def handler(request: Request, matrix: dict[str, str] = None): 206 | return "ok" 207 | 208 | router.add("/", endpoint=handler) 209 | 210 | request = Request("GET", "/") 211 | assert router.dispatch(request).data == b"ok" 212 | 213 | def test_with_typed_dict(self): 214 | try: 215 | from typing import Unpack 216 | except ImportError: 217 | pytest.skip("This test only works with Python >=3.11") 218 | 219 | router = Router(dispatcher=handler_dispatcher()) 220 | 221 | class Test(TypedDict, total=False): 222 | path: str 223 | random_value: str 224 | 225 | def func(request: Request, **kwargs: Unpack[Test]): 226 | return f"path={kwargs.get('path')},random_value={kwargs.get('random_value')}" 227 | 228 | router.add( 229 | "/", 230 | endpoint=func, 231 | defaults={"path": "", "random_value": "dev"}, 232 | ) 233 | 234 | request = Request("GET", "/") 235 | assert router.dispatch(request).data == b"path=,random_value=dev" 236 | -------------------------------------------------------------------------------- /tests/gateway/test_chain.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | from werkzeug.datastructures import Headers 4 | 5 | from rolo.gateway import CompositeFinalizer, CompositeHandler, 
HandlerChain, RequestContext 6 | from rolo.response import Response 7 | 8 | 9 | def test_response_handler_exception(): 10 | def _raise(*args, **kwargs): 11 | raise ValueError("oh noes") 12 | 13 | response1 = mock.MagicMock() 14 | response2 = _raise 15 | response3 = mock.MagicMock() 16 | exception = mock.MagicMock() 17 | finalizer = mock.MagicMock() 18 | 19 | chain = HandlerChain( 20 | response_handlers=[response1, response2, response3], 21 | exception_handlers=[exception], 22 | finalizers=[finalizer], 23 | ) 24 | chain.handle(RequestContext(), Response()) 25 | 26 | response1.assert_called_once() 27 | response3.assert_called_once() # all response handlers should be called 28 | exception.assert_not_called() # response handlers don't trigger exception handlers 29 | finalizer.assert_called_once() 30 | 31 | assert chain.error is None 32 | 33 | 34 | def test_finalizer_handler_exception(): 35 | def _raise(*args, **kwargs): 36 | raise ValueError("oh noes") 37 | 38 | response = mock.MagicMock() 39 | exception = mock.MagicMock() 40 | finalizer1 = mock.MagicMock() 41 | finalizer2 = _raise 42 | finalizer3 = mock.MagicMock() 43 | 44 | chain = HandlerChain( 45 | response_handlers=[response], 46 | exception_handlers=[exception], 47 | finalizers=[finalizer1, finalizer2, finalizer3], 48 | ) 49 | chain.handle(RequestContext(), Response()) 50 | 51 | response.assert_called_once() 52 | exception.assert_not_called() # response handlers don't trigger exception handlers 53 | finalizer1.assert_called_once() 54 | finalizer3.assert_called_once() 55 | 56 | assert chain.error is None 57 | 58 | 59 | def test_composite_finalizer_handler_exception(): 60 | def _raise(*args, **kwargs): 61 | raise ValueError("oh noes") 62 | 63 | response = mock.MagicMock() 64 | exception = mock.MagicMock() 65 | finalizer1 = mock.MagicMock() 66 | finalizer2 = _raise 67 | finalizer3 = mock.MagicMock() 68 | 69 | finalizer = CompositeFinalizer() 70 | finalizer.append(finalizer1) 71 | finalizer.append(finalizer2) 72 | 
finalizer.append(finalizer3) 73 | 74 | chain = HandlerChain( 75 | response_handlers=[response], 76 | exception_handlers=[exception], 77 | finalizers=[finalizer], 78 | ) 79 | chain.handle(RequestContext(), Response()) 80 | 81 | response.assert_called_once() 82 | exception.assert_not_called() # response handlers don't trigger exception handlers 83 | finalizer1.assert_called_once() 84 | finalizer3.assert_called_once() 85 | 86 | assert chain.error is None 87 | 88 | 89 | def test_respond_with_json_response(): 90 | def handle(chain_: HandlerChain, _context, _response): 91 | chain_.respond(202, {"foo": "bar"}, headers={"X-Foo": "Bar"}) 92 | 93 | chain = HandlerChain(request_handlers=[handle]) 94 | chain.handle(RequestContext(), Response()) 95 | 96 | assert chain.response.status_code == 202 97 | assert chain.response.json == {"foo": "bar"} 98 | assert chain.response.headers.get("x-foo") == "Bar" 99 | assert chain.response.mimetype == "application/json" 100 | 101 | 102 | def test_respond_with_string_response(): 103 | def handle(chain_: HandlerChain, _context, _response): 104 | chain_.respond(200, "foobar", Headers({"X-Foo": "Bar"})) 105 | 106 | chain = HandlerChain(request_handlers=[handle]) 107 | chain.handle(RequestContext(), Response()) 108 | 109 | assert chain.response.status_code == 200 110 | assert chain.response.data == b"foobar" 111 | assert chain.response.headers.get("x-foo") == "Bar" 112 | assert chain.response.mimetype == "text/plain" 113 | 114 | 115 | class TestCompositeHandler: 116 | def test_composite_handler_stops_handler_chain(self): 117 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response): 118 | _chain.stop() 119 | 120 | inner2 = mock.MagicMock() 121 | outer1 = mock.MagicMock() 122 | outer2 = mock.MagicMock() 123 | response1 = mock.MagicMock() 124 | finalizer = mock.MagicMock() 125 | 126 | chain = HandlerChain() 127 | 128 | composite = CompositeHandler() 129 | composite.handlers.append(inner1) 130 | 
composite.handlers.append(inner2) 131 | 132 | chain.request_handlers.append(outer1) 133 | chain.request_handlers.append(composite) 134 | chain.request_handlers.append(outer2) 135 | chain.response_handlers.append(response1) 136 | chain.finalizers.append(finalizer) 137 | 138 | chain.handle(RequestContext(), Response()) 139 | outer1.assert_called_once() 140 | outer2.assert_not_called() 141 | inner2.assert_not_called() 142 | response1.assert_called_once() 143 | finalizer.assert_called_once() 144 | 145 | def test_composite_handler_terminates_handler_chain(self): 146 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response): 147 | _chain.terminate() 148 | 149 | inner2 = mock.MagicMock() 150 | outer1 = mock.MagicMock() 151 | outer2 = mock.MagicMock() 152 | response1 = mock.MagicMock() 153 | finalizer = mock.MagicMock() 154 | 155 | chain = HandlerChain() 156 | 157 | composite = CompositeHandler() 158 | composite.handlers.append(inner1) 159 | composite.handlers.append(inner2) 160 | 161 | chain.request_handlers.append(outer1) 162 | chain.request_handlers.append(composite) 163 | chain.request_handlers.append(outer2) 164 | chain.response_handlers.append(response1) 165 | chain.finalizers.append(finalizer) 166 | 167 | chain.handle(RequestContext(), Response()) 168 | outer1.assert_called_once() 169 | outer2.assert_not_called() 170 | inner2.assert_not_called() 171 | response1.assert_not_called() 172 | finalizer.assert_called_once() 173 | 174 | def test_composite_handler_with_not_return_on_stop(self): 175 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response): 176 | _chain.stop() 177 | 178 | inner2 = mock.MagicMock() 179 | outer1 = mock.MagicMock() 180 | outer2 = mock.MagicMock() 181 | response1 = mock.MagicMock() 182 | finalizer = mock.MagicMock() 183 | 184 | chain = HandlerChain() 185 | 186 | composite = CompositeHandler(return_on_stop=False) 187 | composite.handlers.append(inner1) 188 | composite.handlers.append(inner2) 189 | 190 | 
chain.request_handlers.append(outer1) 191 | chain.request_handlers.append(composite) 192 | chain.request_handlers.append(outer2) 193 | chain.response_handlers.append(response1) 194 | chain.finalizers.append(finalizer) 195 | 196 | chain.handle(RequestContext(), Response()) 197 | outer1.assert_called_once() 198 | outer2.assert_not_called() 199 | inner2.assert_called_once() 200 | response1.assert_called_once() 201 | finalizer.assert_called_once() 202 | 203 | def test_composite_handler_continues_handler_chain(self): 204 | inner1 = mock.MagicMock() 205 | inner2 = mock.MagicMock() 206 | outer1 = mock.MagicMock() 207 | outer2 = mock.MagicMock() 208 | response1 = mock.MagicMock() 209 | finalizer = mock.MagicMock() 210 | 211 | chain = HandlerChain() 212 | 213 | composite = CompositeHandler() 214 | composite.handlers.append(inner1) 215 | composite.handlers.append(inner2) 216 | 217 | chain.request_handlers.append(outer1) 218 | chain.request_handlers.append(composite) 219 | chain.request_handlers.append(outer2) 220 | chain.response_handlers.append(response1) 221 | chain.finalizers.append(finalizer) 222 | 223 | chain.handle(RequestContext(), Response()) 224 | outer1.assert_called_once() 225 | outer2.assert_called_once() 226 | inner1.assert_called_once() 227 | inner2.assert_called_once() 228 | response1.assert_called_once() 229 | finalizer.assert_called_once() 230 | 231 | def test_composite_handler_exception_calls_outer_exception_handlers(self): 232 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response): 233 | raise ValueError() 234 | 235 | inner2 = mock.MagicMock() 236 | outer1 = mock.MagicMock() 237 | outer2 = mock.MagicMock() 238 | exception_handler = mock.MagicMock() 239 | response1 = mock.MagicMock() 240 | finalizer = mock.MagicMock() 241 | 242 | chain = HandlerChain() 243 | 244 | composite = CompositeHandler() 245 | composite.handlers.append(inner1) 246 | composite.handlers.append(inner2) 247 | 248 | chain.request_handlers.append(outer1) 249 | 
chain.request_handlers.append(composite) 250 | chain.request_handlers.append(outer2) 251 | chain.exception_handlers.append(exception_handler) 252 | chain.response_handlers.append(response1) 253 | chain.finalizers.append(finalizer) 254 | 255 | chain.handle(RequestContext(), Response()) 256 | outer1.assert_called_once() 257 | outer2.assert_not_called() 258 | inner2.assert_not_called() 259 | exception_handler.assert_called_once() 260 | response1.assert_called_once() 261 | finalizer.assert_called_once() 262 | -------------------------------------------------------------------------------- /rolo/routing/rules.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing as t 3 | 4 | from werkzeug.routing import Map, Rule, RuleFactory 5 | 6 | 7 | class _RuleAttributes(t.NamedTuple): 8 | path: str 9 | host: t.Optional[str] = (None,) 10 | methods: t.Optional[t.Iterable[str]] = None 11 | kwargs: t.Optional[dict[str, t.Any]] = {} 12 | 13 | 14 | class _RouteEndpoint(t.Protocol): 15 | """ 16 | An endpoint that encapsulates ``_RuleAttributes`` for the creation of a ``Rule`` inside a ``Router``. 
17 | """ 18 | 19 | rule_attributes: list[_RuleAttributes] 20 | 21 | def __call__(self, *args, **kwargs): 22 | raise NotImplementedError 23 | 24 | 25 | class WithHost(RuleFactory): 26 | def __init__(self, host: str, rules: t.Iterable[RuleFactory]) -> None: 27 | self.host = host 28 | self.rules = rules 29 | 30 | def get_rules(self, map: Map) -> t.Iterator[Rule]: 31 | for rulefactory in self.rules: 32 | for rule in rulefactory.get_rules(map): 33 | rule = rule.empty() 34 | rule.host = self.host 35 | yield rule 36 | 37 | 38 | class RuleGroup(RuleFactory): 39 | def __init__(self, rules: t.Iterable[RuleFactory]): 40 | self.rules = rules 41 | 42 | def get_rules(self, map: Map) -> t.Iterable[Rule]: 43 | for rule in self.rules: 44 | yield from rule.get_rules(map) 45 | 46 | 47 | class RuleAdapter(RuleFactory): 48 | """ 49 | Takes something that can also be passed to ``Router.add``, and exposes it as a ``RuleFactory`` that generates the 50 | appropriate Werkzeug rules. This can be used in combination with other rule factories like ``Submount``, 51 | and creates general compatibility with werkzeug rules. Here's an example:: 52 | 53 | @route("/my_api", methods=["GET"]) 54 | def do_get(request: Request, _args): 55 | # should be inherited 56 | return Response(f"{request.path}/do-get") 57 | 58 | def hello(request: Request, _args): 59 | return Response(f"hello world") 60 | 61 | router = Router() 62 | 63 | # base endpoints 64 | endpoints = RuleAdapter([ 65 | do_get, 66 | RuleAdapter("/hello", hello) 67 | ]) 68 | 69 | router.add([ 70 | endpoints, 71 | Submount("/foo", [endpoints]) 72 | ]) 73 | 74 | """ 75 | 76 | factory: RuleFactory 77 | """The underlying real rule factory.""" 78 | 79 | @t.overload 80 | def __init__( 81 | self, 82 | path: str, 83 | endpoint: t.Callable, 84 | host: t.Optional[str] = None, 85 | methods: t.Optional[t.Iterable[str]] = None, 86 | **kwargs, 87 | ): 88 | """ 89 | Basically a ``Rule``. 90 | 91 | :param path: the path pattern to match. 
This path rule, in contrast to the default behavior of Werkzeug, will be 92 | matched against the raw / original (potentially URL-encoded) path. 93 | :param endpoint: the endpoint to invoke 94 | :param host: an optional host matching pattern. if no pattern is given, the rule matches any host 95 | :param methods: the allowed HTTP verbs for this rule 96 | :param kwargs: any other argument that can be passed to ``werkzeug.routing.Rule`` 97 | """ 98 | ... 99 | 100 | @t.overload 101 | def __init__(self, fn: _RouteEndpoint): 102 | """ 103 | Takes a route endpoint (typically a function decorated with ``@route``) and wraps it in an ``_EndpointFunction``. 104 | 105 | :param fn: the RouteEndpoint function 106 | """ 107 | ... 108 | 109 | @t.overload 110 | def __init__(self, rule_factory: RuleFactory): 111 | """ 112 | Wraps a ``Rule`` or any other ``RuleFactory``; the rules it generates are passed through 113 | unchanged by ``get_rules``. 114 | 115 | :param rule_factory: a ``Rule`` or ``RuleFactory`` 116 | """ 117 | ... 118 | 119 | @t.overload 120 | def __init__(self, obj: t.Any): 121 | """ 122 | Scans the given object for members that can be used as a ``_RouteEndpoint`` and exposes them as rules. 123 | 124 | :param obj: the object to scan 125 | """ 126 | ... 127 | 128 | @t.overload 129 | def __init__(self, rules: list[t.Union[_RouteEndpoint, RuleFactory, t.Any]]): 130 | """Add multiple rules at once""" 131 | ... 132 | 133 | def __init__(self, *args, **kwargs): 134 | """ 135 | Dispatcher for overloaded ``__init__`` methods, based on keyword names or the type of the first positional argument.
136 | """ 137 | if "path" in kwargs or isinstance(args[0], str): 138 | self.factory = _EndpointRule(*args, **kwargs) 139 | elif "fn" in kwargs or callable(args[0]): 140 | self.factory = _EndpointFunction(*args, **kwargs) 141 | elif "rule_factory" in kwargs: 142 | self.factory = kwargs["rule_factory"] 143 | elif isinstance(args[0], RuleFactory): 144 | self.factory = args[0] 145 | elif isinstance(args[0], list): 146 | self.factory = RuleGroup([RuleAdapter(rule) for rule in args[0]]) 147 | else: 148 | self.factory = _EndpointsObject(*args, **kwargs) 149 | 150 | def get_rules(self, map: Map) -> t.Iterable[Rule]: 151 | yield from self.factory.get_rules(map) 152 | 153 | 154 | class _EndpointRule(RuleFactory): 155 | """ 156 | Generates default werkzeug ``Rule`` object with the given attributes. Additionally, it makes sure that 157 | the generated rule always has a default host value, if the map has host matching enabled. Specifically, 158 | it adds the well-known placeholder ``<__host__>``, which is later stripped out of the request arguments 159 | when dispatching to the endpoint. This ensures compatibility of rule definitions across routers that 160 | have host matching enabled or not. 
161 | """ 162 | 163 | def __init__( 164 | self, 165 | path: str, 166 | endpoint: t.Callable, 167 | host: t.Optional[str] = None, 168 | methods: t.Optional[t.Iterable[str]] = None, 169 | **kwargs, 170 | ): 171 | self.path = path 172 | self.endpoint = endpoint 173 | self.host = host 174 | self.methods = methods 175 | self.kwargs = kwargs 176 | 177 | def get_rules(self, map: Map) -> t.Iterable[Rule]: 178 | host = self.host 179 | 180 | if host is None and map.host_matching: 181 | # this creates a "match any" rule, and will put the value of the host 182 | # into the variable "__host__" 183 | host = "<__host__>" 184 | 185 | # the typing for endpoint is a str, but the doc states it can be any value, 186 | # however then the redirection URL building will not work 187 | rule = Rule( 188 | self.path, endpoint=self.endpoint, methods=self.methods, host=host, **self.kwargs 189 | ) 190 | yield rule 191 | 192 | 193 | class _EndpointFunction(RuleFactory): 194 | """ 195 | Internal rule factory that generates router Rules from ``@route`` annotated functions, or anything else 196 | that can be interpreted as a ``_RouteEndpoint``. It reads the rule attributes from the 197 | ``rule_attributes`` list (of ``_RuleAttributes``) defined by ``_RouteEndpoint``. Example:: 198 | 199 | @route("/my_api", methods=["GET"]) 200 | def do_get(request: Request, _args): 201 | # should be inherited 202 | return Response(f"{request.path}/do-get") 203 | 204 | router.add(do_get) # <- will use an _EndpointFunction RuleFactory.
205 | """ 206 | 207 | def __init__(self, fn: _RouteEndpoint): 208 | self.fn = fn 209 | 210 | def get_rules(self, map: Map) -> t.Iterable[Rule]: 211 | attrs: list[_RuleAttributes] = self.fn.rule_attributes 212 | for attr in attrs: 213 | yield from _EndpointRule( 214 | path=attr.path, 215 | endpoint=self.fn, 216 | host=attr.host, 217 | methods=attr.methods, 218 | **attr.kwargs, 219 | ).get_rules(map) 220 | 221 | 222 | class _EndpointsObject(RuleFactory): 223 | """ 224 | Scans the given object for members that can be used as a ``_RouteEndpoint`` (i.e. that have a ``rule_attributes`` attribute) and yields them as rules. 225 | """ 226 | 227 | def __init__(self, obj: object): 228 | self.obj = obj 229 | 230 | def get_rules(self, map: Map) -> t.Iterable[Rule]: 231 | endpoints: list[_RouteEndpoint] = [] 232 | 233 | members = inspect.getmembers(self.obj) 234 | for _, member in members: 235 | if hasattr(member, "rule_attributes"): 236 | endpoints.append(member) 237 | 238 | # make sure rules with "HEAD" are added first, otherwise a "GET" rule for the same path could take precedence over them in werkzeug.
239 | for endpoint in endpoints: 240 | for attr in endpoint.rule_attributes: 241 | if attr.methods and "HEAD" in attr.methods: 242 | yield from _EndpointRule( 243 | path=attr.path, 244 | endpoint=endpoint, 245 | host=attr.host, 246 | methods=attr.methods, 247 | **attr.kwargs, 248 | ).get_rules(map) 249 | 250 | for endpoint in endpoints: 251 | for attr in endpoint.rule_attributes: 252 | if not attr.methods or "HEAD" not in attr.methods: 253 | yield from _EndpointRule( 254 | path=attr.path, 255 | endpoint=endpoint, 256 | host=attr.host, 257 | methods=attr.methods, 258 | **attr.kwargs, 259 | ).get_rules(map) 260 | -------------------------------------------------------------------------------- /rolo/testing/pytest.py: -------------------------------------------------------------------------------- 1 | """rolo pytest plugin used both for internal testing and as testing library.""" 2 | import asyncio 3 | import dataclasses 4 | import socket 5 | import threading 6 | import time 7 | import typing 8 | from typing import Protocol 9 | 10 | import pytest 11 | from werkzeug import Request as WerkzeugRequest 12 | from werkzeug import serving 13 | 14 | from rolo import Router 15 | from rolo.asgi import ASGIAdapter, ASGILifespanListener 16 | from rolo.gateway import Gateway 17 | from rolo.gateway.asgi import AsgiGateway 18 | from rolo.gateway.wsgi import WsgiGateway 19 | from rolo.routing import handler_dispatcher 20 | from rolo.serving.twisted import HeaderPreservingHTTPChannel, TwistedGateway 21 | from rolo.websocket.adapter import WebSocketListener 22 | 23 | if typing.TYPE_CHECKING: 24 | from hypercorn.typing import ASGIFramework 25 | 26 | 27 | class ServerInfo(Protocol): 28 | url: str 29 | host: str 30 | port: int 31 | 32 | 33 | class Server(ServerInfo): 34 | def shutdown(self): 35 | ... 
36 | 37 | 38 | @dataclasses.dataclass 39 | class _ServerInfo: 40 | host: str 41 | port: int 42 | url: str 43 | 44 | 45 | @pytest.fixture 46 | def serve_wsgi_app(): 47 | servers: list[serving.BaseWSGIServer] = [] 48 | 49 | def _serve(app, host: str = "localhost", port: int = None) -> serving.BaseWSGIServer | Server: 50 | srv = serving.make_server(host, port or 0, app, threaded=True) 51 | name = threading._newname("test-server-%d") 52 | threading.Thread(target=srv.serve_forever, name=name, daemon=True).start() 53 | servers.append(srv) 54 | srv.url = f"http://{srv.host}:{srv.port}" 55 | return srv 56 | 57 | yield _serve 58 | 59 | for server in servers: 60 | server.shutdown() 61 | 62 | 63 | @pytest.fixture 64 | def wsgi_router_server(serve_wsgi_app) -> tuple[Router, serving.BaseWSGIServer | Server]: 65 | """Creates a new Router with a handler dispatcher, serves it through a newly created WSGI server, and returns 66 | both the router and the server. 67 | """ 68 | router = Router(dispatcher=handler_dispatcher()) 69 | app = WerkzeugRequest.application(router.dispatch) 70 | return router, serve_wsgi_app(app) 71 | 72 | 73 | @pytest.fixture() 74 | def serve_asgi_app(): 75 | import hypercorn 76 | import hypercorn.asyncio 77 | 78 | _server_shutdown = [] 79 | 80 | def _create( 81 | app: "ASGIFramework", 82 | config: hypercorn.Config = None, 83 | event_loop: asyncio.AbstractEventLoop = None, 84 | ) -> Server: 85 | host = "localhost" 86 | port = get_random_tcp_port() 87 | bind = f"localhost:{port}" 88 | 89 | if not config: 90 | config = hypercorn.Config() 91 | config.h11_pass_raw_headers = True 92 | config.bind = [bind] 93 | 94 | event_loop = event_loop or asyncio.new_event_loop() 95 | close = asyncio.Event() 96 | closed = threading.Event() 97 | 98 | async def _set_close(): 99 | close.set() 100 | 101 | def _run(): 102 | event_loop.run_until_complete( 103 | hypercorn.asyncio.serve(app, config, shutdown_trigger=close.wait) 104 | ) 105 | closed.set() 106 | 107 | def _shutdown(): 108
| if close.is_set(): 109 | return 110 | asyncio.run_coroutine_threadsafe(_set_close(), event_loop) 111 | closed.wait(timeout=10) 112 | try: 113 | app.close() 114 | except AttributeError: 115 | pass 116 | asyncio.run_coroutine_threadsafe(event_loop.shutdown_asyncgens(), event_loop) # NOTE(review): `closed` is only set after run_until_complete returned, so the loop is presumably idle here and this coroutine may never run — confirm 117 | event_loop.shutdown_default_executor() # NOTE(review): this is a coroutine method in asyncio; calling it without awaiting looks like a no-op — confirm 118 | event_loop.stop() 119 | event_loop.close() 120 | 121 | _server_shutdown.append(_shutdown) 122 | threading.Thread( 123 | target=_run, name=threading._newname("asgi-server-%d"), daemon=True 124 | ).start() 125 | 126 | srv = _ServerInfo(host, port, f"http://{host}:{port}") 127 | srv.shutdown = _shutdown 128 | 129 | assert wait_server_is_up(srv), f"gave up waiting for server {srv}" 130 | 131 | return srv 132 | 133 | yield _create 134 | 135 | for server_shutdown in _server_shutdown: 136 | server_shutdown() 137 | 138 | 139 | @pytest.fixture() 140 | def serve_asgi_adapter(serve_asgi_app): 141 | def _create( 142 | wsgi_app, 143 | lifespan_listener: ASGILifespanListener = None, 144 | websocket_listener: WebSocketListener = None, 145 | ): 146 | loop = asyncio.new_event_loop() 147 | return serve_asgi_app( 148 | ASGIAdapter( 149 | wsgi_app, 150 | event_loop=loop, 151 | lifespan_listener=lifespan_listener, 152 | websocket_listener=websocket_listener, 153 | ), 154 | event_loop=loop, 155 | ) 156 | 157 | yield _create 158 | 159 | 160 | @pytest.fixture 161 | def serve_wsgi_gateway(serve_wsgi_app): 162 | def _serve(gateway: Gateway) -> Server: 163 | return serve_wsgi_app(WsgiGateway(gateway)) 164 | 165 | return _serve 166 | 167 | 168 | @pytest.fixture 169 | def serve_asgi_gateway(serve_asgi_app): 170 | def _serve(gateway: Gateway) -> Server: 171 | loop = asyncio.new_event_loop() 172 | return serve_asgi_app(AsgiGateway(gateway, event_loop=loop), event_loop=loop) 173 | 174 | return _serve 175 | 176 | 177 | @pytest.fixture(scope="session") 178 | def twisted_reactor(): 179 | """Session fixture that controls the lifecycle of the main twisted reactor.""" 180 |
from twisted.internet import reactor 181 | from twisted.internet.error import ReactorAlreadyRunning 182 | from twisted.web.http import HTTPFactory 183 | 184 | def _run(): 185 | if reactor.running: 186 | return 187 | 188 | try: 189 | # for some reason, when using a `SelectReactor` (like you do by default on MacOS), whatever 190 | # protocols are added to the reactor via `listenTCP` _after_ `run` has been called, 191 | # are not served properly. We see this because the request calls in tests block forever. If we 192 | # add any listener here before calling `run`, then for some reason it works. 🤷 193 | reactor.listenTCP(get_random_tcp_port(), HTTPFactory()) 194 | reactor.run(installSignalHandlers=False) 195 | except ReactorAlreadyRunning: 196 | pass 197 | 198 | threading.Thread(target=_run, daemon=True).start() 199 | 200 | assert poll_condition( 201 | lambda: reactor.running, timeout=5 202 | ), f"gave up waiting for {reactor} to start" 203 | 204 | yield reactor 205 | 206 | reactor.stop() 207 | 208 | 209 | @pytest.fixture 210 | def serve_twisted_tcp_server(twisted_reactor): 211 | """Factory ficture for serving a twisted protocol factory (like ``Site``) through the twisted reactor.""" 212 | from twisted.internet.tcp import Port 213 | 214 | ports: list[Port] = [] 215 | 216 | def _create(protocol_factory): 217 | port = get_random_tcp_port() 218 | host = "localhost" 219 | ports.append(twisted_reactor.listenTCP(port, protocol_factory)) 220 | srv = _ServerInfo(host, port, f"http://{host}:{port}") 221 | assert wait_server_is_up(srv), f"gave up waiting for {srv}" 222 | return srv 223 | 224 | yield _create 225 | 226 | for _port in ports: 227 | _port.stopListening() 228 | 229 | 230 | @pytest.fixture 231 | def serve_twisted_gateway(serve_twisted_tcp_server): 232 | def _create(gateway): 233 | return serve_twisted_tcp_server(TwistedGateway(gateway)) 234 | 235 | yield _create 236 | 237 | 238 | @pytest.fixture 239 | def serve_twisted_websocket_listener(twisted_reactor, 
serve_twisted_tcp_server): 240 | """ 241 | This fixture creates a Twisted Site, without the need to serve a fully-fledged rolo Gateway. 242 | This is inspired by `rolo.serving.twisted.TwistedGateway`, but directly uses `WebsocketResourceDecorator` to 243 | pass the `WebSocketListener` instead of `gateway.accept` to the `websocketListener` parameter. 244 | It allows us to test the low-level behavior of WebSockets without being dependent on the Gateway implementation. 245 | """ 246 | from twisted.web.server import Site 247 | 248 | from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator 249 | 250 | def _create(websocket_listener: WebSocketListener): 251 | site = Site( 252 | WebsocketResourceDecorator( 253 | original=HeaderPreservingWSGIResource( 254 | twisted_reactor, twisted_reactor.getThreadPool(), None 255 | ), 256 | websocketListener=websocket_listener, 257 | ) 258 | ) 259 | site.protocol = HeaderPreservingHTTPChannel.protocol_factory 260 | return serve_twisted_tcp_server(site) 261 | 262 | return _create 263 | 264 | 265 | def is_server_up(srv: ServerInfo): 266 | args = socket.getaddrinfo(srv.host, srv.port, socket.AF_INET, socket.SOCK_STREAM) 267 | for family, socktype, proto, _canonname, sockaddr in args: 268 | s = socket.socket(family, socktype, proto) 269 | try: 270 | s.connect(sockaddr) 271 | except socket.error: 272 | return False 273 | else: 274 | s.close() 275 | return True 276 | 277 | 278 | def wait_server_is_up(srv: ServerInfo, timeout: float = 10, interval: float = 0.1) -> bool: 279 | return poll_condition(lambda: is_server_up(srv), timeout=timeout, interval=interval) 280 | 281 | 282 | def get_random_tcp_port() -> int: 283 | import socket 284 | 285 | sock = socket.socket() 286 | sock.bind(("", 0)) 287 | return sock.getsockname()[1] 288 | 289 | 290 | def poll_condition( 291 | condition: typing.Callable[[], bool], 292 | timeout: float = None, 293 | interval: float = 0.5, 294 | ) -> bool: 295 | """ 296 | Poll evaluates the 
given condition until a truthy value is returned. It does this every `interval` seconds 297 | (0.5 by default), until the timeout (in seconds, if any) is reached. 298 | 299 | Poll returns True once `condition()` returns a truthy value, or False if the timeout is reached. 300 | """ 301 | remaining = 0 302 | if timeout is not None: 303 | remaining = timeout 304 | 305 | while not condition(): 306 | if timeout is not None: 307 | remaining -= interval 308 | 309 | if remaining <= 0: 310 | return False 311 | 312 | time.sleep(interval) 313 | 314 | return True 315 | -------------------------------------------------------------------------------- /tests/test_request.py: -------------------------------------------------------------------------------- 1 | import wsgiref.validate 2 | 3 | import pytest 4 | from werkzeug.exceptions import BadRequest 5 | 6 | from rolo.request import Request, dummy_wsgi_environment, get_raw_path, restore_payload 7 | 8 | 9 | def test_get_json(): 10 | r = Request( 11 | "POST", 12 | "/", 13 | headers={"Content-Type": "application/json"}, 14 | body=b'{"foo": "bar", "baz": 420}', 15 | ) 16 | assert r.json == {"foo": "bar", "baz": 420} 17 | assert r.content_type == "application/json" 18 | 19 | 20 | def test_get_json_force(): 21 | r = Request("POST", "/", body=b'{"foo": "bar", "baz": 420}') 22 | assert r.get_json(force=True) == {"foo": "bar", "baz": 420} 23 | 24 | 25 | def test_get_json_invalid(): 26 | r = Request("POST", "/", body=b'{"foo": "') 27 | 28 | with pytest.raises(BadRequest): 29 | assert r.get_json(force=True) 30 | 31 | assert r.get_json(force=True, silent=True) is None 32 | 33 | 34 | def test_get_data(): 35 | r = Request("GET", "/", body="foobar") 36 | assert r.data == b"foobar" 37 | 38 | 39 | def test_get_data_as_text(): 40 | r = Request("GET", "/", body="foobar") 41 | assert r.get_data(as_text=True) == "foobar" 42 | 43 | 44 | def test_get_stream(): 45 | r = Request("GET", "/", body=b"foobar") 46 | assert r.stream.read(3) == b"foo" 47 | 
assert r.stream.read(3) == b"bar" 48 | 49 | 50 | def test_args(): 51 | r = Request("GET", "/", query_string="foo=420&bar=69") 52 | assert len(r.args) == 2 53 | assert r.args["foo"] == "420" 54 | assert r.args["bar"] == "69" 55 | 56 | 57 | def test_values(): 58 | r = Request("GET", "/", query_string="foo=420&bar=69") 59 | assert len(r.values) == 2 60 | assert r.values["foo"] == "420" 61 | assert r.values["bar"] == "69" 62 | 63 | 64 | def test_form_empty(): 65 | r = Request("POST", "/") 66 | assert len(r.form) == 0 67 | 68 | 69 | def test_post_form_urlencoded_and_query(): 70 | # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/POST#example 71 | r = Request( 72 | "POST", 73 | "/form", 74 | query_string="query1=foo&query2=bar", 75 | body=b"field1=value1&field2=value2", 76 | headers={"Content-Type": "application/x-www-form-urlencoded"}, 77 | ) 78 | 79 | assert len(r.form) == 2 80 | assert r.form["field1"] == "value1" 81 | assert r.form["field2"] == "value2" 82 | 83 | assert len(r.args) == 2 84 | assert r.args["query1"] == "foo" 85 | assert r.args["query2"] == "bar" 86 | 87 | assert len(r.values) == 4 88 | assert r.values["field1"] == "value1" 89 | assert r.values["field2"] == "value2" 90 | assert r.args["query1"] == "foo" 91 | assert r.args["query2"] == "bar" 92 | 93 | 94 | def test_validate_dummy_environment(): 95 | def validate(*args, **kwargs): 96 | assert wsgiref.validate.check_environ(dummy_wsgi_environment(*args, **kwargs)) is None 97 | 98 | validate(path="/foo/bar", body="foo") 99 | validate(path="/foo/bar", query_string="foo=420&bar=69") 100 | validate(server=("localstack.cloud", 4566)) 101 | validate(server=("localstack.cloud", None)) 102 | validate(remote_addr="127.0.0.1") 103 | validate(headers={"Content-Type": "text/xml"}, body=b"") 104 | validate(headers={"Content-Type": "text/xml", "x-amz-target": "foobar"}, body=b"") 105 | 106 | 107 | def test_content_length_is_set_automatically(): 108 | # checking that the value is calculated automatically 
109 | request = Request("GET", "/", body="foobar") 110 | assert request.content_length == 6 111 | 112 | 113 | def test_content_length_is_overwritten(): 114 | # checking that the value passed from headers take precedence 115 | request = Request("GET", "/", body="foobar", headers={"Content-Length": "7"}) 116 | assert request.content_length == 7 117 | 118 | 119 | def test_get_custom_headers(): 120 | request = Request("GET", "/", body="foobar", headers={"x-amz-target": "foobar"}) 121 | assert request.headers["x-amz-target"] == "foobar" 122 | 123 | 124 | def test_get_raw_path(): 125 | request = Request("GET", "/foo/bar/ed", raw_path="/foo%2Fbar/ed") 126 | 127 | assert request.path == "/foo/bar/ed" 128 | assert request.environ["RAW_URI"] == "/foo%2Fbar/ed" 129 | assert get_raw_path(request) == "/foo%2Fbar/ed" 130 | 131 | 132 | def test_get_raw_path_with_query(): 133 | request = Request("GET", "/foo/bar/ed", raw_path="/foo%2Fbar/ed?fizz=buzz") 134 | 135 | assert request.path == "/foo/bar/ed" 136 | assert request.environ["RAW_URI"] == "/foo%2Fbar/ed?fizz=buzz" 137 | assert get_raw_path(request) == "/foo%2Fbar/ed" 138 | 139 | 140 | def test_get_raw_path_with_prefix_slashes(): 141 | request = Request("GET", "/foo/bar/ed", raw_path="//foo%2Fbar/ed?fizz=buzz") 142 | 143 | assert request.path == "/foo/bar/ed" 144 | assert request.environ["RAW_URI"] == "//foo%2Fbar/ed?fizz=buzz" 145 | assert get_raw_path(request) == "//foo%2Fbar/ed" 146 | 147 | 148 | def test_get_raw_path_with_full_uri(): 149 | # raw_path is actually raw_uri in the WSGI environment 150 | # it can be a full URL 151 | request = Request("GET", "/foo/bar/ed", raw_path="http://localhost:4566/foo%2Fbar/ed") 152 | 153 | assert request.path == "/foo/bar/ed" 154 | assert request.environ["RAW_URI"] == "http://localhost:4566/foo%2Fbar/ed" 155 | assert get_raw_path(request) == "/foo%2Fbar/ed" 156 | 157 | 158 | def test_headers_retain_dashes(): 159 | request = Request("GET", "/foo/bar/ed", {"X-Amz-Meta--foo_bar-ed": 
"foobar"}) 160 | assert "x-amz-meta--foo_bar-ed" in request.headers 161 | assert request.headers["x-amz-meta--foo_bar-ed"] == "foobar" 162 | 163 | 164 | def test_headers_retain_case(): 165 | request = Request("GET", "/foo/bar/ed", {"X-Amz-Meta--FOO_BaR-ed": "foobar"}) 166 | keys = list(request.headers.keys()) 167 | for k in keys: 168 | if k.lower().startswith("x-amz-meta"): 169 | assert k == "X-Amz-Meta--FOO_BaR-ed" 170 | return 171 | pytest.fail(f"key not in header keys {keys}") 172 | 173 | 174 | def test_multipart_parsing(): 175 | body = ( 176 | b"--4efd159eae0c4f4e125a5a509e073d85" 177 | b"\r\n" 178 | b'Content-Disposition: form-data; name="foo"; filename="foo"' 179 | b"\r\n\r\n" 180 | b"bar" 181 | b"\r\n" 182 | b"--4efd159eae0c4f4e125a5a509e073d85" 183 | b"\r\n" 184 | b'Content-Disposition: form-data; name="baz"; filename="baz"' 185 | b"\r\n\r\n" 186 | b"ed" 187 | b"\r\n--4efd159eae0c4f4e125a5a509e073d85--" 188 | b"\r\n" 189 | ) 190 | 191 | request = Request( 192 | "POST", 193 | path="/", 194 | body=body, 195 | headers={"Content-Type": "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"}, 196 | ) 197 | result = {} 198 | for k, file_storage in request.files.items(): 199 | result[k] = file_storage.stream.read().decode("utf-8") 200 | 201 | assert result == {"foo": "bar", "baz": "ed"} 202 | 203 | 204 | def test_utf8_path(): 205 | r = Request("GET", "/foo/Ā0Ä") 206 | 207 | assert r.path == "/foo/Ā0Ä" 208 | assert r.environ["PATH_INFO"] == "/foo/Ä\x800Ã\x84" # quoted and latin-1 encoded 209 | 210 | 211 | def test_restore_payload_multipart_parsing(): 212 | body = ( 213 | b"\r\n" 214 | b"--4efd159eae0c4f4e125a5a509e073d85" 215 | b"\r\n" 216 | b'Content-Disposition: form-data; name="formfield"' 217 | b"\r\n\r\n" 218 | b"not a file, just a field" 219 | b"\r\n" 220 | b"--4efd159eae0c4f4e125a5a509e073d85" 221 | b"\r\n" 222 | b'Content-Disposition: form-data; name="foo"; filename="foo"' 223 | b"\r\n" 224 | b"Content-Type: text/plain;" 225 | b"\r\n\r\n" 226 | 
b"bar" 227 | b"\r\n" 228 | b"--4efd159eae0c4f4e125a5a509e073d85" 229 | b"\r\n" 230 | b'Content-Disposition: form-data; name="baz"; filename="baz"' 231 | b"\r\n" 232 | b"Content-Type: text/plain;" 233 | b"\r\n\r\n" 234 | b"ed" 235 | b"\r\n" 236 | b"\r\n--4efd159eae0c4f4e125a5a509e073d85--" 237 | b"\r\n" 238 | ) 239 | 240 | request = Request( 241 | "POST", 242 | path="/", 243 | body=body, 244 | headers={"Content-Type": "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"}, 245 | ) 246 | 247 | form = {} 248 | for k, field in request.form.items(): 249 | form[k] = field 250 | 251 | assert form == {"formfield": "not a file, just a field"} 252 | 253 | files = [] 254 | for k, file_storage in request.files.items(): 255 | assert file_storage.stream 256 | # we do not want to consume the file storage stream, because we can't restore the payload then 257 | files.append(k) 258 | 259 | assert files == ["foo", "baz"] 260 | restored_data = restore_payload(request) 261 | 262 | assert restored_data == body 263 | 264 | 265 | def test_request_mixed_multipart(): 266 | # this is almost how we previously restored a form that would have both `form` fields and `files` 267 | # we would URL encode the form first then add multipart, which does not work, the first part should be ignored 268 | # and make certain strict multipart parser fail (Starlette), because it finds data before the first boundary 269 | 270 | # this test does something a bit different to prove it is ignored (add an URL encoded part in the beginning) 271 | body = ( 272 | b"formfield=not+a+file%2C+just+a+field\r\n" 273 | b"--4efd159eae0c4f4e125a5a509e073d85" 274 | b"\r\n" 275 | b'Content-Disposition: form-data; name="foo"; filename="foo"' 276 | b"\r\n" 277 | b"Content-Type: text/plain;" 278 | b"\r\n\r\n" 279 | b"bar" 280 | b"\r\n" 281 | b"--4efd159eae0c4f4e125a5a509e073d85" 282 | b"\r\n" 283 | b'Content-Disposition: form-data; name="baz"; filename="baz"' 284 | b"\r\n" 285 | b"Content-Type: text/plain;" 286 | 
b"\r\n\r\n" 287 | b"ed" 288 | b"\r\n" 289 | b"\r\n--4efd159eae0c4f4e125a5a509e073d85--" 290 | b"\r\n" 291 | ) 292 | 293 | request = Request( 294 | "POST", 295 | path="/", 296 | body=body, 297 | headers={"Content-Type": "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"}, 298 | ) 299 | 300 | form = {} 301 | for k, field in request.form.items(): 302 | form[k] = field 303 | 304 | assert form == {} 305 | 306 | files = [] 307 | for k, file_storage in request.files.items(): 308 | assert file_storage.stream 309 | # we do not want to consume the file storage stream, because we can't restore the payload then 310 | files.append(k) 311 | 312 | assert files == ["foo", "baz"] 313 | 314 | restored_data = restore_payload(request) 315 | assert b"formfield" not in restored_data 316 | 317 | 318 | def test_restore_payload_form_urlencoded(): 319 | body = b"formfield=not+a+file%2C+just+a+field" 320 | 321 | request = Request( 322 | "POST", 323 | path="/", 324 | body=body, 325 | headers={"Content-Type": "application/x-www-form-urlencoded"}, 326 | ) 327 | 328 | form = {} 329 | for k, field in request.form.items(): 330 | form[k] = field 331 | 332 | assert form == {"formfield": "not a file, just a field"} 333 | 334 | assert not request.files 335 | 336 | restored_data = restore_payload(request) 337 | 338 | assert restored_data == body 339 | -------------------------------------------------------------------------------- /rolo/websocket/request.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import typing as t 3 | 4 | from werkzeug import Response 5 | from werkzeug._internal import _wsgi_decoding_dance 6 | from werkzeug.datastructures import EnvironHeaders, Headers, MultiDict 7 | from werkzeug.sansio.request import Request as _SansIORequest 8 | from werkzeug.wsgi import _get_server 9 | 10 | from .adapter import ( 11 | BytesMessage, 12 | CreateConnection, 13 | Message, 14 | TextMessage, 15 | WebSocketAdapter, 16 | 
WebSocketEnvironment, 17 | WebSocketListener, 18 | ) 19 | from .errors import WebSocketDisconnectedError, WebSocketProtocolError 20 | 21 | 22 | class WebSocket: 23 | """ 24 | High-level interface to interact with a websocket after a handshake has been completed with 25 | `WebsocketRequest.accept()`. 26 | """ 27 | 28 | request: "WebSocketRequest" 29 | socket: WebSocketAdapter 30 | 31 | def __init__(self, request: "WebSocketRequest", socket: WebSocketAdapter): 32 | self.request = request 33 | self.socket = socket 34 | 35 | def __enter__(self): 36 | return self 37 | 38 | def __exit__(self, exc_type, exc_val, exc_tb): 39 | self.close() 40 | 41 | def __iter__(self): 42 | while True: 43 | try: 44 | yield self.receive() 45 | except WebSocketDisconnectedError: 46 | break 47 | 48 | def send(self, text_or_bytes: str | bytes, timeout: float = None): 49 | """ 50 | Send data to the websocket connection. 51 | 52 | :param text_or_bytes: the data to send. Use strings for text-mode sockets (default). 53 | :param timeout: the timeout in seconds to wait before raising a timeout error 54 | """ 55 | if text_or_bytes is None: 56 | raise ValueError("text_or_bytes cannot be None") 57 | 58 | if isinstance(text_or_bytes, str): 59 | event = TextMessage(text_or_bytes) 60 | elif isinstance(text_or_bytes, bytes): 61 | event = BytesMessage(text_or_bytes) 62 | else: 63 | raise WebSocketProtocolError( 64 | f"Cannot send data type {type(text_or_bytes)} over websocket" 65 | ) 66 | self.socket.send(event, timeout) 67 | 68 | def receive(self) -> str | bytes: 69 | """ 70 | Receive the next data package from the websocket. Will be string or byte data and set the 71 | underlying binary for the frame automatically. 
72 | 73 | :raise WebSocketDisconnectedError: if the websocket was closed in the meantime 74 | :raise WebSocketProtocolError: error in the interaction between the app and the webserver 75 | :return: the next data package from the websocket 76 | """ 77 | event = self.socket.receive() 78 | if isinstance(event, Message): 79 | data = event.data 80 | if data is None: 81 | raise WebSocketProtocolError("No data returned by the websocket.") 82 | return data 83 | else: 84 | raise WebSocketProtocolError( 85 | f"Unexpected websocket event type {event.__class__.__name__}." 86 | ) 87 | 88 | def close(self, code: int = 1000, reason: t.Optional[str] = None, timeout: float = None): 89 | """ 90 | Closes the websocket connection with specific code. 91 | 92 | :param code: the websocket close code. 93 | :param reason: optional reason 94 | :param timeout: connection timeout 95 | """ 96 | self.socket.close(code, reason, timeout) 97 | 98 | 99 | class WebSocketRequest(_SansIORequest): 100 | """ 101 | A websocket request represents the incoming HTTP request to upgrade the connection to a WebSocket 102 | connection. The request method is an artificial ``WEBSOCKET`` method that can also be used in the Router: 103 | ``@route("/path", method=["WEBSOCKET"])``. 104 | 105 | The websocket connection needs to be either accepted or rejected. When calling 106 | ``WebSocketRequest.accept``, an upgrade response will be sent to the client, and the protocol will be 107 | switched to the bidirectional WebSocket protocol. If ``WebSocketRequest.reject`` is called, the server 108 | immediately returns an HTTP response and closes the connection. 109 | """ 110 | 111 | def __init__(self, environ: WebSocketEnvironment): 112 | """ 113 | Creates a new request from the given WebSocketEnvironment. This is like a sans-IO WSGI Environment, 114 | with an additional field ``rolo.websocket`` that contains a ``WebSocketAdapter`` interface. 
115 | 116 | :param environ: the WebSocketEnvironment 117 | """ 118 | raw_headers = environ.get("rolo.headers") 119 | if raw_headers: 120 | # restores raw headers from the server scope, to have proper casing or dashes. This can depend on server 121 | # behavior, but we want a unified way to keep header casing/formatting. 122 | # This is similar to what we do in wsgi.py 123 | headers = Headers( 124 | MultiDict([(k.decode("latin-1"), v.decode("latin-1")) for (k, v) in raw_headers]) 125 | ) 126 | else: 127 | headers = Headers(EnvironHeaders(environ)) 128 | 129 | # copied from werkzeug.wrappers.request 130 | super().__init__( 131 | method=environ.get("REQUEST_METHOD", "WEBSOCKET"), 132 | scheme=environ.get("wsgi.url_scheme", "ws"), 133 | server=_get_server(environ), 134 | root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""), 135 | path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""), 136 | query_string=environ.get("QUERY_STRING", "").encode("latin1"), 137 | headers=headers, 138 | remote_addr=environ.get("REMOTE_ADDR"), 139 | ) 140 | self.environ = environ 141 | 142 | self.shallow = True # compatibility with werkzeug.Request 143 | 144 | self._upgraded = False 145 | self._rejected = False 146 | 147 | @property 148 | def socket(self) -> WebSocketAdapter: 149 | """ 150 | Returns the underlying WebSocketAdapter from the environment. This is analogous to ``Request.stream`` 151 | in the default werkzeug HTTP request object. 152 | 153 | :return: the WebSocketAdapter from the environment 154 | """ 155 | return self.environ.get("rolo.websocket") or self.environ.get("asgi.websocket") 156 | 157 | def is_upgraded(self) -> bool: 158 | """Returns true if ``accept`` was called.""" 159 | return self._upgraded 160 | 161 | def is_rejected(self) -> bool: 162 | """Returns true if ``reject`` was called.""" 163 | return self._rejected 164 | 165 | def reject(self, response: Response): 166 | """ 167 | Reject the websocket upgrade and return the given response. 
Will raise a ``ValueError`` if the 168 | request has already been accepted or rejected before. 169 | 170 | :param response: the HTTP response to return to the client. 171 | """ 172 | if self._upgraded: 173 | raise ValueError("Websocket connection already upgraded") 174 | if self._rejected: 175 | raise ValueError("Websocket connection already rejected") 176 | 177 | self.socket.reject( 178 | response.status_code, 179 | response.headers, 180 | response.iter_encoded(), 181 | ) 182 | self._rejected = True 183 | 184 | def accept( 185 | self, subprotocol: str = None, headers: Headers = None, timeout: float = None 186 | ) -> WebSocket: 187 | """ 188 | Performs the websocket connection upgrade handshake. After calling ``accept``, a new ``Websocket`` 189 | instance is returned that represents the bidirectional communication channel, which you should 190 | continue operating on. Example:: 191 | 192 | def app(request: WebsocketRequest): 193 | # example: do authorization first 194 | auth = request.headers.get("Authorization") 195 | if not is_authorized(auth): 196 | request.reject(Response("no dice", 403)) 197 | return 198 | 199 | # then continue working with the websocket 200 | with request.accept() as websocket: 201 | websocket.send("hello world!") 202 | data = websocket.receive() 203 | # ... 204 | 205 | The handshake using the WebSocketAdapter works as follows: receive the ``CreateConnection`` event 206 | from the websocket and then call the ``socket.accept(...)``. If the handshake failed because 207 | the websocket sent an unexpected exception, the connection is closed and the method raises an error. 208 | 209 | :param subprotocol: The subprotocol the server wishes to accept. 
Optional 210 | :param headers: Response headers 211 | :param timeout: connection timeout 212 | :return: a websocket 213 | :raises ProtocolError: if unexpected events were received from the websocket server 214 | """ 215 | if self._upgraded: 216 | raise ValueError("Websocket connection already upgraded") 217 | if self._rejected: 218 | raise ValueError("Websocket connection already rejected") 219 | 220 | event = self.socket.receive(timeout) 221 | if isinstance(event, CreateConnection): 222 | self.socket.accept(subprotocol, [], headers, timeout) 223 | self._upgraded = True 224 | return WebSocket(self, self.socket) 225 | else: 226 | reason = f"Unexpected event {event.__class__.__name__}" 227 | self.socket.close(1003, reason) 228 | raise WebSocketProtocolError(reason) 229 | 230 | def close(self): 231 | """ 232 | Explicitly close the websocket. If this is called after ``reject(...)`` or ``accept(...)`` has been 233 | called, this will have no effect. Calling ``reject`` inherently closes the websocket connection 234 | since it immediately returns an HTTP response. After calling ``accept`` you should call 235 | ``WebSocket.close`` instead. 236 | """ 237 | if self._rejected or self._upgraded: 238 | return 239 | self.socket.close(1000) 240 | 241 | @classmethod 242 | def listener(cls, fn: t.Callable[["WebSocketRequest"], None]) -> WebSocketListener: 243 | """ 244 | Convenience function inspired by ``werkzeug.Request.application`` that transforms a function into a 245 | ``WebsocketListener`` for the use in server code that support ``WebsocketListeners``. Example:: 246 | 247 | @WebsocketRequest.listener 248 | def app(request: WebSocketRequest): 249 | with request.accept() as ws: 250 | ws.send("hello world") 251 | 252 | adapter = ASGIAdapter(wsgi_app=..., websocket_listener=app) 253 | # ... 
serve adapter 254 | 255 | 256 | :param fn: the function to wrap 257 | :return: a WebsocketListener compatible interface 258 | """ 259 | from werkzeug.exceptions import HTTPException 260 | 261 | @functools.wraps(fn) 262 | def application(*args): 263 | request = cls(args[-1]) 264 | try: 265 | fn(*args[:-1] + (request,)) 266 | except HTTPException as e: 267 | resp = e.get_response(args[-1]) 268 | request.reject(resp) 269 | finally: 270 | request.close() 271 | 272 | return application 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /rolo/request.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from typing import IO, TYPE_CHECKING, Dict, Mapping, Optional, Tuple, Union 3 | from urllib.parse import quote, unquote, urlencode, urlparse 4 | 5 | if TYPE_CHECKING: 6 | from _typeshed.wsgi import WSGIEnvironment 7 | 8 | from werkzeug.datastructures import Headers, MultiDict 9 | from werkzeug.test import encode_multipart 10 | from werkzeug.wrappers.request import Request as WerkzeugRequest 11 | 12 | 13 | def dummy_wsgi_environment( 14 | method: str = "GET", 15 | path: str = "", 16 | headers: Optional[Union[Dict, Headers]] = None, 17 | body: Optional[Union[bytes, str, IO[bytes]]] = None, 18 | scheme: str = "http", 19 | root_path: str = "/", 20 | query_string: Optional[str] = None, 21 | remote_addr: Optional[str] = None, 22 | server: Optional[Tuple[str, Optional[int]]] = None, 23 | raw_uri: Optional[str] = None, 24 | ) -> "WSGIEnvironment": 25 | """ 26 | Creates a dummy WSGIEnvironment that represents a standalone 
sans-IO HTTP requests. 27 | 28 | See https://wsgi.readthedocs.io/en/latest/definitions.html#standard-environ-keys 29 | 30 | :param method: The HTTP request method (such as GET or POST) 31 | :param path: The remainder of the request URL's path. This may be an empty string, if the 32 | request URL targets the application root and does not have a trailing slash. 33 | :param headers: optional HTTP headers 34 | :param body: the body of the request 35 | :param scheme: the scheme (http or https) 36 | :param root_path: The initial portion of the request URL's path that corresponds to the 37 | application object. 38 | :param query_string: The portion of the request URL that follows the “?”, if any. May be 39 | empty or absent. 40 | :param remote_addr: The address making the request 41 | :param server: The server (tuple of server name and port) 42 | :param raw_uri: The original path that may contain url encoded path elements. 43 | :return: A WSGIEnvironment dictionary 44 | """ 45 | 46 | # Standard environ keys 47 | environ = { 48 | "REQUEST_METHOD": method, 49 | # prepare the paths for the "WSGI decoding dance" done by werkzeug 50 | "SCRIPT_NAME": unquote(quote(root_path.rstrip("/")), "latin-1"), 51 | "PATH_INFO": unquote(quote(path), "latin-1"), 52 | "SERVER_PROTOCOL": "HTTP/1.1", 53 | "QUERY_STRING": query_string or "", 54 | } 55 | 56 | if raw_uri: 57 | if query_string: 58 | raw_uri += "?" 
+ query_string 59 | environ["RAW_URI"] = raw_uri 60 | environ["REQUEST_URI"] = environ["RAW_URI"] 61 | 62 | if server: 63 | environ["SERVER_NAME"] = server[0] 64 | if server[1]: 65 | environ["SERVER_PORT"] = str(server[1]) 66 | else: 67 | environ["SERVER_PORT"] = "80" 68 | else: 69 | environ["SERVER_NAME"] = "127.0.0.1" 70 | environ["SERVER_PORT"] = "80" 71 | 72 | if remote_addr: 73 | environ["REMOTE_ADDR"] = remote_addr 74 | 75 | if headers: 76 | set_environment_headers(environ, headers) 77 | 78 | if not body: 79 | body = b"" 80 | 81 | if isinstance(body, (str, bytes)): 82 | data = body.encode("utf-8") if isinstance(body, str) else body 83 | 84 | wsgi_input = BytesIO(data) 85 | if "CONTENT_LENGTH" not in environ: 86 | # try to determine content length from body 87 | environ["CONTENT_LENGTH"] = str(len(data)) 88 | else: 89 | wsgi_input = body 90 | 91 | # WSGI environ keys 92 | environ["wsgi.version"] = (1, 0) 93 | environ["wsgi.url_scheme"] = scheme 94 | environ["wsgi.input"] = wsgi_input 95 | environ["wsgi.input_terminated"] = True 96 | environ["wsgi.errors"] = BytesIO() 97 | environ["wsgi.multithread"] = True 98 | environ["wsgi.multiprocess"] = False 99 | environ["wsgi.run_once"] = False 100 | 101 | return environ 102 | 103 | 104 | def set_environment_headers(environ: "WSGIEnvironment", headers: Union[Dict, Headers]): 105 | # Collect all the headers to set 106 | # (this might be is accessing the environment, this needs to be done before removing the items from the env) 107 | new_headers = {} 108 | for k, v in headers.items(): 109 | name = k.upper().replace("-", "_") 110 | 111 | if name not in ("CONTENT_TYPE", "CONTENT_LENGTH"): 112 | name = f"HTTP_{name}" 113 | 114 | val = v 115 | if name in new_headers: 116 | val = new_headers[name] + "," + val 117 | 118 | new_headers[name] = val 119 | 120 | # Clear the HTTP headers in the env 121 | header_keys = [name for name in environ if name.startswith("HTTP_")] 122 | for name in header_keys: 123 | environ.pop(name, None) 
124 | 125 | # Set the new headers in the env 126 | for k, v in new_headers.items(): 127 | environ[k] = v 128 | 129 | 130 | class Request(WerkzeugRequest): 131 | """ 132 | An HTTP request object. This is (and should remain) a drop-in replacement for werkzeug's WSGI 133 | compliant Request objects. It allows simple sans-IO requests outside a web server environment. 134 | 135 | DO NOT add methods that are not also part of werkzeug.wrappers.request.Request object. 136 | """ 137 | 138 | def __init__( 139 | self, 140 | method: str = "GET", 141 | path: str = "", 142 | headers: Union[Mapping, Headers] = None, 143 | body: Union[bytes, str] = None, 144 | scheme: str = "http", 145 | root_path: str = "/", 146 | query_string: Union[bytes, str] = b"", 147 | remote_addr: str = None, 148 | server: Optional[Tuple[str, Optional[int]]] = None, 149 | raw_path: str = None, 150 | ): 151 | # decode query string if necessary (latin-1 is what werkzeug would expect) 152 | if isinstance(query_string, bytes): 153 | query_string = query_string.decode("latin-1") 154 | 155 | # create the WSGIEnvironment dictionary that represents this request 156 | environ = dummy_wsgi_environment( 157 | method=method, 158 | path=path, 159 | headers=headers, 160 | body=body, 161 | scheme=scheme, 162 | root_path=root_path, 163 | query_string=query_string, 164 | remote_addr=remote_addr, 165 | server=server, 166 | raw_uri=raw_path, 167 | ) 168 | 169 | super(Request, self).__init__(environ) 170 | 171 | # restore originally passed headers: 172 | # werkzeug normally provides read-only access to headers set in the WSGIEnvironment through the EnvironHeaders 173 | # class, here we make them mutable again. moreover, WSGI header encoding conflicts with RFC2616. 
see this github 174 | # issue for a discussion: https://github.com/pallets/werkzeug/issues/940 175 | headers = Headers(headers) 176 | # these two headers are treated separately in the WSGI environment, so we extract them if necessary 177 | for h in ["content-length", "content-type"]: 178 | if h not in headers and h in self.headers: 179 | headers[h] = self.headers[h] 180 | self.headers = headers 181 | 182 | @classmethod 183 | def application(cls, *args): 184 | # werkzeug's application decorator assumes its Request constructor signature, which our Request doesn't support. 185 | # using ``application`` from our request therefore creates runtime errors. this makes sure no one runs into 186 | # these problems. if we want to support it, we need to create compatibility with werkzeug's Request constructor 187 | raise NotImplementedError 188 | 189 | 190 | def get_raw_path(request) -> str: 191 | """ 192 | Returns the raw_path inside the request without the query string. The request can either be a Quart Request 193 | object (that encodes the raw path in request.scope['raw_path']) or a Werkzeug WSGI request (that encodes the raw 194 | URI in request.environ['RAW_URI']). 
195 | 196 | :param request: the request object 197 | :return: the raw path if any 198 | """ 199 | if hasattr(request, "environ"): 200 | # werkzeug/flask request (already a string, and contains the query part) 201 | # we need to parse it, because the RAW_URI can contain a full URL if it is specified in the HTTP request 202 | raw_uri: str = request.environ.get("RAW_URI", "") 203 | if raw_uri.startswith("//"): 204 | # if the RAW_URI starts with double slashes, `urlparse` will fail to decode it as path only 205 | # it also means that we already only have the path, so we just need to remove the query string 206 | return raw_uri.split("?")[0] 207 | return urlparse(raw_uri or request.path).path 208 | 209 | if hasattr(request, "scope"): 210 | # quart request raw_path comes as bytes, and without the query part 211 | return request.scope.get("raw_path", request.path).decode("utf-8") 212 | 213 | raise ValueError("cannot extract raw path from request object %s" % request) 214 | 215 | 216 | def get_full_raw_path(request: WerkzeugRequest) -> str: 217 | """ 218 | Returns the full raw request path (with original URL encoding), including the query string. 219 | This is _not_ equal to request.url, since there the path section would be url-encoded while the query part will be 220 | (partly) url-decoded. 221 | """ 222 | query_str = f"?{request.query_string.decode('latin1')}" if request.query_string else "" 223 | raw_path = f"{get_raw_path(request)}{query_str}" 224 | return raw_path 225 | 226 | 227 | def get_raw_base_url(request: Request) -> str: 228 | """ 229 | Returns the base URL (with original URL encoding). This does not include the query string. 230 | This is the encoding-preserving equivalent to `request.base_url`. 
231 | """ 232 | return get_raw_current_url( 233 | request.scheme, request.host, request.root_path, get_raw_path(request) 234 | ) 235 | 236 | 237 | def get_raw_current_url( 238 | scheme: str, 239 | host: str, 240 | root_path: Optional[str] = None, 241 | path: Optional[str] = None, 242 | query_string: Optional[bytes] = None, 243 | ) -> str: 244 | """ 245 | `werkzeug.sansio.utils.get_current_url` implementation without 246 | any encoding dances. 247 | The given paths and query string are directly used without any encodings 248 | (to avoid any double encodings). 249 | It can be used to recreate the raw URL for a request. 250 | 251 | :param scheme: The protocol the request used, like ``"https"``. 252 | :param host: The host the request was made to. See :func:`get_host`. 253 | :param root_path: Prefix that the application is mounted under. This 254 | is prepended to ``path``. 255 | :param path: The path part of the URL after ``root_path``. 256 | :param query_string: The portion of the URL after the "?". 257 | """ 258 | url = [scheme, "://", host] 259 | 260 | if root_path is None: 261 | url.append("/") 262 | return "".join(url) 263 | 264 | url.append(root_path.rstrip("/")) 265 | url.append("/") 266 | 267 | if path is None: 268 | return "".join(url) 269 | 270 | url.append(path.lstrip("/")) 271 | 272 | if query_string: 273 | url.append("?") 274 | url.append(query_string) 275 | 276 | return "".join(url) 277 | 278 | 279 | def restore_payload(request: Request) -> bytes: 280 | """ 281 | This method takes a request and restores the original payload from it even after it has been consumed. A werkzeug 282 | request consumes form/multipart data from the stream, and here we are serializing it back to a request a client 283 | could make . This is a useful method to have when we are proxying requests, i.e., create outgoing requests from 284 | incoming requests that were subject to previous parsing. 
285 | 286 | TODO: this construct is not great and will definitely be a source of trouble. but something like this will 287 | inevitably become the the basis of proxying werkzeug requests. The alternative is to build our own request object 288 | that memoizes the original payload before parsing. 289 | """ 290 | if request.shallow: 291 | return b"" 292 | 293 | data = request.data 294 | 295 | if request.method != "POST": 296 | return data 297 | 298 | if request.mimetype == "multipart/form-data": 299 | boundary = request.content_type.split("=")[1] 300 | 301 | fields = MultiDict() 302 | fields.update(request.form) 303 | fields.update(request.files) 304 | 305 | _, data_files = encode_multipart(fields, boundary) 306 | data += data_files 307 | 308 | elif request.mimetype == "application/x-www-form-urlencoded": 309 | data += urlencode(list(request.form.items(multi=True))).encode("utf-8") 310 | 311 | return data 312 | --------------------------------------------------------------------------------