├── rolo
├── py.typed
├── testing
│ ├── __init__.py
│ └── pytest.py
├── resource.py
├── serving
│ └── __init__.py
├── dispatcher.py
├── __init__.py
├── websocket
│ ├── __init__.py
│ ├── websocket.py
│ ├── errors.py
│ ├── adapter.py
│ └── request.py
├── routing
│ ├── __init__.py
│ ├── converter.py
│ ├── pydantic.py
│ ├── handler.py
│ ├── resource.py
│ └── rules.py
├── router.py
├── gateway
│ ├── __init__.py
│ ├── wsgi.py
│ ├── gateway.py
│ ├── asgi.py
│ └── handlers.py
├── response.py
├── client.py
├── proxy.py
└── request.py
├── docs
├── _static
│ ├── .gitkeep
│ └── rolo.png
├── _templates
│ └── .gitkeep
├── requirements.txt
├── Makefile
├── make.bat
├── gateway.md
├── index.md
├── conf.py
├── websockets.md
├── serving.md
├── getting_started.md
├── router.md
└── handler_chain.md
├── tests
├── gateway
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_context.py
│ ├── test_websocket.py
│ ├── test_headers.py
│ ├── test_handlers.py
│ └── test_chain.py
├── serving
│ ├── __init__.py
│ └── test_twisted.py
├── static
│ ├── __init__.py
│ ├── test.txt
│ └── index.html
├── websocket
│ ├── __init__.py
│ └── test_websockets.py
├── conftest.py
├── __init__.py
├── test_client.py
├── test_response.py
├── test_dispatcher.py
├── test_resource.py
├── test_pydantic.py
└── test_request.py
├── CODEOWNERS
├── .github
├── workflows
│ ├── docs-publish.yml
│ └── build.yml
└── PULL_REQUEST_TEMPLATE.md
├── Makefile
├── .gitignore
├── pyproject.toml
├── examples
└── json-rpc
│ └── server.py
├── README.md
└── LICENSE
/rolo/py.typed:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/_static/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_templates/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rolo/testing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/gateway/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/serving/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/static/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/websocket/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/static/test.txt:
--------------------------------------------------------------------------------
1 | hello world
2 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | furo
3 | myst_parser
4 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | pytest_plugins = [
2 | "rolo.testing.pytest",
3 | ]
4 |
--------------------------------------------------------------------------------
/tests/static/index.html:
--------------------------------------------------------------------------------
1 |
2 |
hello
3 |
4 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # by default, add all repo maintainers as reviewers
2 | * @thrau
3 |
--------------------------------------------------------------------------------
/docs/_static/rolo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/localstack/rolo/HEAD/docs/_static/rolo.png
--------------------------------------------------------------------------------
/rolo/resource.py:
--------------------------------------------------------------------------------
1 | """
2 | DEPRECATED: use ``from rolo.routing import Resource, resource`` instead
3 | """
4 | from .routing.resource import Resource, resource
5 |
6 | __all__ = [
7 | "Resource",
8 | "resource",
9 | ]
10 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # This module cannot be named http, since PyCharm can then no longer run the tests in the module above in debug mode
2 | # - https://youtrack.jetbrains.com/issue/PY-54265
3 | # - https://youtrack.jetbrains.com/issue/PY-35220
4 |
--------------------------------------------------------------------------------
/rolo/serving/__init__.py:
--------------------------------------------------------------------------------
1 | """This module holds glue code between web servers and rolo. Any ASGI-compliant server can serve rolo with
2 | a few patches via ``rolo.serving.asgi``. Moreover, rolo can be served through twisted via
3 | ``rolo.serving.twisted``."""
4 |
--------------------------------------------------------------------------------
/rolo/dispatcher.py:
--------------------------------------------------------------------------------
1 | """
2 | DEPRECATED: use ``from rolo.routing import handler_dispatcher`` instead
3 | """
4 | from .routing.handler import ResultValue, handler_dispatcher
5 |
6 | __all__ = [
7 | "handler_dispatcher",
8 | "ResultValue",
9 | ]
10 |
--------------------------------------------------------------------------------
/rolo/__init__.py:
--------------------------------------------------------------------------------
1 | from .request import Request
2 | from .response import Response
3 | from .routing.resource import Resource, resource
4 | from .routing.router import Router, route
5 |
6 | __all__ = [
7 | "route",
8 | "resource",
9 | "Resource",
10 | "Router",
11 | "Response",
12 | "Request",
13 | ]
14 |
--------------------------------------------------------------------------------
/.github/workflows/docs-publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish docs
2 |
3 | on:
4 | push:
5 | branches: [ main ] # branch to trigger deployment
6 |
7 | jobs:
8 | pages:
9 | runs-on: ubuntu-latest
10 | environment:
11 | name: github-pages
12 | url: ${{ steps.deployment.outputs.page_url }}
13 | permissions:
14 | pages: write
15 | id-token: write
16 | steps:
17 | - id: deployment
18 | uses: sphinx-notes/pages@v3
19 |
--------------------------------------------------------------------------------
/rolo/websocket/__init__.py:
--------------------------------------------------------------------------------
1 | from .adapter import WebSocketEnvironment, WebSocketListener
2 | from .errors import WebSocketDisconnectedError, WebSocketError, WebSocketProtocolError
3 | from .request import WebSocket, WebSocketRequest
4 |
5 | __all__ = [
6 | "WebSocket",
7 | "WebSocketDisconnectedError",
8 | "WebSocketEnvironment",
9 | "WebSocketError",
10 | "WebSocketListener",
11 | "WebSocketProtocolError",
12 | "WebSocketRequest",
13 | ]
14 |
--------------------------------------------------------------------------------
/rolo/routing/__init__.py:
--------------------------------------------------------------------------------
1 | from .converter import PortConverter, RegexConverter
2 | from .handler import handler_dispatcher
3 | from .router import RequestArguments, Router, route
4 | from .rules import RuleAdapter, RuleGroup, WithHost
5 |
6 | __all__ = [
7 | "PortConverter",
8 | "RegexConverter",
9 | "RequestArguments",
10 | "Router",
11 | "RuleAdapter",
12 | "RuleGroup",
13 | "WithHost",
14 | "handler_dispatcher",
15 | "route",
16 | ]
17 |
--------------------------------------------------------------------------------
/rolo/router.py:
--------------------------------------------------------------------------------
1 | """
2 | DEPRECATED: use ``rolo.routing`` instead.
3 | """
4 |
5 | from .routing.converter import PortConverter, RegexConverter
6 | from .routing.router import RequestArguments, Router, route
7 | from .routing.rules import RuleAdapter, RuleGroup, WithHost
8 |
9 | __all__ = [
10 | "Router",
11 | "route",
12 | "RegexConverter",
13 | "PortConverter",
14 | "RuleAdapter",
15 | "RuleGroup",
16 | "WithHost",
17 | "RequestArguments",
18 | ]
19 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | ## Motivation
3 |
4 |
5 |
6 | ## Changes
7 |
8 |
9 |
13 |
14 |
15 |
23 |
--------------------------------------------------------------------------------
/rolo/websocket/websocket.py:
--------------------------------------------------------------------------------
1 | """DEPRECATED: use ``rolo.websocket`` instead. TODO: remove with the next release (kept only so existing LocalStack imports keep working)."""
2 | from .adapter import WebSocketEnvironment, WebSocketListener
3 | from .errors import WebSocketDisconnectedError, WebSocketError, WebSocketProtocolError
4 | from .request import WebSocket, WebSocketRequest
5 |
6 | __all__ = [
7 | "WebSocket",
8 | "WebSocketDisconnectedError",
9 | "WebSocketEnvironment",
10 | "WebSocketError",
11 | "WebSocketListener",
12 | "WebSocketProtocolError",
13 | "WebSocketRequest",
14 | ]
15 |
--------------------------------------------------------------------------------
/rolo/gateway/__init__.py:
--------------------------------------------------------------------------------
1 | from .chain import (
2 | CompositeExceptionHandler,
3 | CompositeFinalizer,
4 | CompositeHandler,
5 | CompositeResponseHandler,
6 | ExceptionHandler,
7 | Handler,
8 | HandlerChain,
9 | RequestContext,
10 | )
11 | from .gateway import Gateway
12 |
13 | __all__ = [
14 | "Gateway",
15 | "HandlerChain",
16 | "RequestContext",
17 | "Handler",
18 | "ExceptionHandler",
19 | "CompositeHandler",
20 | "CompositeExceptionHandler",
21 | "CompositeResponseHandler",
22 | "CompositeFinalizer",
23 | ]
24 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/tests/gateway/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from rolo.gateway import Gateway
4 | from rolo.testing.pytest import Server
5 |
6 |
@pytest.fixture
def serve_gateway(request):
    """
    Factory fixture that serves a ``Gateway`` through a specific server type and returns the ``Server``.

    The server type is selected through indirect parametrization, e.g.::

        @pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True)

    and defaults to ``"wsgi"`` when the fixture is not parametrized.

    :raises ValueError: (when the returned factory is called) if the server type is not one of
        ``wsgi``, ``asgi``, or ``twisted``
    """
    # maps the parametrized server type to the name of the fixture that actually serves the gateway
    fixture_names = {
        "wsgi": "serve_wsgi_gateway",
        "asgi": "serve_asgi_gateway",
        "twisted": "serve_twisted_gateway",
    }

    def _serve(gateway: Gateway) -> Server:
        # ``request.param`` only exists when the fixture is parametrized (indirectly)
        gw_type = getattr(request, "param", "wsgi")
        try:
            fixture_name = fixture_names[gw_type]
        except KeyError:
            # fail loudly instead of silently falling back to wsgi, so typos in parametrize are caught
            raise ValueError(
                f"unknown gateway type {gw_type!r}, expected one of {sorted(fixture_names)}"
            ) from None

        return request.getfixturevalue(fixture_name)(gateway)

    yield _serve
25 |
--------------------------------------------------------------------------------
/rolo/websocket/errors.py:
--------------------------------------------------------------------------------
1 | class WebSocketError(IOError):
2 | """Base class for websocket errors"""
3 |
4 | pass
5 |
6 |
7 | class WebSocketDisconnectedError(WebSocketError):
8 | """Raised when the client has disconnected while the server is still trying to receive data."""
9 |
10 | default_code = 1005
11 | """https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event-ws"""
12 |
13 | def __init__(self, code: int = None):
14 | self.code = code if code is not None else self.default_code
15 | super().__init__(f"Websocket disconnected code={self.code}")
16 |
17 |
18 | class WebSocketProtocolError(WebSocketError):
19 | """Raised if there is a problem in the interaction between app and the websocket server."""
20 |
21 | pass
22 |
--------------------------------------------------------------------------------
/tests/gateway/test_context.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from rolo.gateway import RequestContext
4 |
5 |
def test_set_and_access_data():
    """Attributes assigned on a context can be read back as attributes."""
    ctx = RequestContext()
    ctx.some_data = "foo"

    assert ctx.some_data == "foo"


def test_access_non_existing_data():
    """Reading an attribute that was never set raises ``AttributeError``."""
    ctx = RequestContext()

    with pytest.raises(AttributeError):
        _ = ctx.some_data


def test_get_non_existing_data():
    """``get`` returns ``None`` for attributes that were never set."""
    assert RequestContext().get("some_data") is None


def test_set_and_get_data():
    """``get`` always reflects the most recently assigned attribute value."""
    ctx = RequestContext()

    for value in ("foo", "bar"):
        ctx.some_data = value
        assert ctx.get("some_data") == value
33 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/tests/gateway/test_websocket.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import websocket
3 |
4 | from rolo import Router, route
5 | from rolo.gateway import Gateway
6 | from rolo.gateway.handlers import RouterHandler
7 | from rolo.websocket.request import WebSocketRequest
8 |
9 |
@pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True)
def test_gateway_router_websocket_integration(serve_gateway):
    """A websocket route registered on a router is reachable through the gateway's ``RouterHandler``."""

    @route("/ws", methods=["WEBSOCKET"])
    def _handler(request: WebSocketRequest, args):
        # greet the client, then echo back the first message it sends
        with request.accept() as ws:
            ws.send("hello")
            ws.send(ws.receive())

    router = Router()
    router.add(_handler)

    server = serve_gateway(Gateway(request_handlers=[RouterHandler(router)]))

    client = websocket.WebSocket()
    client.connect(server.url.replace("http://", "ws://") + "/ws")
    try:
        assert client.recv() == "hello"
        client.send("foobar")
        assert client.recv() == "foobar"
    finally:
        # close even if an assertion fails, so the socket is not leaked and the handler's
        # pending ``receive()`` is not left waiting on an open connection
        client.close()
30 |
--------------------------------------------------------------------------------
/tests/gateway/test_headers.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pytest
4 | import requests
5 |
6 | from rolo.gateway import Gateway, HandlerChain, RequestContext
7 |
8 |
@pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True)
def test_raw_header_handling(serve_gateway):
    """Header names keep their exact spelling (casing, underscores) in both directions."""

    def handler(chain: HandlerChain, context: RequestContext, response):
        # echo the request headers as the handler chain sees them
        echoed = {"headers": dict(context.request.headers)}
        response.data = json.dumps(echoed)
        response.mimetype = "application/json"
        response.headers["X-fOO_bar"] = "FooBar"
        return response

    srv = serve_gateway(Gateway(request_handlers=[handler]))

    sent_headers = {"x-mIxEd-CaSe": "myheader", "X-UPPER__CASE": "uppercase"}
    response = requests.get(srv.url, headers=sent_headers)

    echoed_headers = response.json()["headers"]
    for name in sent_headers:
        assert name in echoed_headers
    assert "X-fOO_bar" in dict(response.headers)
29 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
name: CI

on:
  push:
    paths-ignore:
      - 'README.md'
      - 'docs/**'
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [ "3.10", "3.11", "3.12" ]
    env:
      # presumably consumed by `coveralls` in the "Report coverage" step
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
      # checkout@v4 and setup-python@v5 run on Node 20; the previous majors (v3 / v4) used the
      # deprecated Node 16 runtime
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        id: setup-python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run install
        run: |
          make install

      - name: Run linting
        run: |
          make lint

      - name: Run tests
        run: |
          make test-coverage

      - name: Report coverage
        run: |
          make coveralls
48 |
--------------------------------------------------------------------------------
/rolo/routing/converter.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from werkzeug.routing import BaseConverter, Map
4 |
5 |
class RegexConverter(BaseConverter):
    """
    A converter that can be used to inject a regex as parameter, e.g., ``/<regex('[a-z]+'):my_var>``
    (assuming the converter is registered under the name ``regex`` in the URL map).
    When using groups in regex, make sure they are non-capturing ``(?:[a-z]+)``
    """

    # NOTE(review): newer werkzeug versions treat converters as "part isolating" (matched per path
    # segment) unless ``part_isolating`` is False; since ``regex`` is only set per instance here,
    # patterns meant to match across '/' may not — confirm against the werkzeug version in use.

    def __init__(self, map: Map, *args: t.Any, **kwargs: t.Any) -> None:
        super().__init__(map, *args, **kwargs)
        # the first positional converter argument is the pattern itself
        self.regex = args[0]


class PortConverter(BaseConverter):
    """
    Useful to optionally match ports for host patterns, like ``localhost.localstack.cloud`` followed by an
    optional port. Notice how you don't need to specify the colon. The regex matches it if the port is there,
    and will remove the colon if matched.
    The converter converts the port to an int, or returns None if there's no port in the input string.
    When building URLs, it performs the inverse: an int port becomes ``:<port>``, and ``None`` becomes ``""``.
    """

    regex = r"(?::[0-9]{1,5})?"

    def to_python(self, value: str) -> t.Any:
        # ``value`` is either "" (no port) or ":<digits>"; strip the leading colon
        if value:
            return int(value[1:])
        return None

    def to_url(self, value: t.Any) -> str:
        # inverse of ``to_python``: re-add the colon that ``regex`` expects, and emit nothing for a missing
        # port. Without this, the default implementation would render 4566 as "4566" and None as "None".
        if value is None or value == "":
            return ""
        return f":{int(value)}"
30 |
--------------------------------------------------------------------------------
/docs/gateway.md:
--------------------------------------------------------------------------------
1 | Gateway
2 | =======
3 |
4 | The `Gateway` serves as the factory for `HandlerChain` instances.
5 | It also serves as the interface to HTTP servers.
6 |
7 | ## Creating a Gateway
8 |
9 | ```python
10 | from rolo.gateway import Gateway
11 |
12 | gateway = Gateway(
13 | request_handlers=[
14 | ...
15 | ],
16 | response_handlers=[
17 | ...
18 | ],
19 | exception_handlers=[
20 | ...
21 | ],
22 | finalizers=[
23 | ...
24 | ]
25 | )
26 | ```
27 |
28 | ## Protocol adapters
29 |
30 | You can use `rolo.gateway.wsgi` or `rolo.gateway.asgi` to expose a `Gateway` as either a WSGI or ASGI app.
31 |
32 | Read more in the [serving](serving.md) section.
33 |
34 | ## Custom `RequestContext`
35 |
36 | You can add a custom request context with type hints or your own methods by setting the `context_class` parameter in the constructor.
37 | First, define a request context subclass:
38 |
39 | ```python
40 | from rolo.gateway import RequestContext
41 |
42 | class MyContext(RequestContext):
43 | myattr: str
44 | ```
45 |
46 | Then, when you instantiate the Gateway:
47 |
48 | ```python
49 | gateway = Gateway(
50 | request_handlers=[
51 | ...
52 | ],
53 | context_class=MyContext
54 | )
55 | ```
56 |
57 | In your handlers, you can now reference your context:
58 |
59 | ```python
60 | def handler(chain: HandlerChain, context: MyContext, response: Response):
61 | ...
62 | ```
63 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Development tasks for rolo. The virtualenv location can be overridden, e.g. `make VENV_DIR=/tmp/venv test`.
VENV_BIN = python3 -m venv
VENV_DIR ?= .venv
VENV_ACTIVATE = $(VENV_DIR)/bin/activate
VENV_RUN = . $(VENV_ACTIVATE)
ROOT_MODULE = rolo

venv: $(VENV_ACTIVATE)

# create the venv in $(VENV_DIR) (not a hard-coded .venv) so that overriding VENV_DIR is honored
$(VENV_ACTIVATE): pyproject.toml
	test -d $(VENV_DIR) || $(VENV_BIN) $(VENV_DIR)
	$(VENV_RUN); pip install --upgrade pip setuptools wheel
	$(VENV_RUN); pip install -e .[dev]
	touch $(VENV_ACTIVATE)

install: venv

clean:
	rm -rf $(VENV_DIR)
	rm -rf build/
	rm -rf .eggs/
	rm -rf *.egg-info/

format:
	$(VENV_RUN); python -m ruff check --show-source --fix .; python -m black .

lint:
	$(VENV_RUN); python -m ruff check --show-source . && python -m black --check .

test: venv
	$(VENV_RUN); python -m pytest

test-coverage: venv
	$(VENV_RUN); coverage run --source=$(ROOT_MODULE) -m pytest tests/

coveralls: venv
	$(VENV_RUN); coveralls

$(VENV_DIR)/.docs-install: pyproject.toml $(VENV_ACTIVATE)
	$(VENV_RUN); pip install -e .[docs]
	touch $(VENV_DIR)/.docs-install

install-docs: $(VENV_DIR)/.docs-install

# use $(MAKE) -C so that make flags are propagated to the sub-make
docs: install-docs
	$(VENV_RUN); $(MAKE) -C docs html

dist: venv
	$(VENV_RUN); pip install --upgrade build; python -m build

publish: clean-dist venv test dist
	$(VENV_RUN); pip install --upgrade twine; twine upload dist/*

clean-dist: clean
	rm -rf dist/

# none of these targets produce a file of the same name; `docs` and `dist` would otherwise clash with directories
.PHONY: venv install clean format lint test test-coverage coveralls install-docs docs dist publish clean-dist
57 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | Rolo documentation
2 | ==================
3 |
4 |
5 |
6 |
7 |
8 | Rolo HTTP: A Python framework for building HTTP-based server applications.
9 |
10 |
11 | ## Introduction
12 |
13 | Rolo is a flexible framework and library to build HTTP-based server applications beyond microservices and REST APIs.
14 | You can build HTTP-based RPC servers, websocket proxies, or other server types that typical web frameworks are not designed for.
15 | Rolo was originally designed to build the AWS RPC protocol server in [LocalStack](https://github.com/localstack/localstack).
16 |
17 | Rolo extends [Werkzeug](https://github.com/pallets/werkzeug/), a flexible Python HTTP server library, for you to use concepts you are familiar with like ``@route``, ``Request``, or ``Response``.
18 | It introduces the concept of a ``Gateway`` and ``HandlerChain``, an implementation variant of the [chain-of-responsibility pattern](https://en.wikipedia.org/wiki/Chain-of-responsibility_pattern).
19 |
20 | Rolo is designed for environments that do not use asyncio, but still require asynchronous HTTP features like HTTP/2, server-sent events (SSE), or WebSockets.
21 | To allow asynchronous communication, Rolo introduces an ASGI/WSGI bridge that allows you to serve Rolo applications through ASGI servers like Hypercorn.
22 |
23 | ## Table of Contents
24 |
25 | ```{toctree}
26 | :caption: Quickstart
27 | :maxdepth: 2
28 |
29 | getting_started
30 | ```
31 |
32 | ```{toctree}
33 | :caption: User Guide
34 | :maxdepth: 2
35 |
36 | router
37 | handler_chain
38 | gateway
39 | websockets
40 | serving
41 | ```
42 |
43 | ```{toctree}
44 | :caption: Tutorials
45 | :maxdepth: 1
46 |
47 | tutorials/jsonrpc-server
48 | ```
49 |
50 |
--------------------------------------------------------------------------------
/rolo/gateway/wsgi.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from werkzeug.datastructures import Headers, MultiDict
4 | from werkzeug.wrappers import Request
5 |
6 | from rolo.response import Response
7 |
8 | if t.TYPE_CHECKING:
9 | from _typeshed.wsgi import StartResponse, WSGIEnvironment
10 |
11 | import logging
12 |
13 | from .gateway import Gateway
14 |
15 | LOG = logging.getLogger(__name__)
16 |
17 |
class WsgiGateway:
    """
    Adapter that makes a ``Gateway`` callable as a WSGI application.

    Each call builds a request from the WSGI environment, lets the gateway process it into a fresh
    ``Response``, and hands that response back to the WSGI server.
    """

    gateway: Gateway

    def __init__(self, gateway: Gateway) -> None:
        super().__init__()
        self.gateway = gateway

    def __call__(
        self, environ: "WSGIEnvironment", start_response: "StartResponse"
    ) -> t.Iterable[bytes]:
        method = environ["REQUEST_METHOD"]
        LOG.debug("%s %s%s", method, environ.get("HTTP_HOST"), environ.get("RAW_URI"))

        request = Request(environ)
        request.headers = self._build_headers(environ, request)

        response = Response()
        self.gateway.process(request, response)

        return response(environ, start_response)

    @staticmethod
    def _build_headers(environ: "WSGIEnvironment", request: Request) -> Headers:
        """
        Create a mutable ``Headers`` object for ``request``.

        If the server put the raw header pairs into the environment (``rolo.headers`` or ``asgi.headers``),
        the headers are rebuilt from those latin-1 encoded byte pairs; otherwise the request's headers are copied.
        """
        raw_headers = environ.get("rolo.headers") or environ.get("asgi.headers")
        if not raw_headers:
            # werkzeug requests created from a WSGI environ have immutable headers by default
            return Headers(request.headers)

        # the raw pairs keep header names exactly as received, which the HTTP_* environ keys cannot represent
        # see https://github.com/pallets/werkzeug/issues/940
        decoded = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in raw_headers]
        return Headers(MultiDict(decoded))
59 |
--------------------------------------------------------------------------------
/tests/serving/test_twisted.py:
--------------------------------------------------------------------------------
1 | import http.client
2 | import io
3 | import json
4 |
5 | import requests
6 |
7 | from rolo import Request, Router, route
8 | from rolo.dispatcher import handler_dispatcher
9 | from rolo.gateway import Gateway
10 | from rolo.gateway.handlers import RouterHandler
11 |
12 |
def test_large_file_upload(serve_twisted_gateway):
    """A POST with a large file-like request body is handled and answered successfully."""
    router = Router(handler_dispatcher())

    @route("/hello", methods=["POST"])
    def hello(request: Request):
        return "ok"

    router.add(hello)

    gateway = Gateway(request_handlers=[RouterHandler(router, True)])
    server = serve_twisted_gateway(gateway)

    # body of 100001 bytes, passed as a file-like object
    response = requests.post(server.url + "/hello", io.BytesIO(b"0" * 100001))

    assert response.status_code == 200


def test_full_absolute_form_uri(serve_twisted_gateway):
    """A request line with an absolute-form URI is routed by its path, and RAW_URI keeps the full URI."""
    router = Router(handler_dispatcher())

    @route("/hello", methods=["GET"])
    def hello(request: Request):
        return {"path": request.path, "raw_uri": request.environ["RAW_URI"]}

    router.add(hello)

    gateway = Gateway(request_handlers=[RouterHandler(router, True)])
    server = serve_twisted_gateway(gateway)
    host = server.url

    conn = http.client.HTTPConnection(host="127.0.0.1", port=server.port)
    try:
        # passing the full URL puts the absolute-form URI into the request line, i.e.,
        # ``GET {server.url}/hello HTTP/1.1`` instead of ``GET /hello HTTP/1.1``
        conn.request("GET", url=f"{host}/hello")
        response = conn.getresponse()

        assert response.status == 200
        response_data = json.loads(response.read())
    finally:
        # close the connection even if the request or an assertion fails
        conn.close()

    assert response_data["path"] == "/hello"
    assert response_data["raw_uri"].startswith("http")
55 |
--------------------------------------------------------------------------------
/tests/test_client.py:
--------------------------------------------------------------------------------
1 | from pytest_httpserver import HTTPServer
2 | from werkzeug import Request as WerkzeugRequest
3 | from werkzeug.datastructures import Headers
4 |
5 | from rolo import Response
6 | from rolo.client import SimpleRequestsClient
7 | from rolo.request import Request
8 |
9 |
def echo_request_metadata_handler(request: WerkzeugRequest) -> Response:
    """
    Echo handler for the test HTTP server: responds with a JSON document describing the incoming request.

    :param request: the incoming HTTP request
    :return: an HTTP response whose JSON body holds the request's method, path, url, and headers
    """
    metadata = {
        "method": request.method,
        "path": request.path,
        "url": request.url,
        "headers": dict(Headers(request.headers)),
    }
    response = Response()
    response.set_json(metadata)
    return response


class TestSimpleRequestClient:
    def test_empty_accept_encoding_header(self, httpserver: HTTPServer):
        httpserver.expect_request("/").respond_with_handler(echo_request_metadata_handler)

        with SimpleRequestsClient() as client:
            response = client.request(Request(path="/", method="GET"), httpserver.url_for("/"))

        echoed_headers = response.json["headers"]
        # the client must not add an Accept-Encoding header in any spelling
        for name in ("Accept-Encoding", "accept-encoding"):
            assert name not in echoed_headers

    def test_multi_values_headers(self, httpserver: HTTPServer):
        def multi_values_handler(_request: WerkzeugRequest) -> Response:
            headers = Headers()
            for cookie in ("value1", "value2"):
                headers.add("Set-Cookie", cookie)
            assert headers.getlist("Set-Cookie") == ["value1", "value2"]

            return Response(headers=headers)

        httpserver.expect_request("/").respond_with_handler(multi_values_handler)

        with SimpleRequestsClient() as client:
            response = client.request(Request(path="/", method="GET"), httpserver.url_for("/"))
            # repeated response headers must survive as separate values
            assert response.headers.getlist("Set-Cookie") == ["value1", "value2"]
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.iml
3 | *~
4 |
5 | # General
6 | .DS_Store
7 | .AppleDouble
8 | .LSOverride
9 |
10 | # Icon must end with two \r
11 | Icon
12 |
13 | # Thumbnails
14 | ._*
15 |
16 | # Files that might appear in the root of a volume
17 | .DocumentRevisions-V100
18 | .fseventsd
19 | .Spotlight-V100
20 | .TemporaryItems
21 | .Trashes
22 | .VolumeIcon.icns
23 | .com.apple.timemachine.donotpresent
24 |
25 | # Directories potentially created on remote AFP share
26 | .AppleDB
27 | .AppleDesktop
28 | Network Trash Folder
29 | Temporary Items
30 | .apdisk
31 |
32 | # Byte-compiled / optimized / DLL files
33 | __pycache__/
34 | *.py[cod]
35 | *$py.class
36 |
37 | # C extensions
38 | *.so
39 |
40 | # Distribution / packaging
41 | .Python
42 | build/
43 | develop-eggs/
44 | dist/
45 | downloads/
46 | eggs/
47 | .eggs/
48 | lib/
49 | lib64/
50 | parts/
51 | sdist/
52 | var/
53 | wheels/
54 | *.egg-info/
55 | .installed.cfg
56 | *.egg
57 | MANIFEST
58 |
59 | # PyInstaller
60 | # Usually these files are written by a python script from a template
61 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
62 | *.manifest
63 | *.spec
64 |
65 | # Installer logs
66 | pip-log.txt
67 | pip-delete-this-directory.txt
68 |
69 | # Unit test / coverage reports
70 | htmlcov/
71 | .tox/
72 | .coverage
73 | .coverage.*
74 | .cache
75 | nosetests.xml
76 | coverage.xml
77 | *.cover
78 | .hypothesis/
79 |
80 | # Translations
81 | *.mo
82 | *.pot
83 |
84 | # Django stuff:
85 | *.log
86 | .static_storage/
87 | .media/
88 | local_settings.py
89 |
90 | # Flask stuff:
91 | instance/
92 | .webassets-cache
93 |
94 | # Scrapy stuff:
95 | .scrapy
96 |
97 | # Sphinx documentation
98 | docs/_build/
99 |
100 | # PyBuilder
101 | target/
102 |
103 | # Jupyter Notebook
104 | .ipynb_checkpoints
105 |
106 | # pyenv
107 | .python-version
108 |
109 | # celery beat schedule file
110 | celerybeat-schedule
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 |
--------------------------------------------------------------------------------
/rolo/gateway/gateway.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing as t
3 |
4 | from ..request import Request
5 | from ..response import Response
6 | from ..websocket.request import WebSocketRequest
7 | from .chain import ExceptionHandler, Handler, HandlerChain, RequestContext
8 |
9 | LOG = logging.getLogger(__name__)
10 |
11 |
class Gateway:
    """
    A gateway creates new HandlerChain instances for each request and processes requests through them.

    Every chain is built from the gateway's handler lists, and every request gets its own context
    object, instantiated from ``context_class``.
    """

    request_handlers: list[Handler]
    response_handlers: list[Handler]
    finalizers: list[Handler]
    exception_handlers: list[ExceptionHandler]
    context_class: t.Type[RequestContext]

    def __init__(
        self,
        request_handlers: t.Optional[list[Handler]] = None,
        response_handlers: t.Optional[list[Handler]] = None,
        finalizers: t.Optional[list[Handler]] = None,
        exception_handlers: t.Optional[list[ExceptionHandler]] = None,
        context_class: t.Optional[t.Type[RequestContext]] = None,
    ) -> None:
        """
        :param request_handlers: passed to every new chain as its request handlers (default: none)
        :param response_handlers: passed to every new chain as its response handlers (default: none)
        :param finalizers: passed to every new chain as its finalizers (default: none)
        :param exception_handlers: passed to every new chain as its exception handlers (default: none)
        :param context_class: the class instantiated with the request to create each request's
            context (default: ``RequestContext``)
        """
        super().__init__()
        # ``is not None`` rather than ``or``: an empty list supplied by the caller is kept as the
        # very same list object instead of being replaced by a fresh one
        self.request_handlers = request_handlers if request_handlers is not None else []
        self.response_handlers = response_handlers if response_handlers is not None else []
        self.exception_handlers = exception_handlers if exception_handlers is not None else []
        self.finalizers = finalizers if finalizers is not None else []
        self.context_class = context_class or RequestContext

    def new_chain(self) -> HandlerChain:
        """Create a new ``HandlerChain`` from the gateway's handler lists."""
        return HandlerChain(
            self.request_handlers,
            self.response_handlers,
            self.finalizers,
            self.exception_handlers,
        )

    def process(self, request: Request, response: Response):
        """
        Run ``request`` through a newly created handler chain.

        :param request: the incoming request, wrapped into a new ``context_class`` instance
        :param response: the response object handed to the chain for the handlers to populate
        """
        chain = self.new_chain()

        context = self.context_class(request)

        chain.handle(context, response)

    def accept(self, request: WebSocketRequest):
        """
        Process a websocket request through the handler chain.

        The chain starts with a ``101`` response. If the handlers change that status and the
        websocket has been neither upgraded nor rejected yet, the request is rejected with the
        response the handlers populated.
        """
        response = Response(status=101)
        self.process(request, response)

        # only send the populated response if the websocket hasn't already done so before
        if response.status_code != 101:
            if request.is_upgraded():
                return
            if request.is_rejected():
                return
            request.reject(response)
63 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'Rolo'
10 | copyright = '2024, LocalStack'
11 | author = 'Thomas Rausch'
12 | release = '0.6.x'  # NOTE(review): pyproject.toml declares version "0.7.6" — confirm whether this should be bumped
13 |
14 | # -- General configuration ---------------------------------------------------
15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
16 |
17 | extensions = [
18 | 'myst_parser'  # MyST lets Sphinx build the Markdown (.md) pages of these docs
19 | ]
20 |
21 | templates_path = ['_templates']
22 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
23 |
24 |
25 | # -- Options for HTML output -------------------------------------------------
26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
27 |
28 | html_theme = "furo"  # provided by the "furo" package of the "docs" extra in pyproject.toml
29 | html_static_path = ["_static"]
30 | html_title = "Rolo documentation"
31 |
32 | html_logo = "_static/rolo.png"
33 | html_theme_options = {
34 | "top_of_page_buttons": ["view", "edit"],
35 | "source_repository": "https://github.com/localstack/rolo/",
36 | "source_branch": "main",
37 | "source_directory": "docs/",  # location of these sources in the repository, used by the view/edit buttons
38 | "footer_icons": [  # NOTE(review): the inline SVG markup of the "html" entry below appears to be missing in this copy — verify the icon still renders
39 | {
40 | "name": "GitHub",
41 | "url": "https://github.com/localstack/rolo",
42 | "html": """
43 |
46 | """,
47 | "class": "",
48 | },
49 | ],
50 | }
51 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "rolo"
7 | authors = [
8 | { name = "LocalStack Contributors", email = "info@localstack.cloud" }
9 | ]
10 | version = "0.7.6"
11 | description = "A Python framework for building HTTP-based server applications"
12 | dependencies = [
13 | "requests>=2.20",
14 | "werkzeug>=3.0"
15 | ]
16 | requires-python = ">=3.10"
17 | license = {file = "LICENSE"}
18 | classifiers = [
19 | "Development Status :: 5 - Production/Stable",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | "Programming Language :: Python :: 3",
23 | "Programming Language :: Python :: 3.10",
24 | "Programming Language :: Python :: 3.11",
25 | "Programming Language :: Python :: 3.12",
26 | "Topic :: Software Development :: Libraries"
27 | ]
28 | dynamic = ["readme"]
29 |
30 |
31 | [project.optional-dependencies]
32 | dev = [
33 | "black==23.10.0",
34 | "pytest>=7.0",
35 | "hypercorn",
36 | "pydantic",
37 | "pytest_httpserver",
38 | "websocket-client>=1.7.0",
39 | "coverage[toml]>=5.0.0",
40 | "coveralls>=3.3",
41 | "localstack-twisted",
42 | "ruff==0.1.0"
43 | ]
44 | docs = [
45 | "sphinx",
46 | "furo",
47 | "myst_parser",
48 | ]
49 |
50 | [tool.setuptools]
51 | include-package-data = false
52 |
53 | [tool.setuptools.dynamic]
54 | readme = {file = ["README.md"], content-type = "text/markdown"}
55 |
56 | [tool.setuptools.packages.find]
57 | include = ["rolo*"]
58 | exclude = ["tests*"]
59 |
60 | [tool.setuptools.package-data]
61 | "*" = ["*.md"]
62 |
63 | [tool.black]
64 | line_length = 100
65 | include = '((rolo)/.*\.py$|tests/.*\.py$)'
66 | #extend_exclude = '()'
67 |
68 | [tool.ruff]
69 | # Always generate Python 3.10-compatible code.
70 | target-version = "py310"
71 | line-length = 110
72 | select = ["B", "C", "E", "F", "I", "W", "T", "B9"]
73 | ignore = [
74 | "E501", # E501 Line too long - handled by black, see https://docs.astral.sh/ruff/faq/#is-ruff-compatible-with-black
75 | ]
76 | exclude = [
77 | ".venv*",
78 | "venv*",
79 | "dist",
80 | "build",
81 | "target",
82 | "*.egg-info",
83 | ".git",
84 | ]
85 |
86 | [tool.coverage.report]
87 | exclude_lines = [
88 | "if __name__ == .__main__.:",
89 | "raise NotImplemented.",
90 | "return NotImplemented",
91 | "def __repr__",
92 | "__all__",
93 | ]
94 |
--------------------------------------------------------------------------------
/docs/websockets.md:
--------------------------------------------------------------------------------
1 | Websockets
2 | ==========
3 |
4 | Rolo supports Websockets through ASGI and Twisted (see [serving](serving.md)).
5 |
6 | ## Websocket requests
7 |
8 | Rolo introduces an HTTP method called `WEBSOCKET`, which can be used to register routes that deal with websocket
9 | requests.
10 |
11 | ```python
12 | from rolo import route
13 | from rolo.websocket import WebSocketRequest
14 |
15 |
16 | @route("/stream/<name>", methods=["WEBSOCKET"])
17 | def handler(request: WebSocketRequest, name: str):
18 | ...
19 | ```
20 |
21 | You can add such a route to a `Router`, but the Router needs to be handled through a `Gateway` using the `RouterHandler`, and
22 | served through an ASGI web server or Twisted.
23 |
24 | With a tool like [websocat](https://github.com/vi/websocat), you could now connect to the websocket.
25 |
26 | ### Accepting or rejecting the connection
27 |
28 | The websocket connection needs to be either accepted or rejected via `WebSocketRequest`.
29 | When calling ``WebSocketRequest.accept``, an upgrade response will be sent to the client, and the protocol will be
30 | switched to the bidirectional WebSocket protocol.
31 | If ``WebSocketRequest.reject`` is called, the server immediately returns an HTTP response and closes the connection.
32 |
33 | You may want to do this when doing authorization for example:
34 |
35 | ```python
36 | def app(request: WebSocketRequest):
37 | # example: do authorization first
38 | auth = request.headers.get("Authorization")
39 | if not is_authorized(auth):
40 | request.reject(Response("no dice", 403))
41 | return
42 |
43 | # then continue working with the websocket
44 | with request.accept() as websocket:
45 | websocket.send("hello world!")
46 | data = websocket.receive()
47 | # ...
48 | ```
49 |
50 | ## Websocket object
51 |
52 | `WebSocketRequest.accept` also returns a `WebSocket` object, which can then be used to send and receive data.
53 |
54 | You can explicitly call `WebSocket.receive`, or you can simply iterate over the `WebSocket` object.
55 | Here is an example:
56 |
57 | ```python
58 |
59 | from rolo import route
60 | from rolo.websocket import WebSocketRequest
61 |
62 |
63 | @route("/echo/<name>", methods=["WEBSOCKET"])
64 | def handler(request: WebSocketRequest, name: str):
65 | with request.accept() as websocket:
66 | websocket.send(f"thanks for connecting {name}")
67 | for line in websocket:
68 | websocket.send(f"echo: {line}")
69 | if line == "exit":
70 | websocket.send("ok bye!")
71 | return
72 | ```
73 |
--------------------------------------------------------------------------------
/docs/serving.md:
--------------------------------------------------------------------------------
1 | Serving
2 | =======
3 |
4 | This guide shows you how to serve Rolo components through different Python web server technologies.
5 |
6 | WSGI
7 | ----
8 |
9 | ### Serving a Router as WSGI app
10 |
11 | If you only need a `Router` instance to serve your application, you can convert it to a WSGI app using the `Router.wsgi()` method.
12 |
13 | ```python
14 | from rolo import Router, route
15 | from rolo.routing import handler_dispatcher
16 |
17 | @route("/")
18 | def index(request):
19 | return "hello world"
20 |
21 | router = Router(dispatcher=handler_dispatcher())
22 | router.add(index)
23 |
24 | app = router.wsgi()
25 | ```
26 |
27 | Now you can use any WSGI-compliant server to serve the application.
28 | For example, if this file is stored in `myapp.py`, using gunicorn, you can:
29 |
30 | ```sh
31 | pip install gunicorn
32 | gunicorn -w 4 myapp:app
33 | ```
34 |
35 | ### Serving a Gateway as WSGI app
36 |
37 | Unless you need Websockets, the Rolo Request object is fully WSGI compliant, so you can also use any WSGI server to serve a `Gateway`.
38 | Simply use the `WsgiGateway` adapter.
39 |
40 | ```python
41 | from rolo.gateway import Gateway
42 | from rolo.gateway.wsgi import WsgiGateway
43 |
44 | gateway: Gateway = ...
45 |
46 | app = WsgiGateway(gateway)
47 | ```
48 |
49 | Similar to the previous example, you can serve the `app` object through any WSGI compliant server.
50 |
51 | ASGI
52 | ----
53 |
54 | ASGI servers like Hypercorn allow asynchronous server communication, which is needed for HTTP/2 streaming or Websockets.
55 | Gateways can be served through the `AsgiGateway` adapter, which exposes a `Gateway` as an ASGI3 application.
56 | Under the hood, it uses our own ASGI/WSGI bridge (`ASGIAdapter`), which converts ASGI calls to WSGI calls for regular HTTP requests and uses ASGI websockets for serving Rolo websockets.
57 | File `myapp.py`:
58 |
59 | ```python
60 | from rolo.gateway import Gateway
61 | from rolo.gateway.asgi import AsgiGateway
62 |
63 | gateway: Gateway = ...
64 |
65 | app = AsgiGateway(gateway)
66 | ```
67 |
68 | Now you can use Hypercorn or other ASGI servers to serve the `app` object.
69 |
70 | ```sh
71 | pip install hypercorn
72 | hypercorn myapp:app
73 | ```
74 |
75 | Twisted
76 | -------
77 |
78 | Rolo can be served through [Twisted](https://twisted.org/), which supports both WSGI and Websockets.
79 | You will need `twisted` and `wsproto` installed: `pip install twisted wsproto`.
80 |
81 | ```python
82 | from rolo.gateway import Gateway
83 | from rolo.serving.twisted import TwistedGateway
84 | from twisted.internet import endpoints, reactor
85 |
86 | gateway: Gateway = ...
87 |
88 | # Rolo/Twisted adapter, that exposes a Rolo Gateway as a twisted.web.server.Site object
89 | site = TwistedGateway(gateway)
90 |
91 | endpoint = endpoints.TCP4ServerEndpoint(reactor, 8000)
92 | endpoint.listen(site)
93 |
94 | reactor.run()
95 | ```
96 |
--------------------------------------------------------------------------------
/tests/test_response.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import pytest
4 | from werkzeug.exceptions import NotFound
5 |
6 | from rolo import Response
7 | from tests import static
8 |
9 |
def test_for_resource_html():
    """An ``.html`` resource from the static test package is served as UTF-8 HTML with status 200."""
    html_response = Response.for_resource(static, "index.html")

    assert html_response.content_type == "text/html; charset=utf-8"
    # NOTE(review): this expected body looks like index.html with its markup stripped in this
    # copy of the file — verify against tests/static/index.html
    assert html_response.get_data() == b'\nhello\n\n'
    assert html_response.status == "200 OK"
15 |
16 |
def test_for_resource_txt():
    """A ``.txt`` resource is served as UTF-8 plain text with status 200."""
    text_response = Response.for_resource(static, "test.txt")

    assert text_response.content_type == "text/plain; charset=utf-8"
    assert text_response.get_data() == b"hello world\n"
    assert text_response.status == "200 OK"
22 |
23 |
def test_for_resource_with_custom_response_status_and_headers():
    """A caller-supplied status and extra headers are applied on top of the served resource."""
    extra_headers = {"X-Foo": "Bar"}
    resp = Response.for_resource(static, "test.txt", status=201, headers=extra_headers)

    assert resp.content_type == "text/plain; charset=utf-8"
    assert resp.get_data() == b"hello world\n"
    assert resp.status == "201 CREATED"
    assert resp.headers.get("X-Foo") == "Bar"
30 |
31 |
def test_for_resource_not_found():
    """Asking for a resource that is not in the package raises werkzeug's ``NotFound``."""
    missing_name = "doesntexist.txt"

    with pytest.raises(NotFound):
        Response.for_resource(static, missing_name)
35 |
36 |
def test_for_json():
    """``Response.for_json`` serializes the document and defaults to status 200."""
    document = {"foo": "bar", "420": 69, "isTrue": True}
    resp = Response.for_json(document)

    assert resp.content_type == "application/json"
    assert resp.get_data() == b'{"foo": "bar", "420": 69, "isTrue": true}'
    assert resp.status == "200 OK"
44 |
45 |
def test_for_json_with_custom_response_status_and_headers():
    """``Response.for_json`` honours an explicit status code and additional headers."""
    document = {"foo": "bar", "420": 69, "isTrue": True}
    resp = Response.for_json(document, status=201, headers={"X-Foo": "Bar"})

    assert resp.content_type == "application/json"
    assert resp.get_data() == b'{"foo": "bar", "420": 69, "isTrue": true}'
    assert resp.status == "201 CREATED"
    assert resp.headers.get("X-Foo") == "Bar"
56 |
57 |
@pytest.mark.parametrize(
    argnames="data",
    argvalues=[
        b"foobar",
        "foobar",
        io.BytesIO(b"foobar"),
        [b"foo", b"bar"],
    ],
)
def test_set_response(data):
    """``set_response`` accepts bytes, str, a file-like object, or an iterable of bytes."""
    resp = Response()
    resp.set_response(data)

    body = resp.get_data()
    assert body == b"foobar"
71 |
72 |
def test_update_from():
    """``update_from`` copies body, status, headers and content type from another response."""
    source = Response(
        [b"foo", b"bar"],
        202,
        headers={"X-Foo": "Bar"},
        mimetype="application/octet-stream",
    )

    target = Response()
    target.update_from(source)

    assert target.get_data() == b"foobar"
    assert target.status_code == 202
    assert target.headers.get("X-Foo") == "Bar"
    assert target.content_type == "application/octet-stream"
85 |
--------------------------------------------------------------------------------
/rolo/gateway/asgi.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import concurrent.futures.thread
3 | from asyncio import AbstractEventLoop
4 | from typing import Optional
5 |
6 | from rolo.asgi import ASGIAdapter, ASGILifespanListener
7 | from rolo.websocket.request import WebSocketRequest
8 |
9 | from .gateway import Gateway
10 | from .wsgi import WsgiGateway
11 |
12 |
13 | class _ThreadPool(concurrent.futures.thread.ThreadPoolExecutor):
14 | """
15 | This thread pool executor removes the threads it creates from the global ``_thread_queues`` of
16 | ``concurrent.futures.thread``, which joins all created threads at python exit and will block interpreter shutdown if
17 | any threads are still running, even if they are daemon threads.
18 | """
19 |
20 | def _adjust_thread_count(self) -> None:
21 | super()._adjust_thread_count()
22 |
23 | for t in self._threads:
24 | if not t.daemon:
25 | continue
26 | try:
27 | del concurrent.futures.thread._threads_queues[t]
28 | except KeyError:
29 | pass
30 |
31 |
class AsgiGateway:
    """
    Exposes a Gateway as an ASGI3 application. Under the hood, it uses a WsgiGateway with a threading async/sync bridge.

    HTTP requests go through a ``WsgiGateway`` wrapped in an ``ASGIAdapter`` that runs on a
    dedicated thread pool. Websocket requests are handed to ``websocket_listener``, which by default
    is ``WebSocketRequest.listener(gateway.accept)``.
    """

    # maximum number of worker threads when ``threads`` is not given
    default_thread_count = 1000

    gateway: Gateway

    def __init__(
        self,
        gateway: Gateway,
        event_loop: Optional[AbstractEventLoop] = None,
        threads: Optional[int] = None,
        lifespan_listener: Optional[ASGILifespanListener] = None,
        websocket_listener=None,
    ) -> None:
        """
        :param gateway: the gateway that processes the requests
        :param event_loop: stored as ``self.event_loop``; falls back to ``asyncio.get_event_loop()``
        :param threads: maximum number of worker threads; a falsy value falls back to
            ``default_thread_count``
        :param lifespan_listener: forwarded to the ``ASGIAdapter``
        :param websocket_listener: forwarded to the ``ASGIAdapter``; defaults to a listener that
            calls ``gateway.accept``
        """
        self.gateway = gateway

        self.event_loop = event_loop or asyncio.get_event_loop()
        self.executor = _ThreadPool(
            threads or self.default_thread_count, thread_name_prefix="asgi_gw"
        )
        self.adapter = ASGIAdapter(
            WsgiGateway(gateway),
            # NOTE(review): this forwards the argument as given (possibly None), not
            # ``self.event_loop`` — presumably so the adapter can resolve the loop it actually
            # runs on; confirm before changing
            event_loop=event_loop,
            executor=self.executor,
            lifespan_listener=lifespan_listener,
            websocket_listener=websocket_listener or WebSocketRequest.listener(gateway.accept),
        )
        self._closed = False

    async def __call__(self, scope, receive, send) -> None:
        """
        ASGI3 application interface.

        :param scope: the ASGI request scope
        :param receive: the receive callable
        :param send: the send callable
        :raises RuntimeError: if the gateway has been closed
        """
        if self._closed:
            raise RuntimeError("Cannot accept new request on closed ASGIGateway")

        return await self.adapter(scope, receive, send)

    def close(self):
        """
        Close the ASGIGateway by shutting down the underlying executor.

        Does not wait for running work. Futures that have not started yet are cancelled, and any
        later call to the application raises ``RuntimeError``.
        """
        self._closed = True
        self.executor.shutdown(wait=False, cancel_futures=True)
83 |
--------------------------------------------------------------------------------
/rolo/gateway/handlers.py:
--------------------------------------------------------------------------------
1 | """Several gateway handlers"""
2 | import typing as t
3 |
4 | from werkzeug.datastructures import Headers
5 | from werkzeug.exceptions import HTTPException, NotFound
6 |
7 | from rolo.response import Response
8 | from rolo.routing import Router
9 |
10 | from .chain import HandlerChain, RequestContext
11 |
12 |
class RouterHandler:
    """
    Handler adapter that dispatches the context's request through a ``Router``.

    When a rule matches, the router's response is copied into the chain's response and the chain is
    stopped. When the router raises ``NotFound``, the chain continues untouched. If
    ``respond_not_found`` is set, the chain instead responds with a plain ``404 not found``.
    """

    router: Router
    respond_not_found: bool

    def __init__(self, router: Router, respond_not_found: bool = False) -> None:
        self.router = router
        self.respond_not_found = respond_not_found

    def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
        try:
            dispatched = self.router.dispatch(context.request)
            response.update_from(dispatched)
            chain.stop()
        except NotFound:
            if not self.respond_not_found:
                return
            chain.respond(404, "not found")
33 |
34 |
class EmptyResponseHandler:
    """
    Handler that creates a default response if the response in the context is empty.

    A response counts as empty when its status code is ``0`` or ``None`` and it has no body. Such a
    response gets the configured status code and body, and the configured headers are merged in.
    """

    status_code: int
    body: bytes
    headers: Headers

    def __init__(
        self,
        status_code: int = 404,
        body: t.Optional[bytes] = None,
        headers: t.Optional[Headers] = None,
    ):
        """
        :param status_code: status code of the default response (default: 404)
        :param body: body of the default response (default: empty)
        :param headers: headers merged into the default response (default: none)
        """
        self.status_code = status_code
        self.body = body or b""
        self.headers = headers or Headers()

    def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
        if self.is_empty_response(response):
            self.populate_default_response(response)

    def is_empty_response(self, response: Response):
        """Return True if no status code has been set and the response carries no body."""
        return response.status_code in [0, None] and not response.response

    def populate_default_response(self, response: Response):
        """Overwrite status and body with the defaults and merge in the default headers."""
        response.status_code = self.status_code
        response.data = self.body
        response.headers.update(self.headers)
60 |
61 |
class WerkzeugExceptionHandler:
    """
    Exception handler that turns werkzeug ``HTTPException`` errors into responses.

    The response carries the exception's status code and headers. Its payload is either a JSON
    document ``{"code": ..., "description": ...}`` (``"json"``, the default) or werkzeug's HTML error
    body (``"html"``). Exceptions that are not ``HTTPException`` instances are ignored.
    """

    def __init__(self, output_format: t.Optional[t.Literal["json", "html"]] = None) -> None:
        self.format = output_format or "json"

    def __call__(
        self, chain: HandlerChain, exception: Exception, context: RequestContext, response: Response
    ):
        if not isinstance(exception, HTTPException):
            return

        # Keep headers specific to the exception (e.g. ``Allow`` for MethodNotAllowed,
        # ``Retry-After``, ``WWW-Authenticate``), but drop werkzeug's default ``Content-Type``
        # (presumably so ``chain.respond`` can pick one matching the payload). It must be removed
        # by name: werkzeug appends those extra headers *after* the Content-Type, so popping the
        # last entry would discard them instead (and would fail on an empty header list).
        headers = Headers(exception.get_headers())
        headers.remove("Content-Type")

        if self.format == "html":
            chain.respond(status_code=exception.code, headers=headers, payload=exception.get_body())
        elif self.format == "json":
            chain.respond(
                status_code=exception.code,
                headers=headers,
                payload={"code": exception.code, "description": exception.description},
            )
        else:
            raise ValueError(f"unknown rendering format {self.format}")
85 |
--------------------------------------------------------------------------------
/docs/getting_started.md:
--------------------------------------------------------------------------------
1 | Getting started
2 | ===============
3 |
4 | ## Installation
5 |
6 | Rolo is hosted on [pypi](https://pypi.org/project/rolo/) and can be installed via pip.
7 |
8 | ```sh
9 | pip install rolo
10 | ```
11 |
12 | ## Hello World
13 |
14 | Rolo provides different ways of building a web application.
15 | It provides familiar concepts such as `Router` and `@route`, but also more flexible concepts like a Handler Chain.
16 |
17 | ### Router
18 |
19 | Here is a simple [`Router`](router.md) that can be served as WSGI application using the Werkzeug dev server.
20 | If you are familiar with Flask, `@route` works in a similar way.
21 |
22 | ```python
23 | from werkzeug import Request
24 | from werkzeug.serving import run_simple
25 |
26 | from rolo import Router, route
27 | from rolo.routing import handler_dispatcher
28 |
29 | @route("/")
30 | def hello(request: Request):
31 | return {"message": "Hello World"}
32 |
33 | router = Router(dispatcher=handler_dispatcher())
34 | router.add(hello)
35 |
36 | run_simple("localhost", 8000, router.wsgi())
37 | ```
38 |
39 | And to test:
40 | ```console
41 | curl localhost:8000/
42 | ```
43 | Should yield
44 | ```json
45 | {"message": "Hello World"}
46 | ```
47 |
48 | `rolo.Request` and `rolo.Response` objects work in the same way as Werkzeug's [Request / Response](https://werkzeug.palletsprojects.com/en/latest/wrappers/) wrappers.
49 |
50 | ### Gateway
51 |
52 | A Gateway holds a set of handlers that are combined into a handler chain.
53 | Here is a simple example with a single request handler that dynamically creates a response object similar to httpbin.
54 |
55 | ```python
56 | from werkzeug.serving import run_simple
57 |
58 | from rolo import Response
59 | from rolo.gateway import Gateway, RequestContext, HandlerChain
60 | from rolo.gateway.wsgi import WsgiGateway
61 |
62 |
63 | def echo_handler(chain: HandlerChain, context: RequestContext, response: Response):
64 | response.status_code = 200
65 | response.set_json(
66 | {
67 | "method": context.request.method,
68 | "path": context.request.path,
69 | "query": context.request.args,
70 | "headers": dict(context.request.headers),
71 | }
72 | )
73 | chain.stop()
74 |
75 |
76 | gateway = Gateway(
77 | request_handlers=[echo_handler],
78 | )
79 |
80 | run_simple("localhost", 8000, WsgiGateway(gateway))
81 | ```
82 |
83 | And to test:
84 | ```console
85 | curl -s -X POST "localhost:8000/foo/bar?a=1&b=2" | jq .
86 | ```
87 | Should give you:
88 | ```json
89 | {
90 | "method": "POST",
91 | "path": "/foo/bar",
92 | "query": {
93 | "a": "1",
94 | "b": "2"
95 | },
96 | "headers": {
97 | "Host": "localhost:8000",
98 | "User-Agent": "curl/7.81.0",
99 | "Accept": "*/*"
100 | }
101 | }
102 | ```
103 |
104 | ## Next Steps
105 |
106 | Learn how to
107 | * Use the [Router](router.md)
108 | * Use the [Handler Chain](handler_chain.md)
109 | * [Serve](serving.md) rolo through your favorite web server
110 |
--------------------------------------------------------------------------------
/tests/test_dispatcher.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import pytest
4 | from werkzeug.exceptions import NotFound
5 |
6 | from rolo import Request, Response, Router
7 | from rolo.routing import handler_dispatcher
8 |
9 |
class TestHandlerDispatcher:
    """Tests for ``handler_dispatcher``: URL variables become handler arguments, return values become responses."""

    def test_handler_dispatcher(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler_foo(_request: Request) -> Response:
            return Response("ok")

        def handler_bar(_request: Request, bar, baz) -> Response:
            response = Response()
            response.set_json({"bar": bar, "baz": baz})
            return response

        router.add("/foo", handler_foo)
        # the int converter turns "420" into 420 and makes non-numeric values fail to match
        router.add("/bar/<int:bar>/<baz>", handler_bar)

        assert router.dispatch(Request("GET", "/foo")).data == b"ok"
        assert router.dispatch(Request("GET", "/bar/420/ed")).json == {"bar": 420, "baz": "ed"}

        # "asfg" is not an int, so no rule matches
        with pytest.raises(NotFound):
            router.dispatch(Request("GET", "/bar/asfg/ed"))

    def test_handler_dispatcher_invalid_signature(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler(_request: Request, arg1) -> Response:  # invalid signature: arg2 is missing
            return Response("ok")

        router.add("/foo/<arg1>/<arg2>", handler)

        with pytest.raises(TypeError):
            router.dispatch(Request("GET", "/foo/a/b"))

    def test_handler_dispatcher_with_dict_return(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler(_request: Request, arg1) -> Dict[str, Any]:
            return {"arg1": arg1, "hello": "there"}

        router.add("/foo/<arg1>", handler)
        assert router.dispatch(Request("GET", "/foo/a")).json == {"arg1": "a", "hello": "there"}

    def test_handler_dispatcher_with_list_return(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler(_request: Request, arg1) -> list[Any]:
            return [{"arg1": arg1, "hello": "there"}, 1, 2, "3", [4, 5]]

        router.add("/foo/<arg1>", handler)
        assert router.dispatch(Request("GET", "/foo/a")).json == [
            {"arg1": "a", "hello": "there"},
            1,
            2,
            "3",
            [4, 5],
        ]

    def test_handler_dispatcher_with_text_return(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler(_request: Request, arg1) -> str:
            return f"hello: {arg1}"

        router.add("/<arg1>", handler)
        assert router.dispatch(Request("GET", "/world")).data == b"hello: world"

    def test_handler_dispatcher_with_none_return(self):
        router = Router(dispatcher=handler_dispatcher())

        def handler(_request: Request):
            return None

        router.add("/", handler)
        assert router.dispatch(Request("GET", "/")).status_code == 200
83 |
--------------------------------------------------------------------------------
/tests/gateway/test_handlers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import requests
3 | from werkzeug import Request
4 | from werkzeug.exceptions import BadRequest
5 |
6 | from rolo import Response, Router
7 | from rolo.gateway import Gateway, HandlerChain, RequestContext
8 | from rolo.gateway.handlers import EmptyResponseHandler, RouterHandler, WerkzeugExceptionHandler
9 |
10 |
def _echo_handler(request: Request, args):
    """Router endpoint that reflects the request's path, method and headers back as JSON."""
    echoed = {
        "path": request.path,
        "method": request.method,
        "headers": dict(request.headers),
    }
    return Response.for_json(echoed)
18 | )
19 |
20 |
@pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True)
class TestWerkzeugExceptionHandler:
    def test_json_output_format(self, serve_gateway):
        """A werkzeug ``BadRequest`` raised by a handler is rendered as a JSON error document."""

        def handler(chain: HandlerChain, context: RequestContext, response: Response):
            if context.request.method == "GET":
                chain.respond(payload="ok")
                return
            raise BadRequest("oh noes")

        gateway = Gateway(
            request_handlers=[handler],
            exception_handlers=[WerkzeugExceptionHandler()],
        )
        server = serve_gateway(gateway)

        ok_resp = requests.get(server.url)
        assert ok_resp.status_code == 200
        assert ok_resp.text == "ok"

        error_resp = requests.post(server.url)
        assert error_resp.status_code == 400
        assert error_resp.json() == {"code": 400, "description": "oh noes"}
48 |
49 |
@pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True)
class TestRouterHandler:
    def test_router_handler_with_respond_not_found(self, serve_gateway):
        """Matched routes are dispatched; unmatched paths yield a 404 when respond_not_found is set."""
        router = Router()
        router.add("/foo", _echo_handler)

        gateway = Gateway(request_handlers=[RouterHandler(router, True)])
        server = serve_gateway(gateway)

        echoed = requests.get(server.url + "/foo", headers={"Foo-Bar": "foobar"}).json()
        assert echoed["path"] == "/foo"
        assert echoed["method"] == "GET"
        assert echoed["headers"]["Foo-Bar"] == "foobar"

        missing = requests.get(server.url + "/bar")
        assert missing.status_code == 404
        assert missing.text == "not found"
72 |
73 |
@pytest.mark.parametrize("serve_gateway", ["wsgi", "asgi", "twisted"], indirect=True)
class TestEmptyResponseHandler:
    def test_empty_response_handler(self, serve_gateway):
        """Only responses left empty by the request handlers are replaced with the default."""

        def _handler(chain, context, response):
            # non-GET requests leave the response "empty" (status 0, no body)
            if context.request.method != "GET":
                response.status_code = 0
                return
            chain.respond(202, "ok")

        fallback = EmptyResponseHandler(status_code=412, body=b"teapot?")
        server = serve_gateway(Gateway(request_handlers=[_handler], response_handlers=[fallback]))

        response = requests.get(server.url)
        assert response.text == "ok"
        assert response.status_code == 202

        response = requests.post(server.url)
        assert response.text == "teapot?"
        assert response.status_code == 412
97 |
--------------------------------------------------------------------------------
/rolo/routing/pydantic.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import typing as t
3 |
4 | import pydantic
5 |
6 | from rolo.request import Request
7 | from rolo.response import Response
8 |
9 | from .handler import Handler, HandlerDispatcher, ResultValue
10 | from .router import RequestArguments
11 |
12 |
def _get_model_argument(endpoint: Handler) -> t.Optional[tuple[str, t.Type[pydantic.BaseModel]]]:
    """
    Inspects the endpoint function using Python reflection to find in its signature an argument
    annotated with a ``pydantic.BaseModel`` subclass.

    Only plain functions and methods are inspected. Annotations are read as-is from
    ``__annotations__``, so string (postponed) annotations are not resolved and are skipped.

    :param endpoint: the endpoint to inspect
    :return: a tuple containing the argument name and model class of the first match, or None
    """
    if not inspect.isfunction(endpoint) and not inspect.ismethod(endpoint):
        # cannot yet dispatch to other callables (e.g. an object with a `__call__` method)
        return None

    # finds the first pydantic.BaseModel in the list of annotations (in declaration order).
    # ``def foo(request: Request, id: int, item: MyItem)`` would yield ``('item', MyItem)``
    for arg_name, arg_type in endpoint.__annotations__.items():
        if arg_name in ("self", "return"):
            continue
        if not inspect.isclass(arg_type):
            continue
        try:
            if issubclass(arg_type, pydantic.BaseModel):
                return arg_name, arg_type
        except TypeError:
            # FIXME: this is needed for Python 3.10 support
            # (presumably for generic aliases like ``list[int]``, which pass ``inspect.isclass``
            # but make ``issubclass`` raise there — verify)
            continue

    return None
39 |
40 |
def _try_parse_pydantic_request_body(
    request: Request, endpoint: Handler
) -> t.Optional[tuple[str, pydantic.BaseModel]]:
    """
    Parses the request body into the pydantic model the endpoint declares as a parameter, if it declares one.

    :param request: the incoming request
    :param endpoint: the endpoint the request is dispatched to
    :return: a tuple of the parameter name and the validated model instance, or None if the endpoint does not
        declare a pydantic model parameter
    :raises pydantic.ValidationError: if the body is empty or does not validate against the model
    :raises werkzeug.exceptions.BadRequest: if the body is not valid JSON
    """
    arg = _get_model_argument(endpoint)

    if not arg:
        return None

    arg_name, arg_type = arg

    # Check the actual body instead of ``content_length``: requests using chunked transfer encoding have no
    # Content-Length, but may still carry a body. ``get_data`` caches the body by default, so ``get_json``
    # below does not need to read the stream again.
    if not request.get_data():
        # forces a ValidationError "Invalid JSON: EOF while parsing a value at line 1 column 0"
        arg_type.model_validate_json(b"")

    # will raise a werkzeug.BadRequest error if the JSON is invalid
    obj = request.get_json(force=True)

    return arg_name, arg_type.model_validate(obj)
59 |
60 |
class PydanticHandlerDispatcher(HandlerDispatcher):
    """
    ``HandlerDispatcher`` that parses request bodies into pydantic models declared by the endpoint, and turns
    pydantic models returned by the endpoint into plain data before building the response.
    """

    def invoke_endpoint(
        self,
        request: Request,
        endpoint: t.Callable,
        request_args: RequestArguments,
    ) -> t.Any:
        """
        Injects the parsed pydantic model (if the endpoint declares one) into the request arguments, then calls
        the endpoint. A body that fails validation results in a 400 response with the validation errors as JSON.
        """
        try:
            parsed = _try_parse_pydantic_request_body(request, endpoint)
        except pydantic.ValidationError as error:
            return Response(error.json(), mimetype="application/json", status=400)

        if parsed is not None:
            name, model = parsed
            # copy instead of mutating the caller's mapping
            request_args = {**request_args, name: model}

        return super().invoke_endpoint(request, endpoint, request_args)

    def populate_response(self, response: Response, value: ResultValue):
        """
        Dumps a pydantic model, or every pydantic model inside a list/tuple, via ``model_dump``, and delegates
        the resulting value to the parent implementation.
        """

        def _to_plain(item):
            return item.model_dump() if isinstance(item, pydantic.BaseModel) else item

        if isinstance(value, pydantic.BaseModel):
            value = value.model_dump()
        elif isinstance(value, (list, tuple)):
            # tuples become lists here, which the parent serializes as JSON arrays
            value = [_to_plain(element) for element in value]

        super().populate_response(response, value)
98 |
--------------------------------------------------------------------------------
/rolo/routing/handler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import typing as t
4 |
5 | from werkzeug import Response as WerkzeugResponse
6 |
7 | try:
8 | import pydantic # noqa
9 |
10 | ENABLE_PYDANTIC = True
11 | except ImportError:
12 | ENABLE_PYDANTIC = False
13 |
14 | from rolo.request import Request
15 | from rolo.response import Response
16 |
17 | from .router import Dispatcher, RequestArguments
18 |
LOG = logging.getLogger(__name__)

# The values an endpoint invoked through a ``HandlerDispatcher`` may return.
ResultValue = t.Union[
    WerkzeugResponse,  # returned as-is by ``HandlerDispatcher.to_response``
    str,
    bytes,
    bytearray,  # raw body, like ``bytes``
    dict[str, t.Any],  # a JSON dict
    list[t.Any],  # a JSON list
    None,  # results in an empty response
]
28 |
29 |
class Handler(t.Protocol):
    """
    A protocol used by a ``Router`` together with the dispatcher created with ``handler_dispatcher``. Endpoints added
    with this protocol take the HTTP request object as their first argument, followed by the request parameters
    of the rule as keyword arguments. This makes them work very similarly to Flask routes.

    Example code could look like this::

        def my_route(request: Request, organization: str, repo: str):
            return {"something": "returned as json response"}

        router = Router(dispatcher=handler_dispatcher())
        router.add("/<organization>/<repo>", endpoint=my_route)

    """

    def __call__(self, request: Request, **kwargs) -> ResultValue:
        """
        Handle the given request.

        :param request: the HTTP request object
        :param kwargs: the url request parameters
        :return: a string or bytes value, a dict to create a json response, or a raw werkzeug Response object.
        """
        raise NotImplementedError
55 |
56 |
class HandlerDispatcher:
    """
    Dispatcher that invokes endpoints following the ``Handler`` protocol and converts their return values into
    ``Response`` objects.
    """

    def __init__(self, json_encoder: t.Optional[t.Type[json.JSONEncoder]] = None):
        """
        :param json_encoder: optional ``json.JSONEncoder`` subclass used to serialize dict and list results
        """
        self.json_encoder = json_encoder

    def __call__(
        self, request: Request, endpoint: t.Callable, request_args: RequestArguments
    ) -> Response:
        result = self.invoke_endpoint(request, endpoint, request_args)
        return self.to_response(result)

    def invoke_endpoint(
        self,
        request: Request,
        endpoint: t.Callable,
        request_args: RequestArguments,
    ) -> t.Any:
        """Calls the endpoint with the request as positional argument and the request args as keyword arguments."""
        return endpoint(request, **request_args)

    def to_response(self, value: ResultValue) -> Response:
        """
        Converts an endpoint result into a response. Werkzeug responses are returned as-is, ``None`` yields an empty
        ``Response``, and every other value is written into a new ``Response`` by ``populate_response``.
        """
        if isinstance(value, WerkzeugResponse):
            return value

        response = Response()
        if value is None:
            return response

        self.populate_response(response, value)
        return response

    def populate_response(self, response: Response, value: ResultValue):
        """
        Writes the value into the response: strings and bytes become the raw body, dicts and lists are serialized
        to JSON (using ``self.json_encoder``) with the ``application/json`` mimetype.

        :raises ValueError: if the value has an unsupported type
        """
        if isinstance(value, (str, bytes, bytearray)):
            response.data = value
        elif isinstance(value, (dict, list)):
            response.data = json.dumps(value, cls=self.json_encoder)
            response.mimetype = "application/json"
        else:
            # format the type into the message; exception constructors do not apply %-formatting
            raise ValueError(f"unhandled result type {type(value)}")
94 |
95 |
def handler_dispatcher(json_encoder: t.Type[json.JSONEncoder] = None) -> Dispatcher[Handler]:
    """
    Build a Dispatcher that calls endpoints as callables of the ``Handler`` protocol.

    When pydantic is importable, the returned dispatcher also parses request bodies into pydantic models and
    serializes returned models.

    :param json_encoder: optional json encoder class used when translating results into responses
    :return: a new dispatcher
    """
    if not ENABLE_PYDANTIC:
        return HandlerDispatcher(json_encoder)

    # imported lazily, since that module imports pydantic unconditionally
    from rolo.routing.pydantic import PydanticHandlerDispatcher

    return PydanticHandlerDispatcher(json_encoder)
109 |
--------------------------------------------------------------------------------
/rolo/routing/resource.py:
--------------------------------------------------------------------------------
1 | """
2 | This module enables the resource class pattern, where each respective ``on_<verb>`` method of a class is
3 | treated like an endpoint for the respective HTTP method. The following shows an example of how the pattern is used::
4 |
5 | class Foo:
6 | def on_get(self, request: Request):
7 | return {"ok": "GET it"}
8 |
9 | def on_post(self, request: Request):
10 | return {"ok": "it was POSTed"}
11 |
12 |
13 | router = Router(dispatcher=resource_dispatcher())
14 | router.add(Resource("/foo", Foo()))
15 | """
16 | from typing import Any, Iterable, Optional, Type
17 |
18 | from werkzeug.routing import Map, Rule, RuleFactory
19 |
20 | from .router import route
21 |
22 | _resource_methods = [
23 | "on_head", # it's important that HEAD rules are added first (werkzeug matching order)
24 | "on_get",
25 | "on_post",
26 | "on_put",
27 | "on_patch",
28 | "on_delete",
29 | "on_options",
30 | "on_trace",
31 | ]
32 |
33 |
def resource(path: str, host: Optional[str] = None, **kwargs):
    """
    Class decorator that turns every method that follows the pattern ``on_<verb>`` into a route,
    where the allowed method for that route is automatically set to the method specified in the function name. Example
    when using a Router with the ``handler_dispatcher``::

        @resource("/myresource/<resource_id>")
        class MyResource:
            def on_get(self, request: Request, resource_id: str) -> Response:
                return Response(f"GET called on {resource_id}")

            def on_post(self, request: Request, resource_id: str) -> Response:
                return Response(f"POST called on {resource_id}")

    This class can then be added to a router via ``router.add(MyResource())``.

    Note that, if an on_get method is present in the resource but on_head is not, then HEAD requests are automatically
    routed to ``on_get``. This replicates Werkzeug's behavior https://werkzeug.palletsprojects.com/en/2.2.x/routing/.

    :param path: the path pattern to match
    :param host: an optional host matching pattern. if no pattern is given, the rule matches any host
    :param kwargs: any other argument that can be passed to ``werkzeug.routing.Rule``. A ``methods`` argument is
        ignored, since the allowed method is derived from each method name.
    :return: a class decorator that wraps each ``on_<verb>`` method of the class with ``route`` and returns the
        class itself
    """
    # the HTTP method of each route comes from the method name, so an explicit ``methods`` is discarded
    kwargs.pop("methods", None)

    def _wrapper(cls: Type):
        for name in _resource_methods:
            member = getattr(cls, name, None)
            if member is None:
                continue

            # "on_get" -> "GET"
            http_method = name[3:].upper()
            setattr(cls, name, route(path, host, methods=[http_method], **kwargs)(member))

        return cls

    return _wrapper
72 |
73 |
class Resource(RuleFactory):
    """
    Exposes a given object that follows the "Resource" class pattern as a ``RuleFactory`` that can then be added to a
    Router. Each ``on_<verb>`` method of the object becomes the endpoint of a rule that only allows that HTTP method.
    Example use when using a Router with the ``handler_dispatcher``::

        class MyResource:
            def on_get(self, request: Request, resource_id: str) -> Response:
                return Response(f"GET called on {resource_id}")

            def on_post(self, request: Request, resource_id: str) -> Response:
                return Response(f"POST called on {resource_id}")

        router.add(Resource("/myresource/<resource_id>", MyResource()))

    Note that, if an on_get method is present in the resource but on_head is not, then HEAD requests are automatically
    routed to ``on_get``. This replicates Werkzeug's behavior https://werkzeug.palletsprojects.com/en/2.2.x/routing/.
    """

    def __init__(self, path: str, obj: Any, host: Optional[str] = None, **kwargs):
        """
        :param path: the path pattern to match
        :param obj: the object whose ``on_<verb>`` methods become the rule endpoints
        :param host: an optional host matching pattern. if no pattern is given, the rules match any host
        :param kwargs: any other argument that can be passed to ``werkzeug.routing.Rule``. A ``methods`` argument
            is ignored (as in the ``resource`` decorator), since each rule's method is derived from the method name.
        """
        # passing ``methods`` through would clash with the explicit ``methods=`` in ``get_rules``
        kwargs.pop("methods", None)

        self.path = path
        self.obj = obj
        self.host = host
        self.kwargs = kwargs

    def get_rules(self, map: Map) -> Iterable[Rule]:
        """Creates one rule per ``on_<verb>`` method found on the wrapped object, in ``_resource_methods`` order."""
        rules = []
        for name in _resource_methods:
            member = getattr(self.obj, name, None)
            if member is None:
                continue

            # "on_get" -> "GET"
            http_method = name[3:].upper()
            rules.append(
                Rule(self.path, endpoint=member, methods=[http_method], host=self.host, **self.kwargs)
            )
        return rules
112 |
--------------------------------------------------------------------------------
/examples/json-rpc/server.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | import logging
4 | from typing import Callable
5 |
6 | from werkzeug.exceptions import BadRequest
7 | from werkzeug.serving import run_simple
8 |
9 | from rolo import Response
10 | from rolo.gateway import Gateway, HandlerChain, RequestContext
11 | from rolo.gateway.wsgi import WsgiGateway
12 |
13 | LOG = logging.getLogger(__name__)
14 |
15 |
16 | @dataclasses.dataclass
17 | class RpcRequest:
18 | jsonrpc: str
19 | method: str
20 | id: str | int | None
21 | params: dict | list | None = None
22 |
23 |
class RpcError(Exception):
    """
    Base class for JSON-RPC errors. Subclasses define the ``code`` and ``message`` that ``serialize_rpc_error``
    writes into the ``error`` member of the response.
    """

    code: int
    message: str


class ParseError(RpcError):
    """JSON-RPC "Parse error" (-32700)."""

    code = -32700
    message = "Parse error"


class InvalidRequest(RpcError):
    """JSON-RPC "Invalid Request" (-32600)."""

    code = -32600
    message = "Invalid Request"


class MethodNotFoundError(RpcError):
    """JSON-RPC "Method not found" (-32601)."""

    code = -32601
    message = "Method not found"


class InternalError(RpcError):
    """JSON-RPC "Internal error" (-32603)."""

    code = -32603
    message = "Internal error"
47 |
48 |
def parse_request(chain: HandlerChain, context: RequestContext, response: Response):
    """
    Parses the HTTP request body into an ``RpcRequest`` and stores it in ``context.rpc_request``.

    ``context.rpc_request_id`` is set as early as possible so that error responses can reference it.

    :raises ParseError: if the body is not valid JSON
    :raises InvalidRequest: if the JSON is not a request object or lacks the ``jsonrpc`` or ``method`` member
    """
    context.rpc_request_id = None

    try:
        # force=True parses the body regardless of the Content-Type header, so that any malformed or missing
        # JSON surfaces as a BadRequest
        doc = context.request.get_json(force=True)
    except BadRequest as e:
        raise ParseError() from e

    if not isinstance(doc, dict):
        # e.g. a bare number or string, or a batch (list), which this example does not support
        raise InvalidRequest()

    # the "id" member is optional: requests without it are notifications
    context.rpc_request_id = doc.get("id")

    try:
        context.rpc_request = RpcRequest(
            doc["jsonrpc"],
            doc["method"],
            context.rpc_request_id,
            doc.get("params"),
        )
    except KeyError as e:
        # valid JSON, but not a valid request object
        raise InvalidRequest() from e
67 |
68 |
def log_request(chain: HandlerChain, context: RequestContext, response: Response):
    """Logs the parsed RPC request object at INFO level, if one is present in the context."""
    rpc_request = context.rpc_request
    if not rpc_request:
        return
    LOG.info("RPC request object: %s", rpc_request)
72 |
73 |
def serialize_rpc_error(
    chain: HandlerChain,
    exception: Exception,
    context: RequestContext,
    response: Response,
):
    """Renders an ``RpcError`` into a JSON-RPC error response; any other exception is left untouched."""
    if isinstance(exception, RpcError):
        error = {"code": exception.code, "message": exception.message}
        response.set_json({"jsonrpc": "2.0", "error": error, "id": context.rpc_request_id})
90 |
91 |
def log_exception(
    chain: HandlerChain,
    exception: Exception,
    context: RequestContext,
    response: Response,
):
    """Logs any exception raised in the handler chain at ERROR level, including its traceback."""
    message = "Exception in handler chain"
    LOG.error(message, exc_info=exception)
99 |
100 |
class Registry:
    """Request handler that resolves the requested RPC method name to a Python callable."""

    methods: dict[str, Callable]

    def __init__(self, methods: dict[str, Callable]):
        """:param methods: mapping of RPC method names to the callables implementing them"""
        self.methods = methods

    def __call__(
        self, chain: HandlerChain, context: RequestContext, response: Response
    ):
        """Stores the callable for ``context.rpc_request.method`` in ``context.method``."""
        try:
            method = self.methods[context.rpc_request.method]
        except KeyError as e:
            raise MethodNotFoundError() from e
        context.method = method
114 |
115 |
def dispatch(chain: HandlerChain, context: RequestContext, response: Response):
    """
    Invokes ``context.method`` with the request params and stores its return value in ``context.result``.

    :raises InvalidRequest: if params are present but neither a list nor a dict
    :raises InternalError: if the method raises anything other than an ``RpcError``
    """
    request: RpcRequest = context.rpc_request

    if request.params is None:
        # params may be omitted entirely, in which case the method is called without arguments
        args, kwargs = [], {}
    elif isinstance(request.params, list):
        args, kwargs = request.params, {}
    elif isinstance(request.params, dict):
        args, kwargs = [], request.params
    else:
        raise InvalidRequest()

    try:
        context.result = context.method(*args, **kwargs)
    except RpcError:
        # if the method raises an RpcError, just re-raise it since it will be handled later
        raise
    except Exception as e:
        # all other exceptions are considered unhandled and therefore "Internal"
        raise InternalError() from e
136 |
137 |
def serialize_result(chain: HandlerChain, context: RequestContext, response: Response):
    """Writes the JSON-RPC success response for ``context.result`` into the HTTP response."""
    # compare against None: 0 is a valid request id and must still be answered
    if context.rpc_request_id is None:
        # this is a notification, so we don't want to respond
        return

    response.set_json(
        {
            "jsonrpc": "2.0",
            # embed the result as-is; the whole document is serialized by set_json, so dumping it here
            # would double-encode it as a JSON string
            "result": context.result,
            "id": context.rpc_request_id,
        }
    )
150 |
151 |
def main():
    """Serves the JSON-RPC example on localhost:8000 using werkzeug's ``run_simple``."""
    logging.basicConfig(level=logging.DEBUG)

    def subtract(minuend: int, subtrahend: int):
        # minuend - subtrahend, so both [42, 23] and {"minuend": 42, "subtrahend": 23} yield 19
        return minuend - subtrahend

    locate_method = Registry(
        {
            "subtract": subtract,
        }
    )

    gateway = Gateway(
        request_handlers=[
            parse_request,
            log_request,
            locate_method,
            dispatch,
            # last request handler: it only runs if no previous handler raised, since an exception stops
            # the request handlers, so it never overwrites an error response
            serialize_result,
        ],
        exception_handlers=[
            log_exception,
            serialize_rpc_error,
        ],
    )

    run_simple("localhost", 8000, WsgiGateway(gateway))


if __name__ == "__main__":
    main()
182 |
--------------------------------------------------------------------------------
/tests/test_resource.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from werkzeug.exceptions import MethodNotAllowed
3 |
4 | from rolo import Request, Resource, Response, Router, resource
5 | from rolo.routing import handler_dispatcher
6 |
7 |
class TestResource:
    def test_resource_decorator_dispatches_correctly(self):
        router = Router(dispatcher=handler_dispatcher())

        requests = []

        @resource("/_localstack/health")
        class HealthResource:
            def on_get(self, req):
                requests.append(req)
                return "GET/OK"

            def on_post(self, req):
                requests.append(req)
                return {"ok": "POST"}

            def on_head(self, req):
                # an explicit on_head handles HEAD requests instead of on_get, since HEAD rules are added first
                requests.append(req)
                return "HEAD/OK"

        router.add(HealthResource())

        request1 = Request("GET", "/_localstack/health")
        request2 = Request("POST", "/_localstack/health")
        request3 = Request("HEAD", "/_localstack/health")
        assert router.dispatch(request1).get_data(True) == "GET/OK"
        assert router.dispatch(request1).get_data(True) == "GET/OK"
        assert router.dispatch(request2).json == {"ok": "POST"}
        assert router.dispatch(request3).get_data(True) == "HEAD/OK"
        assert len(requests) == 4
        assert requests[0] is request1
        assert requests[1] is request1
        assert requests[2] is request2
        assert requests[3] is request3

    def test_resource_dispatches_correctly(self):
        router = Router(dispatcher=handler_dispatcher())

        class HealthResource:
            def on_get(self, req):
                return "GET/OK"

            def on_post(self, req):
                return "POST/OK"

            def on_head(self, req):
                return "HEAD/OK"

        router.add(Resource("/_localstack/health", HealthResource()))

        request1 = Request("GET", "/_localstack/health")
        request2 = Request("POST", "/_localstack/health")
        request3 = Request("HEAD", "/_localstack/health")
        assert router.dispatch(request1).get_data(True) == "GET/OK"
        assert router.dispatch(request2).get_data(True) == "POST/OK"
        assert router.dispatch(request3).get_data(True) == "HEAD/OK"

    def test_dispatch_to_non_existing_method_raises_exception(self):
        router = Router(dispatcher=handler_dispatcher())

        @resource("/_localstack/health")
        class HealthResource:
            def on_post(self, request):
                return "POST/OK"

        router.add(HealthResource())

        with pytest.raises(MethodNotAllowed):
            router.dispatch(Request("GET", "/_localstack/health"))
        assert router.dispatch(Request("POST", "/_localstack/health")).get_data(True) == "POST/OK"

    def test_resource_with_default_dispatcher(self):
        router = Router()

        # the default dispatcher passes the rule variables as a dict in the second argument
        @resource("/_localstack/<path>")
        class PathResource:
            def on_get(self, req, args):
                return Response.for_json({"message": "GET/OK", "path": args["path"]})

            def on_post(self, req, args):
                return Response.for_json({"message": "POST/OK", "path": args["path"]})

        router.add(PathResource())
        assert router.dispatch(Request("GET", "/_localstack/health")).json == {
            "message": "GET/OK",
            "path": "health",
        }
        assert router.dispatch(Request("POST", "/_localstack/foobar")).json == {
            "message": "POST/OK",
            "path": "foobar",
        }

    def test_resource_overwrite_with_resource_wrapper(self):
        router = Router(dispatcher=handler_dispatcher())

        @resource("/_localstack/health")
        class TestResourceHealth:
            def on_get(self, req):
                return Response.for_json({"message": "GET/OK", "path": req.path})

            def on_post(self, req):
                return Response.for_json({"message": "POST/OK", "path": req.path})

        endpoints = TestResourceHealth()
        router.add(endpoints)
        # the same instance can additionally be exposed under another path via the Resource wrapper
        router.add(Resource("/health", endpoints))

        assert router.dispatch(Request("GET", "/_localstack/health")).json == {
            "message": "GET/OK",
            "path": "/_localstack/health",
        }
        assert router.dispatch(Request("POST", "/_localstack/health")).json == {
            "message": "POST/OK",
            "path": "/_localstack/health",
        }

        assert router.dispatch(Request("GET", "/health")).json == {
            "message": "GET/OK",
            "path": "/health",
        }
        assert router.dispatch(Request("POST", "/health")).json == {
            "message": "POST/OK",
            "path": "/health",
        }
133 |
--------------------------------------------------------------------------------
/docs/router.md:
--------------------------------------------------------------------------------
1 | Router
2 | ======
3 |
4 | Routers are based on Werkzeug's [URL Map](https://werkzeug.palletsprojects.com/en/2.3.x/routing/), but dispatch to handler functions directly.
5 | All features from Werkzeug's URL routing are inherited, including the [rule string format](https://werkzeug.palletsprojects.com/en/latest/routing/#rule-format) and [type converters](https://werkzeug.palletsprojects.com/en/latest/routing/#built-in-converters).
6 |
7 | `@route`
8 | --------
9 |
10 | The `@route` decorator works similarly to the ones in Flask or FastAPI, but routes are not tied to an application object.
11 | Instead, you can define routes on functions or methods, and then add them directly to the router.
12 |
13 | ```python
14 | from rolo import Router, route, Response
15 | from werkzeug import Request
16 | from werkzeug.serving import run_simple
17 |
18 | @route("/users")
19 | def list_users(_request: Request, args):
20 | assert not args
21 | return Response("user")
22 |
23 | @route("/users/<user_id>")
24 | def get_user_by_id(_request: Request, args):
25 | assert args
26 | return Response(f"{args['user_id']}")
27 |
28 | router = Router()
29 | router.add(list_users)
30 | router.add(get_user_by_id)
31 |
32 | # convert Router to a WSGI app and serve it through werkzeug
33 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True)
34 | ```
35 |
36 | Depending on the _dispatcher_ your Router uses, the signature of your endpoints will look different.
37 |
38 | Handler dispatcher
39 | ------------------
40 |
41 | Routers use dispatchers to dispatch the request to functions.
42 | In the previous example, the default dispatcher calls the function with the `Request` object and the request arguments as a dictionary.
43 | The "handler dispatcher" lets you write more Flask- or FastAPI-like functions: URL parameters are passed as keyword arguments, and return values such as strings or dicts are automatically transformed into responses.
44 |
45 | ```python
46 | from rolo import Router, route
47 | from rolo.routing import handler_dispatcher
48 |
49 | from werkzeug import Request
50 | from werkzeug.serving import run_simple
51 |
52 | @route("/users")
53 | def list_users(request: Request):
54 | # query from db using the ?q= query string
55 | query = request.args["q"]
56 | # ...
57 | return [{"user_id": ...}, ...]
58 |
59 | @route("/users/<int:user_id>")
60 | def get_user_by_id(_request: Request, user_id: int):
61 | return {"user_id": user_id, "name": ...}
62 |
63 | router = Router(dispatcher=handler_dispatcher())
64 | router.add(list_users)
65 | router.add(get_user_by_id)
66 |
67 | # convert Router to a WSGI app and serve it through werkzeug
68 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True)
69 | ```
70 |
71 | Using classes
72 | -------------
73 |
74 | Unlike Flask or FastAPI, Rolo allows you to use classes to organize your routes.
75 | The above example can also be written as follows:
76 |
77 | ```python
78 | from rolo import Router, route, Request
79 | from rolo.routing import handler_dispatcher
80 |
81 | class UserResource:
82 |
83 | @route("/users/")
84 | def list_users(self, _request: Request):
85 | return "user"
86 |
87 |     @route("/users/<int:user_id>")
88 | def get_user_by_id(self, _request: Request, user_id: int):
89 | return f"{user_id}"
90 |
91 | router = Router(dispatcher=handler_dispatcher())
92 | router.add(UserResource())
93 | ```
94 | The router will scan the instantiated `UserResource` for `@route` decorators, and add them automatically.
95 |
96 | Resource classes
97 | ----------------
98 |
99 | If you prefer the RESTful style that [Falcon](https://falconframework.org/) implements, you can use the `@resource` decorator on a class.
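100 | This will automatically create routes for all `on_<verb>` methods (e.g., `on_get` or `on_post`), where each route only accepts the HTTP method named by the method.
101 | Here is an example: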
102 |
103 |
104 | ```python
105 | from rolo import Router, resource, Request
106 | from rolo.routing import handler_dispatcher
107 |
108 | @resource("/users/<int:user_id>")
109 | class UserResource:
110 |     def on_get(self, request: Request, user_id: int):
111 |         return {"user_id": user_id, "user": ...}
112 |
113 |     def on_post(self, request: Request, user_id: int):
114 |         data = request.json
115 |         # ... do something
116 |
117 | router = Router(dispatcher=handler_dispatcher())
118 | router.add(UserResource())
119 | ```
120 |
121 | Pydantic integration
122 | --------------------
123 |
124 | Here's how the default example from the FastAPI documentation would look with rolo. Pydantic models are only parsed when the router uses the handler dispatcher:
125 |
126 | ```python
127 | import pydantic
128 |
129 | from rolo import Request, Router, route
130 | from rolo.routing import handler_dispatcher
131 |
132 | class Item(pydantic.BaseModel):
133 |     name: str
134 |     price: float
135 |     is_offer: bool | None = None
136 |
137 |
138 | @route("/", methods=["GET"])
139 | def read_root(request: Request):
140 |     return {"Hello": "World"}
141 |
142 |
143 | @route("/items/<int:item_id>", methods=["GET"])
144 | def read_item(request: Request, item_id: int):
145 |     return {"item_id": item_id, "q": request.args.get("q")}
146 |
147 |
148 | @route("/items/<int:item_id>", methods=["PUT"])
149 | def update_item(request: Request, item_id: int, item: Item):
150 |     return {"item_name": item.name, "item_id": item_id}
151 |
152 |
153 | router = Router(dispatcher=handler_dispatcher())
154 | router.add(read_root)
155 | router.add(read_item)
156 | router.add(update_item)
157 | ```
158 |
--------------------------------------------------------------------------------
/docs/handler_chain.md:
--------------------------------------------------------------------------------
1 | Handler Chain
2 | =============
3 |
4 | The rolo handler chain implements a variant of the chain-of-responsibility pattern to process an incoming HTTP request.
5 | It is meant to be used together with a [`Gateway`](gateway.md), which is responsible for creating `HandlerChain` instances.
6 |
7 | Handler chains are a powerful abstraction to create complex HTTP server behavior, while keeping code cleanly encapsulated and the high-level logic easy to understand.
8 | You can find a simple example how to create a handler chain in the [Getting Started](getting_started.md) guide.
9 |
10 | ## Behavior
11 |
12 | A handler chain consists of:
13 | * request handlers: process the request and attempt to create an initial response
14 | * response handlers: process the response
15 | * finalizers: handlers that are always executed at the end of running a handler chain
16 | * exception handlers: run when an exception occurs during the execution of a handler
17 |
18 | Each HTTP request coming into the server has its own `HandlerChain` instance, since the handler chain holds state for the handling of a request.
19 | A handler chain can be in three states that can be controlled by the handlers.
20 |
21 | * Running - the implicit state in which _all_ handlers are executed sequentially
22 | * Stopped - a handler has called `chain.stop()`. This stops the execution of all request handlers, and
23 | proceeds immediately to executing the response handlers. Response handlers and finalizers will be run,
24 | even if the chain has been stopped.
25 | * Terminated - a handler has called `chain.terminate()`. This stops the execution of all request
26 | handlers, and all response handlers, but runs the finalizers at the end.
27 |
28 | If an exception occurs during the execution of request handlers, the chain is stopped by default,
29 | then each exception handler is run, and finally the response handlers are run.
30 | Exceptions that happen during the execution of response or exception handlers are logged but do not modify the control flow of the chain.
31 |
32 | ## Request Context
33 |
34 | The `RequestContext` object holds the HTTP `Request` object in `context.request`, as well as any arbitrary data you would like to pass down to other handlers.
35 | It's a universal attribute store, so you can simply call: `context.myattr = "foo"` to set a value.
36 | You can add type hints for your request context, see [gateway](gateway.md).
37 |
38 | ## Handlers
39 |
40 | Request handlers, response handlers, and finalizers need to satisfy the `Handler` protocol:
41 |
42 | ```python
43 | from rolo import Response
44 | from rolo.gateway import HandlerChain, RequestContext
45 |
46 | def handle(chain: HandlerChain, context: RequestContext, response: Response):
47 | ...
48 | ```
49 |
50 | * `chain`: the HandlerChain instance currently being executed. A handler can, for example, call `chain.stop()` to skip all remaining request handlers.
51 | * `context`: the RequestContext contains the rolo `Request` object, as well as a universal property store. You can simply call `context.myattr = ...` to pass a value down to the next handler.
52 | * `response`: Handlers of a handler chain don't return a response. Instead, the response being populated is handed down from handler to handler, and can thus be enriched.
53 |
54 | ### Exception Handlers
55 |
56 | Exception handlers are similar, only they are also passed the `Exception` that was raised in the handler chain.
57 |
58 | ```python
59 | from rolo import Response
60 | from rolo.gateway import HandlerChain, RequestContext
61 |
62 | def handle(chain: HandlerChain, exception: Exception, context: RequestContext, response: Response):
63 | ...
64 | ```
65 |
66 | ## Builtin Handlers
67 |
68 | ### Router handler
69 |
70 | Sometimes you have a `Gateway` but also want to use the [`Router`](router.md).
71 | You can use the `RouterHandler` adapter to make a `Router` look like a handler chain `Handler`, and then pass it as handler to a Gateway.
72 |
73 | ```python
74 | from rolo import Router
75 | from rolo.gateway import Gateway
76 | from rolo.gateway.handlers import RouterHandler
77 |
78 | router: Router = ...
79 | gateway: Gateway = ...
80 |
81 | gateway.request_handlers.append(RouterHandler(router))
82 | ```
83 |
84 | ### Empty response handler
85 |
86 | The `EmptyResponseHandler` response handler automatically creates a default response if the response in the chain is empty.
87 | By default, it creates an empty 404 response, but it can be customized:
88 |
89 | ```python
90 | from rolo.gateway.handlers import EmptyResponseHandler
91 |
92 | gateway.response_handlers.append(EmptyResponseHandler(status_code=404, body=b'404 Not Found'))
93 | ```
94 |
95 | ### Werkzeug exception handler
96 |
97 | Werkzeug has a very useful [HTTP exception hierarchy](https://werkzeug.palletsprojects.com/en/latest/exceptions/) that can be used to programmatically trigger HTTP errors.
98 | For instance, a request handler may raise a `NotFound` error.
99 | To get the Gateway to automatically handle those exceptions and render them into JSON objects or HTML, you can use the `WerkzeugExceptionHandler`.
100 |
101 | ```python
102 | from rolo.gateway.handlers import WerkzeugExceptionHandler
103 |
104 | gateway.exception_handlers.append(WerkzeugExceptionHandler(output_format="json"))
105 | ```
106 |
107 | In your request handler you can now raise any exception from `werkzeug.exceptions` and it will be rendered accordingly.
108 |
--------------------------------------------------------------------------------
/rolo/websocket/adapter.py:
--------------------------------------------------------------------------------
1 | """Adapter API between high-level rolo.websocket.request and an underlying IO frameworks like ASGI or
2 | twisted."""
3 | import dataclasses
4 | import typing as t
5 |
6 | from werkzeug.datastructures import Headers
7 |
8 | WebSocketEnvironment: t.TypeAlias = t.Dict[str, t.Any]
9 | """Special WSGIEnvironment that has a ``rolo.websocket`` key that stores a `Websocket` instance."""
10 |
11 |
12 | class Event:
13 | """A websocket event (subset of ``wsproto.events``)."""
14 |
15 | pass
16 |
17 |
18 | @dataclasses.dataclass
19 | class Message(Event):
20 | data: bytes | str
21 |
22 |
23 | @dataclasses.dataclass
24 | class TextMessage(Message):
25 | data: str
26 |
27 |
28 | @dataclasses.dataclass
29 | class BytesMessage(Message):
30 | data: bytes
31 |
32 |
33 | @dataclasses.dataclass
34 | class CreateConnection(Event):
35 | """
36 | This indicates the first event of the websocket after a connection upgrade. For example, in wsproto
37 | this corresponds to a ``Request`` event, or ``websocket.connect`` event in ASGI.
38 | """
39 |
40 | pass
41 |
42 |
43 | @dataclasses.dataclass
44 | class AcceptConnection(Event):
45 | subprotocol: t.Optional[str] = None
46 | extensions: list[str] = dataclasses.field(default_factory=list)
47 | extra_headers: list[tuple[bytes, bytes]] = dataclasses.field(default_factory=list)
48 |
49 |
class WebSocketAdapter:
    """
    Adapter to plug the high-level interfaces ``WebSocket`` and ``WebSocketRequest`` into an IO framework.
    It doesn't cover the full websocket protocol API (for instance there are no Ping/Pong events),
    under the assumption that the lower-level IO framework will abstract them away.

    Every method of this base class raises ``NotImplementedError``; framework-specific adapters override them.
    """

    def accept(
        self,
        subprotocol: t.Optional[str] = None,
        extensions: t.Optional[list[str]] = None,
        extra_headers: t.Optional[Headers] = None,
        timeout: t.Optional[float] = None,
    ):
        """
        Accept the websocket upgrade request and send an accept message back to the client. This or
        ``reject`` must be the first method to be called.

        :param subprotocol: the accepted subprotocol
        :param extensions: any accepted extensions to use
        :param extra_headers: headers to pass to the accept response
        :param timeout: optional timeout
        """
        raise NotImplementedError

    def reject(
        self,
        status_code: int,
        headers: t.Optional[Headers] = None,
        body: t.Optional[t.Iterable[bytes]] = None,
        timeout: t.Optional[float] = None,
    ):
        """
        Reject the websocket request. This means sending an actual HTTP response back to the client, i.e.,
        not upgrading the connection. This only makes sense before any call to ``receive`` was made.

        :param status_code: the HTTP response status code
        :param headers: the HTTP response headers
        :param body: the body
        :param timeout: optional timeout
        """
        raise NotImplementedError

    def receive(
        self,
        timeout: t.Optional[float] = None,
    ) -> CreateConnection | Message:
        """
        Blocking IO method to wait for the next ``Message`` or, if the connection is not initialized yet, the
        first ``CreateConnection`` event.

        :param timeout: optional timeout
        :return: the received event
        """
        raise NotImplementedError

    def send(
        self,
        event: Message,
        timeout: t.Optional[float] = None,
    ):
        """
        Send the given message to the websocket.

        :param event: the message to send
        :param timeout: optional timeout
        """
        raise NotImplementedError

    def close(self, code: int = 1001, reason: t.Optional[str] = None, timeout: t.Optional[float] = None):
        """
        Close the websocket connection. If the underlying websocket connection has already been closed, this
        call is ignored, so it's safe to always call.

        :param code: the websocket close status code (defaults to 1001, "going away" per RFC 6455)
        :param reason: optional close reason
        :param timeout: optional timeout
        """
        raise NotImplementedError
121 |
122 |
class WebSocketListener(t.Protocol):
    """
    Similar protocol to a WSGIApplication, only it is called with a ``WebSocketEnvironment`` instead of a
    WSGI environment.
    """

    def __call__(self, environ: WebSocketEnvironment):
        """
        Called when a new Websocket connection is established. To initiate the connection, you need to perform
        the connect handshake yourself: first, receive the ``CreateConnection`` event (``websocket.connect`` in
        ASGI), and then call ``accept`` (the equivalent of sending ``websocket.accept`` in ASGI). Here's a
        minimal example::

            def listener(environ: WebSocketEnvironment):
                websocket: WebSocketAdapter = environ['rolo.websocket']
                event = websocket.receive()
                if isinstance(event, CreateConnection):
                    websocket.accept()
                else:
                    websocket.close(1002)  # protocol error
                    return

                while True:
                    # NOTE: how a client disconnect surfaces from ``receive`` depends on the adapter
                    # implementation; this module does not define a close event.
                    event = websocket.receive()
                    print(event)

        In reality, you wouldn't be using the websocket adapter directly, the server would probably create a
        ``rolo.websocket.WebSocketRequest`` and serve it accordingly through a ``Gateway``.

        :param environ: The new Websocket environment
        """
        raise NotImplementedError
155 |
--------------------------------------------------------------------------------
/rolo/response.py:
--------------------------------------------------------------------------------
1 | import json
2 | import mimetypes
3 | import typing as t
4 | from importlib import resources
5 |
6 | from werkzeug.exceptions import NotFound
7 | from werkzeug.wrappers import Response as WerkzeugResponse
8 |
9 | if t.TYPE_CHECKING:
10 | from types import ModuleType
11 |
12 |
class _StreamIterableWrapper(t.Iterable[bytes]):
    """
    Wraps an ``IO[bytes]`` stream into an iterable that yields chunks of at most ``chunk_size`` bytes
    (65536 by default).
    """

    def __init__(self, stream: t.IO[bytes], chunk_size: int = 65536):
        self.stream = stream
        self._chunk_size = chunk_size

    def __iter__(self) -> t.Iterator[bytes]:
        """
        Yield the stream's content chunk by chunk until ``read`` returns an empty result.

        When passing a stream back to the WSGI server, it will often iterate only 1 byte at a time. Reading
        fixed-size chunks allows us to bypass this issue. The caller needs to call ``close()`` to properly
        close the file descriptor.
        """
        # the walrus condition already ends the loop on the first empty read, so no separate EOF check is needed
        while chunk := self.stream.read(self._chunk_size):
            yield chunk

    def close(self):
        """Close the wrapped stream, if it can be closed."""
        if hasattr(self.stream, "close"):
            self.stream.close()
38 |
39 |
class Response(WerkzeugResponse):
    """
    An HTTP Response object, which simply extends werkzeug's Response object with a few convenience methods.
    """

    def update_from(self, other: WerkzeugResponse):
        """
        Updates this response object with the data from the given response object. It copies the status code
        and the response iterable, and updates its own headers (overwrites existing headers, but does not remove
        ones not present in the given object). The ``call_on_close`` callbacks of ``other`` are appended to the
        callbacks of this response.

        :param other: the response object to read from
        """
        self.status_code = other.status_code
        self.response = other.response
        # ``_on_close`` is werkzeug's (private) list of callbacks registered via ``call_on_close``
        self._on_close.extend(other._on_close)
        self.headers.update(other.headers)

    def set_json(self, doc: t.Any, cls: t.Optional[t.Type[json.JSONEncoder]] = None):
        """
        Serializes the given document with ``json.dumps`` into the response body, and sets the mimetype
        automatically to ``application/json``.

        :param doc: the document to be serialized as JSON
        :param cls: the JSON encoder class to use for serializing the passed document (if None, ``json.dumps``
            uses its default encoder)
        """
        self.data = json.dumps(doc, cls=cls)
        self.mimetype = "application/json"

    def set_response(
        self, response: t.Optional[t.Union[str, bytes, bytearray, t.Iterable[bytes]]]
    ) -> "Response":
        """
        Function to set the low-level ``response`` object. This is copied from the werkzeug Response constructor. The
        response attribute always holds an iterable of bytes. Passing a str, bytes or bytearray is equivalent to
        calling ``response.data = ``. If None is passed, then it will create an empty list. If anything
        else is passed, the value is set directly. This value can be a list of bytes, or an iterator that returns
        bytes (e.g., a generator), which can be used by the underlying server to stream responses to the client.
        Anything else (like passing dicts) will result in errors at lower levels of the server.

        :param response: the response value
        :return: this response object, to allow chaining
        """
        if response is None:
            self.response = []
        elif isinstance(response, (str, bytes, bytearray)):
            self.data = response
        else:
            self.response = response

        return self

    def to_readonly_response_dict(self) -> t.Dict:
        """
        Returns a read-only version of a response dictionary as it is often expected by other libraries like boto.
        The body is the response stream for streamed responses, and the response data otherwise.
        """
        return {
            "body": self.stream if self.is_streamed else self.data,
            "status_code": self.status_code,
            "headers": dict(self.headers),
        }

    @classmethod
    def for_json(cls, doc: t.Any, *args, **kwargs) -> "Response":
        """
        Creates a new JSON response from the given document. It automatically sets the mimetype to ``application/json``.

        :param doc: the document to serialize into JSON
        :param args: arguments passed to the ``Response`` constructor
        :param kwargs: keyword arguments passed to the ``Response`` constructor
        :return: a new Response object
        """
        response = cls(*args, **kwargs)
        response.set_json(doc)
        return response

    @classmethod
    def for_resource(cls, module: "ModuleType", path: str, *args, **kwargs) -> "Response":
        """
        Looks up the given file in the given module, and creates a new Response object that streams the contents
        of that file. The mimetype is guessed from the file name (falling back to ``application/octet-stream``),
        unless a ``mimetype`` keyword argument is passed explicitly.

        :param module: the module to look up the file in
        :param path: the path/file name
        :param args: arguments passed to the ``Response`` constructor
        :param kwargs: keyword arguments passed to the ``Response`` constructor
        :return: a new Response object
        :raises NotFound: if the file does not exist
        """
        resource = resources.files(module).joinpath(path)
        if not resource.is_file():
            raise NotFound()

        guessed_type, _ = mimetypes.guess_type(resource.name)
        # an explicitly passed mimetype wins; previously this raised a "multiple values" TypeError
        kwargs.setdefault("mimetype", guessed_type or "application/octet-stream")

        body = _StreamIterableWrapper(resource.open("rb"))
        try:
            return cls(body, *args, **kwargs)
        except BaseException:
            # don't leak the opened file if the constructor rejects the arguments
            body.close()
            raise
131 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Rolo HTTP: A Python framework for building HTTP-based server applications.
6 |
7 |
8 | # Rolo HTTP
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | Rolo is a flexible framework and library to build HTTP-based server applications beyond microservices and REST APIs.
19 | You can build HTTP-based RPC servers, websocket proxies, or other server types that typical web frameworks are not designed for.
20 |
21 | Rolo extends [Werkzeug](https://github.com/pallets/werkzeug/), a flexible Python HTTP server library, for you to use concepts you are familiar with like `Router`, `Request`, `Response`, or `@route`.
22 | It introduces the concept of a `Gateway` and `HandlerChain`, an implementation variant of the [chain-of-responsibility pattern](https://en.wikipedia.org/wiki/Chain-of-responsibility_pattern).
23 |
24 | Rolo is designed for environments that do not use asyncio, but still require asynchronous HTTP features like HTTP/2, server-sent events (SSE), or WebSockets.
25 | To allow asynchronous communication, Rolo introduces an ASGI/WSGI bridge that allows you to serve Rolo applications through ASGI servers like Hypercorn.
26 |
27 | ## Usage
28 |
29 | ### Default router example
30 |
31 | Routers are based on Werkzeug's [URL Map](https://werkzeug.palletsprojects.com/en/2.3.x/routing/), but dispatch to handler functions directly.
32 | The `@route` decorator works similarly to the one in Flask or FastAPI, but routes are not tied to an application object.
33 | Instead, you can define routes on functions or methods, and then add them directly to the router.
34 |
35 | ```python
36 | from rolo import Router, route, Response
37 | from werkzeug import Request
38 | from werkzeug.serving import run_simple
39 |
40 | @route("/users")
41 | def user(_request: Request, args):
42 | assert not args
43 | return Response("user")
44 |
45 | @route("/users/<int:user_id>")
46 | def user_id(_request: Request, args):
47 | assert args
48 | return Response(f"{args['user_id']}")
49 |
50 | router = Router()
51 | router.add(user)
52 | router.add(user_id)
53 |
54 | # convert Router to a WSGI app and serve it through werkzeug
55 | run_simple('localhost', 8080, router.wsgi(), use_reloader=True)
56 | ```
57 |
58 | ### Pydantic integration
59 |
60 | Routers use dispatchers to dispatch the request to functions.
61 | In the previous example, the default dispatcher calls the function with the `Request` object and the request arguments as dictionary.
62 | The "handler dispatcher" can transform functions into more Flask or FastAPI-like functions that also let you integrate with Pydantic.
63 | Here's how the default example from the FastAPI documentation would look with rolo:
64 |
65 | ```python
66 | import pydantic
67 | from werkzeug import Request
68 | from werkzeug.serving import run_simple
69 |
70 | from rolo import Router, route
71 |
72 |
73 | class Item(pydantic.BaseModel):
74 | name: str
75 | price: float
76 | is_offer: bool | None = None
77 |
78 |
79 | @route("/", methods=["GET"])
80 | def read_root(request: Request):
81 | return {"Hello": "World"}
82 |
83 |
84 | @route("/items/<int:item_id>", methods=["GET"])
85 | def read_item(request: Request, item_id: int):
86 | return {"item_id": item_id, "q": request.query_string}
87 |
88 |
89 | @route("/items/<int:item_id>", methods=["PUT"])
90 | def update_item(request: Request, item_id: int, item: Item):
91 | return {"item_name": item.name, "item_id": item_id}
92 |
93 |
94 | router = Router()
95 | router.add(read_root)
96 | router.add(read_item)
97 | router.add(update_item)
98 |
99 | # convert Router to a WSGI app and serve it through werkzeug
100 | run_simple("localhost", 8080, router.wsgi(), use_reloader=True)
101 | ```
102 |
103 | ### Gateway & Handler Chain
104 |
105 | A rolo `Gateway` holds a set of request, response, and exception handlers, as well as request finalizers.
106 | Gateway instances can then be served as WSGI or ASGI applications by using the appropriate serving adapter.
107 | Here is a simple example of a Gateway with just one handler that returns the URL and method that was invoked.
108 |
109 | ```python
110 | from werkzeug import run_simple
111 |
112 | from rolo import Response
113 | from rolo.gateway import Gateway, HandlerChain, RequestContext
114 | from rolo.gateway.wsgi import WsgiGateway
115 |
116 |
117 | def echo_handler(chain: HandlerChain, context: RequestContext, response: Response):
118 | response.status_code = 200
119 | response.set_json({"url": context.request.url, "method": context.request.method})
120 |
121 |
122 | gateway = Gateway(request_handlers=[echo_handler])
123 |
124 | app = WsgiGateway(gateway)
125 | run_simple("localhost", 8080, app, use_reloader=True)
126 | ```
127 |
128 | Serving this will yield:
129 |
130 | ```console
131 | curl http://localhost:8080/hello-world
132 | {"url": "http://localhost:8080/hello-world", "method": "GET"}
133 | ```
134 |
135 |
136 | ## Develop
137 |
138 | ### Quickstart
139 |
140 | To install the Python and other developer requirements into a venv, run:
141 |
142 | make install
143 |
144 | ### Format code
145 |
146 | We use black and isort as code style tools.
147 | To execute them, run:
148 |
149 | make format
150 |
151 | ### Build distribution
152 |
153 | To build a wheel and source distribution, simply run
154 |
155 | make dist
156 |
--------------------------------------------------------------------------------
/tests/websocket/test_websockets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import threading
3 | from queue import Queue
4 |
5 | import pytest
6 | import websocket
7 | from _pytest.fixtures import SubRequest
8 | from werkzeug.datastructures import Headers
9 |
10 | from rolo import Response, Router
11 | from rolo.websocket.request import (
12 | WebSocketDisconnectedError,
13 | WebSocketProtocolError,
14 | WebSocketRequest,
15 | )
16 |
17 |
@pytest.fixture(params=["asgi", "twisted"])
def serve_websocket_listener(request: SubRequest):
    """
    Parametrized fixture yielding a factory that serves a websocket listener either through the ASGI
    adapter or through twisted, and returns the resulting server.
    """

    def _serve(listener):
        if request.param != "asgi":
            serve_twisted = request.getfixturevalue("serve_twisted_websocket_listener")
            return serve_twisted(listener)

        serve_asgi = request.getfixturevalue("serve_asgi_adapter")
        return serve_asgi(wsgi_app=None, websocket_listener=listener)

    yield _serve
29 |
30 |
def test_websocket_basic_interaction(serve_websocket_listener):
    """Exchange text messages in both directions and check that the server notices the client closing."""
    disconnect_seen = threading.Event()

    @WebSocketRequest.listener
    def app(request: WebSocketRequest):
        with request.accept() as ws:
            ws.send("hello")
            assert ws.receive() == "foobar"
            ws.send("world")

            # the client closes the connection after "world", so the next receive must fail
            with pytest.raises(WebSocketDisconnectedError):
                ws.receive()

            disconnect_seen.set()

    server = serve_websocket_listener(app)
    ws_url = server.url.replace("http://", "ws://")

    client = websocket.WebSocket()
    client.connect(ws_url)
    assert client.recv() == "hello"
    client.send("foobar")
    assert client.recv() == "world"
    client.close()

    assert disconnect_seen.wait(timeout=3)
56 |
57 |
def test_websocket_disconnect_while_iter(serve_websocket_listener):
    """Makes sure that the ``for line in iter(ws)`` pattern works smoothly when the client disconnects."""
    loop_finished = threading.Event()
    lines = []

    @WebSocketRequest.listener
    def app(request: WebSocketRequest):
        with request.accept() as ws:
            # the explicit for-loop over iter(ws) is the pattern under test, so keep it as is
            for line in iter(ws):
                lines.append(line)

            loop_finished.set()

    server = serve_websocket_listener(app)
    ws_url = server.url.replace("http://", "ws://")

    client = websocket.WebSocket()
    client.connect(ws_url)

    for payload in ("foo", "bar"):
        client.send(payload)
    client.close()

    assert loop_finished.wait(timeout=3)
    assert lines[0] == "foo"
    assert lines[1] == "bar"
83 |
84 |
def test_websocket_headers(serve_websocket_listener):
    """
    Check that accept headers reach the client and that the request headers (including their casing) are
    visible to the listener.
    """

    @WebSocketRequest.listener
    def echo_headers(request: WebSocketRequest):
        with request.accept(headers=Headers({"x-foo-bar": "foobar"})) as ws:
            ws.send(json.dumps(dict(request.headers)))

    server = serve_websocket_listener(echo_headers)

    client = websocket.WebSocket()
    client.connect(
        server.url.replace("http://", "ws://"),
        header=["Authorization: Basic let-me-in", "CasedHeader: hello"],
    )
    try:
        assert client.handshake_response.status == 101
        assert client.getheaders()["x-foo-bar"] == "foobar"
        doc = client.recv()
        headers = json.loads(doc)
        assert headers["Connection"] == "Upgrade"
        assert headers["Authorization"] == "Basic let-me-in"
        assert headers["CasedHeader"] == "hello"
    finally:
        # release the client socket even if an assertion fails
        client.close()
106 |
107 |
def test_websocket_reject(serve_websocket_listener):
    """Rejecting the upgrade must deliver the HTTP response (status and body) to the client."""

    @WebSocketRequest.listener
    def respond(request: WebSocketRequest):
        request.reject(Response("nope", 403))

    server = serve_websocket_listener(respond)
    ws_url = server.url.replace("http://", "ws://")

    client = websocket.WebSocket()
    with pytest.raises(websocket.WebSocketBadStatusException) as exc_info:
        client.connect(ws_url)

    rejection = exc_info.value
    assert rejection.status_code == 403
    assert rejection.resp_body == b"nope"
121 |
122 |
def test_binary_and_text_mode(serve_websocket_listener):
    """Check that bytes and str payloads keep their binary/text type in both directions."""
    received = Queue()

    @WebSocketRequest.listener
    def exchange_messages(request: WebSocketRequest):
        with request.accept() as ws:
            ws.send(b"foo")
            ws.send("textfoo")
            received.put(ws.receive())
            received.put(ws.receive())

    server = serve_websocket_listener(exchange_messages)

    client = websocket.WebSocket()
    client.connect(server.url.replace("http://", "ws://"))
    try:
        assert client.handshake_response.status == 101
        data = client.recv()
        assert data == b"foo"

        data = client.recv()
        assert data == "textfoo"

        client.send("textbar")
        client.send_binary(b"bar")

        assert received.get(timeout=5) == "textbar"
        assert received.get(timeout=5) == b"bar"
    finally:
        # release the client socket even if an assertion fails
        client.close()
151 |
152 |
def test_send_non_confirming_data(serve_websocket_listener):
    """Sending a payload like a dict over the websocket must raise a ``WebSocketProtocolError``."""
    match = Queue()

    @WebSocketRequest.listener
    def send_invalid_data(request: WebSocketRequest):
        with request.accept() as ws:
            with pytest.raises(WebSocketProtocolError) as e:
                ws.send({"foo": "bar"})
            match.put(e)

    server = serve_websocket_listener(send_invalid_data)

    client = websocket.WebSocket()
    client.connect(server.url.replace("http://", "ws://"))
    try:
        e = match.get(timeout=5)
        assert e.match("Cannot send data type over websocket")
    finally:
        # release the client socket even if an assertion fails
        client.close()
170 |
171 |
def test_router_integration(serve_websocket_listener):
    """Websocket requests dispatched through a ``Router`` receive the URL arguments and request headers."""
    router = Router()

    def _handler(request: WebSocketRequest, request_args: dict):
        with request.accept() as ws:
            ws.send("foo")
            ws.send(f"id={request_args['id']}")
            ws.send(json.dumps(dict(request.headers)))

    # the handler reads ``request_args['id']``, so the rule needs an ``<id>`` placeholder
    router.add("/foo/<id>", _handler)

    server = serve_websocket_listener(WebSocketRequest.listener(router.dispatch))
    client = websocket.WebSocket()
    client.connect(
        server.url.replace("http://", "ws://") + "/foo/bar", header=["CasedHeader: hello"]
    )
    try:
        assert client.recv() == "foo"
        assert client.recv() == "id=bar"
        assert "CasedHeader" in json.loads(client.recv())
    finally:
        # release the client socket even if an assertion fails
        client.close()
191 |
--------------------------------------------------------------------------------
/rolo/client.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from urllib.parse import urlparse
3 |
4 | import requests
5 | import urllib3.util
6 | from werkzeug import Request, Response
7 | from werkzeug.datastructures import Headers
8 |
9 | from .request import get_raw_base_url, get_raw_current_url, get_raw_path, restore_payload
10 |
11 |
class HttpClient(abc.ABC):
    """
    Abstract HTTP client that performs requests described by werkzeug ``Request`` objects. Clients are
    context managers: leaving the ``with`` block calls ``close()``.
    """

    @abc.abstractmethod
    def request(self, request: Request, server: str | None = None) -> Response:
        """
        Perform the given HTTP request as a client.

        :param request: the request to make
        :param server: the URL to send the request to, which defaults to the host component of the original Request.
        :return: the response.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def close(self):
        """
        Release any underlying resources held by the client.
        """

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
40 |
41 |
class _VerifyRespectingSession(requests.Session):
    """
    ``requests.Session`` subclass that works around https://github.com/psf/requests/issues/3829: if
    ``session.verify`` is ``False``, TLS verification stays disabled even when ``REQUESTS_CA_BUNDLE`` or
    ``CURL_CA_BUNDLE`` are set.
    """

    def merge_environment_settings(self, url, proxies, stream, verify, *args, **kwargs):
        # a session-level ``verify=False`` must win over CA bundles picked up from the environment
        effective_verify = False if self.verify is False else verify
        return super().merge_environment_settings(
            url, proxies, stream, effective_verify, *args, **kwargs
        )
56 |
57 |
class SimpleRequestsClient(HttpClient):
    """
    ``HttpClient`` that performs requests through a ``requests.Session``. Except for HEAD requests, response
    bodies are streamed from the upstream connection without decoding the content.
    """

    session: requests.Session
    follow_redirects: bool

    def __init__(self, session: requests.Session | None = None, follow_redirects: bool = True):
        """
        :param session: the session to use; a new ``_VerifyRespectingSession`` is created if not given
        :param follow_redirects: whether the client follows redirects
        """
        self.session = session or _VerifyRespectingSession()
        self.follow_redirects = follow_redirects

    @staticmethod
    def _get_destination_url(request: Request, server: str | None = None) -> str:
        """
        Build the URL to send ``request`` to. If ``server`` is given (with or without a scheme), it replaces
        the scheme and host of the original request; otherwise the request's raw base URL is used.
        """
        if server:
            # accepts "http://localhost:5000" or "localhost:5000"
            if "://" in server:
                parts = urlparse(server)
                scheme, server = parts.scheme, parts.netloc
            else:
                scheme = request.scheme
            return get_raw_current_url(scheme, server, request.root_path, get_raw_path(request))

        return get_raw_base_url(request)

    @staticmethod
    def _transform_response_headers(response: requests.Response) -> Headers:
        """
        `requests` by default concatenate headers in response under a single header separated by a comma
        This behavior is generally the same as having the same header multiple times with different values in a
        response.
        However, specific headers like `Set-Cookie` needs to be defined multiple times. By directly using the raw
        `urllib3` response that still contains non-concatenate values, we can follow more closely the response.
        """
        headers = Headers()
        for k, v in response.raw.headers.iteritems():
            headers.add(k, v)
        return headers

    def request(self, request: Request, server: str | None = None) -> Response:
        """
        Very naive implementation to make the given HTTP request using the requests library, i.e., process the request
        as a client. Redirects are followed if ``self.follow_redirects`` is set.

        :param request: the request to perform
        :param server: the URL to send the request to, which defaults to the host component of the original Request.
        :return: the response. Except for HEAD requests, its body is streamed from the upstream connection, which
            is released when the returned response is closed.
        """

        url = self._get_destination_url(request, server)

        headers = dict(request.headers.items())

        # urllib3 (used by requests) will set an Accept-Encoding header ("gzip,deflate")
        # - See urllib3.util.request.ACCEPT_ENCODING
        # - The solution to this, provided by urllib3, is to use `urllib3.util.SKIP_HEADER`
        #   to prevent the header from being added.
        if not request.headers.get("accept-encoding"):
            headers["accept-encoding"] = urllib3.util.SKIP_HEADER

        response = self.session.request(
            method=request.method,
            # use raw base url to preserve path url encoding
            url=url,
            # request.args are only the url parameters
            params=list(request.args.items(multi=True)),
            headers=headers,
            data=restore_payload(request),
            stream=True,
            allow_redirects=self.follow_redirects,
        )

        if request.method == "HEAD":
            # for HEAD requests we have to keep the original content-length, but it will be re-calculated when creating
            # the final_response object
            final_response = Response(
                response=response.content,
                status=response.status_code,
                headers=self._transform_response_headers(response),
            )
            final_response.content_length = response.headers.get("Content-Length", 0)
            return final_response

        response_headers = self._transform_response_headers(response)

        if "chunked" in (transfer_encoding := response_headers.get("Transfer-Encoding", "")):
            response_headers.pop("Content-Length", None)
            # We should not set `Transfer-Encoding` in a Response, because it is the responsibility of the webserver
            # to do so, if there are no Content-Length. However, gzip behavior is more related to the actual content of
            # the response, so we keep that one.
            transfer_encoding_values = [v.strip() for v in transfer_encoding.split(",")]
            transfer_encoding_no_chunked = [
                v for v in transfer_encoding_values if v.lower() != "chunked"
            ]
            response_headers.setlist("Transfer-Encoding", transfer_encoding_no_chunked)

        final_response = Response(
            response=(chunk for chunk in response.raw.stream(1024, decode_content=False)),
            status=response.status_code,
            headers=response_headers,
        )
        # with stream=True, requests only releases the connection once the body is fully read or the response is
        # closed; closing it together with the final response avoids leaking the connection on early disconnects
        final_response.call_on_close(response.close)

        return final_response

    def close(self):
        """Close the underlying ``requests`` session."""
        self.session.close()
161 |
162 |
def make_request(request: Request, server: str | None = None) -> Response:
    """
    Convenience method to make the given HTTP request as a client, using a short-lived ``SimpleRequestsClient``.

    :param request: the request to make
    :param server: the URL to send the request to, which defaults to the host component of the original Request.
    :return: the response.
    """
    with SimpleRequestsClient() as client:
        return client.request(request, server=server)
172 |
--------------------------------------------------------------------------------
/rolo/proxy.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from typing import Mapping, Union
3 | from urllib.parse import urlparse
4 |
5 | from werkzeug import Request, Response
6 | from werkzeug.datastructures import Headers
7 | from werkzeug.test import EnvironBuilder
8 |
9 | from .client import HttpClient, SimpleRequestsClient
10 | from .request import get_raw_path, restore_payload, set_environment_headers
11 |
12 |
def forward(
    request: Request,
    forward_base_url: str,
    forward_path: str = None,
    headers: Union[Headers, Mapping[str, str]] = None,
) -> Response:
    """
    Shortcut that forwards ``request`` through a temporary ``Proxy`` for ``forward_base_url``, which is closed
    afterwards. See ``Proxy`` and ``Proxy.forward`` for the meaning of the parameters.
    """
    proxy = Proxy(forward_base_url=forward_base_url)
    with proxy:
        return proxy.forward(request, forward_path=forward_path, headers=headers)
25 |
26 |
class Proxy(HttpClient):
    """
    An ``HttpClient`` that forwards requests to the backend configured by ``forward_base_url``, delegating the
    actual calls to another ``HttpClient``.
    """

    preserve_host: bool

    def __init__(
        self, forward_base_url: str, client: HttpClient = None, preserve_host: bool = True
    ):
        """
        Creates a new HTTP Proxy which can be used to forward incoming requests according to the configuration.

        :param forward_base_url: the base url (backend) to forward the requests to.
        :param client: the HTTP Client used to make the requests (a new ``SimpleRequestsClient`` if not given)
        :param preserve_host: True to ensure that the Host header of the incoming request is preserved.
            If False, then the Host header will be set to the Host from the perspective of the Proxy.
        """
        self.forward_base_url = forward_base_url
        self.client = client or SimpleRequestsClient()
        self.preserve_host = preserve_host

    def request(self, request: Request, server: str | None = None) -> Response:
        """
        Compatibility with HttpClient interface. A call is equivalent to ``Proxy.forward(request, None, None)``.

        :param request: the request to proxy
        :param server: ignored for a proxy, since the server is already set by `forward_base_url`.
        :return: the proxied response
        """
        return self.forward(request)

    def forward(
        self,
        request: Request,
        forward_path: str = None,
        headers: Union[Headers, Mapping[str, str]] = None,
    ) -> Response:
        """
        Uses the client to forward the given request according to the proxy's configuration.

        :param request: the base request to forward (with the original URL and path data)
        :param forward_path: the path to forward the request to. if set, the original path will be replaced completely,
            otherwise the original path will be used
        :param headers: additional custom headers to send as part of the proxy request
        :return: the proxied response
        """
        extra_headers = Headers(headers) if headers else Headers()

        # append the client address to the X-Forwarded-For chain
        client_ip = request.remote_addr
        if client_ip:
            upstream_chain = request.headers.get("X-Forwarded-For")
            extra_headers["X-Forwarded-For"] = (
                f"{upstream_chain}, {client_ip}" if upstream_chain else f"{client_ip}"
            )

        path = get_raw_path(request) if forward_path is None else forward_path
        # normalize non-empty paths to exactly one leading slash
        if path:
            path = "/" + path.lstrip("/")

        proxy_request = _copy_request(request, self.forward_base_url, path, extra_headers)

        if self.preserve_host and "Host" in request.headers:
            proxy_request.headers["Host"] = request.headers["Host"]

        backend = urlparse(self.forward_base_url)
        return self.client.request(proxy_request, server=f"{backend.scheme}://{backend.netloc}")

    def close(self):
        """Close the underlying HTTP client."""
        self.client.close()
93 |
94 |
class ProxyHandler:
    """
    A dispatcher Handler which can be used in a ``Router[Handler]`` that proxies incoming requests according to the
    configuration.

    The Handler is expected to be used together with a route that uses a ``path`` parameter named ``path`` in the URL.
    For example: if you want to forward all requests from ``/foobar/<path>`` to ``http://localhost:8080/v1/<path>``,
    you would do the following::

        router = Router(dispatcher=handler_dispatcher())
        router.add("/foobar/<path:path>", ProxyHandler("http://localhost:8080/v1"))

    This is similar to the common nginx configuration where proxy_pass is a URI::

        location /foobar {
            proxy_pass http://localhost:8080/v1/;
        }
    """

    def __init__(self, forward_base_url: str, client: HttpClient = None, preserve_host: bool = True):
        """
        Creates a new Proxy with the given ``forward_base_url`` (see ``Proxy``).

        :param forward_base_url: the base url (backend) to forward the requests to.
        :param client: the HTTP Client used to make the requests
        :param preserve_host: passed on to ``Proxy``; True to preserve the Host header of the incoming request
        """
        self.proxy = Proxy(
            forward_base_url=forward_base_url, client=client, preserve_host=preserve_host
        )

    def __call__(self, request: Request, **kwargs) -> Response:
        """
        Forward ``request``; the ``path`` URL argument (empty if absent) is used as the forward path.
        """
        return self.proxy.forward(request, forward_path=kwargs.get("path", ""))

    def close(self):
        """Close the underlying proxy."""
        self.proxy.close()
128 |
129 |
130 | def _copy_request(
131 | request: Request,
132 | base_url: str = None,
133 | path: str = None,
134 | headers: Union[Headers, Mapping[str, str]] = None,
135 | ) -> Request:
136 | """
137 | Creates a new request from the given one that can be used to perform a proxy call.
138 |
139 | :param request: the original request
140 | :param base_url: the url to forward the request to (e.g., http://localhost:8080)
141 | :param path: the path to forward the request to (e.g., /foobar), if set to None, the original path will be used
142 | :param headers: optional headers to overwrite
143 | :return: a new request with slightly modified underlying environment but the same data stream
144 | """
145 | # ensure that the headers in the env are set on the environment
146 | # FIXME: we should preserve header casing like we do with the `asgi.headers` property in the asgi/wsgi bridge to
147 | # pass through the raw headers.
148 | set_environment_headers(request.environ, request.headers)
149 | builder = EnvironBuilder.from_environ(request.environ)
150 |
151 | if base_url:
152 | builder.base_url = base_url
153 | builder.headers["Host"] = builder.host
154 |
155 | if path is not None:
156 | builder.path = path
157 |
158 | if headers:
159 | builder.headers.update(headers)
160 |
161 | # FIXME: unfortunately, EnvironBuilder expects the input stream to be seekable, but we don't have that when using
162 | # the asgi/wsgi bridge. we need a better way of dealing with IO!
163 | data = restore_payload(request)
164 | builder.input_stream = BytesIO(data)
165 | builder.content_length = len(data)
166 | # Since the payload is completely restored, the proxy forwarding is not streamed.
167 | # Therefore, we need to remove a potential "chunked" Transfer-Encoding
168 | if builder.headers.get("Transfer-Encoding", None) == "chunked":
169 | builder.headers.pop("Transfer-Encoding")
170 |
171 | new_request = builder.get_request()
172 |
173 | # explicitly set the path in the environment and in the newly created request
174 | if path is not None:
175 | new_request.environ["RAW_URI"] = path or "/"
176 |
177 | # copy headers s.t. they are no longer immutable (by default, EnvironHeaders are used)
178 | new_request.headers = Headers(new_request.headers)
179 |
180 | return new_request
181 |
--------------------------------------------------------------------------------
/tests/test_pydantic.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict
2 |
3 | import pydantic
4 | import pytest
5 | from werkzeug.exceptions import BadRequest
6 |
7 | from rolo import Request, Router, resource
8 | from rolo.routing import handler as routing_handler
9 | from rolo.routing import handler_dispatcher
10 |
11 | pydantic_version = pydantic.version.version_short()
12 |
13 |
14 | class MyItem(pydantic.BaseModel):
15 | name: str
16 | price: float
17 | is_offer: bool = None
18 |
19 |
20 | class TestPydanticHandlerDispatcher:
21 | def test_request_arg(self):
22 | router = Router(dispatcher=handler_dispatcher())
23 |
24 | def handler(_request: Request, item: MyItem) -> dict:
25 | return {"item": item.model_dump()}
26 |
27 | router.add("/items", handler)
28 |
29 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}')
30 | assert router.dispatch(request).get_json(force=True) == {
31 | "item": {
32 | "name": "rolo",
33 | "price": 420.69,
34 | "is_offer": None,
35 | },
36 | }
37 |
38 | def test_request_args(self):
39 | router = Router(dispatcher=handler_dispatcher())
40 |
41 | def handler(_request: Request, item_id: int, item: MyItem) -> dict:
42 | return {"item_id": item_id, "item": item.model_dump()}
43 |
44 | router.add("/items/", handler)
45 |
46 | request = Request("POST", "/items/123", body=b'{"name":"rolo","price":420.69}')
47 | assert router.dispatch(request).get_json(force=True) == {
48 | "item_id": 123,
49 | "item": {
50 | "name": "rolo",
51 | "price": 420.69,
52 | "is_offer": None,
53 | },
54 | }
55 |
56 | def test_request_args_empty_body(self):
57 | router = Router(dispatcher=handler_dispatcher())
58 |
59 | def handler(_request: Request, item_id: int, item: MyItem) -> dict:
60 | return {"item_id": item_id, "item": item.model_dump()}
61 |
62 | router.add("/items/", handler)
63 |
64 | request = Request("POST", "/items/123", body=b"")
65 | assert router.dispatch(request).get_json(force=True) == [
66 | {
67 | "type": "json_invalid",
68 | "loc": [],
69 | "msg": "Invalid JSON: EOF while parsing a value at line 1 column 0",
70 | "ctx": {"error": "EOF while parsing a value at line 1 column 0"},
71 | "input": "",
72 | "url": f"https://errors.pydantic.dev/{pydantic_version}/v/json_invalid",
73 | }
74 | ]
75 |
76 | def test_response(self):
77 | router = Router(dispatcher=handler_dispatcher())
78 |
79 | def handler(_request: Request, item_id: int) -> MyItem:
80 | return MyItem(name="rolo", price=420.69)
81 |
82 | router.add("/items/", handler)
83 |
84 | request = Request("GET", "/items/123")
85 | assert router.dispatch(request).get_json() == {
86 | "name": "rolo",
87 | "price": 420.69,
88 | "is_offer": None,
89 | }
90 |
91 | def test_response_list(self):
92 | router = Router(dispatcher=handler_dispatcher())
93 |
94 | def handler(_request: Request) -> list[MyItem]:
95 | return [
96 | MyItem(name="rolo", price=420.69),
97 | MyItem(name="twiks", price=1.23, is_offer=True),
98 | ]
99 |
100 | router.add("/items", handler)
101 |
102 | request = Request("GET", "/items")
103 | assert router.dispatch(request).get_json() == [
104 | {
105 | "name": "rolo",
106 | "price": 420.69,
107 | "is_offer": None,
108 | },
109 | {
110 | "name": "twiks",
111 | "price": 1.23,
112 | "is_offer": True,
113 | },
114 | ]
115 |
116 | def test_request_arg_validation_error(self):
117 | router = Router(dispatcher=handler_dispatcher())
118 |
119 | def handler(_request: Request, item_id: int, item: MyItem) -> str:
120 | return item.model_dump_json()
121 |
122 | router.add("/items/", handler)
123 |
124 | request = Request("POST", "/items/123", body=b'{"name":"rolo"}')
125 | assert router.dispatch(request).get_json() == [
126 | {
127 | "type": "missing",
128 | "loc": ["price"],
129 | "msg": "Field required",
130 | "input": {"name": "rolo"},
131 | "url": f"https://errors.pydantic.dev/{pydantic_version}/v/missing",
132 | }
133 | ]
134 |
135 | def test_request_arg_invalid_json(self):
136 | router = Router(dispatcher=handler_dispatcher())
137 |
138 | def handler(_request: Request, item_id: int, item: MyItem) -> str:
139 | return item.model_dump_json()
140 |
141 | router.add("/items/", handler)
142 |
143 | request = Request("POST", "/items/123", body=b'{"}')
144 | with pytest.raises(BadRequest):
145 | assert router.dispatch(request)
146 |
147 | def test_missing_annotation(self):
148 | router = Router(dispatcher=handler_dispatcher())
149 |
150 | # without an annotation, we cannot be sure what type to serialize into, so the dispatcher doesn't pass
151 | # anything into ``item``.
152 | def handler(_request: Request, item=None) -> dict:
153 | return {"item": item}
154 |
155 | router.add("/items", handler)
156 |
157 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}')
158 | assert router.dispatch(request).get_json(force=True) == {"item": None}
159 |
160 | def test_with_pydantic_disabled(self, monkeypatch):
161 | monkeypatch.setattr(routing_handler, "ENABLE_PYDANTIC", False)
162 | router = Router(dispatcher=handler_dispatcher())
163 |
164 | def handler(_request: Request, item: MyItem) -> dict:
165 | return {"item": item.model_dump()}
166 |
167 | router.add("/items", handler)
168 |
169 | request = Request("POST", "/items", body=b'{"name":"rolo","price":420.69}')
170 | with pytest.raises(TypeError):
171 | # "missing 1 required positional argument: 'item'"
172 | assert router.dispatch(request)
173 |
174 | def test_with_resource(self):
175 | router = Router(dispatcher=handler_dispatcher())
176 |
177 | @resource("/items/")
178 | class MyResource:
179 | def on_get(self, request: Request, item_id: int):
180 | return MyItem(name="rolo", price=420.69)
181 |
182 | def on_post(self, request: Request, item_id: int, item: MyItem):
183 | return {"item_id": item_id, "item": item.model_dump()}
184 |
185 | router.add(MyResource())
186 |
187 | response = router.dispatch(Request("GET", "/items/123"))
188 | assert response.get_json() == {
189 | "name": "rolo",
190 | "price": 420.69,
191 | "is_offer": None,
192 | }
193 |
194 | response = router.dispatch(
195 | Request("POST", "/items/123", body=b'{"name":"rolo","price":420.69}')
196 | )
197 | assert response.get_json() == {
198 | "item": {"is_offer": None, "name": "rolo", "price": 420.69},
199 | "item_id": 123,
200 | }
201 |
202 | def test_with_generic_type_alias(self):
203 | router = Router(dispatcher=handler_dispatcher())
204 |
205 | def handler(request: Request, matrix: dict[str, str] = None):
206 | return "ok"
207 |
208 | router.add("/", endpoint=handler)
209 |
210 | request = Request("GET", "/")
211 | assert router.dispatch(request).data == b"ok"
212 |
213 | def test_with_typed_dict(self):
214 | try:
215 | from typing import Unpack
216 | except ImportError:
217 | pytest.skip("This test only works with Python >=3.11")
218 |
219 | router = Router(dispatcher=handler_dispatcher())
220 |
221 | class Test(TypedDict, total=False):
222 | path: str
223 | random_value: str
224 |
225 | def func(request: Request, **kwargs: Unpack[Test]):
226 | return f"path={kwargs.get('path')},random_value={kwargs.get('random_value')}"
227 |
228 | router.add(
229 | "/",
230 | endpoint=func,
231 | defaults={"path": "", "random_value": "dev"},
232 | )
233 |
234 | request = Request("GET", "/")
235 | assert router.dispatch(request).data == b"path=,random_value=dev"
236 |
--------------------------------------------------------------------------------
/tests/gateway/test_chain.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | from werkzeug.datastructures import Headers
4 |
5 | from rolo.gateway import CompositeFinalizer, CompositeHandler, HandlerChain, RequestContext
6 | from rolo.response import Response
7 |
8 |
9 | def test_response_handler_exception():
10 | def _raise(*args, **kwargs):
11 | raise ValueError("oh noes")
12 |
13 | response1 = mock.MagicMock()
14 | response2 = _raise
15 | response3 = mock.MagicMock()
16 | exception = mock.MagicMock()
17 | finalizer = mock.MagicMock()
18 |
19 | chain = HandlerChain(
20 | response_handlers=[response1, response2, response3],
21 | exception_handlers=[exception],
22 | finalizers=[finalizer],
23 | )
24 | chain.handle(RequestContext(), Response())
25 |
26 | response1.assert_called_once()
27 | response3.assert_called_once() # all response handlers should be called
28 | exception.assert_not_called() # response handlers don't trigger exception handlers
29 | finalizer.assert_called_once()
30 |
31 | assert chain.error is None
32 |
33 |
34 | def test_finalizer_handler_exception():
35 | def _raise(*args, **kwargs):
36 | raise ValueError("oh noes")
37 |
38 | response = mock.MagicMock()
39 | exception = mock.MagicMock()
40 | finalizer1 = mock.MagicMock()
41 | finalizer2 = _raise
42 | finalizer3 = mock.MagicMock()
43 |
44 | chain = HandlerChain(
45 | response_handlers=[response],
46 | exception_handlers=[exception],
47 | finalizers=[finalizer1, finalizer2, finalizer3],
48 | )
49 | chain.handle(RequestContext(), Response())
50 |
51 | response.assert_called_once()
52 | exception.assert_not_called() # response handlers don't trigger exception handlers
53 | finalizer1.assert_called_once()
54 | finalizer3.assert_called_once()
55 |
56 | assert chain.error is None
57 |
58 |
59 | def test_composite_finalizer_handler_exception():
60 | def _raise(*args, **kwargs):
61 | raise ValueError("oh noes")
62 |
63 | response = mock.MagicMock()
64 | exception = mock.MagicMock()
65 | finalizer1 = mock.MagicMock()
66 | finalizer2 = _raise
67 | finalizer3 = mock.MagicMock()
68 |
69 | finalizer = CompositeFinalizer()
70 | finalizer.append(finalizer1)
71 | finalizer.append(finalizer2)
72 | finalizer.append(finalizer3)
73 |
74 | chain = HandlerChain(
75 | response_handlers=[response],
76 | exception_handlers=[exception],
77 | finalizers=[finalizer],
78 | )
79 | chain.handle(RequestContext(), Response())
80 |
81 | response.assert_called_once()
82 | exception.assert_not_called() # response handlers don't trigger exception handlers
83 | finalizer1.assert_called_once()
84 | finalizer3.assert_called_once()
85 |
86 | assert chain.error is None
87 |
88 |
89 | def test_respond_with_json_response():
90 | def handle(chain_: HandlerChain, _context, _response):
91 | chain_.respond(202, {"foo": "bar"}, headers={"X-Foo": "Bar"})
92 |
93 | chain = HandlerChain(request_handlers=[handle])
94 | chain.handle(RequestContext(), Response())
95 |
96 | assert chain.response.status_code == 202
97 | assert chain.response.json == {"foo": "bar"}
98 | assert chain.response.headers.get("x-foo") == "Bar"
99 | assert chain.response.mimetype == "application/json"
100 |
101 |
102 | def test_respond_with_string_response():
103 | def handle(chain_: HandlerChain, _context, _response):
104 | chain_.respond(200, "foobar", Headers({"X-Foo": "Bar"}))
105 |
106 | chain = HandlerChain(request_handlers=[handle])
107 | chain.handle(RequestContext(), Response())
108 |
109 | assert chain.response.status_code == 200
110 | assert chain.response.data == b"foobar"
111 | assert chain.response.headers.get("x-foo") == "Bar"
112 | assert chain.response.mimetype == "text/plain"
113 |
114 |
115 | class TestCompositeHandler:
116 | def test_composite_handler_stops_handler_chain(self):
117 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response):
118 | _chain.stop()
119 |
120 | inner2 = mock.MagicMock()
121 | outer1 = mock.MagicMock()
122 | outer2 = mock.MagicMock()
123 | response1 = mock.MagicMock()
124 | finalizer = mock.MagicMock()
125 |
126 | chain = HandlerChain()
127 |
128 | composite = CompositeHandler()
129 | composite.handlers.append(inner1)
130 | composite.handlers.append(inner2)
131 |
132 | chain.request_handlers.append(outer1)
133 | chain.request_handlers.append(composite)
134 | chain.request_handlers.append(outer2)
135 | chain.response_handlers.append(response1)
136 | chain.finalizers.append(finalizer)
137 |
138 | chain.handle(RequestContext(), Response())
139 | outer1.assert_called_once()
140 | outer2.assert_not_called()
141 | inner2.assert_not_called()
142 | response1.assert_called_once()
143 | finalizer.assert_called_once()
144 |
145 | def test_composite_handler_terminates_handler_chain(self):
146 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response):
147 | _chain.terminate()
148 |
149 | inner2 = mock.MagicMock()
150 | outer1 = mock.MagicMock()
151 | outer2 = mock.MagicMock()
152 | response1 = mock.MagicMock()
153 | finalizer = mock.MagicMock()
154 |
155 | chain = HandlerChain()
156 |
157 | composite = CompositeHandler()
158 | composite.handlers.append(inner1)
159 | composite.handlers.append(inner2)
160 |
161 | chain.request_handlers.append(outer1)
162 | chain.request_handlers.append(composite)
163 | chain.request_handlers.append(outer2)
164 | chain.response_handlers.append(response1)
165 | chain.finalizers.append(finalizer)
166 |
167 | chain.handle(RequestContext(), Response())
168 | outer1.assert_called_once()
169 | outer2.assert_not_called()
170 | inner2.assert_not_called()
171 | response1.assert_not_called()
172 | finalizer.assert_called_once()
173 |
174 | def test_composite_handler_with_not_return_on_stop(self):
175 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response):
176 | _chain.stop()
177 |
178 | inner2 = mock.MagicMock()
179 | outer1 = mock.MagicMock()
180 | outer2 = mock.MagicMock()
181 | response1 = mock.MagicMock()
182 | finalizer = mock.MagicMock()
183 |
184 | chain = HandlerChain()
185 |
186 | composite = CompositeHandler(return_on_stop=False)
187 | composite.handlers.append(inner1)
188 | composite.handlers.append(inner2)
189 |
190 | chain.request_handlers.append(outer1)
191 | chain.request_handlers.append(composite)
192 | chain.request_handlers.append(outer2)
193 | chain.response_handlers.append(response1)
194 | chain.finalizers.append(finalizer)
195 |
196 | chain.handle(RequestContext(), Response())
197 | outer1.assert_called_once()
198 | outer2.assert_not_called()
199 | inner2.assert_called_once()
200 | response1.assert_called_once()
201 | finalizer.assert_called_once()
202 |
203 | def test_composite_handler_continues_handler_chain(self):
204 | inner1 = mock.MagicMock()
205 | inner2 = mock.MagicMock()
206 | outer1 = mock.MagicMock()
207 | outer2 = mock.MagicMock()
208 | response1 = mock.MagicMock()
209 | finalizer = mock.MagicMock()
210 |
211 | chain = HandlerChain()
212 |
213 | composite = CompositeHandler()
214 | composite.handlers.append(inner1)
215 | composite.handlers.append(inner2)
216 |
217 | chain.request_handlers.append(outer1)
218 | chain.request_handlers.append(composite)
219 | chain.request_handlers.append(outer2)
220 | chain.response_handlers.append(response1)
221 | chain.finalizers.append(finalizer)
222 |
223 | chain.handle(RequestContext(), Response())
224 | outer1.assert_called_once()
225 | outer2.assert_called_once()
226 | inner1.assert_called_once()
227 | inner2.assert_called_once()
228 | response1.assert_called_once()
229 | finalizer.assert_called_once()
230 |
231 | def test_composite_handler_exception_calls_outer_exception_handlers(self):
232 | def inner1(_chain: HandlerChain, request: RequestContext, response: Response):
233 | raise ValueError()
234 |
235 | inner2 = mock.MagicMock()
236 | outer1 = mock.MagicMock()
237 | outer2 = mock.MagicMock()
238 | exception_handler = mock.MagicMock()
239 | response1 = mock.MagicMock()
240 | finalizer = mock.MagicMock()
241 |
242 | chain = HandlerChain()
243 |
244 | composite = CompositeHandler()
245 | composite.handlers.append(inner1)
246 | composite.handlers.append(inner2)
247 |
248 | chain.request_handlers.append(outer1)
249 | chain.request_handlers.append(composite)
250 | chain.request_handlers.append(outer2)
251 | chain.exception_handlers.append(exception_handler)
252 | chain.response_handlers.append(response1)
253 | chain.finalizers.append(finalizer)
254 |
255 | chain.handle(RequestContext(), Response())
256 | outer1.assert_called_once()
257 | outer2.assert_not_called()
258 | inner2.assert_not_called()
259 | exception_handler.assert_called_once()
260 | response1.assert_called_once()
261 | finalizer.assert_called_once()
262 |
--------------------------------------------------------------------------------
/rolo/routing/rules.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import typing as t
3 |
4 | from werkzeug.routing import Map, Rule, RuleFactory
5 |
6 |
7 | class _RuleAttributes(t.NamedTuple):
8 | path: str
9 | host: t.Optional[str] = (None,)
10 | methods: t.Optional[t.Iterable[str]] = None
11 | kwargs: t.Optional[dict[str, t.Any]] = {}
12 |
13 |
14 | class _RouteEndpoint(t.Protocol):
15 | """
16 | An endpoint that encapsulates ``_RuleAttributes`` for the creation of a ``Rule`` inside a ``Router``.
17 | """
18 |
19 | rule_attributes: list[_RuleAttributes]
20 |
21 | def __call__(self, *args, **kwargs):
22 | raise NotImplementedError
23 |
24 |
25 | class WithHost(RuleFactory):
26 | def __init__(self, host: str, rules: t.Iterable[RuleFactory]) -> None:
27 | self.host = host
28 | self.rules = rules
29 |
30 | def get_rules(self, map: Map) -> t.Iterator[Rule]:
31 | for rulefactory in self.rules:
32 | for rule in rulefactory.get_rules(map):
33 | rule = rule.empty()
34 | rule.host = self.host
35 | yield rule
36 |
37 |
38 | class RuleGroup(RuleFactory):
39 | def __init__(self, rules: t.Iterable[RuleFactory]):
40 | self.rules = rules
41 |
42 | def get_rules(self, map: Map) -> t.Iterable[Rule]:
43 | for rule in self.rules:
44 | yield from rule.get_rules(map)
45 |
46 |
47 | class RuleAdapter(RuleFactory):
48 | """
49 | Takes something that can also be passed to ``Router.add``, and exposes it as a ``RuleFactory`` that generates the
50 | appropriate Werkzeug rules. This can be used in combination with other rule factories like ``Submount``,
51 | and creates general compatibility with werkzeug rules. Here's an example::
52 |
53 | @route("/my_api", methods=["GET"])
54 | def do_get(request: Request, _args):
55 | # should be inherited
56 | return Response(f"{request.path}/do-get")
57 |
58 | def hello(request: Request, _args):
59 | return Response(f"hello world")
60 |
61 | router = Router()
62 |
63 | # base endpoints
64 | endpoints = RuleAdapter([
65 | do_get,
66 | RuleAdapter("/hello", hello)
67 | ])
68 |
69 | router.add([
70 | endpoints,
71 | Submount("/foo", [endpoints])
72 | ])
73 |
74 | """
75 |
76 | factory: RuleFactory
77 | """The underlying real rule factory."""
78 |
79 | @t.overload
80 | def __init__(
81 | self,
82 | path: str,
83 | endpoint: t.Callable,
84 | host: t.Optional[str] = None,
85 | methods: t.Optional[t.Iterable[str]] = None,
86 | **kwargs,
87 | ):
88 | """
89 | Basically a ``Rule``.
90 |
91 | :param path: the path pattern to match. This path rule, in contrast to the default behavior of Werkzeug, will be
92 | matched against the raw / original (potentially URL-encoded) path.
93 | :param endpoint: the endpoint to invoke
94 | :param host: an optional host matching pattern. if not pattern is given, the rule matches any host
95 | :param methods: the allowed HTTP verbs for this rule
96 | :param kwargs: any other argument that can be passed to ``werkzeug.routing.Rule``
97 | """
98 | ...
99 |
100 | @t.overload
101 | def __init__(self, fn: _RouteEndpoint):
102 | """
103 | Takes a route endpoint (typically a function decorated with ``@route``) and adds it as ``EndpointRule``.
104 |
105 | :param fn: the RouteEndpoint function
106 | """
107 | ...
108 |
109 | @t.overload
110 | def __init__(self, rule_factory: RuleFactory):
111 | """
112 | Adds a ``Rule`` or the rules created by a ``RuleFactory`` to the given router. It passes the rules down to
113 | the underlying Werkzeug ``Map``, but also returns the created Rules.
114 |
115 | :param rule_factory: a `Rule` or ``RuleFactory`
116 | """
117 | ...
118 |
119 | @t.overload
120 | def __init__(self, obj: t.Any):
121 | """
122 | Scans the given object for members that can be used as a `RouteEndpoint` and adds them to the router.
123 |
124 | :param obj: the object to scan
125 | """
126 | ...
127 |
128 | @t.overload
129 | def __init__(self, rules: list[t.Union[_RouteEndpoint, RuleFactory, t.Any]]):
130 | """Add multiple rules at once"""
131 | ...
132 |
133 | def __init__(self, *args, **kwargs):
134 | """
135 | Dispatcher for overloaded ``__init__`` methods.
136 | """
137 | if "path" in kwargs or isinstance(args[0], str):
138 | self.factory = _EndpointRule(*args, **kwargs)
139 | elif "fn" in kwargs or callable(args[0]):
140 | self.factory = _EndpointFunction(*args, **kwargs)
141 | elif "rule_factory" in kwargs:
142 | self.factory = kwargs["rule_factory"]
143 | elif isinstance(args[0], RuleFactory):
144 | self.factory = args[0]
145 | elif isinstance(args[0], list):
146 | self.factory = RuleGroup([RuleAdapter(rule) for rule in args[0]])
147 | else:
148 | self.factory = _EndpointsObject(*args, **kwargs)
149 |
150 | def get_rules(self, map: Map) -> t.Iterable[Rule]:
151 | yield from self.factory.get_rules(map)
152 |
153 |
154 | class _EndpointRule(RuleFactory):
155 | """
156 | Generates default werkzeug ``Rule`` object with the given attributes. Additionally, it makes sure that
157 | the generated rule always has a default host value, if the map has host matching enabled. Specifically,
158 | it adds the well-known placeholder ``<__host__>``, which is later stripped out of the request arguments
159 | when dispatching to the endpoint. This ensures compatibility of rule definitions across routers that
160 | have host matching enabled or not.
161 | """
162 |
163 | def __init__(
164 | self,
165 | path: str,
166 | endpoint: t.Callable,
167 | host: t.Optional[str] = None,
168 | methods: t.Optional[t.Iterable[str]] = None,
169 | **kwargs,
170 | ):
171 | self.path = path
172 | self.endpoint = endpoint
173 | self.host = host
174 | self.methods = methods
175 | self.kwargs = kwargs
176 |
177 | def get_rules(self, map: Map) -> t.Iterable[Rule]:
178 | host = self.host
179 |
180 | if host is None and map.host_matching:
181 | # this creates a "match any" rule, and will put the value of the host
182 | # into the variable "__host__"
183 | host = "<__host__>"
184 |
185 | # the typing for endpoint is a str, but the doc states it can be any value,
186 | # however then the redirection URL building will not work
187 | rule = Rule(
188 | self.path, endpoint=self.endpoint, methods=self.methods, host=host, **self.kwargs
189 | )
190 | yield rule
191 |
192 |
193 | class _EndpointFunction(RuleFactory):
194 | """
195 | Internal rule factory that generates router Rules from ``@route`` annotated functions, or anything else
196 | that can be interpreted as a ``_RouteEndpoint``. It extracts the rule attributes from the
197 | ``_RuleAttributes`` attribute defined by ``_RouteEndpoint``. Example::
198 |
199 | @route("/my_api", methods=["GET"])
200 | def do_get(request: Request, _args):
201 | # should be inherited
202 | return Response(f"{request.path}/do-get")
203 |
204 | router.add(do_get) # <- will use an _EndpointFunction RuleFactory.
205 | """
206 |
207 | def __init__(self, fn: _RouteEndpoint):
208 | self.fn = fn
209 |
210 | def get_rules(self, map: Map) -> t.Iterable[Rule]:
211 | attrs: list[_RuleAttributes] = self.fn.rule_attributes
212 | for attr in attrs:
213 | yield from _EndpointRule(
214 | path=attr.path,
215 | endpoint=self.fn,
216 | host=attr.host,
217 | methods=attr.methods,
218 | **attr.kwargs,
219 | ).get_rules(map)
220 |
221 |
222 | class _EndpointsObject(RuleFactory):
223 | """
224 | Scans the given object for members that can be used as a `RouteEndpoint` and yields them as rules.
225 | """
226 |
227 | def __init__(self, obj: object):
228 | self.obj = obj
229 |
230 | def get_rules(self, map: Map) -> t.Iterable[Rule]:
231 | endpoints: list[_RouteEndpoint] = []
232 |
233 | members = inspect.getmembers(self.obj)
234 | for _, member in members:
235 | if hasattr(member, "rule_attributes"):
236 | endpoints.append(member)
237 |
238 | # make sure rules with "HEAD" are added first, otherwise werkzeug would let any "GET" rule would overwrite them.
239 | for endpoint in endpoints:
240 | for attr in endpoint.rule_attributes:
241 | if attr.methods and "HEAD" in attr.methods:
242 | yield from _EndpointRule(
243 | path=attr.path,
244 | endpoint=endpoint,
245 | host=attr.host,
246 | methods=attr.methods,
247 | **attr.kwargs,
248 | ).get_rules(map)
249 |
250 | for endpoint in endpoints:
251 | for attr in endpoint.rule_attributes:
252 | if not attr.methods or "HEAD" not in attr.methods:
253 | yield from _EndpointRule(
254 | path=attr.path,
255 | endpoint=endpoint,
256 | host=attr.host,
257 | methods=attr.methods,
258 | **attr.kwargs,
259 | ).get_rules(map)
260 |
--------------------------------------------------------------------------------
/rolo/testing/pytest.py:
--------------------------------------------------------------------------------
1 | """rolo pytest plugin used both for internal testing and as testing library."""
2 | import asyncio
3 | import dataclasses
4 | import socket
5 | import threading
6 | import time
7 | import typing
8 | from typing import Protocol
9 |
10 | import pytest
11 | from werkzeug import Request as WerkzeugRequest
12 | from werkzeug import serving
13 |
14 | from rolo import Router
15 | from rolo.asgi import ASGIAdapter, ASGILifespanListener
16 | from rolo.gateway import Gateway
17 | from rolo.gateway.asgi import AsgiGateway
18 | from rolo.gateway.wsgi import WsgiGateway
19 | from rolo.routing import handler_dispatcher
20 | from rolo.serving.twisted import HeaderPreservingHTTPChannel, TwistedGateway
21 | from rolo.websocket.adapter import WebSocketListener
22 |
23 | if typing.TYPE_CHECKING:
24 | from hypercorn.typing import ASGIFramework
25 |
26 |
27 | class ServerInfo(Protocol):
28 | url: str
29 | host: str
30 | port: int
31 |
32 |
33 | class Server(ServerInfo):
34 | def shutdown(self):
35 | ...
36 |
37 |
38 | @dataclasses.dataclass
39 | class _ServerInfo:
40 | host: str
41 | port: int
42 | url: str
43 |
44 |
45 | @pytest.fixture
46 | def serve_wsgi_app():
47 | servers: list[serving.BaseWSGIServer] = []
48 |
49 | def _serve(app, host: str = "localhost", port: int = None) -> serving.BaseWSGIServer | Server:
50 | srv = serving.make_server(host, port or 0, app, threaded=True)
51 | name = threading._newname("test-server-%d")
52 | threading.Thread(target=srv.serve_forever, name=name, daemon=True).start()
53 | servers.append(srv)
54 | srv.url = f"http://{srv.host}:{srv.port}"
55 | return srv
56 |
57 | yield _serve
58 |
59 | for server in servers:
60 | server.shutdown()
61 |
62 |
63 | @pytest.fixture
64 | def wsgi_router_server(serve_wsgi_app) -> tuple[Router, serving.BaseWSGIServer | Server]:
65 | """Creates a new Router with a handler dispatcher, serves it through a newly created ASGI server, and returns
66 | both the router and the server.
67 | """
68 | router = Router(dispatcher=handler_dispatcher())
69 | app = WerkzeugRequest.application(router.dispatch)
70 | return router, serve_wsgi_app(app)
71 |
72 |
73 | @pytest.fixture()
74 | def serve_asgi_app():
75 | import hypercorn
76 | import hypercorn.asyncio
77 |
78 | _server_shutdown = []
79 |
80 | def _create(
81 | app: "ASGIFramework",
82 | config: hypercorn.Config = None,
83 | event_loop: asyncio.AbstractEventLoop = None,
84 | ) -> Server:
85 | host = "localhost"
86 | port = get_random_tcp_port()
87 | bind = f"localhost:{port}"
88 |
89 | if not config:
90 | config = hypercorn.Config()
91 | config.h11_pass_raw_headers = True
92 | config.bind = [bind]
93 |
94 | event_loop = event_loop or asyncio.new_event_loop()
95 | close = asyncio.Event()
96 | closed = threading.Event()
97 |
98 | async def _set_close():
99 | close.set()
100 |
101 | def _run():
102 | event_loop.run_until_complete(
103 | hypercorn.asyncio.serve(app, config, shutdown_trigger=close.wait)
104 | )
105 | closed.set()
106 |
107 | def _shutdown():
108 | if close.is_set():
109 | return
110 | asyncio.run_coroutine_threadsafe(_set_close(), event_loop)
111 | closed.wait(timeout=10)
112 | try:
113 | app.close()
114 | except AttributeError:
115 | pass
116 | asyncio.run_coroutine_threadsafe(event_loop.shutdown_asyncgens(), event_loop)
117 | event_loop.shutdown_default_executor()
118 | event_loop.stop()
119 | event_loop.close()
120 |
121 | _server_shutdown.append(_shutdown)
122 | threading.Thread(
123 | target=_run, name=threading._newname("asgi-server-%d"), daemon=True
124 | ).start()
125 |
126 | srv = _ServerInfo(host, port, f"http://{host}:{port}")
127 | srv.shutdown = _shutdown
128 |
129 | assert wait_server_is_up(srv), f"gave up waiting for server {srv}"
130 |
131 | return srv
132 |
133 | yield _create
134 |
135 | for server_shutdown in _server_shutdown:
136 | server_shutdown()
137 |
138 |
139 | @pytest.fixture()
140 | def serve_asgi_adapter(serve_asgi_app):
141 | def _create(
142 | wsgi_app,
143 | lifespan_listener: ASGILifespanListener = None,
144 | websocket_listener: WebSocketListener = None,
145 | ):
146 | loop = asyncio.new_event_loop()
147 | return serve_asgi_app(
148 | ASGIAdapter(
149 | wsgi_app,
150 | event_loop=loop,
151 | lifespan_listener=lifespan_listener,
152 | websocket_listener=websocket_listener,
153 | ),
154 | event_loop=loop,
155 | )
156 |
157 | yield _create
158 |
159 |
160 | @pytest.fixture
161 | def serve_wsgi_gateway(serve_wsgi_app):
162 | def _serve(gateway: Gateway) -> Server:
163 | return serve_wsgi_app(WsgiGateway(gateway))
164 |
165 | return _serve
166 |
167 |
168 | @pytest.fixture
169 | def serve_asgi_gateway(serve_asgi_app):
170 | def _serve(gateway: Gateway) -> Server:
171 | loop = asyncio.new_event_loop()
172 | return serve_asgi_app(AsgiGateway(gateway, event_loop=loop), event_loop=loop)
173 |
174 | return _serve
175 |
176 |
@pytest.fixture(scope="session")
def twisted_reactor():
    """
    Session fixture that controls the lifecycle of the main twisted reactor.

    Unless it is already running, the global reactor is started in a daemon thread. The fixture waits up to
    5 seconds for the reactor to be running, yields it, and stops it again at the end of the session.
    """
    from twisted.internet import reactor
    from twisted.internet.error import ReactorAlreadyRunning
    from twisted.web.http import HTTPFactory

    def _run():
        if reactor.running:
            return

        try:
            # for some reason, when using a `SelectReactor` (like you do by default on MacOS), whatever
            # protocols are added to the reactor via `listenTCP` _after_ `run` has been called,
            # are not served properly. We see this because the request calls in tests block forever. If we
            # add any listener here before calling `run`, then for some reason it works. 🤷
            reactor.listenTCP(get_random_tcp_port(), HTTPFactory())
            reactor.run(installSignalHandlers=False)
        except ReactorAlreadyRunning:
            pass

    threading.Thread(target=_run, daemon=True).start()

    assert poll_condition(
        lambda: reactor.running, timeout=5
    ), f"gave up waiting for {reactor} to start"

    yield reactor

    # the reactor runs in a different thread and twisted's APIs are not thread-safe, so the stop has to be
    # scheduled into the reactor thread instead of calling ``reactor.stop()`` directly from here
    reactor.callFromThread(reactor.stop)
207 |
208 |
@pytest.fixture
def serve_twisted_tcp_server(twisted_reactor):
    """
    Factory fixture for serving a twisted protocol factory (like ``Site``) through the twisted reactor.

    Each call listens on a random localhost TCP port, waits until the port accepts connections and returns the
    server info. All listening ports are closed again when the test finishes.
    """
    from twisted.internet.tcp import Port
    from twisted.internet.threads import blockingCallFromThread

    ports: list[Port] = []

    def _create(protocol_factory):
        port = get_random_tcp_port()
        host = "localhost"
        # the reactor runs in a different thread and twisted's APIs are not thread-safe, so the listener is
        # registered from within the reactor thread (this call blocks until it is done)
        listening_port = blockingCallFromThread(
            twisted_reactor, twisted_reactor.listenTCP, port, protocol_factory
        )
        ports.append(listening_port)
        srv = _ServerInfo(host, port, f"http://{host}:{port}")
        assert wait_server_is_up(srv), f"gave up waiting for {srv}"
        return srv

    yield _create

    for _port in ports:
        # also waits for the deferred returned by ``stopListening``, so the port is closed before continuing
        blockingCallFromThread(twisted_reactor, _port.stopListening)
228 |
229 |
@pytest.fixture
def serve_twisted_gateway(serve_twisted_tcp_server):
    """Factory fixture that serves a rolo ``Gateway`` through the twisted reactor via ``TwistedGateway``."""

    def _create(gateway):
        site = TwistedGateway(gateway)
        return serve_twisted_tcp_server(site)

    yield _create
236 |
237 |
@pytest.fixture
def serve_twisted_websocket_listener(twisted_reactor, serve_twisted_tcp_server):
    """
    Creates a Twisted Site without serving a fully-fledged rolo Gateway.

    Modeled after `rolo.serving.twisted.TwistedGateway`, but it hands the `WebSocketListener` directly to
    `WebsocketResourceDecorator` as `websocketListener`, instead of `gateway.accept`. This way the low-level
    WebSocket behavior can be tested independently of the Gateway implementation.
    """
    from twisted.web.server import Site

    from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator

    def _create(websocket_listener: WebSocketListener):
        # ``None`` is passed where presumably the WSGI application would go: only websockets are tested here
        wsgi_resource = HeaderPreservingWSGIResource(
            twisted_reactor, twisted_reactor.getThreadPool(), None
        )
        root_resource = WebsocketResourceDecorator(
            original=wsgi_resource,
            websocketListener=websocket_listener,
        )
        site = Site(root_resource)
        site.protocol = HeaderPreservingHTTPChannel.protocol_factory
        return serve_twisted_tcp_server(site)

    return _create
263 |
264 |
def is_server_up(srv: "ServerInfo") -> bool:
    """
    Check whether a TCP connection can currently be established to the given server.

    :param srv: the server info; only its ``host`` and ``port`` attributes are used
    :return: True if connecting to any of the resolved IPv4 addresses succeeds, False otherwise (including when
        the host name cannot be resolved)
    """
    try:
        addresses = socket.getaddrinfo(srv.host, srv.port, socket.AF_INET, socket.SOCK_STREAM)
    except socket.gaierror:
        # an unresolvable host name means the server cannot be reached, so treat it as "not up"
        return False

    for family, socktype, proto, _canonname, sockaddr in addresses:
        # the context manager closes the socket on both the success and the failure path
        with socket.socket(family, socktype, proto) as s:
            try:
                s.connect(sockaddr)
            except OSError:
                # try the next resolved address, if any
                continue
            return True

    return False
276 |
277 |
def wait_server_is_up(srv: ServerInfo, timeout: float = 10, interval: float = 0.1) -> bool:
    """
    Repeatedly check with ``is_server_up`` whether ``srv`` accepts TCP connections.

    :param srv: the server to check
    :param timeout: how long (in seconds) to keep polling
    :param interval: pause (in seconds) between two checks
    :return: True once the server is up, False if it did not come up within the timeout
    """

    def _server_is_up() -> bool:
        return is_server_up(srv)

    return poll_condition(_server_is_up, timeout=timeout, interval=interval)
280 |
281 |
def get_random_tcp_port() -> int:
    """
    Return a TCP port number that was free at the time of the call.

    The OS assigns the port by binding a socket to port 0. The socket is closed again before returning, so the
    port is not reserved: another process may still grab it before the caller binds it.

    :return: the port number
    """
    import socket

    # closing the probe socket explicitly avoids leaking it (and the resulting ResourceWarning)
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
288 |
289 |
def poll_condition(
    condition: typing.Callable[[], bool],
    timeout: typing.Optional[float] = None,
    interval: float = 0.5,
) -> bool:
    """
    Poll evaluates the given condition until a truthy value is returned. It does this every `interval` seconds
    (0.5 by default), until the timeout (in seconds, if any) is reached.

    Poll returns True once `condition()` returns a truthy value, or False if the timeout is reached.

    The timeout is measured with a monotonic clock, so time spent inside ``condition()`` counts towards it.
    The condition is evaluated one last time when the deadline is reached, and the function never sleeps past
    the deadline.

    :param condition: the callable to evaluate
    :param timeout: maximum time to wait in seconds, or None to wait indefinitely
    :param interval: time to sleep between two evaluations of the condition, in seconds
    :return: True if the condition became truthy, False if the timeout was reached first
    """
    deadline = None if timeout is None else time.monotonic() + timeout

    while not condition():
        if deadline is None:
            time.sleep(interval)
            continue

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # never sleep beyond the deadline
        time.sleep(min(interval, remaining))

    return True
315 |
--------------------------------------------------------------------------------
/tests/test_request.py:
--------------------------------------------------------------------------------
1 | import wsgiref.validate
2 |
3 | import pytest
4 | from werkzeug.exceptions import BadRequest
5 |
6 | from rolo.request import Request, dummy_wsgi_environment, get_raw_path, restore_payload
7 |
8 |
def test_get_json():
    # a JSON body with a JSON content type is parsed by the ``json`` property
    payload = b'{"foo": "bar", "baz": 420}'
    request = Request("POST", "/", headers={"Content-Type": "application/json"}, body=payload)

    assert request.json == {"foo": "bar", "baz": 420}
    assert request.content_type == "application/json"
18 |
19 |
def test_get_json_force():
    # no content type is set, so the JSON body is only parsed when forced
    request = Request("POST", "/", body=b'{"foo": "bar", "baz": 420}')
    expected = {"foo": "bar", "baz": 420}
    assert request.get_json(force=True) == expected
23 |
24 |
def test_get_json_invalid():
    request = Request("POST", "/", body=b'{"foo": "')

    # truncated JSON raises ``BadRequest`` ...
    with pytest.raises(BadRequest):
        request.get_json(force=True)

    # ... unless errors are silenced, in which case None is returned
    assert request.get_json(force=True, silent=True) is None
32 |
33 |
def test_get_data():
    # a string body is exposed as bytes through ``data``
    assert Request("GET", "/", body="foobar").data == b"foobar"
37 |
38 |
def test_get_data_as_text():
    request = Request("GET", "/", body="foobar")
    text = request.get_data(as_text=True)
    assert text == "foobar"
42 |
43 |
def test_get_stream():
    stream = Request("GET", "/", body=b"foobar").stream
    # consecutive reads continue where the previous one stopped
    first = stream.read(3)
    second = stream.read(3)
    assert (first, second) == (b"foo", b"bar")
48 |
49 |
def test_args():
    args = Request("GET", "/", query_string="foo=420&bar=69").args
    assert len(args) == 2
    assert [args["foo"], args["bar"]] == ["420", "69"]
55 |
56 |
def test_values():
    values = Request("GET", "/", query_string="foo=420&bar=69").values
    # without a body, ``values`` only holds the query arguments
    assert len(values) == 2
    assert [values["foo"], values["bar"]] == ["420", "69"]
62 |
63 |
def test_form_empty():
    # a POST request without a body has no form fields
    form = Request("POST", "/").form
    assert not form
67 |
68 |
def test_post_form_urlencoded_and_query():
    # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/POST#example
    r = Request(
        "POST",
        "/form",
        query_string="query1=foo&query2=bar",
        body=b"field1=value1&field2=value2",
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )

    # the URL-encoded body ends up in ``form``
    assert len(r.form) == 2
    assert r.form["field1"] == "value1"
    assert r.form["field2"] == "value2"

    # the query string ends up in ``args``
    assert len(r.args) == 2
    assert r.args["query1"] == "foo"
    assert r.args["query2"] == "bar"

    # ``values`` combines both the form fields and the query arguments
    assert len(r.values) == 4
    assert r.values["field1"] == "value1"
    assert r.values["field2"] == "value2"
    assert r.values["query1"] == "foo"
    assert r.values["query2"] == "bar"
92 |
93 |
def test_validate_dummy_environment():
    # each generated environment has to pass wsgiref's WSGI environment checks
    cases = [
        {"path": "/foo/bar", "body": "foo"},
        {"path": "/foo/bar", "query_string": "foo=420&bar=69"},
        {"server": ("localstack.cloud", 4566)},
        {"server": ("localstack.cloud", None)},
        {"remote_addr": "127.0.0.1"},
        {"headers": {"Content-Type": "text/xml"}, "body": b""},
        {"headers": {"Content-Type": "text/xml", "x-amz-target": "foobar"}, "body": b""},
    ]
    for kwargs in cases:
        environ = dummy_wsgi_environment(**kwargs)
        assert wsgiref.validate.check_environ(environ) is None
105 |
106 |
def test_content_length_is_set_automatically():
    # no Content-Length header is given, so it is derived from the body
    body = "foobar"
    assert Request("GET", "/", body=body).content_length == 6
111 |
112 |
def test_content_length_is_overwritten():
    # an explicit Content-Length header wins over the actual body length
    explicit_headers = {"Content-Length": "7"}
    request = Request("GET", "/", body="foobar", headers=explicit_headers)
    assert request.content_length == 7
117 |
118 |
def test_get_custom_headers():
    custom_headers = {"x-amz-target": "foobar"}
    request = Request("GET", "/", body="foobar", headers=custom_headers)
    assert request.headers["x-amz-target"] == "foobar"
122 |
123 |
def test_get_raw_path():
    raw = "/foo%2Fbar/ed"
    request = Request("GET", "/foo/bar/ed", raw_path=raw)

    # the decoded path loses the encoded slash, the raw URI keeps it
    assert request.path == "/foo/bar/ed"
    assert request.environ["RAW_URI"] == raw
    assert get_raw_path(request) == raw
130 |
131 |
def test_get_raw_path_with_query():
    raw_uri = "/foo%2Fbar/ed?fizz=buzz"
    request = Request("GET", "/foo/bar/ed", raw_path=raw_uri)

    assert request.path == "/foo/bar/ed"
    assert request.environ["RAW_URI"] == raw_uri
    # the query string is not part of the raw path
    assert get_raw_path(request) == "/foo%2Fbar/ed"
138 |
139 |
def test_get_raw_path_with_prefix_slashes():
    raw_uri = "//foo%2Fbar/ed?fizz=buzz"
    request = Request("GET", "/foo/bar/ed", raw_path=raw_uri)

    assert request.path == "/foo/bar/ed"
    assert request.environ["RAW_URI"] == raw_uri
    # leading double slashes must not be mistaken for a network location
    assert get_raw_path(request) == "//foo%2Fbar/ed"
146 |
147 |
def test_get_raw_path_with_full_uri():
    # in the WSGI environment raw_path is really RAW_URI, which may also be an absolute URL
    absolute_uri = "http://localhost:4566/foo%2Fbar/ed"
    request = Request("GET", "/foo/bar/ed", raw_path=absolute_uri)

    assert request.path == "/foo/bar/ed"
    assert request.environ["RAW_URI"] == absolute_uri
    # scheme and host are stripped from the raw path
    assert get_raw_path(request) == "/foo%2Fbar/ed"
156 |
157 |
def test_headers_retain_dashes():
    request = Request("GET", "/foo/bar/ed", {"X-Amz-Meta--foo_bar-ed": "foobar"})
    # looked up in lower case, but double dashes and underscores must be preserved
    key = "x-amz-meta--foo_bar-ed"
    assert key in request.headers
    assert request.headers[key] == "foobar"
162 |
163 |
def test_headers_retain_case():
    request = Request("GET", "/foo/bar/ed", {"X-Amz-Meta--FOO_BaR-ed": "foobar"})
    keys = list(request.headers.keys())

    # the first x-amz-meta header key has to keep its original casing
    meta_key = next((key for key in keys if key.lower().startswith("x-amz-meta")), None)
    if meta_key is None:
        pytest.fail(f"key not in header keys {keys}")
    assert meta_key == "X-Amz-Meta--FOO_BaR-ed"
172 |
173 |
def test_multipart_parsing():
    content_type = "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"
    body = (
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="foo"; filename="foo"'
        b"\r\n\r\n"
        b"bar"
        b"\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="baz"; filename="baz"'
        b"\r\n\r\n"
        b"ed"
        b"\r\n--4efd159eae0c4f4e125a5a509e073d85--"
        b"\r\n"
    )
    request = Request("POST", path="/", body=body, headers={"Content-Type": content_type})

    # both parts carry a filename, so they are parsed as files
    contents = {
        name: storage.stream.read().decode("utf-8") for name, storage in request.files.items()
    }
    assert contents == {"foo": "bar", "baz": "ed"}
202 |
203 |
def test_utf8_path():
    r = Request("GET", "/foo/Ā0Ä")

    assert r.path == "/foo/Ā0Ä"
    # PATH_INFO holds the UTF-8 bytes of the path decoded as latin-1 (not percent-quoted):
    # "Ā" -> b"\xc4\x80" -> "Ä\x80" and "Ä" -> b"\xc3\x84" -> "Ã\x84"
    assert r.environ["PATH_INFO"] == "/foo/Ä\x800Ã\x84"
209 |
210 |
def test_restore_payload_multipart_parsing():
    content_type = "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"
    body = (
        b"\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="formfield"'
        b"\r\n\r\n"
        b"not a file, just a field"
        b"\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="foo"; filename="foo"'
        b"\r\n"
        b"Content-Type: text/plain;"
        b"\r\n\r\n"
        b"bar"
        b"\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="baz"; filename="baz"'
        b"\r\n"
        b"Content-Type: text/plain;"
        b"\r\n\r\n"
        b"ed"
        b"\r\n"
        b"\r\n--4efd159eae0c4f4e125a5a509e073d85--"
        b"\r\n"
    )
    request = Request("POST", path="/", body=body, headers={"Content-Type": content_type})

    # the part without a filename is a plain form field
    assert dict(request.form.items()) == {"formfield": "not a file, just a field"}

    file_names = []
    for name, storage in request.files.items():
        assert storage.stream
        # only collect the names: consuming the FileStorage stream would make restoring the payload impossible
        file_names.append(name)
    assert file_names == ["foo", "baz"]

    # the restored payload is byte-identical to the original body
    assert restore_payload(request) == body
263 |
264 |
def test_request_mixed_multipart():
    # This is almost how we previously restored a form that had both `form` fields and `files`: we would
    # URL-encode the form first and then append the multipart data. That does not work: the leading part should
    # be ignored, and it makes some strict multipart parsers (Starlette) fail because they find data before the
    # first boundary.

    # This test does something slightly different to prove that the leading part is ignored: it prepends a
    # URL-encoded part to the multipart body.
    content_type = "multipart/form-data; boundary=4efd159eae0c4f4e125a5a509e073d85"
    body = (
        b"formfield=not+a+file%2C+just+a+field\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="foo"; filename="foo"'
        b"\r\n"
        b"Content-Type: text/plain;"
        b"\r\n\r\n"
        b"bar"
        b"\r\n"
        b"--4efd159eae0c4f4e125a5a509e073d85"
        b"\r\n"
        b'Content-Disposition: form-data; name="baz"; filename="baz"'
        b"\r\n"
        b"Content-Type: text/plain;"
        b"\r\n\r\n"
        b"ed"
        b"\r\n"
        b"\r\n--4efd159eae0c4f4e125a5a509e073d85--"
        b"\r\n"
    )
    request = Request("POST", path="/", body=body, headers={"Content-Type": content_type})

    # the URL-encoded preamble must not show up as a form field
    assert dict(request.form.items()) == {}

    file_names = []
    for name, storage in request.files.items():
        assert storage.stream
        # only collect the names: consuming the FileStorage stream would make restoring the payload impossible
        file_names.append(name)
    assert file_names == ["foo", "baz"]

    # the ignored preamble is not part of the restored payload either
    assert b"formfield" not in restore_payload(request)
316 |
317 |
def test_restore_payload_form_urlencoded():
    body = b"formfield=not+a+file%2C+just+a+field"
    request = Request(
        "POST",
        path="/",
        body=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )

    # the URL-encoded body is decoded into a single form field and contains no files
    assert dict(request.form.items()) == {"formfield": "not a file, just a field"}
    assert not request.files

    # restoring yields the original bytes
    assert restore_payload(request) == body
339 |
--------------------------------------------------------------------------------
/rolo/websocket/request.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import typing as t
3 |
4 | from werkzeug import Response
5 | from werkzeug._internal import _wsgi_decoding_dance
6 | from werkzeug.datastructures import EnvironHeaders, Headers, MultiDict
7 | from werkzeug.sansio.request import Request as _SansIORequest
8 | from werkzeug.wsgi import _get_server
9 |
10 | from .adapter import (
11 | BytesMessage,
12 | CreateConnection,
13 | Message,
14 | TextMessage,
15 | WebSocketAdapter,
16 | WebSocketEnvironment,
17 | WebSocketListener,
18 | )
19 | from .errors import WebSocketDisconnectedError, WebSocketProtocolError
20 |
21 |
22 | class WebSocket:
23 | """
24 | High-level interface to interact with a websocket after a handshake has been completed with
25 | `WebsocketRequest.accept()`.
26 | """
27 |
28 | request: "WebSocketRequest"
29 | socket: WebSocketAdapter
30 |
31 | def __init__(self, request: "WebSocketRequest", socket: WebSocketAdapter):
32 | self.request = request
33 | self.socket = socket
34 |
35 | def __enter__(self):
36 | return self
37 |
38 | def __exit__(self, exc_type, exc_val, exc_tb):
39 | self.close()
40 |
41 | def __iter__(self):
42 | while True:
43 | try:
44 | yield self.receive()
45 | except WebSocketDisconnectedError:
46 | break
47 |
48 | def send(self, text_or_bytes: str | bytes, timeout: float = None):
49 | """
50 | Send data to the websocket connection.
51 |
52 | :param text_or_bytes: the data to send. Use strings for text-mode sockets (default).
53 | :param timeout: the timeout in seconds to wait before raising a timeout error
54 | """
55 | if text_or_bytes is None:
56 | raise ValueError("text_or_bytes cannot be None")
57 |
58 | if isinstance(text_or_bytes, str):
59 | event = TextMessage(text_or_bytes)
60 | elif isinstance(text_or_bytes, bytes):
61 | event = BytesMessage(text_or_bytes)
62 | else:
63 | raise WebSocketProtocolError(
64 | f"Cannot send data type {type(text_or_bytes)} over websocket"
65 | )
66 | self.socket.send(event, timeout)
67 |
68 | def receive(self) -> str | bytes:
69 | """
70 | Receive the next data package from the websocket. Will be string or byte data and set the
71 | underlying binary for the frame automatically.
72 |
73 | :raise WebSocketDisconnectedError: if the websocket was closed in the meantime
74 | :raise WebSocketProtocolError: error in the interaction between the app and the webserver
75 | :return: the next data package from the websocket
76 | """
77 | event = self.socket.receive()
78 | if isinstance(event, Message):
79 | data = event.data
80 | if data is None:
81 | raise WebSocketProtocolError("No data returned by the websocket.")
82 | return data
83 | else:
84 | raise WebSocketProtocolError(
85 | f"Unexpected websocket event type {event.__class__.__name__}."
86 | )
87 |
88 | def close(self, code: int = 1000, reason: t.Optional[str] = None, timeout: float = None):
89 | """
90 | Closes the websocket connection with specific code.
91 |
92 | :param code: the websocket close code.
93 | :param reason: optional reason
94 | :param timeout: connection timeout
95 | """
96 | self.socket.close(code, reason, timeout)
97 |
98 |
class WebSocketRequest(_SansIORequest):
    """
    A websocket request represents the incoming HTTP request to upgrade the connection to a WebSocket
    connection. The request method is an artificial ``WEBSOCKET`` method that can also be used in the Router:
    ``@route("/path", method=["WEBSOCKET"])``.

    The websocket connection needs to be either accepted or rejected. When calling
    ``WebSocketRequest.accept``, an upgrade response will be sent to the client, and the protocol will be
    switched to the bidirectional WebSocket protocol. If ``WebSocketRequest.reject`` is called, the server
    immediately returns an HTTP response and closes the connection.
    """

    def __init__(self, environ: WebSocketEnvironment):
        """
        Creates a new request from the given WebSocketEnvironment. This is like a sans-IO WSGI Environment,
        with an additional field ``rolo.websocket`` that contains a ``WebSocketAdapter`` interface.

        :param environ: the WebSocketEnvironment
        """
        raw_headers = environ.get("rolo.headers")
        if raw_headers:
            # restores raw headers from the server scope, to have proper casing or dashes. This can depend on server
            # behavior, but we want a unified way to keep header casing/formatting.
            # This is similar to what we do in wsgi.py
            headers = Headers(
                MultiDict([(k.decode("latin-1"), v.decode("latin-1")) for (k, v) in raw_headers])
            )
        else:
            headers = Headers(EnvironHeaders(environ))

        # copied from werkzeug.wrappers.request
        super().__init__(
            method=environ.get("REQUEST_METHOD", "WEBSOCKET"),
            scheme=environ.get("wsgi.url_scheme", "ws"),
            server=_get_server(environ),
            root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""),
            path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""),
            query_string=environ.get("QUERY_STRING", "").encode("latin1"),
            headers=headers,
            remote_addr=environ.get("REMOTE_ADDR"),
        )
        self.environ = environ

        self.shallow = True  # compatibility with werkzeug.Request

        # handshake state: ``accept`` and ``reject`` both refuse to run once either flag is set
        self._upgraded = False
        self._rejected = False

    @property
    def socket(self) -> WebSocketAdapter:
        """
        Returns the underlying WebSocketAdapter from the environment (``rolo.websocket``, falling back to
        ``asgi.websocket``). This is analogous to ``Request.stream`` in the default werkzeug HTTP request object.

        :return: the WebSocketAdapter from the environment
        """
        return self.environ.get("rolo.websocket") or self.environ.get("asgi.websocket")

    def is_upgraded(self) -> bool:
        """Returns true if ``accept`` was called."""
        return self._upgraded

    def is_rejected(self) -> bool:
        """Returns true if ``reject`` was called."""
        return self._rejected

    def reject(self, response: Response):
        """
        Reject the websocket upgrade and return the given response. Will raise a ``ValueError`` if the
        request has already been accepted or rejected before.

        :param response: the HTTP response to return to the client.
        :raises ValueError: if the request was already accepted or rejected
        """
        if self._upgraded:
            raise ValueError("Websocket connection already upgraded")
        if self._rejected:
            raise ValueError("Websocket connection already rejected")

        self.socket.reject(
            response.status_code,
            response.headers,
            response.iter_encoded(),
        )
        self._rejected = True

    def accept(
        self,
        subprotocol: t.Optional[str] = None,
        headers: t.Optional[Headers] = None,
        timeout: t.Optional[float] = None,
    ) -> WebSocket:
        """
        Performs the websocket connection upgrade handshake. After calling ``accept``, a new ``WebSocket``
        instance is returned that represents the bidirectional communication channel, which you should
        continue operating on. Example::

            def app(request: WebSocketRequest):
                # example: do authorization first
                auth = request.headers.get("Authorization")
                if not is_authorized(auth):
                    request.reject(Response("no dice", 403))
                    return

                # then continue working with the websocket
                with request.accept() as websocket:
                    websocket.send("hello world!")
                    data = websocket.receive()
                    # ...

        The handshake using the WebSocketAdapter works as follows: receive the ``CreateConnection`` event
        from the websocket and then call ``socket.accept(...)``. If the handshake fails because the websocket
        sent an unexpected event, the connection is closed with code 1003 and the method raises an error.

        :param subprotocol: The subprotocol the server wishes to accept. Optional
        :param headers: Response headers
        :param timeout: connection timeout
        :return: a websocket
        :raises ValueError: if the request was already accepted or rejected
        :raises WebSocketProtocolError: if unexpected events were received from the websocket server
        """
        if self._upgraded:
            raise ValueError("Websocket connection already upgraded")
        if self._rejected:
            raise ValueError("Websocket connection already rejected")

        event = self.socket.receive(timeout)
        if isinstance(event, CreateConnection):
            self.socket.accept(subprotocol, [], headers, timeout)
            self._upgraded = True
            return WebSocket(self, self.socket)
        else:
            reason = f"Unexpected event {event.__class__.__name__}"
            self.socket.close(1003, reason)
            raise WebSocketProtocolError(reason)

    def close(self):
        """
        Explicitly close the websocket. If this is called after ``reject(...)`` or ``accept(...)`` has been
        called, this will have no effect. Calling ``reject`` inherently closes the websocket connection
        since it immediately returns an HTTP response. After calling ``accept`` you should call
        ``WebSocket.close`` instead.
        """
        if self._rejected or self._upgraded:
            return
        self.socket.close(1000)

    @classmethod
    def listener(cls, fn: t.Callable[["WebSocketRequest"], None]) -> WebSocketListener:
        """
        Convenience function inspired by ``werkzeug.Request.application`` that transforms a function into a
        ``WebSocketListener`` for the use in server code that supports ``WebSocketListeners``. Example::

            @WebSocketRequest.listener
            def app(request: WebSocketRequest):
                with request.accept() as ws:
                    ws.send("hello world")

            adapter = ASGIAdapter(wsgi_app=..., websocket_listener=app)
            # ... serve adapter

        An ``HTTPException`` raised by ``fn`` before the request was accepted or rejected is turned into a
        rejection with the exception's HTTP response. Once the request has been accepted or rejected, no HTTP
        response can be sent anymore, so the exception is re-raised unchanged.

        :param fn: the function to wrap
        :return: a WebSocketListener compatible interface
        """
        from werkzeug.exceptions import HTTPException

        @functools.wraps(fn)
        def application(*args):
            # the environment is always the last positional argument
            request = cls(args[-1])
            try:
                fn(*args[:-1] + (request,))
            except HTTPException as e:
                if request.is_upgraded() or request.is_rejected():
                    # ``reject`` would raise a ValueError here and mask the original error
                    raise
                request.reject(e.get_response(args[-1]))
            finally:
                # no-op if the request was accepted or rejected
                request.close()

        return application
273 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/rolo/request.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from typing import IO, TYPE_CHECKING, Dict, Mapping, Optional, Tuple, Union
3 | from urllib.parse import quote, unquote, urlencode, urlparse
4 |
5 | if TYPE_CHECKING:
6 | from _typeshed.wsgi import WSGIEnvironment
7 |
8 | from werkzeug.datastructures import Headers, MultiDict
9 | from werkzeug.test import encode_multipart
10 | from werkzeug.wrappers.request import Request as WerkzeugRequest
11 |
12 |
def dummy_wsgi_environment(
    method: str = "GET",
    path: str = "",
    headers: Optional[Union[Dict, Headers]] = None,
    body: Optional[Union[bytes, str, IO[bytes]]] = None,
    scheme: str = "http",
    root_path: str = "/",
    query_string: Optional[str] = None,
    remote_addr: Optional[str] = None,
    server: Optional[Tuple[str, Optional[int]]] = None,
    raw_uri: Optional[str] = None,
) -> "WSGIEnvironment":
    """
    Creates a dummy WSGIEnvironment that represents a standalone sans-IO HTTP request.

    See https://wsgi.readthedocs.io/en/latest/definitions.html#standard-environ-keys

    :param method: The HTTP request method (such as GET or POST)
    :param path: The remainder of the request URL's path. This may be an empty string, if the
        request URL targets the application root and does not have a trailing slash.
    :param headers: optional HTTP headers
    :param body: the body of the request. A ``str`` is encoded as UTF-8. For ``str``/``bytes`` bodies
        ``CONTENT_LENGTH`` is derived from the data unless the headers already set it. Any other
        object (e.g. a binary file-like object) is used as ``wsgi.input`` as-is.
    :param scheme: the scheme (http or https)
    :param root_path: The initial portion of the request URL's path that corresponds to the
        application object.
    :param query_string: The portion of the request URL that follows the "?", if any. May be
        empty or absent.
    :param remote_addr: The address making the request
    :param server: The server (tuple of server name and port). Without a server, or without a
        port, ``SERVER_PORT`` is set to "80".
    :param raw_uri: The original path that may contain url encoded path elements. If a query
        string is given, it is appended to it.
    :return: A WSGIEnvironment dictionary
    """
    from io import StringIO

    # Standard environ keys
    environ = {
        "REQUEST_METHOD": method,
        # prepare the paths for the "WSGI decoding dance" done by werkzeug: the path is
        # percent-encoded (non-ASCII as UTF-8 bytes) and then decoded byte-wise as latin-1
        "SCRIPT_NAME": unquote(quote(root_path.rstrip("/")), "latin-1"),
        "PATH_INFO": unquote(quote(path), "latin-1"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "QUERY_STRING": query_string or "",
    }

    if raw_uri:
        if query_string:
            raw_uri += "?" + query_string
        environ["RAW_URI"] = raw_uri
        environ["REQUEST_URI"] = environ["RAW_URI"]

    if server:
        environ["SERVER_NAME"] = server[0]
        if server[1]:
            environ["SERVER_PORT"] = str(server[1])
        else:
            environ["SERVER_PORT"] = "80"
    else:
        environ["SERVER_NAME"] = "127.0.0.1"
        environ["SERVER_PORT"] = "80"

    if remote_addr:
        environ["REMOTE_ADDR"] = remote_addr

    if headers:
        set_environment_headers(environ, headers)

    if not body:
        body = b""

    if isinstance(body, (str, bytes)):
        data = body.encode("utf-8") if isinstance(body, str) else body

        wsgi_input = BytesIO(data)
        if "CONTENT_LENGTH" not in environ:
            # try to determine content length from body
            environ["CONTENT_LENGTH"] = str(len(data))
    else:
        wsgi_input = body

    # WSGI environ keys
    environ["wsgi.version"] = (1, 0)
    environ["wsgi.url_scheme"] = scheme
    environ["wsgi.input"] = wsgi_input
    environ["wsgi.input_terminated"] = True
    # PEP 3333 requires a *text-mode* error stream (applications write str to it); a BytesIO
    # would raise TypeError on the first write of a str
    environ["wsgi.errors"] = StringIO()
    environ["wsgi.multithread"] = True
    environ["wsgi.multiprocess"] = False
    environ["wsgi.run_once"] = False

    return environ
102 |
103 |
def set_environment_headers(environ: "WSGIEnvironment", headers: Union[Dict, Headers]):
    """
    Replace the ``HTTP_*`` entries of a WSGI environment with the given headers.

    Header names are upper-cased and ``-`` becomes ``_``. ``Content-Type`` and ``Content-Length``
    map to the un-prefixed ``CONTENT_TYPE`` / ``CONTENT_LENGTH`` keys; every other header gets the
    ``HTTP_`` prefix. Names that map to the same key are joined with ``","`` in iteration order.
    All existing ``HTTP_*`` keys are removed first; existing ``CONTENT_TYPE`` / ``CONTENT_LENGTH``
    entries are only overwritten (if present in ``headers``), never removed.

    :param environ: the WSGI environment, modified in place
    :param headers: a dict or werkzeug ``Headers`` (anything with an ``items()`` method)
    """
    # Translate everything before touching ``environ``: ``headers`` may itself be backed by the
    # environment (e.g. werkzeug's EnvironHeaders), so it must be fully read before keys are removed.
    translated = {}
    for raw_name, value in headers.items():
        key = raw_name.upper().replace("-", "_")
        if key != "CONTENT_TYPE" and key != "CONTENT_LENGTH":
            key = "HTTP_" + key
        if key in translated:
            value = translated[key] + "," + value
        translated[key] = value

    # drop every previously set HTTP_* header
    stale_keys = [key for key in environ if key.startswith("HTTP_")]
    for key in stale_keys:
        environ.pop(key, None)

    environ.update(translated)
128 |
129 |
class Request(WerkzeugRequest):
    """
    An HTTP request object. This is (and should remain) a drop-in replacement for werkzeug's WSGI
    compliant Request objects. It allows simple sans-IO requests outside a web server environment.

    DO NOT add methods that are not also part of werkzeug.wrappers.request.Request object.
    """

    def __init__(
        self,
        method: str = "GET",
        path: str = "",
        headers: Optional[Union[Mapping, Headers]] = None,
        body: Optional[Union[bytes, str, IO[bytes]]] = None,
        scheme: str = "http",
        root_path: str = "/",
        query_string: Union[bytes, str] = b"",
        remote_addr: Optional[str] = None,
        server: Optional[Tuple[str, Optional[int]]] = None,
        raw_path: Optional[str] = None,
    ):
        """
        Build a request from its parts by synthesizing a WSGI environment with
        ``dummy_wsgi_environment`` and passing it to werkzeug's Request.

        :param method: the HTTP method
        :param path: the request path below ``root_path``
        :param headers: the request headers; they are kept as a mutable ``Headers`` copy
        :param body: the request body (``str`` is UTF-8 encoded; a binary file-like object is used
            as the input stream as-is)
        :param scheme: the scheme (http or https)
        :param root_path: the path prefix the application is mounted under
        :param query_string: the query string without the leading "?"; bytes are decoded as latin-1
        :param remote_addr: the address making the request
        :param server: tuple of server name and (optional) port
        :param raw_path: the original, still url-encoded path (stored as ``RAW_URI``)
        """
        # decode query string if necessary (latin-1 is what werkzeug would expect)
        if isinstance(query_string, bytes):
            query_string = query_string.decode("latin-1")

        # create the WSGIEnvironment dictionary that represents this request
        environ = dummy_wsgi_environment(
            method=method,
            path=path,
            headers=headers,
            body=body,
            scheme=scheme,
            root_path=root_path,
            query_string=query_string,
            remote_addr=remote_addr,
            server=server,
            raw_uri=raw_path,
        )

        super().__init__(environ)

        # restore originally passed headers:
        # werkzeug normally provides read-only access to headers set in the WSGIEnvironment through the EnvironHeaders
        # class, here we make them mutable again. moreover, WSGI header encoding conflicts with RFC2616. see this github
        # issue for a discussion: https://github.com/pallets/werkzeug/issues/940
        headers = Headers(headers)
        # these two headers are treated separately in the WSGI environment (e.g. the content length may have been
        # derived from the body), so we copy them over from the environment if they were not passed explicitly
        for h in ["content-length", "content-type"]:
            if h not in headers and h in self.headers:
                headers[h] = self.headers[h]
        self.headers = headers

    @classmethod
    def application(cls, *args):
        # werkzeug's application decorator assumes its Request constructor signature, which our Request doesn't support.
        # using ``application`` from our request therefore creates runtime errors. this makes sure no one runs into
        # these problems. if we want to support it, we need to create compatibility with werkzeug's Request constructor
        raise NotImplementedError(
            "Request.application is not supported: werkzeug's decorator assumes werkzeug's Request constructor"
        )
188 |
189 |
def get_raw_path(request) -> str:
    """
    Returns the raw_path inside the request without the query string. The request can either be a Quart Request
    object (that encodes the raw path in request.scope['raw_path']) or a Werkzeug WSGI request (that encodes the raw
    URI in request.environ['RAW_URI']).

    If no raw path is recorded, ``request.path`` is returned instead.

    :param request: the request object
    :return: the raw path if any
    :raises ValueError: if the request has neither an ``environ`` nor a ``scope``
    """
    from urllib.parse import urlsplit

    if hasattr(request, "environ"):
        # werkzeug/flask request (already a string, and contains the query part)
        # we need to parse it, because the RAW_URI can contain a full URL if it is specified in the HTTP request
        raw_uri: str = request.environ.get("RAW_URI", "")
        if not raw_uri:
            # no raw URI available: request.path never contains a query string, and it is already
            # decoded, so running it through a URL parser could only mangle it (e.g. a decoded "?" or ";")
            return request.path
        if raw_uri.startswith("//"):
            # if the RAW_URI starts with double slashes, the URL parser would treat the first segment as netloc
            # it also means that we already only have the path, so we just need to remove the query string
            return raw_uri.split("?")[0]
        # urlsplit instead of urlparse: urlparse strips ";params" off the last path segment
        return urlsplit(raw_uri).path

    if hasattr(request, "scope"):
        # quart request raw_path comes as bytes, and without the query part. the ASGI spec allows it to be missing
        # or None, in which case we fall back to the (decoded) request path
        raw_path = request.scope.get("raw_path")
        if raw_path is None:
            return request.path
        if isinstance(raw_path, bytes):
            return raw_path.decode("utf-8")
        return raw_path

    raise ValueError("cannot extract raw path from request object %s" % request)
214 |
215 |
def get_full_raw_path(request: WerkzeugRequest) -> str:
    """
    Builds the request path exactly as it was sent (original URL encoding intact), followed by
    "?" and the query string when the request has one.

    This differs from ``request.url``, which url-encodes the path section while (partly)
    url-decoding the query part.

    :param request: the request; its ``query_string`` bytes are decoded as latin-1
    :return: the raw path, plus "?<query>" if a query string is present
    """
    raw_query = request.query_string
    suffix = "?" + raw_query.decode("latin1") if raw_query else ""
    return get_raw_path(request) + suffix
225 |
226 |
def get_raw_base_url(request: Request) -> str:
    """
    Encoding-preserving counterpart of ``request.base_url``: scheme, host, root path and the raw
    (still url-encoded) path. The query string is never included.

    NOTE(review): for WSGI requests the path comes from RAW_URI, which usually already contains the
    root path; with a non-empty ``root_path`` the prefix may then appear twice — confirm.

    :param request: the request to build the URL for
    :return: the raw base URL
    """
    scheme, host, root_path = request.scheme, request.host, request.root_path
    return get_raw_current_url(
        scheme=scheme,
        host=host,
        root_path=root_path,
        path=get_raw_path(request),
    )
235 |
236 |
def get_raw_current_url(
    scheme: str,
    host: str,
    root_path: Optional[str] = None,
    path: Optional[str] = None,
    query_string: Optional[Union[bytes, str]] = None,
) -> str:
    """
    `werkzeug.sansio.utils.get_current_url` implementation without
    any encoding dances.
    The given paths and query string are directly used without any encodings
    (to avoid any double encodings).
    It can be used to recreate the raw URL for a request.

    :param scheme: The protocol the request used, like ``"https"``.
    :param host: The host the request was made to. See :func:`get_host`.
    :param root_path: Prefix that the application is mounted under. This
        is prepended to ``path``. If ``None``, only ``scheme://host/`` is returned.
    :param path: The path part of the URL after ``root_path``. If ``None``, the
        URL ends after the root path.
    :param query_string: The portion of the URL after the "?". Bytes (as found in
        ``request.query_string``) are decoded as latin-1; strings are used as-is.
    :return: the assembled URL string
    """
    url = [scheme, "://", host]

    if root_path is None:
        url.append("/")
        return "".join(url)

    url.append(root_path.rstrip("/"))
    url.append("/")

    if path is None:
        return "".join(url)

    url.append(path.lstrip("/"))

    if query_string:
        if isinstance(query_string, bytes):
            # str.join cannot mix bytes into the URL; latin-1 maps every byte 1:1 (as in get_full_raw_path)
            query_string = query_string.decode("latin-1")
        url.append("?")
        url.append(query_string)

    return "".join(url)
277 |
278 |
def restore_payload(request: Request) -> bytes:
    """
    This method takes a request and restores the original payload from it even after it has been consumed. A werkzeug
    request consumes form/multipart data from the stream, and here we are serializing it back to a request a client
    could make. This is a useful method to have when we are proxying requests, i.e., create outgoing requests from
    incoming requests that were subject to previous parsing.

    Only POST requests are re-serialized: ``multipart/form-data`` bodies are re-encoded with the boundary from the
    Content-Type header, ``application/x-www-form-urlencoded`` bodies are url-encoded again. Any other request returns
    ``request.data`` unchanged; shallow requests return ``b""``.

    TODO: this construct is not great and will definitely be a source of trouble. but something like this will
    inevitably become the basis of proxying werkzeug requests. The alternative is to build our own request object
    that memoizes the original payload before parsing.

    :param request: the (possibly already parsed) request
    :return: the serialized request body
    """
    if request.shallow:
        return b""

    data = request.data

    if request.method != "POST":
        return data

    if request.mimetype == "multipart/form-data":
        # use werkzeug's parsed content-type parameters rather than splitting the header on "=": the boundary may be
        # quoted, followed by other parameters, or itself contain "=" (e.g. "----=_Part_0")
        boundary = request.mimetype_params.get("boundary")
        if not boundary:
            # without a boundary we cannot produce a body that matches the request's Content-Type header
            return data

        fields = MultiDict()
        fields.update(request.form)
        fields.update(request.files)

        _, data_files = encode_multipart(fields, boundary)
        data += data_files

    elif request.mimetype == "application/x-www-form-urlencoded":
        data += urlencode(list(request.form.items(multi=True))).encode("utf-8")

    return data
312 |
--------------------------------------------------------------------------------