├── tests ├── __init__.py ├── tools │ ├── __init__.py │ ├── rsocket.graphqls │ ├── fixtures_aiohttp.py │ ├── helpers.py │ ├── fixtures_graphql.py │ ├── helpers_aiohttp.py │ ├── http3_client.py │ ├── http_app.py │ ├── fixtures_quart.py │ ├── fixtures_websockets.py │ ├── fixtures_http3.py │ └── fixtures_shared.py ├── rsocket │ ├── __init__.py │ ├── cloudevents │ │ ├── __init__.py │ │ └── test_route_cloud_events.py │ ├── test_stream_data_mimetype.py │ ├── test_stream_helpers.py │ ├── misbehaving_rsocket.py │ ├── test_helpers.py │ ├── test_mimetype.py │ ├── test_without_server.py │ ├── test_payload.py │ ├── test_request_router.py │ ├── test_frame_decode.py │ ├── test_request_routing_decode_payload.py │ ├── test_setup.py │ ├── test_unimplemented_handler.py │ ├── test_extentions.py │ ├── test_metadata_push.py │ └── test_internal.py ├── rx_support │ ├── __init__.py │ └── test_rx_handler.py ├── reactivestreams │ ├── __init__.py │ └── test_reactivestreams.py ├── test_reactivex │ ├── __init__.py │ ├── test_reactivex_handler.py │ ├── test_helper.py │ └── test_concurrency.py └── test_integrations │ ├── __init__.py │ └── test_cloudevents.py ├── examples ├── __init__.py ├── bugs │ ├── __init__.py │ ├── issue_165 │ │ ├── __init__.py │ │ ├── 165_client.py │ │ └── 165_server.py │ ├── issue_290 │ │ └── __init__.py │ └── streaming_file │ │ ├── __init__.py │ │ ├── my_file │ │ ├── client.py │ │ └── server.py ├── graphql │ ├── __init__.py │ ├── java │ │ ├── src │ │ │ ├── main │ │ │ │ ├── resources │ │ │ │ │ ├── application.properties │ │ │ │ │ └── graphql │ │ │ │ │ │ └── rsocket.graphqls │ │ │ │ └── java │ │ │ │ │ └── io │ │ │ │ │ └── rsocket │ │ │ │ │ └── pythontest │ │ │ │ │ └── ServerWithGraphQL.java │ │ │ └── test │ │ │ │ └── java │ │ │ │ └── io │ │ │ │ └── rsocket │ │ │ │ └── pythontest │ │ │ │ └── ServerWithGraphQLTests.java │ │ ├── .gitignore │ │ └── pom.xml │ ├── rsocket.graphqls │ └── server_graphql.py ├── cloudevents │ ├── __init__.py │ ├── server_cloudevents.py │ └── client_cloudevents.py ├── tutorial │ ├── __init__.py │ ├── step1 │ │ ├── __init__.py │ │ ├── readme.md │ │ ├── chat_server.py │ │ └── chat_client.py │ ├── step2 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── chat_client.py │ ├── step3 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step4 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step5 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step6 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step7 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step8 │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── reactivex │ │ ├── __init__.py │ │ ├── readme.md │ │ └── shared.py │ ├── step0 │ │ ├── readme.md │ │ ├── chat_client.py │ │ └── chat_server.py │ └── test_tutorials.py ├── django_channels │ ├── __init__.py │ └── django_rsocket │ │ ├── __init__.py │ │ ├── .gitignore │ │ ├── django_rsocket │ │ ├── __init__.py │ │ ├── routing.py │ │ ├── wsgi.py │ │ ├── asgi.py │ │ ├── urls.py │ │ └── consumers.py │ │ └── manage.py ├── example_fixtures.py ├── java │ ├── src │ │ └── main │ │ │ ├── java │ │ │ └── io │ │ │ │ └── rsocket │ │ │ │ └── pythontest │ │ │ │ ├── Fixtures.java │ │ │ │ ├── Server.java │ │ │ │ ├── ServerWithFragmentation.java │ │ │ │ ├── RoutingRSocket.java │ │ │ │ ├── ClientWebsocket.java │ │ │ │ ├── ClientWebsocketHandler.java │ │ │ │ ├── SimpleRSocketAcceptor.java │ │ │ │ └── ClientChannelHandler.java │ │ │ └── resources │ │ │ └── logback.xml │ └── pom.xml ├── client_website_example.py ├── server_quart_websocket.py ├── server_fastapi_websocket.py ├── server_website_example.py ├── shared_tests.py ├── client.py ├── client_springboot.py ├── client_quic.py ├── server.py ├── client_websocket.py ├── server_with_lease.py ├── server_websockets.py ├── rsocket_in_aiohttp.py ├── response_stream.py ├── server_quic.py ├── server_aiohttp_websocket.py ├── client_reconnect.py ├── response_channel.py ├── certificates │ └── ssl_cert.pem └── cli_demo_server │ └── server.py ├── performance ├── __init__.py └── conftest.py ├── rsocket ├── cli │ └── __init__.py ├── graphql │ └── __init__.py ├── handlers │ ├── __init__.py │ ├── interfaces.py │ ├── request_cahnnel_responder.py │ ├── request_response_responder.py │ ├── request_channel_requester.py │ ├── request_response_requester.py │ └── request_stream_requester.py ├── routing │ └── __init__.py ├── streams │ ├── __init__.py │ ├── exceptions.py │ ├── backpressureapi.py │ ├── empty_stream.py │ ├── error_stream.py │ ├── helpers.py │ ├── null_subscrier.py │ ├── stream_from_async_generator.py │ └── stream_handler.py ├── awaitable │ ├── __init__.py │ ├── collector_subscriber.py │ └── awaitable_rsocket.py ├── cloudevents │ ├── __init__.py │ └── serialize.py ├── extensions │ ├── __init__.py │ ├── routing.py │ ├── mimetype.py │ ├── authentication_types.py │ ├── composite_metadata_item.py │ ├── tagging.py │ ├── authentication_content.py │ └── helpers.py ├── load_balancer │ ├── __init__.py │ ├── load_balancer_strategy.py │ ├── random_client.py │ ├── round_robin.py │ └── load_balancer_rsocket.py ├── reactivex │ ├── __init__.py │ ├── reactivex_channel.py │ └── subscriber_adapter.py ├── rx_support │ ├── __init__.py │ ├── rx_channel.py │ └── subscriber_adapter.py ├── transports │ ├── __init__.py │ ├── abstract_messaging.py │ ├── websockets_transport.py │ ├── transport.py │ ├── quart_websocket.py │ └── tcp.py ├── __init__.py ├── logger.py ├── disposable.py ├── async_helpers.py ├── datetime_helpers.py ├── local_typing.py ├── error_codes.py ├── fragment.py ├── payload.py ├── rsocket_internal.py ├── frame_parser.py ├── rsocket.py ├── queue_peekable.py └── exceptions.py ├── docs ├── changelog.rst ├── requirements.txt ├── _static │ └── theme_override.css ├── guide.rst ├── quickstart.rst ├── index.rst ├── Makefile ├── make.bat ├── api.rst └── extensions.rst ├── clean.sh ├── pyproject.toml ├── reactivestreams ├── __init__.py ├── publisher.py ├── subscription.py └── subscriber.py ├── .coveragerc ├── .readthedocs.yaml ├── .github ├── dependabot.yml └── workflows │ ├── python-publish.yml │ └── python-package.yml ├── tox.ini ├── LICENSE ├── .gitignore └── requirements.txt /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /performance/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/bugs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/graphql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/graphql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/routing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/streams/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/rsocket/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/rx_support/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/cloudevents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/awaitable/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/cloudevents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/load_balancer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/reactivex/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/rx_support/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/transports/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/reactivestreams/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_reactivex/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/bugs/issue_165/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/bugs/issue_290/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/django_channels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step4/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step5/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step6/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step7/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step8/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/rsocket/cloudevents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/bugs/streaming_file/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/reactivex/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rsocket/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.21' 2 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CHANGELOG.rst 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme==2.0.0 2 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step3/readme.md: -------------------------------------------------------------------------------- 1 | Add private messages. -------------------------------------------------------------------------------- /examples/tutorial/step4/readme.md: -------------------------------------------------------------------------------- 1 | Add channel messages. -------------------------------------------------------------------------------- /examples/tutorial/step5/readme.md: -------------------------------------------------------------------------------- 1 | Add file upload/download. -------------------------------------------------------------------------------- /examples/tutorial/step6/readme.md: -------------------------------------------------------------------------------- 1 | Add server/client statistics. -------------------------------------------------------------------------------- /examples/tutorial/step8/readme.md: -------------------------------------------------------------------------------- 1 | Websocket client/server setup -------------------------------------------------------------------------------- /clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf dist build rsocket.egg-info 4 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/.gitignore: -------------------------------------------------------------------------------- 1 | db.sqlite3 2 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/tutorial/step7/readme.md: -------------------------------------------------------------------------------- 1 | Add client side request handler. -------------------------------------------------------------------------------- /examples/tutorial/step1/readme.md: -------------------------------------------------------------------------------- 1 | Add request routing for login endpoint. -------------------------------------------------------------------------------- /examples/tutorial/step2/readme.md: -------------------------------------------------------------------------------- 1 | Add server side session for logged in user. -------------------------------------------------------------------------------- /examples/tutorial/reactivex/readme.md: -------------------------------------------------------------------------------- 1 | Add server/client reactivex implementation -------------------------------------------------------------------------------- /rsocket/streams/exceptions.py: -------------------------------------------------------------------------------- 1 | class FinishedIterator(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /examples/tutorial/step0/readme.md: -------------------------------------------------------------------------------- 1 | Basic server/client setup, single response endpoint. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /rsocket/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def logger(): 5 | return logging.getLogger('pyrsocket') 6 | -------------------------------------------------------------------------------- /docs/_static/theme_override.css: -------------------------------------------------------------------------------- 1 | .descclassname { 2 | opacity: 30%; 3 | } 4 | 5 | .descclassname:hover { 6 | opacity: 100%; 7 | } 8 | -------------------------------------------------------------------------------- /docs/guide.rst: -------------------------------------------------------------------------------- 1 | Guide 2 | ===== 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | 7 | A guide is available at https://rsocket.io/guides/rsocket-py 8 | 9 | -------------------------------------------------------------------------------- /examples/graphql/java/src/main/resources/application.properties: -------------------------------------------------------------------------------- 1 | spring.rsocket.server.port=9191 2 | spring.graphql.rsocket.mapping=graphql 3 | server.port=0 4 | -------------------------------------------------------------------------------- /examples/example_fixtures.py: -------------------------------------------------------------------------------- 1 | from rsocket.frame_helpers import ensure_bytes 2 | 3 | large_data1 = b''.join(ensure_bytes(str(i)) + b'123456789' for i in range(50)) 4 | -------------------------------------------------------------------------------- /rsocket/disposable.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Disposable(metaclass=abc.ABCMeta): 5 | 6 | @abc.abstractmethod 7 | def dispose(self): 8 | ... 9 | -------------------------------------------------------------------------------- /rsocket/async_helpers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | async def async_range(count: int): 5 | for i in range(count): 6 | yield i 7 | await asyncio.sleep(0) 8 | -------------------------------------------------------------------------------- /tests/reactivestreams/test_reactivestreams.py: -------------------------------------------------------------------------------- 1 | from reactivestreams import reactivestreams 2 | 3 | 4 | def test_reactivestreams(): 5 | assert reactivestreams() == 'reactivestreams-0.1' 6 | -------------------------------------------------------------------------------- /docs/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quick start 2 | =========== 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | 7 | A quick getting started guide is available at https://rsocket.io/guides/rsocket-py/simple -------------------------------------------------------------------------------- /rsocket/streams/backpressureapi.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class BackpressureApi(metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def initial_request_n(self, n: int): 7 | ... 8 | -------------------------------------------------------------------------------- /rsocket/datetime_helpers.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | 4 | def to_milliseconds(period: timedelta) -> int: 5 | return round(period.total_seconds() * 1000) + round(period.microseconds / 1000) 6 | -------------------------------------------------------------------------------- /rsocket/streams/empty_stream.py: -------------------------------------------------------------------------------- 1 | from rsocket.helpers import DefaultPublisherSubscription 2 | 3 | 4 | class EmptyStream(DefaultPublisherSubscription): 5 | def request(self, n: int): 6 | self._subscriber.on_complete() 7 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/routing.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | 3 | from .consumers import RSocketConsumer 4 | 5 | websocket_urlpatterns = [ 6 | path('rsocket', RSocketConsumer.as_asgi()), 7 | ] 8 | -------------------------------------------------------------------------------- /examples/graphql/java/src/main/resources/graphql/rsocket.graphqls: -------------------------------------------------------------------------------- 1 | type Query { 2 | greeting: Greeting 3 | } 4 | 5 | type Subscription { 6 | greetings: Greeting 7 | } 8 | 9 | type Greeting { 10 | message: String 11 | } -------------------------------------------------------------------------------- /rsocket/handlers/interfaces.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rsocket.frame import Frame 4 | 5 | 6 | class Requester(metaclass=abc.ABCMeta): 7 | 8 | @abc.abstractmethod 9 | def frame_received(self, frame: Frame): 10 | ... 11 | -------------------------------------------------------------------------------- /reactivestreams/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ReactiveStreams 3 | ~~~~~~~~~~~~~~~ 4 | 5 | Abstract base class definitions for ReactiveStreams Publisher/Subsciber. 6 | """ 7 | 8 | 9 | def reactivestreams(): 10 | return 'reactivestreams-0.1' 11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch=True 3 | source=rsocket 4 | 5 | [coverage:report] 6 | skip_empty = true 7 | exclude_lines = 8 | pragma: no cover 9 | pass 10 | @(abc\.)?abstractmethod 11 | 12 | [html] 13 | directory = coverage_report_html 14 | -------------------------------------------------------------------------------- /tests/tools/rsocket.graphqls: -------------------------------------------------------------------------------- 1 | type Query { 2 | greeting: Greeting 3 | getMessage: String 4 | } 5 | 6 | type Subscription { 7 | greetings: Greeting 8 | } 9 | 10 | type Greeting { 11 | message: String 12 | } 13 | 14 | type Mutation { 15 | setMessage(message: String): Greeting 16 | } 17 | -------------------------------------------------------------------------------- /examples/graphql/rsocket.graphqls: -------------------------------------------------------------------------------- 1 | type Query { 2 | greeting: Greeting 3 | getMessage: String 4 | } 5 | 6 | type Subscription { 7 | greetings: Greeting 8 | } 9 | 10 | type Greeting { 11 | message: String 12 | } 13 | 14 | type Mutation { 15 | setMessage(message: String): Greeting 16 | } 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | - requirements: requirements.txt 12 | - method: pip 13 | path: . 14 | 15 | 16 | sphinx: 17 | configuration: docs/conf.py 18 | 19 | -------------------------------------------------------------------------------- /rsocket/streams/error_stream.py: -------------------------------------------------------------------------------- 1 | from rsocket.helpers import DefaultPublisherSubscription 2 | 3 | 4 | class ErrorStream(DefaultPublisherSubscription): 5 | 6 | def __init__(self, exception: Exception): 7 | self._exception = exception 8 | 9 | def request(self, n: int): 10 | self._subscriber.on_error(self._exception) 11 | -------------------------------------------------------------------------------- /rsocket/streams/helpers.py: -------------------------------------------------------------------------------- 1 | from asyncio import Queue 2 | 3 | 4 | async def async_generator_from_queue(queue: Queue, stop_value=None): 5 | while True: 6 | value = await queue.get() 7 | 8 | if value is stop_value: 9 | return 10 | else: 11 | yield value 12 | queue.task_done() 13 | -------------------------------------------------------------------------------- /examples/graphql/java/src/test/java/io/rsocket/pythontest/ServerWithGraphQLTests.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import org.junit.jupiter.api.Test; 4 | import org.springframework.boot.test.context.SpringBootTest; 5 | 6 | @SpringBootTest 7 | class ServerWithGraphQLTests { 8 | 9 | @Test 10 | void contextLoads() { 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /rsocket/local_typing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from typing import Union 4 | 5 | if sys.version_info < (3, 9): # here to prevent deprecation warnings on cross version python compatible code. 6 | from typing import Awaitable 7 | else: 8 | from collections.abc import Awaitable 9 | 10 | ByteTypes = Union[bytes, bytearray] 11 | 12 | __all__ = [ 13 | 'Awaitable', 14 | 'ByteTypes' 15 | ] 16 | -------------------------------------------------------------------------------- /rsocket/extensions/routing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional 2 | 3 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 4 | from rsocket.extensions.tagging import TaggingMetadata 5 | 6 | 7 | class RoutingMetadata(TaggingMetadata): 8 | 9 | def __init__(self, tags: Optional[List[Union[bytes, str]]] = None): 10 | super().__init__(WellKnownMimeTypes.MESSAGE_RSOCKET_ROUTING.value.name, tags) 11 | -------------------------------------------------------------------------------- /rsocket/error_codes.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum, unique 2 | 3 | 4 | @unique 5 | class ErrorCode(IntEnum): 6 | INVALID_SETUP = 0x001, 7 | UNSUPPORTED_SETUP = 0x002, 8 | REJECTED_SETUP = 0x003, 9 | REJECTED_RESUME = 0x004, 10 | CONNECTION_ERROR = 0x101, 11 | CONNECTION_ERROR_NO_RETRY = 0x102, 12 | APPLICATION_ERROR = 0x201, 13 | REJECTED = 0x202, 14 | CANCELED = 0x203, 15 | INVALID = 0x204, 16 | RESERVED = 0xFFFFFFFF 17 | -------------------------------------------------------------------------------- /rsocket/rx_support/rx_channel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, Callable 3 | 4 | from rx.core.typing import Observable, Observer, Subject 5 | 6 | from rsocket.frame import MAX_REQUEST_N 7 | 8 | 9 | @dataclass(frozen=True) 10 | class RxChannel: 11 | observable: Optional[Union[Observable, Callable[[Subject], Observable]]] = None 12 | observer: Optional[Observer] = None 13 | limit_rate: int = MAX_REQUEST_N 14 | -------------------------------------------------------------------------------- /rsocket/reactivex/reactivex_channel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, Callable 3 | 4 | from reactivex import Observable, Observer, Subject 5 | 6 | from rsocket.frame import MAX_REQUEST_N 7 | 8 | 9 | @dataclass(frozen=True) 10 | class ReactivexChannel: 11 | observable: Optional[Union[Observable, Callable[[Subject], Observable]]] = None 12 | observer: Optional[Observer] = None 13 | limit_rate: int = MAX_REQUEST_N 14 | -------------------------------------------------------------------------------- /rsocket/streams/null_subscrier.py: -------------------------------------------------------------------------------- 1 | from reactivestreams.subscriber import Subscriber 2 | from reactivestreams.subscription import Subscription 3 | 4 | 5 | class NullSubscriber(Subscriber): 6 | def on_next(self, value, is_complete=False): 7 | pass 8 | 9 | def on_error(self, exception: Exception): 10 | pass 11 | 12 | def on_complete(self): 13 | pass 14 | 15 | def on_subscribe(self, subscription: Subscription): 16 | pass 17 | -------------------------------------------------------------------------------- /reactivestreams/publisher.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from reactivestreams.subscriber import Subscriber 4 | 5 | 6 | class Publisher(metaclass=abc.ABCMeta): 7 | """ 8 | Handles event for subscription to a subscriber 9 | """ 10 | 11 | @abc.abstractmethod 12 | def subscribe(self, subscriber: Subscriber): 13 | ... 14 | 15 | 16 | class DefaultPublisher(Publisher): 17 | def subscribe(self, subscriber: Subscriber): 18 | self._subscriber = subscriber 19 | -------------------------------------------------------------------------------- /rsocket/extensions/mimetype.py: -------------------------------------------------------------------------------- 1 | class WellKnownType: 2 | __slots__ = ( 3 | 'name', 4 | 'id' 5 | ) 6 | 7 | def __init__(self, name: bytes, id_: int): 8 | self.name = name 9 | self.id = id_ 10 | 11 | def __eq__(self, other): 12 | return self.name == other.name and self.id == other.id 13 | 14 | def __hash__(self): 15 | return hash((self.id, self.name)) 16 | 17 | 18 | class WellKnownMimeType(WellKnownType): 19 | pass 20 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/Fixtures.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import java.util.StringJoiner; 4 | import java.util.stream.IntStream; 5 | 6 | public class Fixtures { 7 | public static String largeData() { 8 | final var joiner = new StringJoiner(""); 9 | 10 | IntStream.range(0, 50) 11 | .mapToObj(i -> i + "123456789") 12 | .forEach(joiner::add); 13 | 14 | return joiner.toString(); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /rsocket/load_balancer/load_balancer_strategy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rsocket.rsocket import RSocket 4 | 5 | 6 | class LoadBalancerStrategy(metaclass=abc.ABCMeta): 7 | """ 8 | Base class for load balancer strategies. 9 | """ 10 | 11 | @abc.abstractmethod 12 | def select(self) -> RSocket: 13 | ... 14 | 15 | @abc.abstractmethod 16 | async def connect(self): 17 | ... 18 | 19 | @abc.abstractmethod 20 | async def close(self): 21 | ... 22 | -------------------------------------------------------------------------------- /reactivestreams/subscription.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class Subscription(metaclass=ABCMeta): 5 | """ 6 | Backpressure stream control. 7 | """ 8 | 9 | @abstractmethod 10 | def request(self, n: int): 11 | ... 12 | 13 | @abstractmethod 14 | def cancel(self): 15 | ... 16 | 17 | 18 | class DefaultSubscription(Subscription): 19 | def request(self, n: int): 20 | pass 21 | 22 | def cancel(self): 23 | pass 24 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for django_rsocket project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/5.2/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_rsocket.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /rsocket/rx_support/subscriber_adapter.py: -------------------------------------------------------------------------------- 1 | from rx.core.typing import Observer 2 | 3 | from reactivestreams.subscriber import Subscriber 4 | 5 | 6 | class SubscriberAdapter(Observer): 7 | def __init__(self, subscriber: Subscriber): 8 | self._subscriber = subscriber 9 | 10 | def on_next(self, value): 11 | self._subscriber.on_next(value) 12 | 13 | def on_error(self, error): 14 | self._subscriber.on_error(error) 15 | 16 | def on_completed(self): 17 | self._subscriber.on_complete() 18 | -------------------------------------------------------------------------------- /tests/rx_support/test_rx_handler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rsocket.payload import Payload 4 | from rsocket.rx_support.rx_handler import BaseRxHandler 5 | 6 | 7 | async def test_rx_handler(): 8 | handler = BaseRxHandler() 9 | 10 | with pytest.raises(Exception): 11 | await handler.request_channel(Payload()) 12 | 13 | with pytest.raises(Exception): 14 | await handler.request_response(Payload()) 15 | 16 | with pytest.raises(Exception): 17 | await handler.request_stream(Payload()) 18 | -------------------------------------------------------------------------------- /rsocket/reactivex/subscriber_adapter.py: -------------------------------------------------------------------------------- 1 | from reactivex.abc import ObserverBase 2 | 3 | from reactivestreams.subscriber import Subscriber 4 | 5 | 6 | class SubscriberAdapter(ObserverBase): 7 | def __init__(self, subscriber: Subscriber): 8 | self._subscriber = subscriber 9 | 10 | def on_next(self, value): 11 | self._subscriber.on_next(value) 12 | 13 | def on_error(self, error): 14 | self._subscriber.on_error(error) 15 | 16 | def on_completed(self): 17 | self._subscriber.on_complete() 18 | -------------------------------------------------------------------------------- /tests/test_reactivex/test_reactivex_handler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rsocket.payload import Payload 4 | from rsocket.reactivex.reactivex_handler import BaseReactivexHandler 5 | 6 | 7 | async def test_reactivex_handler(): 8 | handler = BaseReactivexHandler() 9 | 10 | with pytest.raises(Exception): 11 | await handler.request_channel(Payload()) 12 | 13 | with pytest.raises(Exception): 14 | await handler.request_response(Payload()) 15 | 16 | with pytest.raises(Exception): 17 | await handler.request_stream(Payload()) 18 | -------------------------------------------------------------------------------- /examples/graphql/java/.gitignore: -------------------------------------------------------------------------------- 1 | HELP.md 2 | target/ 3 | !.mvn/wrapper/maven-wrapper.jar 4 | !**/src/main/**/target/ 5 | !**/src/test/**/target/ 6 | 7 | ### STS ### 8 | .apt_generated 9 | .classpath 10 | .factorypath 11 | .project 12 | .settings 13 | .springBeans 14 | .sts4-cache 15 | 16 | ### IntelliJ IDEA ### 17 | .idea 18 | *.iws 19 | *.iml 20 | *.ipr 21 | 22 | ### NetBeans ### 23 | /nbproject/private/ 24 | /nbbuild/ 25 | /dist/ 26 | /nbdist/ 27 | /.nb-gradle/ 28 | build/ 29 | !**/src/main/**/build/ 30 | !**/src/test/**/build/ 31 | 32 | ### VS Code ### 33 | .vscode/ 34 | -------------------------------------------------------------------------------- /rsocket/cloudevents/serialize.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from cloudevents.conversion import to_json, from_json 4 | from cloudevents.pydantic import CloudEvent 5 | 6 | from rsocket.payload import Payload 7 | 8 | 9 | def cloud_event_deserialize(cls, payload: Payload) -> Any: 10 | if cls == CloudEvent: 11 | return from_json(CloudEvent, payload.data) 12 | 13 | return payload 14 | 15 | 16 | def cloud_event_serialize(cls, value: Any) -> Payload: 17 | if cls == CloudEvent: 18 | return Payload(to_json(value)) 19 | 20 | return value 21 | -------------------------------------------------------------------------------- /tests/rsocket/test_stream_data_mimetype.py: -------------------------------------------------------------------------------- 1 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 2 | from rsocket.extensions.stream_data_mimetype import StreamDataMimetypes 3 | 4 | 5 | def test_stream_data_mimetypes_equality(): 6 | assert StreamDataMimetypes() == StreamDataMimetypes([]) 7 | assert StreamDataMimetypes([WellKnownMimeTypes.APPLICATION_JSON]) == StreamDataMimetypes( 8 | [WellKnownMimeTypes.APPLICATION_JSON]) 9 | assert StreamDataMimetypes([WellKnownMimeTypes.APPLICATION_JSON]) != StreamDataMimetypes( 10 | [WellKnownMimeTypes.TEXT_PLAIN]) 11 | -------------------------------------------------------------------------------- /tests/rsocket/test_stream_helpers.py: -------------------------------------------------------------------------------- 1 | from asyncio import Queue 2 | 3 | from rsocket.streams.helpers import async_generator_from_queue 4 | 5 | 6 | async def test_async_generator_from_queue(): 7 | queue = Queue() 8 | 9 | for i in range(10): 10 | queue.put_nowait(i) 11 | 12 | queue.put_nowait(None) 13 | 14 | async def collect(): 15 | results = [] 16 | async for i in async_generator_from_queue(queue): 17 | results.append(i) 18 | 19 | return results 20 | 21 | r = await collect() 22 | 23 | assert r == list(range(10)) 24 | -------------------------------------------------------------------------------- /rsocket/transports/abstract_messaging.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import asyncio 3 | 4 | from rsocket.transports.transport import Transport 5 | 6 | 7 | class AbstractMessagingTransport(Transport, metaclass=abc.ABCMeta): 8 | def __init__(self): 9 | super().__init__() 10 | self._incoming_frame_queue = asyncio.Queue() 11 | 12 | async def next_frame_generator(self): 13 | frame = await self._incoming_frame_queue.get() 14 | 15 | if isinstance(frame, Exception): 16 | raise frame 17 | 18 | async def frame_generator(): 19 | yield frame 20 | 21 | return frame_generator() 22 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. rsocket documentation master file 2 | 3 | RSocket 4 | ======= 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | quickstart 10 | guide 11 | api 12 | extensions 13 | changelog 14 | 15 | .. autosummary:: 16 | :toctree: generated 17 | 18 | 19 | The python rsocket package implements the 1.0 version of the RSocket protocol (excluding "resume" functionality) 20 | and is designed for use in python >= 3.8 using asyncio. 21 | 22 | .. note:: 23 | The python package API is not stable. There may be changes until version 1.0.0. 24 | 25 | 26 | Indices and tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /rsocket/fragment.py: -------------------------------------------------------------------------------- 1 | from asyncio import Future 2 | from typing import Optional 3 | 4 | from rsocket.payload import Payload 5 | 6 | 7 | class Fragment(Payload): 8 | __slots__ = ('is_first', 'is_last', 'sent_future') 9 | 10 | def __init__(self, 11 | data: Optional[bytes] = None, 12 | metadata: Optional[bytes] = None, 13 | is_last: Optional[bool] = True, 14 | is_first: Optional[bool] = True, 15 | sent_future: Optional[Future] = None): 16 | super().__init__(data, metadata) 17 | self.is_first = is_first 18 | self.is_last = is_last 19 | self.sent_future = sent_future 20 | -------------------------------------------------------------------------------- /examples/bugs/streaming_file/my_file: -------------------------------------------------------------------------------- 1 | gsdfgsdfg 2 | sdfgsdfgsdfggsdfgsdfg 3 | sdfgsdfgsdfggsdfgsdfg 4 | sdfgsdfgsdfggsdfgsdfg 5 | sdfgsdfgsdfggsdfgsdfg 6 | sdfgsdfgsdfggsdfgsdfg 7 | sdfgsdfgsdfggsdfgsdfg 8 | sdfgsdfgsdfggsdfgsdfg 9 | sdfgsdfgsdfggsdfgsdfg 10 | sdfgsdfgsdfggsdfgsdfg 11 | sdfgsdfgsdfggsdfgsdfg 12 | sdfgsdfgsdfggsdfgsdfg 13 | sdfgsdfgsdfggsdfgsdfg 14 | sdfgsdfgsdfggsdfgsdfg 15 | sdfgsdfgsdfggsdfgsdfg 16 | sdfgsdfgsdfggsdfgsdfg 17 | sdfgsdfgsdfggsdfgsdfg 18 | sdfgsdfgsdfggsdfgsdfg 19 | sdfgsdfgsdfggsdfgsdfg 20 | sdfgsdfgsdfggsdfgsdfg 21 | sdfgsdfgsdfggsdfgsdfg 22 | sdfgsdfgsdfggsdfgsdfg 23 | sdfgsdfgsdfggsdfgsdfg 24 | sdfgsdfgsdfggsdfgsdfg 25 | sdfgsdfgsdfggsdfgsdfg 26 | sdfgsdfgsdfgvvv -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/rsocket/misbehaving_rsocket.py: -------------------------------------------------------------------------------- 1 | from rsocket.frame import Frame 2 | from rsocket.transports.transport import Transport 3 | 4 | 5 | class MisbehavingRSocket: 6 | def __init__(self, transport: Transport): 7 | self._transport = transport 8 | 9 | async def send_frame(self, frame: Frame): 10 | await self._transport.send_frame(frame) 11 | 12 | 13 | class BrokenFrame: 14 | def __init__(self, content: bytes): 15 | self._content = content 16 | 17 | def serialize(self) -> bytes: 18 | return self._content 19 | 20 | 21 | class UnknownFrame(Frame): 22 | def __init__(self): 23 | super().__init__(34) 24 | 25 | def parse(self, buffer: bytes, offset: int): 26 | pass 27 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for django_rsocket project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | from channels.routing import ProtocolTypeRouter, URLRouter 15 | from .routing import websocket_urlpatterns 16 | 17 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_rsocket.settings') 18 | 19 | application = ProtocolTypeRouter({ 20 | "http": get_asgi_application(), 21 | "websocket": URLRouter(websocket_urlpatterns), 22 | }) 23 | -------------------------------------------------------------------------------- /examples/tutorial/step3/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from typing import Optional, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | 15 | 16 | def encode_dataclass(obj): 17 | return ensure_bytes(json.dumps(obj.__dict__)) 18 | 19 | 20 | def dataclass_to_payload(obj) -> Payload: 21 | return Payload(encode_dataclass(obj)) 22 | 23 | 24 | T = TypeVar('T') 25 | 26 | 27 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 28 | return cls(**json.loads(utf8_decode(data))) 29 | -------------------------------------------------------------------------------- /examples/tutorial/step0/chat_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from rsocket.helpers import single_transport_provider, utf8_decode 5 | from rsocket.payload import Payload 6 | from rsocket.rsocket_client import RSocketClient 7 | from rsocket.transports.tcp import TransportTCP 8 | 9 | 10 | async def main(): 11 | connection = await asyncio.open_connection('localhost', 6565) 12 | 13 | async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: 14 | response = await client.request_response(Payload(data=b'George')) 15 | 16 | print(f"Server response: {utf8_decode(response.data)}") 17 | 18 | 19 | if __name__ == '__main__': 20 | logging.basicConfig(level=logging.INFO) 21 | asyncio.run(main()) 22 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38,py39,py310,py311 3 | 4 | [testenv] 5 | deps = Rx==3.2.0 6 | aiohttp==3.8.4 7 | aioquic==0.9.20 8 | asyncclick==8.1.3.4 9 | asyncstdlib==3.10.6 10 | coverage==6.5.0 11 | coveralls==3.3.1 12 | decoy==2.0.1 13 | flake8==5.0.4 14 | pytest-asyncio==0.21.0 15 | pytest-cov==4.0.0 16 | pytest-profiling==1.7.0 17 | pytest-rerunfailures==11.1.2 18 | pytest-timeout==2.1.0 19 | pytest-xdist==3.2.1 20 | pytest==7.3.1 21 | quart==0.18.4 22 | reactivex==4.0.4 23 | starlette==0.26.1 24 | cbitstruct==1.0.9 25 | cloudevents==1.9.0 26 | pydantic==1.10.7 27 | 28 | commands = pytest --cov-report=html --cov --ignore=examples 29 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | """Run administrative tasks.""" 9 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_rsocket.settings') 10 | try: 11 | from django.core.management import execute_from_command_line 12 | except ImportError as exc: 13 | raise ImportError( 14 | "Couldn't import Django. Are you sure it's installed and " 15 | "available on your PYTHONPATH environment variable? Did you " 16 | "forget to activate a virtual environment?" 17 | ) from exc 18 | execute_from_command_line(sys.argv) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /tests/tools/fixtures_aiohttp.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def aiohttp_raw_server(event_loop: asyncio.BaseEventLoop, unused_tcp_port): 8 | try: 9 | from aiohttp.test_utils import RawTestServer 10 | except ModuleNotFoundError: 11 | yield None 12 | return 13 | 14 | servers = [] 15 | 16 | try: 17 | async def go(handler): 18 | server = RawTestServer(handler, port=unused_tcp_port) 19 | await server.start_server() 20 | servers.append(server) 21 | return server 22 | 23 | yield go 24 | finally: 25 | async def finalize() -> None: 26 | while servers: 27 | await servers.pop().close() 28 | 29 | event_loop.run_until_complete(finalize()) 30 | -------------------------------------------------------------------------------- /performance/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def pytest_configure(config): 5 | config.addinivalue_line("markers", "performance: marks performance tests") 6 | 7 | 8 | def setup_logging(level=logging.DEBUG, use_file: bool = False): 9 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | 11 | console_handler = logging.StreamHandler() 12 | console_handler.setFormatter(formatter) 13 | console_handler.setLevel(level) 14 | 15 | handlers = [console_handler] 16 | 17 | if use_file: 18 | file_handler = logging.FileHandler('tests.log') 19 | file_handler.setFormatter(formatter) 20 | file_handler.setLevel(level) 21 | handlers.append(file_handler) 22 | 23 | logging.basicConfig(level=level, handlers=handlers) 24 | 25 | 26 | setup_logging(logging.ERROR) 27 | -------------------------------------------------------------------------------- /examples/client_website_example.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from rsocket.helpers import single_transport_provider 5 | from rsocket.payload import Payload 6 | from rsocket.rsocket_client import RSocketClient 7 | from rsocket.rx_support.rx_rsocket import RxRSocket 8 | from rsocket.transports.tcp import TransportTCP 9 | 10 | 11 | async def main(): 12 | connection = await asyncio.open_connection('localhost', 7878) 13 | 14 | async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: 15 | 16 | rx_client = RxRSocket(client) 17 | payload = Payload(b'Hello World') 18 | 19 | result = await rx_client.request_response(payload).pipe() 20 | 21 | logging.info(result.data) 22 | 23 | 24 | if __name__ == '__main__': 25 | logging.basicConfig(level=logging.INFO) 26 | asyncio.run(main()) 27 | -------------------------------------------------------------------------------- /examples/server_quart_websocket.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from quart import Quart 5 | 6 | from rsocket.helpers import create_future 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.request_handler import BaseRequestHandler 10 | from rsocket.transports.quart_websocket import websocket_handler 11 | 12 | app = Quart(__name__) 13 | 14 | 15 | class Handler(BaseRequestHandler): 16 | 17 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 18 | return create_future(Payload(b'pong')) 19 | 20 | 21 | @app.websocket("/") 22 | async def ws(): 23 | await websocket_handler(handler_factory=Handler) 24 | 25 | 26 | if __name__ == "__main__": 27 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 28 | logging.basicConfig(level=logging.DEBUG) 29 | app.run(port=port) 30 | -------------------------------------------------------------------------------- /rsocket/handlers/request_cahnnel_responder.py: -------------------------------------------------------------------------------- 1 | from rsocket.frame import Frame, RequestChannelFrame 2 | from rsocket.handlers.request_cahnnel_common import RequestChannelCommon 3 | 4 | 5 | class RequestChannelResponder(RequestChannelCommon): 6 | 7 | def setup(self): 8 | super().setup() 9 | 10 | def frame_received(self, frame: Frame): 11 | if isinstance(frame, RequestChannelFrame): 12 | self.setup() 13 | 14 | if self.subscriber.subscription is None: 15 | self.socket.send_complete(self.stream_id) 16 | self.mark_completed_and_finish(sent=True) 17 | else: 18 | self.subscriber.subscription.request(frame.initial_request_n) 19 | 20 | if frame.flags_complete: 21 | self._complete_remote_subscriber() 22 | 23 | else: 24 | super().frame_received(frame) 25 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/urls.py: -------------------------------------------------------------------------------- 1 | """ 2 | URL configuration for django_rsocket project. 3 | 4 | The `urlpatterns` list routes URLs to views. For more information please see: 5 | https://docs.djangoproject.com/en/5.2/topics/http/urls/ 6 | Examples: 7 | Function views 8 | 1. Add an import: from my_app import views 9 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 10 | Class-based views 11 | 1. Add an import: from other_app.views import Home 12 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 13 | Including another URLconf 14 | 1. Import the include() function: from django.urls import include, path 15 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 16 | """ 17 | from django.contrib import admin 18 | from django.urls import path 19 | 20 | urlpatterns = [ 21 | path('admin/', admin.site.urls), 22 | ] 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/tools/helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import datetime 3 | from typing import Awaitable, Any 4 | 5 | from cryptography.hazmat.primitives import serialization 6 | 7 | 8 | def quic_client_configuration(certificate, **kwargs): 9 | from aioquic.quic.configuration import QuicConfiguration 10 | 11 | client_configuration = QuicConfiguration( 12 | is_client=True, 13 | **kwargs 14 | ) 15 | ca_data = certificate.public_bytes(serialization.Encoding.PEM) 16 | client_configuration.load_verify_locations(cadata=ca_data, cafile=None) 17 | return client_configuration 18 | 19 | 20 | @dataclass 21 | class MeasureTime: 22 | result: Any 23 | delta: float 24 | 25 | 26 | async def measure_time(coroutine: Awaitable) -> MeasureTime: 27 | start = datetime.now() 28 | result = await coroutine 29 | return MeasureTime(result, (datetime.now() - start).total_seconds()) 30 | -------------------------------------------------------------------------------- /rsocket/payload.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from rsocket.frame_helpers import safe_len 4 | from rsocket.local_typing import ByteTypes 5 | 6 | 7 | class Payload: 8 | """ 9 | A response/stream message (upstream or downstream). Contains data and metadata, both `bytes`. 10 | 11 | :param data: data segment of payload 12 | :param metadata: metadata segment of payload 13 | """ 14 | 15 | __slots__ = ('data', 'metadata') 16 | 17 | def __init__(self, data: Optional[ByteTypes] = None, metadata: Optional[ByteTypes] = None): 18 | self.data = data 19 | self.metadata = metadata 20 | 21 | def __str__(self): 22 | return f"" 23 | 24 | def __eq__(self, other): 25 | return self.data == other.data and self.metadata == other.metadata 26 | 27 | def __repr__(self): 28 | return f"Payload({repr(self.data)}, {repr(self.metadata)})" 29 | -------------------------------------------------------------------------------- /examples/django_channels/django_rsocket/django_rsocket/consumers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from channels.routing import ProtocolTypeRouter, URLRouter 4 | from django.urls import path 5 | 6 | from rsocket.helpers import create_future 7 | from rsocket.payload import Payload 8 | from rsocket.request_handler import BaseRequestHandler 9 | from rsocket.transports.channels_transport import rsocket_consumer_factory 10 | 11 | 12 | # Define a request handler for RSocket 13 | class Handler(BaseRequestHandler): 14 | async def request_response(self, payload: Payload): 15 | logging.info(payload.data) 16 | 17 | return create_future(Payload(b'Echo: ' + payload.data)) 18 | 19 | 20 | # Create a consumer using the factory 21 | RSocketConsumer = rsocket_consumer_factory(handler_factory=Handler) 22 | 23 | # Django Channels routing configuration 24 | application = ProtocolTypeRouter({ 25 | 'websocket': URLRouter([ 26 | path('rsocket', RSocketConsumer.as_asgi()), 27 | ]), 28 | }) 29 | -------------------------------------------------------------------------------- /examples/server_fastapi_websocket.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from fastapi import FastAPI, WebSocket 3 | 4 | from rsocket.helpers import create_future 5 | from rsocket.local_typing import Awaitable 6 | from rsocket.payload import Payload 7 | from rsocket.request_handler import BaseRequestHandler 8 | from rsocket.rsocket_server import RSocketServer 9 | from rsocket.transports.http3_transport import Http3TransportWebsocket 10 | 11 | app = FastAPI() 12 | 13 | 14 | class Handler(BaseRequestHandler): 15 | 16 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 17 | return create_future(Payload(b'pong')) 18 | 19 | 20 | @app.websocket("/") 21 | async def endpoint(websocket: WebSocket): 22 | await websocket.accept() 23 | transport = Http3TransportWebsocket(websocket) 24 | RSocketServer(transport=transport, handler_factory=Handler) 25 | await transport.wait_for_disconnect() 26 | 27 | 28 | if __name__ == "__main__": 29 | uvicorn.run(app, host="0.0.0.0", port=6565) 30 | -------------------------------------------------------------------------------- /examples/server_website_example.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from rsocket.helpers import create_future 5 | from rsocket.local_typing import Awaitable 6 | from rsocket.payload import Payload 7 | from rsocket.request_handler import BaseRequestHandler 8 | from rsocket.rsocket_server import RSocketServer 9 | from rsocket.transports.tcp import TransportTCP 10 | 11 | 12 | class Handler(BaseRequestHandler): 13 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 14 | logging.info(payload.data) 15 | 16 | return create_future(Payload(b'Echo: ' + payload.data)) 17 | 18 | 19 | async def run_server(): 20 | def session(*connection): 21 | RSocketServer(TransportTCP(*connection), handler_factory=Handler) 22 | 23 | server = await asyncio.start_server(session, 'localhost', 7878) 24 | 25 | async with server: 26 | await server.serve_forever() 27 | 28 | 29 | if __name__ == '__main__': 30 | logging.basicConfig(level=logging.INFO) 31 | asyncio.run(run_server()) 32 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/Server.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.rsocket.core.RSocketServer; 4 | import io.rsocket.frame.decoder.PayloadDecoder; 5 | import io.rsocket.transport.netty.server.TcpServerTransport; 6 | 7 | import java.util.Objects; 8 | 9 | public class Server { 10 | 11 | public static void main(String[] args) { 12 | int port = getPort(args); 13 | System.out.println("Port: " + port); 14 | RSocketServer rSocketServer = RSocketServer.create(); 15 | rSocketServer.acceptor(new SimpleRSocketAcceptor()); 16 | rSocketServer.payloadDecoder(PayloadDecoder.ZERO_COPY); 17 | Objects.requireNonNull(rSocketServer.bind(TcpServerTransport.create(port)) 18 | .block()) 19 | .onClose() 20 | .block(); 21 | } 22 | 23 | private static int getPort(String[] args) { 24 | if (args.length > 0) { 25 | return Integer.parseInt(args[0]); 26 | } else { 27 | return 6565; 28 | } 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /examples/shared_tests.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import logging 4 | 5 | from rsocket.payload import Payload 6 | from rsocket.rsocket_client import RSocketClient 7 | 8 | 9 | async def simple_client_server_test(client: RSocketClient): 10 | date_format = b'%Y-%m-%d %H:%M:%S' 11 | payload = Payload(date_format) 12 | 13 | async def run_request_response(): 14 | try: 15 | while True: 16 | result = await client.request_response(payload) 17 | data = result.data 18 | 19 | time_received = datetime.datetime.strptime(data.decode(), date_format.decode()) 20 | 21 | logging.info('Response: {}'.format(time_received)) 22 | await asyncio.sleep(1) 23 | except asyncio.CancelledError: 24 | pass 25 | 26 | task = asyncio.create_task(run_request_response()) 27 | await asyncio.sleep(5) 28 | task.cancel() 29 | await task 30 | 31 | 32 | def assert_result_data(result: Payload, expected: bytes): 33 | if result.data != expected: 34 | raise Exception 35 | -------------------------------------------------------------------------------- /examples/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | 5 | from examples.shared_tests import simple_client_server_test 6 | from reactivestreams.subscriber import DefaultSubscriber 7 | from rsocket.helpers import single_transport_provider 8 | from rsocket.rsocket_client import RSocketClient 9 | from rsocket.transports.tcp import TransportTCP 10 | 11 | 12 | class StreamSubscriber(DefaultSubscriber): 13 | 14 | def on_next(self, value, is_complete=False): 15 | logging.info('RS: {}'.format(value)) 16 | self.subscription.request(1) 17 | 18 | 19 | async def main(server_port): 20 | logging.info('Connecting to server at localhost:%s', server_port) 21 | 22 | connection = await asyncio.open_connection('localhost', server_port) 23 | 24 | async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: 25 | await simple_client_server_test(client) 26 | 27 | 28 | if __name__ == '__main__': 29 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 30 | logging.basicConfig(level=logging.DEBUG) 31 | asyncio.run(main(port)) 32 | -------------------------------------------------------------------------------- /examples/tutorial/step4/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from typing import Optional, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | chat_filename_mimetype = b'chat/file-name' 18 | 19 | 20 | def encode_dataclass(obj): 21 | return ensure_bytes(json.dumps(obj.__dict__)) 22 | 23 | 24 | def dataclass_to_payload(obj) -> Payload: 25 | return Payload(encode_dataclass(obj)) 26 | 27 | 28 | T = TypeVar('T') 29 | 30 | 31 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 32 | return cls(**json.loads(utf8_decode(data))) 33 | 34 | 35 | def decode_payload(cls, payload: Payload): 36 | data = payload.data 37 | 38 | if cls is bytes: 39 | return data 40 | if cls is str: 41 | return utf8_decode(data) 42 | 43 | return decode_dataclass(data, cls) 44 | -------------------------------------------------------------------------------- /examples/tutorial/step5/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from typing import Optional, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | chat_filename_mimetype = b'chat/file-name' 18 | 19 | 20 | def encode_dataclass(obj): 21 | return ensure_bytes(json.dumps(obj.__dict__)) 22 | 23 | 24 | def dataclass_to_payload(obj) -> Payload: 25 | return Payload(encode_dataclass(obj)) 26 | 27 | 28 | T = TypeVar('T') 29 | 30 | 31 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 32 | return cls(**json.loads(utf8_decode(data))) 33 | 34 | 35 | def decode_payload(cls, payload: Payload): 36 | data = payload.data 37 | 38 | if cls is bytes: 39 | return data 40 | if cls is str: 41 | return utf8_decode(data) 42 | 43 | return decode_dataclass(data, cls) 44 | -------------------------------------------------------------------------------- /examples/tutorial/test_tutorials.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import subprocess 4 | from time import sleep 5 | 6 | import pytest 7 | 8 | 9 | @pytest.mark.timeout(20) 10 | @pytest.mark.parametrize('step', 11 | [ 12 | 'step0', 13 | 'step1', 14 | 'step2', 15 | 'step3', 16 | 'step4', 17 | 'step5', 18 | 'step6', 19 | 'step7', 20 | 'step8', 21 | 'reactivex'] 22 | 23 | ) 24 | def test_client_server_combinations(step): 25 | pid = os.spawnlp(os.P_NOWAIT, 'python3', 'python3', f'./{step}/chat_server.py') 26 | 27 | try: 28 | sleep(2) 29 | client = subprocess.Popen(['python3', f'./{step}/chat_client.py']) 30 | client.wait(timeout=20) 31 | 32 | assert client.returncode == 0 33 | finally: 34 | os.kill(pid, signal.SIGTERM) 35 | -------------------------------------------------------------------------------- /rsocket/extensions/authentication_types.py: -------------------------------------------------------------------------------- 1 | from enum import unique, Enum 2 | from typing import Optional 3 | 4 | from rsocket.exceptions import RSocketUnknownAuthType 5 | from rsocket.helpers import WellKnownType, map_type_names_by_id, \ 6 | map_type_ids_by_name 7 | 8 | 9 | class WellKnownAuthenticationType(WellKnownType): 10 | pass 11 | 12 | 13 | @unique 14 | class WellKnownAuthenticationTypes(Enum): 15 | SIMPLE = WellKnownAuthenticationType(b'simple', 0x00) 16 | BEARER = WellKnownAuthenticationType(b'bearer', 0x01) 17 | 18 | @classmethod 19 | def require_by_id(cls, numeric_id: int) -> WellKnownAuthenticationType: 20 | try: 21 | return type_by_id[numeric_id] 22 | except KeyError: 23 | raise RSocketUnknownAuthType(numeric_id) 24 | 25 | @classmethod 26 | def get_by_name(cls, metadata_name: str) -> Optional[WellKnownAuthenticationType]: 27 | return type_by_name.get(metadata_name) 28 | 29 | 30 | type_by_id = map_type_names_by_id(WellKnownAuthenticationTypes) 31 | type_by_name = map_type_ids_by_name(WellKnownAuthenticationTypes) 32 | -------------------------------------------------------------------------------- /tests/rsocket/test_helpers.py: -------------------------------------------------------------------------------- 1 | from rsocket.extensions.composite_metadata_item import CompositeMetadataItem 2 | from rsocket.extensions.helpers import metadata_item 3 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 4 | from rsocket.helpers import is_empty_payload, is_non_empty_payload 5 | from rsocket.payload import Payload 6 | 7 | 8 | def test_metadata_item(): 9 | result = metadata_item(b'123', WellKnownMimeTypes.TEXT_PLAIN) 10 | 11 | assert result == CompositeMetadataItem(WellKnownMimeTypes.TEXT_PLAIN, b'123') 12 | 13 | 14 | def test_is_empty_payload(): 15 | assert is_empty_payload(Payload()) 16 | assert not is_empty_payload(Payload(data=b'abc')) 17 | assert not is_empty_payload(Payload(metadata=b'abc')) 18 | assert not is_empty_payload(Payload(data=b'abc', metadata=b'abc')) 19 | 20 | 21 | def test_is_non_empty_payload(): 22 | assert not is_non_empty_payload(Payload()) 23 | assert is_non_empty_payload(Payload(data=b'abc')) 24 | assert is_non_empty_payload(Payload(metadata=b'abc')) 25 | assert is_non_empty_payload(Payload(data=b'abc', metadata=b'abc')) 26 | -------------------------------------------------------------------------------- /examples/tutorial/step0/chat_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from rsocket.frame_helpers import ensure_bytes 5 | from rsocket.helpers import create_future, utf8_decode 6 | from rsocket.local_typing import Awaitable 7 | from rsocket.payload import Payload 8 | from rsocket.request_handler import BaseRequestHandler 9 | from rsocket.rsocket_server import RSocketServer 10 | from rsocket.transports.tcp import TransportTCP 11 | 12 | 13 | class Handler(BaseRequestHandler): 14 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 15 | username = utf8_decode(payload.data) 16 | return create_future(Payload(ensure_bytes(f'Welcome to chat, {username}'))) 17 | 18 | 19 | async def run_server(): 20 | def session(*connection): 21 | RSocketServer(TransportTCP(*connection), handler_factory=Handler) 22 | 23 | async with await asyncio.start_server(session, 'localhost', 6565) as server: 24 | await server.serve_forever() 25 | 26 | 27 | if __name__ == '__main__': 28 | logging.basicConfig(level=logging.INFO) 29 | asyncio.run(run_server()) 30 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Core API Reference 2 | ================== 3 | 4 | Controls 5 | -------- 6 | 7 | Server 8 | ~~~~~~ 9 | 10 | .. automodule:: rsocket.rsocket_server 11 | :members: 12 | :inherited-members: 13 | 14 | Client 15 | ~~~~~~ 16 | 17 | .. automodule:: rsocket.rsocket_client 18 | :members: 19 | :inherited-members: 20 | 21 | Handler 22 | ~~~~~~~ 23 | 24 | .. automodule:: rsocket.request_handler 25 | :members: 26 | 27 | 28 | Enums 29 | ----- 30 | 31 | .. automodule:: rsocket.extensions.mimetypes 32 | :members: 33 | 34 | 35 | Models 36 | ------ 37 | 38 | .. automodule:: rsocket.payload 39 | :members: 40 | 41 | 42 | Interfaces 43 | ---------- 44 | 45 | Publisher 46 | ~~~~~~~~~ 47 | 48 | .. automodule:: reactivestreams.publisher 49 | :members: 50 | 51 | Subscriber 52 | ~~~~~~~~~~ 53 | 54 | .. automodule:: reactivestreams.subscriber 55 | :members: 56 | 57 | Subscription 58 | ~~~~~~~~~~~~ 59 | 60 | .. automodule:: reactivestreams.subscription 61 | :members: 62 | 63 | Transports 64 | ---------- 65 | 66 | .. automodule:: rsocket.transports.transport 67 | :members: 68 | 69 | -------------------------------------------------------------------------------- /rsocket/load_balancer/random_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | from typing import List 4 | 5 | from rsocket.load_balancer.load_balancer_strategy import LoadBalancerStrategy 6 | from rsocket.rsocket import RSocket 7 | 8 | 9 | class LoadBalancerRandom(LoadBalancerStrategy): 10 | """ 11 | Random Load Balancer 12 | """ 13 | 14 | def __init__(self, 15 | pool: List[RSocket], 16 | auto_connect=True, 17 | auto_close=True): 18 | self._auto_close = auto_close 19 | self._auto_connect = auto_connect 20 | self._pool = pool 21 | 22 | def select(self) -> RSocket: 23 | random_client_id = random.randint(0, len(self._pool) - 1) 24 | return self._pool[random_client_id] 25 | 26 | async def connect(self): 27 | if self._auto_connect: 28 | [await client.connect() for client in self._pool] 29 | 30 | async def close(self): 31 | if self._auto_close: 32 | await asyncio.gather(*[client.close() for client in self._pool], 33 | return_exceptions=True) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Gabriel Shaar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/bugs/issue_165/165_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from datetime import timedelta 4 | 5 | from rsocket.helpers import single_transport_provider 6 | from rsocket.payload import Payload 7 | from rsocket.rsocket_client import RSocketClient 8 | from rsocket.transports.tcp import TransportTCP 9 | 10 | 11 | async def main(server_port): 12 | logging.info("Connecting to server at localhost:%s", server_port) 13 | 14 | connection = await asyncio.open_connection("localhost", server_port) 15 | 16 | async with RSocketClient( 17 | single_transport_provider(TransportTCP(*connection)), 18 | fragment_size_bytes=10240, 19 | keep_alive_period=timedelta(seconds=10), 20 | ) as client: 21 | # huge_array = bytearray(16777209) # Works 22 | huge_array = bytearray(16777210) # rsocket.exceptions.ParseError: Frame too short: 0 bytes 23 | payload = Payload(huge_array) 24 | await client.fire_and_forget(payload) 25 | await asyncio.sleep(1) 26 | 27 | 28 | if __name__ == "__main__": 29 | logging.basicConfig(level=logging.DEBUG) 30 | asyncio.run(main(10000)) 31 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/ServerWithFragmentation.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.rsocket.core.RSocketServer; 4 | import io.rsocket.frame.decoder.PayloadDecoder; 5 | import io.rsocket.transport.netty.server.TcpServerTransport; 6 | 7 | import java.util.Objects; 8 | 9 | public class ServerWithFragmentation { 10 | 11 | public static void main(String[] args) { 12 | int port = getPort(args); 13 | System.out.println("Port: " + port); 14 | RSocketServer rSocketServer = RSocketServer.create() 15 | .fragment(64); 16 | rSocketServer.acceptor(new SimpleRoutingAcceptor()); 17 | rSocketServer.payloadDecoder(PayloadDecoder.ZERO_COPY); 18 | Objects.requireNonNull(rSocketServer.bind(TcpServerTransport.create(port)) 19 | .block()) 20 | .onClose() 21 | .block(); 22 | } 23 | 24 | private static int getPort(String[] args) { 25 | if (args.length > 0) { 26 | return Integer.parseInt(args[0]); 27 | } else { 28 | return 6565; 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /rsocket/rsocket_internal.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Optional 3 | 4 | from rsocket.error_codes import ErrorCode 5 | from rsocket.frame import Frame, RequestFrame 6 | from rsocket.payload import Payload 7 | 8 | 9 | class RSocketInternal(metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def send_frame(self, frame: Frame): 12 | ... 13 | 14 | @abc.abstractmethod 15 | def send_complete(self, stream_id: int): 16 | ... 17 | 18 | @abc.abstractmethod 19 | def finish_stream(self, stream_id: int): 20 | ... 21 | 22 | @abc.abstractmethod 23 | def send_request(self, frame: RequestFrame): 24 | ... 25 | 26 | @abc.abstractmethod 27 | def send_payload(self, stream_id: int, payload: Payload, complete=False, is_next=True): 28 | ... 29 | 30 | @abc.abstractmethod 31 | def send_error(self, stream_id: int, exception: Exception): 32 | ... 33 | 34 | @abc.abstractmethod 35 | def stop_all_streams(self, error_code=ErrorCode.CANCELED, data=b''): 36 | ... 37 | 38 | @abc.abstractmethod 39 | def get_fragment_size_bytes(self) -> Optional[int]: 40 | ... 41 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/RoutingRSocket.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.rsocket.Payload; 4 | import io.rsocket.RSocket; 5 | import org.reactivestreams.Publisher; 6 | import reactor.core.publisher.Flux; 7 | import reactor.core.publisher.Mono; 8 | 9 | public interface RoutingRSocket { 10 | default Mono fireAndForget(String route, Payload payload) { 11 | return new RSocket() { 12 | }.fireAndForget(payload); 13 | } 14 | 15 | default Mono requestResponse(String route, Payload payload) { 16 | return new RSocket() { 17 | }.requestResponse(payload); 18 | } 19 | 20 | default Flux requestStream(String route, Payload payload) { 21 | return new RSocket() { 22 | }.requestStream(payload); 23 | } 24 | 25 | default Flux requestChannel(String route, Publisher payloads) { 26 | return new RSocket() { 27 | }.requestChannel(payloads); 28 | } 29 | 30 | default Mono metadataPush(String route, Payload payload) { 31 | return new RSocket() { 32 | }.metadataPush(payload); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /tests/test_reactivex/test_helper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from reactivex import operators, Subject 3 | 4 | from rsocket.reactivex.back_pressure_publisher import observable_from_async_generator, task_from_awaitable 5 | 6 | 7 | @pytest.mark.parametrize('request_n, generate_n, expected_n', ( 8 | (10, 5, 5), 9 | (5, 10, 5), 10 | # (10, 10, 10), # fixme: failing on python 3.10 11 | # (0, 10, 0), # operators.take(0) is problematic 12 | )) 13 | @pytest.mark.skip 14 | async def test_helper(request_n, generate_n, expected_n): 15 | async def generator(): 16 | for i in range(generate_n): 17 | yield i 18 | 19 | feedback = Subject() 20 | 21 | task = await task_from_awaitable( 22 | observable_from_async_generator(generator().__aiter__(), feedback).pipe(operators.take(request_n), 23 | operators.to_list()) 24 | ) 25 | 26 | feedback.on_next(request_n) 27 | 28 | result = await task 29 | feedback.on_completed() 30 | 31 | assert len(result) == expected_n 32 | 33 | # await asyncio.sleep(1) # wait for task to finish 34 | -------------------------------------------------------------------------------- /rsocket/extensions/composite_metadata_item.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | from rsocket.extensions.mimetypes import WellKnownMimeTypes, WellKnownMimeType, ensure_well_known_encoding_enum_value 4 | 5 | _default = object() 6 | 7 | 8 | def default_or_value(value, default=None): 9 | if value is _default: 10 | return default 11 | return value 12 | 13 | 14 | class CompositeMetadataItem: 15 | __slots__ = ( 16 | 'encoding', 17 | 'content' 18 | ) 19 | 20 | def __init__(self, 21 | encoding: Union[bytes, WellKnownMimeTypes, WellKnownMimeType] = _default, 22 | body: Optional[bytes] = _default): 23 | self.encoding = ensure_well_known_encoding_enum_value(default_or_value(encoding)) 24 | self.content = default_or_value(body) 25 | 26 | def parse(self, buffer: bytes): 27 | self.content = buffer 28 | 29 | def serialize(self) -> bytes: 30 | return self.content 31 | 32 | def __eq__(self, other): 33 | if isinstance(other, self.__class__): 34 | return self.content == other.content and self.encoding == other.encoding 35 | 36 | return False 37 | -------------------------------------------------------------------------------- /rsocket/load_balancer/round_robin.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List 3 | 4 | from rsocket.load_balancer.load_balancer_strategy import LoadBalancerStrategy 5 | from rsocket.rsocket import RSocket 6 | 7 | 8 | class LoadBalancerRoundRobin(LoadBalancerStrategy): 9 | """ 10 | Round Robin Load Balancer 11 | """ 12 | def __init__(self, 13 | pool: List[RSocket], 14 | auto_connect=True, 15 | auto_close=True): 16 | self._auto_close = auto_close 17 | self._auto_connect = auto_connect 18 | self._pool = pool 19 | self._current_index = 0 20 | 21 | def select(self) -> RSocket: 22 | client = self._pool[self._current_index] 23 | self._current_index = (self._current_index + 1) % len(self._pool) 24 | return client 25 | 26 | async def connect(self): 27 | if self._auto_connect: 28 | [await client.connect() for client in self._pool] 29 | 30 | async def close(self): 31 | if self._auto_close: 32 | await asyncio.gather(*[client.close() for client in self._pool], 33 | return_exceptions=True) 34 | -------------------------------------------------------------------------------- /rsocket/handlers/request_response_responder.py: -------------------------------------------------------------------------------- 1 | from asyncio import Future 2 | 3 | from rsocket.disposable import Disposable 4 | from rsocket.frame import CancelFrame, Frame 5 | from rsocket.rsocket import RSocket 6 | from rsocket.streams.stream_handler import StreamHandler 7 | 8 | 9 | class RequestResponseResponder(StreamHandler, Disposable): 10 | def __init__(self, socket: RSocket, future: Future): 11 | super().__init__(socket) 12 | self.future = future 13 | 14 | def setup(self): 15 | self.future.add_done_callback(self.future_done) 16 | 17 | def future_done(self, future): 18 | if self.future.cancelled(): 19 | pass 20 | elif not future.exception(): 21 | self.socket.send_payload( 22 | self.stream_id, future.result(), complete=True) 23 | else: 24 | self.socket.send_error(self.stream_id, future.exception()) 25 | 26 | self._finish_stream() 27 | 28 | def dispose(self): 29 | self.future.cancel() 30 | 31 | def frame_received(self, frame: Frame): 32 | if isinstance(frame, CancelFrame): 33 | self.future.cancel() 34 | self._finish_stream() 35 | -------------------------------------------------------------------------------- /examples/bugs/issue_165/165_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from rsocket.helpers import create_future 5 | from rsocket.local_typing import Awaitable 6 | from rsocket.payload import Payload 7 | from rsocket.request_handler import BaseRequestHandler 8 | from rsocket.rsocket_server import RSocketServer 9 | from rsocket.transports.tcp import TransportTCP 10 | 11 | 12 | class Handler(BaseRequestHandler): 13 | async def request_fire_and_forget(self, payload: Payload) -> Awaitable[Payload]: 14 | print(f"Receiving {len(payload.data)} bytes") 15 | return create_future(Payload(b"OK")) 16 | 17 | 18 | async def run_server(server_port): 19 | logging.info("Starting server at localhost:%s", server_port) 20 | 21 | def session(*connection): 22 | RSocketServer(TransportTCP(*connection), handler_factory=Handler, fragment_size_bytes=10240) 23 | 24 | server = await asyncio.start_server(session, "localhost", server_port) 25 | 26 | async with server: 27 | await server.serve_forever() 28 | 29 | 30 | if __name__ == "__main__": 31 | logging.basicConfig(filename="rsocket.log", format="%(asctime)s %(message)s", filemode="w") 32 | asyncio.run(run_server(10000)) 33 | -------------------------------------------------------------------------------- /examples/client_springboot.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | from uuid import uuid4 5 | 6 | from rsocket.extensions.helpers import composite, route, authenticate_simple 7 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 8 | from rsocket.helpers import single_transport_provider 9 | from rsocket.payload import Payload 10 | from rsocket.rsocket_client import RSocketClient 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | async def main(): 15 | connection = await asyncio.open_connection('localhost', 7000) 16 | 17 | setup_payload = Payload( 18 | data=str(uuid4()).encode(), 19 | metadata=composite(route('shell-client'), authenticate_simple('user', 'pass'))) 20 | 21 | async with RSocketClient(single_transport_provider(TransportTCP(*connection)), 22 | setup_payload=setup_payload, 23 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA): 24 | await asyncio.sleep(5) 25 | 26 | 27 | def serialize(message) -> bytes: 28 | return json.dumps(message).encode() 29 | 30 | 31 | if __name__ == '__main__': 32 | logging.basicConfig(level=logging.DEBUG) 33 | asyncio.run(main()) 34 | -------------------------------------------------------------------------------- /tests/rsocket/test_mimetype.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rsocket.exceptions import RSocketUnknownMimetype, RSocketMimetypeTooLong 4 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 5 | from rsocket.helpers import serialize_well_known_encoding 6 | 7 | 8 | def test_mimetype_raise_exception_on_unknown_type(): 9 | with pytest.raises(RSocketUnknownMimetype) as exc_info: 10 | WellKnownMimeTypes.require_by_id(99999) 11 | 12 | assert exc_info.value.mimetype_id == 99999 13 | 14 | 15 | def test_serialize_well_known_encoding_too_long(): 16 | with pytest.raises(RSocketMimetypeTooLong): 17 | serialize_well_known_encoding(b'1' * 1000, WellKnownMimeTypes.get_by_name) 18 | 19 | 20 | def test_mimetype_require_by_id(): 21 | mimetype = WellKnownMimeTypes.require_by_id(0x05) 22 | 23 | assert mimetype is WellKnownMimeTypes.APPLICATION_JSON.value.name 24 | 25 | 26 | def test_mimetype_get_by_name(): 27 | mimetype = WellKnownMimeTypes.get_by_name(b'application/json') 28 | 29 | assert mimetype is WellKnownMimeTypes.APPLICATION_JSON.value.id 30 | 31 | 32 | def test_mimetype_get_by_unknown_name(): 33 | mimetype = WellKnownMimeTypes.get_by_name(b'non_existing/type') 34 | 35 | assert mimetype is None 36 | -------------------------------------------------------------------------------- /examples/client_quic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from pathlib import Path 5 | 6 | from aioquic.quic.configuration import QuicConfiguration 7 | 8 | from examples.shared_tests import simple_client_server_test 9 | from rsocket.helpers import single_transport_provider 10 | from rsocket.rsocket_client import RSocketClient 11 | from rsocket.transports.aioquic_transport import rsocket_connect 12 | 13 | 14 | async def main(server_port): 15 | logging.info('Connecting to server at localhost:%s', server_port) 16 | 17 | client_configuration = QuicConfiguration( 18 | is_client=True 19 | ) 20 | ca_file_path = Path(__file__).parent / 'certificates' / 'pycacert.pem' 21 | client_configuration.load_verify_locations(cafile=str(ca_file_path)) 22 | 23 | async with rsocket_connect('localhost', server_port, 24 | configuration=client_configuration) as transport: 25 | async with RSocketClient(single_transport_provider(transport)) as client: 26 | await simple_client_server_test(client) 27 | 28 | 29 | if __name__ == '__main__': 30 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 31 | logging.basicConfig(level=logging.DEBUG) 32 | asyncio.run(main(port)) 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | if: startsWith(github.ref, 'refs/tags') 34 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 35 | with: 36 | user: __token__ 37 | password: ${{ secrets.PYPI_TOKEN }} 38 | -------------------------------------------------------------------------------- /tests/rsocket/test_without_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | import pytest 5 | 6 | from rsocket.exceptions import RSocketTransportError 7 | from rsocket.logger import logger 8 | from rsocket.request_handler import BaseRequestHandler 9 | from rsocket.rsocket_client import RSocketClient 10 | from rsocket.transports.tcp import TransportTCP 11 | 12 | 13 | @pytest.mark.allow_error_log() 14 | async def test_connection_never_established(unused_tcp_port: int): 15 | class ClientHandler(BaseRequestHandler): 16 | async def on_close(self, rsocket, exception: Optional[Exception] = None): 17 | logger().info('Test Reconnecting (closed)') 18 | await rsocket.reconnect() 19 | 20 | async def transport_provider(): 21 | try: 22 | for i in range(3): 23 | client_connection = await asyncio.open_connection('localhost', unused_tcp_port) 24 | yield TransportTCP(*client_connection) 25 | 26 | except Exception: 27 | logger().error('Client connection error', exc_info=True) 28 | raise 29 | 30 | with pytest.raises(RSocketTransportError): 31 | async with RSocketClient(transport_provider(), handler_factory=ClientHandler): 32 | await asyncio.sleep(1) 33 | -------------------------------------------------------------------------------- /examples/graphql/java/src/main/java/io/rsocket/pythontest/ServerWithGraphQL.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | import org.springframework.graphql.data.method.annotation.QueryMapping; 6 | import org.springframework.graphql.data.method.annotation.SubscriptionMapping; 7 | import org.springframework.stereotype.Controller; 8 | import reactor.core.publisher.Flux; 9 | 10 | import java.time.Duration; 11 | import java.time.Instant; 12 | import java.util.stream.Stream; 13 | 14 | @SpringBootApplication 15 | public class ServerWithGraphQL { 16 | 17 | public static void main(String[] args) { 18 | SpringApplication.run(ServerWithGraphQL.class, args); 19 | } 20 | 21 | } 22 | 23 | 24 | @Controller 25 | class GreetingsController { 26 | @SubscriptionMapping 27 | Flux greetings() { 28 | return Flux.fromStream(Stream.generate(() -> new Greeting("Hello world @" + Instant.now() + "!"))) 29 | .delayElements(Duration.ofSeconds(1)).take(10); 30 | } 31 | 32 | @QueryMapping 33 | Greeting greeting() { 34 | return new Greeting("Hello world"); 35 | } 36 | } 37 | 38 | record Greeting(String message) { 39 | 40 | } 41 | -------------------------------------------------------------------------------- /tests/rsocket/test_payload.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rsocket.payload import Payload 4 | 5 | 6 | @pytest.mark.parametrize('payload, expected_str', ( 7 | (Payload(), ""), 8 | (Payload(b'some data'), ""), 9 | (Payload(metadata=b'some metadata'), ""), 10 | (Payload(b'some data', b'some metadata'), ""), 11 | 12 | )) 13 | def test_payload_to_str(payload, expected_str): 14 | assert str(payload) == expected_str 15 | 16 | 17 | @pytest.mark.parametrize('payload, expected_str', ( 18 | (Payload(), "Payload(None, None)"), 19 | (Payload(b'some data'), "Payload(b'some data', None)"), 20 | (Payload(metadata=b'some metadata'), "Payload(None, b'some metadata')"), 21 | (Payload(b'some data', b'some metadata'), "Payload(b'some data', b'some metadata')"), 22 | 23 | )) 24 | def test_payload_repr(payload, expected_str): 25 | assert repr(payload) == expected_str 26 | 27 | 28 | def test_payload_support_bytearray(): 29 | payload = Payload(bytearray([1, 5, 10]), bytearray([4, 6, 7])) 30 | 31 | assert payload.data == b'\x01\x05\x0a' 32 | assert payload.metadata == b'\x04\x06\x07' 33 | -------------------------------------------------------------------------------- /examples/tutorial/step1/chat_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Awaitable 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode, create_response 7 | from rsocket.payload import Payload 8 | from rsocket.routing.request_router import RequestRouter 9 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 10 | from rsocket.rsocket_server import RSocketServer 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | def handler_factory(): 15 | router = RequestRouter() 16 | 17 | @router.response('login') 18 | async def login(payload: Payload) -> Awaitable[Payload]: 19 | username = utf8_decode(payload.data) 20 | 21 | logging.info(f'New user: {username}') 22 | 23 | return create_response(ensure_bytes(f'Hello {username}')) 24 | 25 | return RoutingRequestHandler(router) 26 | 27 | 28 | async def run_server(): 29 | def session(*connection): 30 | RSocketServer(TransportTCP(*connection), handler_factory=handler_factory) 31 | 32 | async with await asyncio.start_server(session, 'localhost', 6565) as server: 33 | await server.serve_forever() 34 | 35 | 36 | if __name__ == '__main__': 37 | logging.basicConfig(level=logging.INFO) 38 | asyncio.run(run_server()) 39 | -------------------------------------------------------------------------------- /examples/java/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 17 | 18 | 19 | 20 | 21 | 22 | %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /tests/tools/fixtures_graphql.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pathlib import Path 3 | from typing import Dict 4 | 5 | import pytest 6 | from graphql import build_schema 7 | 8 | 9 | @pytest.fixture 10 | def graphql_schema(): 11 | stored_message = "" 12 | 13 | async def greeting(*args) -> Dict: 14 | return { 15 | 'message': "Hello world" 16 | } 17 | 18 | async def get_message(*args) -> str: 19 | return stored_message 20 | 21 | async def set_message(root, _info, message) -> Dict: 22 | nonlocal stored_message 23 | stored_message = message 24 | return { 25 | "message": message 26 | } 27 | 28 | def greetings(*args): 29 | async def results(): 30 | for i in range(10): 31 | yield {'greetings': {'message': f"Hello world {i}"}} 32 | await asyncio.sleep(0.01) 33 | 34 | return results() 35 | 36 | with (Path(__file__).parent / 'rsocket.graphqls').open() as fd: 37 | schema = build_schema(fd.read()) 38 | 39 | schema.query_type.fields['greeting'].resolve = greeting 40 | schema.query_type.fields['getMessage'].resolve = get_message 41 | schema.mutation_type.fields['setMessage'].resolve = set_message 42 | schema.subscription_type.fields['greetings'].subscribe = greetings 43 | 44 | return schema 45 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/ClientWebsocket.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.netty.channel.ChannelHandler; 4 | import io.rsocket.core.RSocketConnector; 5 | import io.rsocket.transport.netty.client.WebsocketClientTransport; 6 | import io.rsocket.util.DefaultPayload; 7 | import reactor.core.scheduler.Schedulers; 8 | import reactor.netty.http.client.HttpClient; 9 | 10 | public class ClientWebsocket { 11 | private static final String host = "localhost"; 12 | private static final int port = 6565; 13 | 14 | public static void main(String[] args) { 15 | 16 | ChannelHandler pingSender = new ClientWebsocketHandler(); 17 | 18 | HttpClient httpClient = HttpClient.create() 19 | .doOnConnected(b -> b.addHandlerLast(pingSender)) 20 | .host(host) 21 | .port(port); 22 | 23 | RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")) 24 | .publishOn(Schedulers.boundedElastic()) 25 | .doOnNext(rSocket -> 26 | rSocket.requestResponse(DefaultPayload.create("ping")) 27 | .doOnNext(response -> System.out.println("Response from server :: " + response.getDataUtf8())) 28 | .block()) 29 | .block(); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /tests/tools/helpers_aiohttp.py: -------------------------------------------------------------------------------- 1 | from asyncio import Event 2 | from contextlib import asynccontextmanager 3 | from typing import Optional 4 | 5 | from rsocket.rsocket_base import RSocketBase 6 | from tests.rsocket.helpers import assert_no_open_streams 7 | 8 | 9 | @asynccontextmanager 10 | async def pipe_factory_aiohttp_websocket(aiohttp_raw_server, unused_tcp_port, client_arguments=None, 11 | server_arguments=None): 12 | from rsocket.transports.aiohttp_websocket import websocket_client, websocket_handler_factory 13 | 14 | server: Optional[RSocketBase] = None 15 | wait_for_server = Event() 16 | 17 | def store_server(new_server): 18 | nonlocal server 19 | server = new_server 20 | wait_for_server.set() 21 | 22 | await aiohttp_raw_server(websocket_handler_factory(on_server_create=store_server, **(server_arguments or {}))) 23 | 24 | # test_overrides = {'keep_alive_period': timedelta(minutes=20)} 25 | client_arguments = client_arguments or {} 26 | # client_arguments.update(test_overrides) 27 | 28 | async with websocket_client('http://localhost:{}'.format(unused_tcp_port), 29 | **client_arguments) as client: 30 | await wait_for_server.wait() 31 | yield server, client 32 | 33 | await server.close() 34 | assert_no_open_streams(client, server) 35 | -------------------------------------------------------------------------------- /tests/rsocket/test_request_router.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rsocket.exceptions import RSocketEmptyRoute 4 | from rsocket.helpers import create_response 5 | from rsocket.local_typing import Awaitable 6 | from rsocket.payload import Payload 7 | from rsocket.routing.request_router import RequestRouter 8 | 9 | 10 | async def test_request_router_exception_on_duplicate_route_with_same_type(): 11 | router = RequestRouter() 12 | 13 | with pytest.raises(KeyError): 14 | @router.response('path1') 15 | async def request_response(payload, composite_metadata) -> Awaitable[Payload]: 16 | return create_response() 17 | 18 | @router.response('path1') 19 | async def request_response2(payload, composite_metadata) -> Awaitable[Payload]: 20 | return create_response() 21 | 22 | 23 | async def test_request_router_disallow_empty_routes(): 24 | router = RequestRouter() 25 | 26 | with pytest.raises(RSocketEmptyRoute): 27 | @router.response('') 28 | async def request_response(payload, composite_metadata) -> Awaitable[Payload]: 29 | return create_response() 30 | 31 | with pytest.raises(RSocketEmptyRoute): 32 | # noinspection PyTypeChecker 33 | @router.response(None) 34 | async def request_response2(payload, composite_metadata) -> Awaitable[Payload]: 35 | return create_response() 36 | -------------------------------------------------------------------------------- /examples/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from datetime import datetime 5 | 6 | from rsocket.helpers import create_future 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.request_handler import BaseRequestHandler 10 | from rsocket.rsocket_server import RSocketServer 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | class Handler(BaseRequestHandler): 15 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 16 | await asyncio.sleep(0.1) # Simulate not immediate process 17 | date_time_format = payload.data.decode('utf-8') 18 | formatted_date_time = datetime.now().strftime(date_time_format) 19 | return create_future(Payload(formatted_date_time.encode('utf-8'))) 20 | 21 | 22 | async def run_server(server_port): 23 | logging.info('Starting server at localhost:%s', server_port) 24 | 25 | def session(*connection): 26 | RSocketServer(TransportTCP(*connection), handler_factory=Handler) 27 | 28 | server = await asyncio.start_server(session, 'localhost', server_port) 29 | 30 | async with server: 31 | await server.serve_forever() 32 | 33 | 34 | if __name__ == '__main__': 35 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 36 | logging.basicConfig(level=logging.DEBUG) 37 | asyncio.run(run_server(port)) 38 | -------------------------------------------------------------------------------- /rsocket/frame_parser.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from typing import AsyncGenerator 3 | 4 | from rsocket.logger import logger 5 | 6 | __all__ = ['FrameParser'] 7 | 8 | from rsocket.frame import Frame, InvalidFrame, parse_or_ignore 9 | 10 | 11 | class FrameParser: 12 | def __init__(self): 13 | self._buffer = bytearray() 14 | 15 | async def receive_data(self, data: bytes, header_length=3) -> AsyncGenerator[Frame, None]: 16 | self._buffer.extend(data) 17 | total = len(self._buffer) 18 | 19 | frame_length_byte_count = header_length 20 | 21 | while total >= frame_length_byte_count: 22 | if header_length > 0: 23 | length = struct.unpack('>I', b'\x00' + self._buffer[:frame_length_byte_count])[0] 24 | else: 25 | length = len(data) 26 | 27 | if total < length + frame_length_byte_count: 28 | return 29 | 30 | try: 31 | new_frame = parse_or_ignore(self._buffer[frame_length_byte_count:length + frame_length_byte_count]) 32 | 33 | if new_frame is not None: 34 | yield new_frame 35 | except Exception: 36 | logger().error('Error parsing frame', exc_info=True) 37 | yield InvalidFrame() 38 | 39 | self._buffer = self._buffer[length + frame_length_byte_count:] 40 | total -= length + frame_length_byte_count 41 | -------------------------------------------------------------------------------- /tests/tools/http3_client.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from typing import cast 3 | from urllib.parse import urlparse 4 | 5 | from aioquic.asyncio.client import connect 6 | from aioquic.h3.connection import H3_ALPN, ErrorCode 7 | 8 | from rsocket.transports.http3_transport import Http3TransportWebsocket, RSocketHttp3ClientProtocol 9 | from tests.tools.helpers import quic_client_configuration 10 | 11 | 12 | @asynccontextmanager 13 | async def http3_ws_transport( 14 | certificate, 15 | url: str 16 | ): 17 | parsed = urlparse(url) 18 | assert parsed.scheme in ( 19 | "https", 20 | "wss", 21 | ), "Only https:// or wss:// URLs are supported." 22 | 23 | host = parsed.hostname 24 | 25 | if parsed.port is not None: 26 | port = parsed.port 27 | else: 28 | port = 443 29 | 30 | async with connect( 31 | host, 32 | port, 33 | configuration=quic_client_configuration(certificate, alpn_protocols=H3_ALPN), 34 | create_protocol=RSocketHttp3ClientProtocol 35 | ) as client: 36 | client = cast(RSocketHttp3ClientProtocol, client) 37 | 38 | if parsed.scheme == "wss": 39 | ws = await client.websocket(url) 40 | transport = Http3TransportWebsocket(ws) 41 | yield transport 42 | await ws.close() 43 | 44 | client._quic.close(error_code=ErrorCode.H3_NO_ERROR) 45 | -------------------------------------------------------------------------------- /tests/rsocket/test_frame_decode.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import asyncstdlib 4 | 5 | from rsocket.extensions.authentication import AuthenticationSimple 6 | from rsocket.extensions.authentication_content import AuthenticationContent 7 | from rsocket.extensions.composite_metadata import CompositeMetadata 8 | 9 | 10 | async def test_decode_spring_demo_auth(): 11 | metadata = bytearray( 12 | b'\xfe\x00\x00\r\x0cshell-client"message/x.rsocket.authentication.v0\x00\x00\x0b\x80\x00\x04userpass') 13 | 14 | composite_metadata = CompositeMetadata() 15 | composite_metadata.parse(metadata) 16 | 17 | assert len(composite_metadata.items) == 2 18 | 19 | composite_item = cast(AuthenticationContent, composite_metadata.items[1]) 20 | authentication = cast(AuthenticationSimple, composite_item.authentication) 21 | assert authentication.username == b'user' 22 | assert authentication.password == b'pass' 23 | 24 | 25 | async def test_multiple_frames(frame_parser): 26 | data = b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' 27 | data += b'\x00\x00\x13\x00\x00\x26\x6a\x2c\x00\x00\x00\x02\x04\x77\x65\x69' 28 | data += b'\x72\x64\x6e\x65\x73\x73' 29 | data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' 30 | data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' 31 | data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' 32 | frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) 33 | assert len(frames) == 5 34 | -------------------------------------------------------------------------------- /rsocket/transports/websockets_transport.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import Queue 3 | 4 | from rsocket.frame import Frame 5 | from rsocket.transports.abstract_messaging import AbstractMessagingTransport 6 | 7 | 8 | class WebsocketsTransport(AbstractMessagingTransport): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self._outgoing_frame_queue = Queue() 13 | 14 | async def send_frame(self, frame: Frame): 15 | await self._outgoing_frame_queue.put(frame) 16 | 17 | async def close(self): 18 | pass 19 | 20 | async def consumer_handler(self, websocket): 21 | async for message in websocket: 22 | async for frame in self._frame_parser.receive_data(message, header_length=0): 23 | await self._incoming_frame_queue.put(frame) 24 | 25 | async def producer_handler(self, websocket): 26 | while True: 27 | frame = await self._outgoing_frame_queue.get() 28 | await websocket.send(frame.serialize()) 29 | 30 | async def handler(self, websocket): 31 | consumer_task = asyncio.create_task(self.consumer_handler(websocket)) 32 | producer_task = asyncio.create_task(self.producer_handler(websocket)) 33 | 34 | done, pending = await asyncio.wait( 35 | [consumer_task, producer_task], 36 | return_when=asyncio.FIRST_COMPLETED, 37 | ) 38 | 39 | for task in pending: 40 | task.cancel() 41 | -------------------------------------------------------------------------------- /rsocket/transports/transport.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from rsocket.frame import Frame 4 | from rsocket.frame_parser import FrameParser 5 | 6 | 7 | class Transport(metaclass=abc.ABCMeta): 8 | """ 9 | Base class for all transports: 10 | 11 | - tcp: :class:`TransportTCP ` 12 | - websocket: :class:`TransportAsyncWebsocketsClient ` 13 | - http3: :class:`Http3TransportWebsocket ` 14 | - aioquic: :class:`RSocketQuicProtocol ` 15 | - aiohttp: :class:`TransportAioHttpWebsocket ` 16 | - quart: :class:`TransportQuartWebsocket ` 17 | """ 18 | 19 | def __init__(self): 20 | self._frame_parser = FrameParser() 21 | 22 | async def connect(self): 23 | pass 24 | 25 | @abc.abstractmethod 26 | async def send_frame(self, frame: Frame): 27 | ... 28 | 29 | @abc.abstractmethod 30 | async def next_frame_generator(self): 31 | ... 32 | 33 | @abc.abstractmethod 34 | async def close(self): 35 | ... 36 | 37 | def requires_length_header(self) -> bool: 38 | return False 39 | 40 | async def on_send_queue_empty(self): 41 | pass 42 | -------------------------------------------------------------------------------- /examples/tutorial/step1/chat_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Optional 4 | 5 | from rsocket.extensions.helpers import composite, route 6 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 7 | from rsocket.frame_helpers import ensure_bytes 8 | from rsocket.helpers import single_transport_provider, utf8_decode 9 | from rsocket.payload import Payload 10 | from rsocket.rsocket_client import RSocketClient 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | class ChatClient: 15 | def __init__(self, rsocket: RSocketClient): 16 | self._rsocket = rsocket 17 | self._username: Optional[str] = None 18 | 19 | async def login(self, username: str): 20 | self._username = username 21 | payload = Payload(ensure_bytes(username), composite(route('login'))) 22 | response = await self._rsocket.request_response(payload) 23 | print(f'Server response: {utf8_decode(response.data)}') 24 | 25 | 26 | async def main(): 27 | connection = await asyncio.open_connection('localhost', 6565) 28 | 29 | async with RSocketClient(single_transport_provider(TransportTCP(*connection)), 30 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client1: 31 | user = ChatClient(client1) 32 | 33 | await user.login('user1') 34 | 35 | 36 | if __name__ == '__main__': 37 | logging.basicConfig(level=logging.INFO) 38 | asyncio.run(main()) 39 | -------------------------------------------------------------------------------- /examples/client_websocket.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import aiohttp 4 | import asyncclick as click 5 | 6 | from rsocket.helpers import single_transport_provider 7 | from rsocket.payload import Payload 8 | from rsocket.rsocket_client import RSocketClient 9 | from rsocket.transports.aiohttp_websocket import TransportAioHttpClient 10 | from rsocket.transports.aiohttp_websocket import websocket_client 11 | 12 | 13 | async def application(with_ssl: bool, serve_port: int): 14 | if with_ssl: 15 | async with aiohttp.ClientSession() as session: 16 | async with session.ws_connect('wss://localhost:%s' % serve_port, verify_ssl=False) as websocket: 17 | async with RSocketClient( 18 | single_transport_provider(TransportAioHttpClient(websocket=websocket))) as client: 19 | result = await client.request_response(Payload(b'ping')) 20 | print(result) 21 | 22 | else: 23 | async with websocket_client('http://localhost:%s' % serve_port) as client: 24 | result = await client.request_response(Payload(b'ping')) 25 | print(result.data) 26 | 27 | 28 | @click.command() 29 | @click.option('--with-ssl', is_flag=False, default=False) 30 | @click.option('--port', is_flag=False, default=6565) 31 | async def command(with_ssl, port: int): 32 | logging.basicConfig(level=logging.DEBUG) 33 | await application(with_ssl, port) 34 | 35 | 36 | if __name__ == '__main__': 37 | command() 38 | -------------------------------------------------------------------------------- /examples/tutorial/step2/chat_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Optional 4 | 5 | from rsocket.extensions.helpers import composite, route 6 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 7 | from rsocket.frame_helpers import ensure_bytes 8 | from rsocket.helpers import single_transport_provider 9 | from rsocket.payload import Payload 10 | from rsocket.rsocket_client import RSocketClient 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | class ChatClient: 15 | def __init__(self, rsocket: RSocketClient): 16 | self._rsocket = rsocket 17 | self._session_id: Optional[str] = None 18 | self._username: Optional[str] = None 19 | 20 | async def login(self, username: str): 21 | self._username = username 22 | payload = Payload(ensure_bytes(username), composite(route('login'))) 23 | self._session_id = (await self._rsocket.request_response(payload)).data 24 | return self 25 | 26 | 27 | async def main(): 28 | connection = await asyncio.open_connection('localhost', 6565) 29 | 30 | async with RSocketClient(single_transport_provider(TransportTCP(*connection)), 31 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client1: 32 | user = ChatClient(client1) 33 | 34 | await user.login('George') 35 | 36 | 37 | if __name__ == '__main__': 38 | logging.basicConfig(level=logging.INFO) 39 | asyncio.run(main()) 40 | -------------------------------------------------------------------------------- /examples/server_with_lease.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from datetime import timedelta 5 | 6 | from rsocket.helpers import create_future 7 | from rsocket.lease import SingleLeasePublisher 8 | from rsocket.payload import Payload 9 | from rsocket.routing.request_router import RequestRouter 10 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 11 | from rsocket.rsocket_server import RSocketServer 12 | from rsocket.transports.tcp import TransportTCP 13 | 14 | router = RequestRouter() 15 | 16 | 17 | @router.response('single_request') 18 | async def single_request_response(payload, composite_metadata): 19 | logging.info('Got single request') 20 | return create_future(Payload(b'single_response')) 21 | 22 | 23 | def handler_factory(): 24 | return RoutingRequestHandler(router) 25 | 26 | 27 | def handle_client(reader, writer): 28 | RSocketServer(TransportTCP(reader, writer), handler_factory=handler_factory, lease_publisher=SingleLeasePublisher( 29 | maximum_request_count=5, 30 | maximum_lease_time=timedelta(seconds=2) 31 | )) 32 | 33 | 34 | async def run_server(server_port): 35 | server = await asyncio.start_server(handle_client, 'localhost', server_port) 36 | 37 | async with server: 38 | await server.serve_forever() 39 | 40 | 41 | if __name__ == '__main__': 42 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 43 | logging.basicConfig(level=logging.DEBUG) 44 | asyncio.run(run_server(port)) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # PyCharm 92 | *.iml 93 | .idea/ 94 | python311/ -------------------------------------------------------------------------------- /rsocket/rsocket.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import asyncio 3 | from typing import Union, Optional, Any 4 | 5 | from reactivestreams.publisher import Publisher 6 | from rsocket.local_typing import Awaitable 7 | from rsocket.payload import Payload 8 | from rsocket.streams.backpressureapi import BackpressureApi 9 | 10 | 11 | class RSocket(metaclass=abc.ABCMeta): 12 | 13 | @abc.abstractmethod 14 | def request_channel( 15 | self, 16 | payload: Payload, 17 | publisher: Optional[Publisher] = None, 18 | sending_done: Optional[asyncio.Event] = None) -> Union[Any, Publisher]: 19 | ... 20 | 21 | @abc.abstractmethod 22 | def request_response(self, payload: Payload) -> Awaitable[Payload]: 23 | ... 24 | 25 | @abc.abstractmethod 26 | def fire_and_forget(self, payload: Payload) -> Awaitable[None]: 27 | ... 28 | 29 | @abc.abstractmethod 30 | def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: 31 | ... 32 | 33 | @abc.abstractmethod 34 | def metadata_push(self, metadata: bytes) -> Awaitable[None]: 35 | ... 36 | 37 | @abc.abstractmethod 38 | async def connect(self): 39 | ... 40 | 41 | @abc.abstractmethod 42 | async def close(self): 43 | ... 44 | 45 | @abc.abstractmethod 46 | async def __aenter__(self) -> 'RSocket': 47 | ... 48 | 49 | @abc.abstractmethod 50 | async def __aexit__(self, exc_type, exc_val, exc_tb): 51 | ... 52 | -------------------------------------------------------------------------------- /tests/rsocket/test_request_routing_decode_payload.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | 4 | from rsocket.extensions.helpers import route, composite 5 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 6 | from rsocket.frame_helpers import ensure_bytes 7 | from rsocket.helpers import utf8_decode, create_response 8 | from rsocket.payload import Payload 9 | from rsocket.routing.request_router import RequestRouter 10 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 11 | 12 | 13 | async def test_request_response_type_hinted_payload(lazy_pipe): 14 | @dataclass 15 | class Message: 16 | user: str 17 | message: str 18 | 19 | router = RequestRouter(lambda cls, payload: cls(**json.loads(utf8_decode(payload.data)))) 20 | 21 | def handler_factory(): 22 | return RoutingRequestHandler(router) 23 | 24 | @router.response('test.path') 25 | async def response(message: Message): 26 | return create_response(ensure_bytes(message.message)) 27 | 28 | async with lazy_pipe( 29 | client_arguments={'metadata_encoding': WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA}, 30 | server_arguments={'handler_factory': handler_factory}) as (server, client): 31 | result = await client.request_response(Payload( 32 | data=ensure_bytes(json.dumps(Message('George', 'hello').__dict__)), 33 | metadata=composite(route('test.path')))) 34 | 35 | assert result.data == b'hello' 36 | -------------------------------------------------------------------------------- /examples/server_websockets.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from datetime import datetime 5 | 6 | import websockets 7 | 8 | from rsocket.helpers import create_future 9 | from rsocket.local_typing import Awaitable 10 | from rsocket.payload import Payload 11 | from rsocket.request_handler import BaseRequestHandler 12 | from rsocket.rsocket_server import RSocketServer 13 | from rsocket.transports.websockets_transport import WebsocketsTransport 14 | 15 | 16 | class Handler(BaseRequestHandler): 17 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 18 | await asyncio.sleep(0.1) # Simulate not immediate process 19 | date_time_format = payload.data.decode('utf-8') 20 | formatted_date_time = datetime.now().strftime(date_time_format) 21 | return create_future(Payload(formatted_date_time.encode('utf-8'))) 22 | 23 | 24 | async def endpoint(websocket): 25 | transport = WebsocketsTransport() 26 | 27 | async with RSocketServer(transport, handler_factory=Handler): 28 | await transport.handler(websocket) 29 | 30 | 31 | async def run_server(server_port): 32 | logging.info('Starting server at localhost:%s', server_port) 33 | 34 | async with websockets.serve(endpoint, "localhost", server_port): 35 | await asyncio.Future() # run forever 36 | 37 | 38 | if __name__ == '__main__': 39 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 40 | logging.basicConfig(level=logging.DEBUG) 41 | asyncio.run(run_server(port)) 42 | -------------------------------------------------------------------------------- /tests/tools/http_app.py: -------------------------------------------------------------------------------- 1 | from starlette.applications import Starlette 2 | from starlette.routing import WebSocketRoute 3 | from starlette.types import Receive, Scope, Send 4 | from starlette.websockets import WebSocket 5 | 6 | from rsocket.helpers import noop 7 | from rsocket.rsocket_server import RSocketServer 8 | from rsocket.transports.http3_transport import Http3TransportWebsocket 9 | 10 | 11 | async def web_socket(server_settings: dict, websocket: WebSocket, on_server_create=noop): 12 | await websocket.accept() 13 | transport = Http3TransportWebsocket(websocket) 14 | server = RSocketServer(transport=transport, **server_settings) 15 | on_server_create(server) 16 | await transport.wait_for_disconnect() 17 | 18 | 19 | async def starlette_factory(server_settings, scope, receive, send, on_server_create=noop): 20 | async def web_socket_factory(websocket): 21 | return await web_socket(server_settings, websocket, on_server_create) 22 | 23 | starlette = Starlette( 24 | routes=[WebSocketRoute("/ws", web_socket_factory)] 25 | ) 26 | 27 | await starlette(scope, receive, send) 28 | 29 | 30 | class ApplicationFactory: 31 | 32 | def __init__(self, server_settings, on_server_create=noop): 33 | self._on_server_create = on_server_create 34 | self._server_settings = server_settings 35 | 36 | async def app(self, scope: Scope, receive: Receive, send: Send) -> None: 37 | await starlette_factory(self._server_settings, scope, receive, send, self._on_server_create) 38 | -------------------------------------------------------------------------------- /examples/cloudevents/server_cloudevents.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import sys 5 | 6 | from cloudevents.pydantic import CloudEvent 7 | 8 | from rsocket.cloudevents.serialize import cloud_event_deserialize, cloud_event_serialize 9 | from rsocket.routing.request_router import RequestRouter 10 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 11 | from rsocket.rsocket_server import RSocketServer 12 | from rsocket.transports.tcp import TransportTCP 13 | 14 | router = RequestRouter(cloud_event_deserialize, 15 | cloud_event_serialize) 16 | 17 | 18 | @router.response('event') 19 | async def event_response(event: CloudEvent) -> CloudEvent: 20 | return CloudEvent.create(attributes={ 21 | 'type': 'io.spring.event.Foo', 22 | 'source': 'https://spring.io/foos' 23 | }, data=json.dumps(json.loads(event.data))) 24 | 25 | 26 | def handler_factory(): 27 | return RoutingRequestHandler(router) 28 | 29 | 30 | async def run_server(server_port): 31 | logging.info('Starting server at localhost:%s', server_port) 32 | 33 | def session(*connection): 34 | RSocketServer(TransportTCP(*connection), handler_factory=handler_factory) 35 | 36 | async with await asyncio.start_server(session, 'localhost', server_port) as server: 37 | await server.serve_forever() 38 | 39 | 40 | if __name__ == '__main__': 41 | port = sys.argv[1] if len(sys.argv) > 1 else 7000 42 | logging.basicConfig(level=logging.DEBUG) 43 | asyncio.run(run_server(port)) 44 | -------------------------------------------------------------------------------- /rsocket/streams/stream_from_async_generator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import timedelta 3 | from typing import AsyncGenerator, Tuple, Callable 4 | 5 | from rsocket.async_helpers import async_range 6 | from rsocket.payload import Payload 7 | from rsocket.streams.exceptions import FinishedIterator 8 | from rsocket.streams.stream_from_generator import StreamFromGenerator 9 | 10 | 11 | class StreamFromAsyncGenerator(StreamFromGenerator): 12 | 13 | def __init__(self, 14 | generator: Callable[[], AsyncGenerator[Tuple[Payload, bool], None]], 15 | delay_between_messages=timedelta(0), 16 | on_cancel=None, 17 | on_complete=None): 18 | super().__init__(generator, delay_between_messages, on_cancel, on_complete) 19 | 20 | async def _start_generator(self): 21 | self._generator: AsyncGenerator = self._generator_factory() 22 | self._iteration = self._generator.__aiter__() 23 | 24 | async def _generate_next_n(self, n: int) -> AsyncGenerator[Tuple[Payload, bool], None]: 25 | is_complete_sent = False 26 | async for i in async_range(n): 27 | try: 28 | next_value = await self._iteration.__anext__() 29 | is_complete_sent = next_value[1] 30 | yield next_value 31 | except StopAsyncIteration: 32 | if not is_complete_sent: 33 | raise FinishedIterator() 34 | return 35 | 36 | def _cancel_generator(self): 37 | asyncio.create_task(self._generator.aclose()) 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Rx==3.2.0 2 | 3 | aiohttp==3.12.14; python_version > "3.8" 4 | aiohttp==3.10.11; python_version == "3.8" 5 | 6 | aioquic==1.2.0 7 | asyncstdlib==3.13.1 8 | 9 | asyncclick==8.1.8; python_version > "3.8" 10 | asyncclick==8.1.7.2; python_version == "3.8" 11 | 12 | coverage==6.5.0 13 | coveralls==3.3.1 14 | decoy==2.2.0 15 | 16 | flake8==7.3.0; python_version > "3.8" 17 | flake8==7.1.2; python_version == "3.8" 18 | 19 | pytest-asyncio==0.25.3; python_version > "3.8" 20 | pytest-asyncio==0.23.4; python_version == "3.8" 21 | 22 | pytest-cov==4.1.0 23 | 24 | pytest-profiling==1.8.1; python_version > "3.8" 25 | pytest-profiling==1.7.0; python_version == "3.8" 26 | 27 | pytest-rerunfailures==15.1; python_version > "3.8" 28 | pytest-rerunfailures==14.0; python_version == "3.8" 29 | 30 | pytest-timeout==2.3.1 31 | pytest-xdist==3.6.1 32 | 33 | pytest==8.3.4; python_version > "3.8" 34 | pytest==7.4.4; python_version == "3.8" 35 | 36 | quart==0.20.0; python_version > "3.8" 37 | quart==0.19.9; python_version == "3.8" 38 | 39 | reactivex==4.0.4 40 | 41 | starlette==0.47.2; python_version > "3.8" 42 | starlette==0.44.0; python_version == "3.8" 43 | 44 | cbitstruct==1.1.0; python_version <= "3.12" 45 | cloudevents==1.11.0 46 | pydantic==1.10.22 47 | 48 | Werkzeug==3.1.3; python_version > "3.8" 49 | Werkzeug==3.0.6; python_version == "3.8" 50 | 51 | graphql-core==3.2.3 52 | gql==3.5.2 53 | 54 | websockets==15.0.1; python_version > "3.8" 55 | websockets==13.1; python_version == "3.8" 56 | 57 | asyncwebsockets==0.9.4 58 | 59 | fastapi==0.116.1 60 | 61 | channels==4.2.2 62 | daphne==4.2.1 -------------------------------------------------------------------------------- /examples/rsocket_in_aiohttp.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | 4 | from aiohttp import web 5 | 6 | from rsocket.helpers import create_future 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.request_handler import BaseRequestHandler 10 | from rsocket.rsocket_server import RSocketServer 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | class Handler(BaseRequestHandler): 15 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 16 | await asyncio.sleep(0.1) # Simulate not immediate process 17 | date_time_format = payload.data.decode('utf-8') 18 | formatted_date_time = datetime.now().strftime(date_time_format) 19 | return create_future(Payload(formatted_date_time.encode('utf-8'))) 20 | 21 | 22 | async def run_server(server_port): 23 | def session(*connection): 24 | RSocketServer(TransportTCP(*connection), handler_factory=Handler) 25 | 26 | print('Listening for rsocket on {}'.format(server_port)) 27 | server = await asyncio.start_server(session, 'localhost', server_port) 28 | 29 | async with server: 30 | await server.serve_forever() 31 | 32 | 33 | async def start_background_tasks(app): 34 | app['rsocket'] = asyncio.create_task(run_server(6565)) 35 | 36 | 37 | async def cleanup_background_tasks(app): 38 | app['rsocket'].cancel() 39 | await app['rsocket'] 40 | 41 | 42 | app = web.Application() 43 | app.on_startup.append(start_background_tasks) 44 | app.on_cleanup.append(cleanup_background_tasks) 45 | web.run_app(app) 46 | -------------------------------------------------------------------------------- /examples/tutorial/step6/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | @dataclass(frozen=True) 18 | class ServerStatistics: 19 | user_count: Optional[int] = None 20 | channel_count: Optional[int] = None 21 | 22 | 23 | @dataclass() 24 | class ServerStatisticsRequest: 25 | ids: Optional[List[str]] = field(default_factory=lambda: ['users', 'channels']) 26 | period_seconds: Optional[int] = field(default_factory=lambda: 1) 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ClientStatistics: 31 | memory_usage: Optional[float] = None 32 | 33 | 34 | chat_filename_mimetype = b'chat/file-name' 35 | 36 | 37 | def encode_dataclass(obj) -> bytes: 38 | return ensure_bytes(json.dumps(obj.__dict__)) 39 | 40 | 41 | def dataclass_to_payload(obj) -> Payload: 42 | return Payload(encode_dataclass(obj)) 43 | 44 | 45 | T = TypeVar('T') 46 | 47 | 48 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 49 | return cls(**json.loads(utf8_decode(data))) 50 | 51 | 52 | def decode_payload(cls, payload: Payload): 53 | data = payload.data 54 | 55 | if cls is bytes: 56 | return data 57 | if cls is str: 58 | return utf8_decode(data) 59 | 60 | return decode_dataclass(data, cls) 61 | -------------------------------------------------------------------------------- /examples/tutorial/step7/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | @dataclass(frozen=True) 18 | class ServerStatistics: 19 | user_count: Optional[int] = None 20 | channel_count: Optional[int] = None 21 | 22 | 23 | @dataclass() 24 | class ServerStatisticsRequest: 25 | ids: Optional[List[str]] = field(default_factory=lambda: ['users', 'channels']) 26 | period_seconds: Optional[int] = field(default_factory=lambda: 1) 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ClientStatistics: 31 | memory_usage: Optional[float] = None 32 | 33 | 34 | chat_filename_mimetype = b'chat/file-name' 35 | 36 | 37 | def encode_dataclass(obj) -> bytes: 38 | return ensure_bytes(json.dumps(obj.__dict__)) 39 | 40 | 41 | def dataclass_to_payload(obj) -> Payload: 42 | return Payload(encode_dataclass(obj)) 43 | 44 | 45 | T = TypeVar('T') 46 | 47 | 48 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 49 | return cls(**json.loads(utf8_decode(data))) 50 | 51 | 52 | def decode_payload(cls, payload: Payload): 53 | data = payload.data 54 | 55 | if cls is bytes: 56 | return data 57 | if cls is str: 58 | return utf8_decode(data) 59 | 60 | return decode_dataclass(data, cls) 61 | -------------------------------------------------------------------------------- /examples/tutorial/step8/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, TypeVar, Type 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | @dataclass(frozen=True) 18 | class ServerStatistics: 19 | user_count: Optional[int] = None 20 | channel_count: Optional[int] = None 21 | 22 | 23 | @dataclass() 24 | class ServerStatisticsRequest: 25 | ids: Optional[List[str]] = field(default_factory=lambda: ['users', 'channels']) 26 | period_seconds: Optional[int] = field(default_factory=lambda: 1) 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ClientStatistics: 31 | memory_usage: Optional[float] = None 32 | 33 | 34 | chat_filename_mimetype = b'chat/file-name' 35 | 36 | 37 | def encode_dataclass(obj) -> bytes: 38 | return ensure_bytes(json.dumps(obj.__dict__)) 39 | 40 | 41 | def dataclass_to_payload(obj) -> Payload: 42 | return Payload(encode_dataclass(obj)) 43 | 44 | 45 | T = TypeVar('T') 46 | 47 | 48 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 49 | return cls(**json.loads(utf8_decode(data))) 50 | 51 | 52 | def decode_payload(cls, payload: Payload): 53 | data = payload.data 54 | 55 | if cls is bytes: 56 | return data 57 | if cls is str: 58 | return utf8_decode(data) 59 | 60 | return decode_dataclass(data, cls) 61 | -------------------------------------------------------------------------------- /examples/response_stream.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from datetime import timedelta 4 | 5 | from rsocket.payload import Payload 6 | from rsocket.streams.stream_from_generator import StreamFromGenerator 7 | 8 | 9 | def sample_sync_response_stream(response_count: int = 3, 10 | delay_between_messages=timedelta(0), 11 | is_infinite_stream: bool = False): 12 | def generator(): 13 | try: 14 | current_response = 0 15 | 16 | def range_counter(): 17 | return range(response_count) 18 | 19 | if not is_infinite_stream: 20 | counter = range_counter 21 | else: 22 | counter = itertools.count 23 | 24 | for i in counter(): 25 | if is_infinite_stream: 26 | is_complete = False 27 | else: 28 | is_complete = (current_response + 1) == response_count 29 | 30 | if delay_between_messages.total_seconds() > 0: 31 | message = 'Slow Item' 32 | else: 33 | message = 'Item' 34 | 35 | message = '%s: %s' % (message, current_response) 36 | yield Payload(message.encode('utf-8'), b'metadata'), is_complete 37 | 38 | if is_complete: 39 | break 40 | 41 | current_response += 1 42 | finally: 43 | logging.info('Closing sync stream generator') 44 | 45 | return StreamFromGenerator(generator, delay_between_messages) 46 | -------------------------------------------------------------------------------- /examples/tutorial/reactivex/shared.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, Type, TypeVar 4 | 5 | from rsocket.frame_helpers import ensure_bytes 6 | from rsocket.helpers import utf8_decode 7 | from rsocket.payload import Payload 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Message: 12 | user: Optional[str] = None 13 | content: Optional[str] = None 14 | channel: Optional[str] = None 15 | 16 | 17 | @dataclass(frozen=True) 18 | class ServerStatistics: 19 | user_count: Optional[int] = None 20 | channel_count: Optional[int] = None 21 | 22 | 23 | @dataclass() 24 | class ServerStatisticsRequest: 25 | ids: Optional[List[str]] = field(default_factory=lambda: ['users', 'channels']) 26 | period_seconds: Optional[int] = field(default_factory=lambda: 2) 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ClientStatistics: 31 | memory_usage: Optional[float] = None 32 | 33 | 34 | chat_filename_mimetype = b'chat/file-name' 35 | 36 | 37 | def encode_dataclass(obj) -> bytes: 38 | return ensure_bytes(json.dumps(obj.__dict__)) 39 | 40 | 41 | def dataclass_to_payload(obj) -> Payload: 42 | return Payload(encode_dataclass(obj)) 43 | 44 | 45 | T = TypeVar('T') 46 | 47 | 48 | def decode_dataclass(data: bytes, cls: Type[T]) -> T: 49 | return cls(**json.loads(utf8_decode(data))) 50 | 51 | 52 | def decode_payload(cls, payload: Payload): 53 | data = payload.data 54 | 55 | if cls is bytes: 56 | return data 57 | if cls is str: 58 | return utf8_decode(data) 59 | 60 | return decode_dataclass(data, cls) 61 | -------------------------------------------------------------------------------- /tests/rsocket/test_setup.py: -------------------------------------------------------------------------------- 1 | from asyncio import Event 2 | from typing import Optional 3 | 4 | import pytest 5 | 6 | from rsocket.helpers import create_response 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.request_handler import BaseRequestHandler 10 | 11 | 12 | @pytest.mark.parametrize('data_mimetype', ( 13 | 'application/json', 14 | 'custom_defined/custom_type' 15 | )) 16 | async def test_setup_with_explicit_data_encoding(lazy_pipe, data_mimetype): 17 | received_data_encoding_event = Event() 18 | received_data_encoding: Optional[bytes] = None 19 | 20 | class ServerHandler(BaseRequestHandler): 21 | def __init__(self): 22 | self._authenticated = False 23 | 24 | async def on_setup(self, 25 | data_encoding: bytes, 26 | metadata_encoding: bytes, 27 | payload: Payload): 28 | nonlocal received_data_encoding 29 | received_data_encoding = data_encoding 30 | received_data_encoding_event.set() 31 | 32 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 33 | return create_response(b'response') 34 | 35 | async with lazy_pipe( 36 | client_arguments={ 37 | 'data_encoding': data_mimetype 38 | }, 39 | server_arguments={ 40 | 'handler_factory': ServerHandler 41 | }): 42 | await received_data_encoding_event.wait() 43 | 44 | assert received_data_encoding == data_mimetype.encode() 45 | -------------------------------------------------------------------------------- /examples/bugs/streaming_file/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from _codecs import utf_8_decode 4 | from typing import List 5 | 6 | from reactivex import operators 7 | 8 | from rsocket.extensions.helpers import composite, route 9 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 10 | from rsocket.helpers import single_transport_provider 11 | from rsocket.payload import Payload 12 | from rsocket.reactivex.reactivex_client import ReactiveXClient 13 | from rsocket.rsocket_client import RSocketClient 14 | from rsocket.transports.tcp import TransportTCP 15 | 16 | 17 | def got_data(payload): 18 | logging.info('From server: ' + payload.data.decode('utf-8')) 19 | return utf_8_decode(payload.data) 20 | 21 | 22 | class Client: 23 | 24 | def __init__(self, rs: RSocketClient): 25 | self._rs = rs 26 | 27 | async def stream_file(self) -> List[str]: 28 | request = Payload(metadata=composite(route('audio'))) 29 | return await ReactiveXClient(self._rs).request_stream(request).pipe( 30 | operators.map(got_data), 31 | operators.to_list() 32 | ) 33 | 34 | 35 | async def main(): 36 | connection = await asyncio.open_connection('localhost', 8000) 37 | async with RSocketClient(single_transport_provider(TransportTCP(*connection)), 38 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA, 39 | fragment_size_bytes=1_000_000) as client: 40 | c = Client(client) 41 | r = await c.stream_file() 42 | print(r) 43 | 44 | if __name__ == '__main__': 45 | logging.basicConfig(level=logging.DEBUG) 46 | asyncio.run(main()) -------------------------------------------------------------------------------- /rsocket/extensions/tagging.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from typing import Union, List, Optional 3 | 4 | from rsocket.exceptions import RSocketError 5 | from rsocket.extensions.composite_metadata_item import CompositeMetadataItem 6 | from rsocket.frame_helpers import ensure_bytes 7 | 8 | 9 | class TaggingMetadata(CompositeMetadataItem): 10 | __slots__ = ( 11 | 'tags' 12 | ) 13 | 14 | def __init__(self, encoding: bytes, tags: Optional[List[Union[bytes, str]]] = None): 15 | self.tags = tags 16 | self.encoding = encoding 17 | 18 | super().__init__(encoding, None) 19 | 20 | def serialize(self) -> bytes: 21 | self.content = self._serialize_tags() 22 | return super().serialize() 23 | 24 | def _serialize_tags(self) -> bytes: 25 | serialized = b'' 26 | 27 | for tag in list(map(ensure_bytes, self.tags)): 28 | if len(tag) > 255: 29 | raise RSocketError('Tag length longer than 255 characters: "%s"' % tag) 30 | 31 | serialized += struct.pack('>B', len(tag)) 32 | serialized += tag 33 | 34 | return serialized 35 | 36 | def parse(self, buffer: bytes): 37 | self.tags = [] 38 | offset = 0 39 | 40 | while offset < len(buffer): 41 | tag_length = struct.unpack('>B', buffer[offset:offset + 1])[0] 42 | offset += 1 43 | self.tags.append(buffer[offset:offset + tag_length]) 44 | offset += tag_length 45 | 46 | def __eq__(self, other): 47 | if isinstance(other, self.__class__): 48 | return self.tags == other.tags and self.encoding == other.encoding 49 | 50 | return False 51 | -------------------------------------------------------------------------------- /examples/cloudevents/client_cloudevents.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import sys 5 | 6 | from cloudevents.conversion import to_json, from_json 7 | from cloudevents.pydantic import CloudEvent 8 | 9 | from rsocket.extensions.helpers import composite, route 10 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 11 | from rsocket.helpers import single_transport_provider 12 | from rsocket.payload import Payload 13 | from rsocket.rsocket_client import RSocketClient 14 | from rsocket.transports.tcp import TransportTCP 15 | 16 | 17 | async def main(server_port: int): 18 | connection = await asyncio.open_connection('localhost', server_port) 19 | 20 | async with RSocketClient(single_transport_provider(TransportTCP(*connection)), 21 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA, 22 | data_encoding=b'application/cloudevents+json') as client: 23 | event = CloudEvent.create(attributes={ 24 | 'type': 'io.spring.event.Foo', 25 | 'source': 'https://spring.io/foos' 26 | }, data=json.dumps({'value': 'Dave'})) 27 | 28 | response = await client.request_response(Payload(data=to_json(event), metadata=composite(route('event')))) 29 | 30 | response_event = from_json(CloudEvent, response.data) 31 | response_data = json.loads(response_event.data) 32 | 33 | assert response_data['value'] == 'Dave' 34 | 35 | print(response_data['value']) 36 | 37 | 38 | if __name__ == '__main__': 39 | port = sys.argv[1] if len(sys.argv) > 1 else 7000 40 | logging.basicConfig(level=logging.DEBUG) 41 | asyncio.run(main(port)) 42 | -------------------------------------------------------------------------------- /reactivestreams/subscriber.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Optional 3 | 4 | from reactivestreams.subscription import Subscription 5 | 6 | 7 | class Subscriber(metaclass=ABCMeta): 8 | """ 9 | Handles stream events. 10 | """ 11 | 12 | @abstractmethod 13 | def on_subscribe(self, subscription: Subscription): 14 | ... 15 | 16 | @abstractmethod 17 | def on_next(self, value, is_complete=False): 18 | ... 19 | 20 | @abstractmethod 21 | def on_error(self, exception: Exception): 22 | ... 23 | 24 | @abstractmethod 25 | def on_complete(self): 26 | ... 27 | 28 | 29 | class DefaultSubscriber(Subscriber): 30 | def __init__(self, on_next=None, 31 | on_error=None, 32 | on_complete=None, 33 | on_subscribe=None): 34 | self._on_subscribe = on_subscribe 35 | self._on_complete = on_complete 36 | self._on_error = on_error 37 | self._on_next = on_next 38 | self.subscription: Optional[Subscription] = None 39 | 40 | def on_next(self, value, is_complete=False): 41 | if self._on_next is not None: 42 | self._on_next(value, is_complete) 43 | 44 | def on_error(self, exception: Exception): 45 | if self._on_error is not None: 46 | self._on_error(exception) 47 | 48 | def on_subscribe(self, subscription: Subscription): 49 | self.subscription = subscription 50 | 51 | if self._on_subscribe is not None: 52 | self._on_subscribe(subscription) 53 | 54 | def on_complete(self): 55 | if self._on_complete is not None: 56 | self._on_complete() 57 | -------------------------------------------------------------------------------- /rsocket/streams/stream_handler.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABCMeta 2 | from typing import Optional, Union 3 | 4 | from rsocket.exceptions import RSocketValueError 5 | from rsocket.frame import Frame, MAX_REQUEST_N 6 | from rsocket.frame_builders import to_cancel_frame, to_request_n_frame 7 | from rsocket.rsocket import RSocket 8 | from rsocket.rsocket_internal import RSocketInternal 9 | from rsocket.streams.backpressureapi import BackpressureApi 10 | 11 | 12 | class StreamHandler(BackpressureApi, metaclass=ABCMeta): 13 | def __init__(self, socket: Union[RSocket, RSocketInternal]): 14 | super().__init__() 15 | self.stream_id: Optional[int] = None 16 | self.socket = socket 17 | self._initial_request_n = MAX_REQUEST_N 18 | 19 | @abstractmethod 20 | def setup(self): 21 | ... 22 | 23 | def initial_request_n(self, n: int): 24 | if n <= 0: 25 | self.socket.finish_stream(self.stream_id) 26 | raise RSocketValueError('Initial request N must be > 0') 27 | 28 | self._initial_request_n = n 29 | return self 30 | 31 | def frame_sent(self, frame: Frame): 32 | """Not being marked abstract, since most handlers won't override.""" 33 | 34 | @abstractmethod 35 | def frame_received(self, frame: Frame): 36 | ... 37 | 38 | def send_cancel(self): 39 | """Convenience method for use by requester subclasses.""" 40 | 41 | self.socket.send_frame(to_cancel_frame(self.stream_id)) 42 | 43 | def send_request_n(self, n: int): 44 | self.socket.send_frame(to_request_n_frame(self.stream_id, n)) 45 | 46 | def _finish_stream(self): 47 | self.socket.finish_stream(self.stream_id) 48 | -------------------------------------------------------------------------------- /rsocket/handlers/request_channel_requester.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | from reactivestreams.publisher import Publisher 5 | from reactivestreams.subscriber import Subscriber 6 | from rsocket.frame_builders import to_request_channel_frame 7 | from rsocket.handlers.interfaces import Requester 8 | from rsocket.handlers.request_cahnnel_common import RequestChannelCommon 9 | from rsocket.payload import Payload 10 | from rsocket.rsocket import RSocket 11 | 12 | 13 | class RequestChannelRequester(RequestChannelCommon, Requester): 14 | 15 | def __init__(self, 16 | socket: RSocket, 17 | payload: Payload, 18 | publisher: Optional[Publisher] = None, 19 | sending_done: Optional[asyncio.Event] = None): 20 | super().__init__(socket, publisher, sending_done) 21 | self._payload = payload 22 | 23 | def setup(self): 24 | super().setup() 25 | 26 | def _send_channel_request(self, payload: Payload): 27 | self.socket.send_request( 28 | to_request_channel_frame(stream_id=self.stream_id, 29 | payload=payload, 30 | initial_request_n=self._initial_request_n, 31 | complete=self._publisher is None, 32 | fragment_size_bytes=self.socket.get_fragment_size_bytes()) 33 | ) 34 | 35 | def subscribe(self, subscriber: Subscriber): 36 | self.setup() 37 | super().subscribe(subscriber) 38 | self._send_channel_request(self._payload) 39 | 40 | if self._publisher is None: 41 | self.mark_completed_and_finish(sent=True) 42 | -------------------------------------------------------------------------------- /tests/tools/fixtures_quart.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import Event 3 | from contextlib import asynccontextmanager 4 | from typing import Optional 5 | 6 | from rsocket.rsocket_base import RSocketBase 7 | from tests.rsocket.helpers import assert_no_open_streams 8 | 9 | 10 | @asynccontextmanager 11 | async def pipe_factory_quart_websocket(unused_tcp_port, client_arguments=None, server_arguments=None): 12 | from quart import Quart 13 | from rsocket.transports.quart_websocket import websocket_handler 14 | from rsocket.transports.asyncwebsockets_transport import websocket_client 15 | 16 | app = Quart(__name__) 17 | server: Optional[RSocketBase] = None 18 | wait_for_server = Event() 19 | 20 | def store_server(new_server): 21 | nonlocal server 22 | server = new_server 23 | wait_for_server.set() 24 | 25 | @app.websocket("/") 26 | async def ws(): 27 | await websocket_handler(on_server_create=store_server, **(server_arguments or {})) 28 | # test_overrides = {'keep_alive_period': timedelta(minutes=20)} 29 | 30 | client_arguments = client_arguments or {} 31 | # client_arguments.update(test_overrides) 32 | server_task = asyncio.create_task(app.run_task(port=unused_tcp_port)) 33 | await asyncio.sleep(0) 34 | 35 | async with websocket_client('ws://localhost:{}'.format(unused_tcp_port), 36 | **client_arguments) as client: 37 | await wait_for_server.wait() 38 | yield server, client 39 | 40 | await server.close() 41 | assert_no_open_streams(client, server) 42 | 43 | try: 44 | server_task.cancel() 45 | await server_task 46 | except asyncio.CancelledError: 47 | pass 48 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/ClientWebsocketHandler.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.netty.channel.Channel; 4 | import io.netty.channel.ChannelHandlerContext; 5 | import io.netty.channel.ChannelInboundHandlerAdapter; 6 | import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; 7 | import io.netty.util.ReferenceCountUtil; 8 | import reactor.core.publisher.MonoProcessor; 9 | 10 | import java.nio.charset.StandardCharsets; 11 | 12 | public class ClientWebsocketHandler extends ChannelInboundHandlerAdapter { 13 | private final MonoProcessor channel = MonoProcessor.create(); 14 | private final MonoProcessor pong = MonoProcessor.create(); 15 | 16 | @Override 17 | public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { 18 | if (msg instanceof PongWebSocketFrame) { 19 | pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); 20 | ReferenceCountUtil.safeRelease(msg); 21 | ctx.read(); 22 | } else { 23 | super.channelRead(ctx, msg); 24 | } 25 | } 26 | 27 | @Override 28 | public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { 29 | Channel ch = ctx.channel(); 30 | if (!channel.isTerminated() && ch.isWritable()) { 31 | channel.onNext(ctx.channel()); 32 | } 33 | super.channelWritabilityChanged(ctx); 34 | } 35 | 36 | @Override 37 | public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 38 | Channel ch = ctx.channel(); 39 | if (ch.isWritable()) { 40 | channel.onNext(ch); 41 | } 42 | super.handlerAdded(ctx); 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /rsocket/extensions/authentication_content.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | from rsocket.extensions.authentication import Authentication, AuthenticationSimple, AuthenticationBearer 4 | from rsocket.extensions.authentication_types import WellKnownAuthenticationTypes 5 | from rsocket.extensions.composite_metadata_item import CompositeMetadataItem 6 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 7 | from rsocket.helpers import serialize_well_known_encoding, parse_well_known_encoding 8 | 9 | 10 | class AuthenticationContent(CompositeMetadataItem): 11 | __slots__ = 'authentication' 12 | 13 | def __init__(self, authentication: Optional[Authentication] = None): 14 | super().__init__(WellKnownMimeTypes.MESSAGE_RSOCKET_AUTHENTICATION.value.name, None) 15 | self.authentication = authentication 16 | 17 | def serialize(self) -> bytes: 18 | serialized = serialize_well_known_encoding(self.authentication.type, WellKnownAuthenticationTypes.get_by_name) 19 | 20 | serialized += self.authentication.serialize() 21 | 22 | return serialized 23 | 24 | def parse(self, buffer: bytes): 25 | authentication_type, offset = parse_well_known_encoding(buffer, WellKnownAuthenticationTypes.require_by_id) 26 | self.authentication = authentication_item_factory(authentication_type)() 27 | self.authentication.parse(buffer[offset:]) 28 | 29 | 30 | metadata_item_factory_by_type = { 31 | WellKnownAuthenticationTypes.SIMPLE.value.name: AuthenticationSimple, 32 | WellKnownAuthenticationTypes.BEARER.value.name: AuthenticationBearer, 33 | } 34 | 35 | 36 | def authentication_item_factory(metadata_encoding: bytes) -> Type[Authentication]: 37 | return metadata_item_factory_by_type[metadata_encoding] 38 | -------------------------------------------------------------------------------- /tests/rsocket/test_unimplemented_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | 5 | from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket 6 | from rsocket.payload import Payload 7 | from rsocket.rsocket_client import RSocketClient 8 | from rsocket.rsocket_server import RSocketServer 9 | from tests.rsocket.helpers import get_components 10 | 11 | 12 | @pytest.mark.allow_error_log(regex_filter='(Protocol|Setup|Unknown) error') 13 | async def test_request_response_not_implemented_by_server_by_default(pipe: Tuple[RSocketServer, RSocketClient]): 14 | payload = Payload(b'abc', b'def') 15 | server, client = get_components(pipe) 16 | 17 | with pytest.raises(RuntimeError) as exc_info: 18 | await client.request_response(payload) 19 | 20 | assert str(exc_info.value) == 'Not implemented' 21 | 22 | 23 | @pytest.mark.allow_error_log(regex_filter='(Protocol|Setup|Unknown) error') 24 | async def test_request_stream_not_implemented_by_server_by_default(pipe: Tuple[RSocketServer, RSocketClient]): 25 | payload = Payload(b'abc', b'def') 26 | server, client = get_components(pipe) 27 | 28 | with pytest.raises(RuntimeError) as exc_info: 29 | await AwaitableRSocket(client).request_stream(payload) 30 | 31 | assert str(exc_info.value) == 'Not implemented' 32 | 33 | 34 | @pytest.mark.allow_error_log(regex_filter='(Protocol|Setup|Unknown) error') 35 | async def test_request_channel_not_implemented_by_server_by_default(pipe: Tuple[RSocketServer, RSocketClient]): 36 | payload = Payload(b'abc', b'def') 37 | server, client = get_components(pipe) 38 | 39 | with pytest.raises(RuntimeError) as exc_info: 40 | await AwaitableRSocket(client).request_channel(payload) 41 | 42 | assert str(exc_info.value) == 'Not implemented' 43 | -------------------------------------------------------------------------------- /examples/server_quic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from datetime import datetime 5 | from pathlib import Path 6 | 7 | from aioquic.quic.configuration import QuicConfiguration 8 | 9 | from rsocket.helpers import create_future 10 | from rsocket.local_typing import Awaitable 11 | from rsocket.payload import Payload 12 | from rsocket.request_handler import BaseRequestHandler 13 | from rsocket.transports.aioquic_transport import rsocket_serve 14 | 15 | 16 | class Handler(BaseRequestHandler): 17 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 18 | await asyncio.sleep(0.1) # Simulate not immediate process 19 | date_time_format = payload.data.decode('utf-8') 20 | formatted_date_time = datetime.now().strftime(date_time_format) 21 | return create_future(Payload(formatted_date_time.encode('utf-8'))) 22 | 23 | 24 | def run_server(server_port): 25 | logging.info('Starting server at localhost:%s', server_port) 26 | 27 | configuration = QuicConfiguration( 28 | is_client=False 29 | ) 30 | 31 | certificates_path = Path(__file__).parent / 'certificates' 32 | configuration.load_cert_chain(certificates_path / 'ssl_cert.pem', certificates_path / 'ssl_key.pem') 33 | 34 | return rsocket_serve(host='localhost', 35 | port=server_port, 36 | configuration=configuration, 37 | handler_factory=Handler) 38 | 39 | 40 | if __name__ == '__main__': 41 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 42 | logging.basicConfig(level=logging.DEBUG) 43 | 44 | loop = asyncio.get_event_loop() 45 | loop.run_until_complete(run_server(port)) 46 | try: 47 | loop.run_forever() 48 | except KeyboardInterrupt: 49 | pass 50 | -------------------------------------------------------------------------------- /rsocket/handlers/request_response_requester.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from rsocket.frame import ErrorFrame, PayloadFrame, Frame, error_frame_to_exception 4 | from rsocket.frame_builders import to_request_response_frame 5 | from rsocket.handlers.interfaces import Requester 6 | from rsocket.helpers import create_future, payload_from_frame 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.rsocket import RSocket 10 | from rsocket.streams.stream_handler import StreamHandler 11 | 12 | 13 | class RequestResponseRequester(StreamHandler, Requester): 14 | def __init__(self, socket: RSocket, payload: Payload): 15 | super().__init__(socket) 16 | self._payload = payload 17 | self._future = create_future() 18 | 19 | def setup(self): 20 | self._future.add_done_callback(self._on_future_complete) 21 | 22 | def run(self) -> Awaitable[Payload]: 23 | request = to_request_response_frame(self.stream_id, 24 | self._payload, 25 | self.socket.get_fragment_size_bytes()) 26 | self.socket.send_request(request) 27 | return self._future 28 | 29 | def frame_received(self, frame: Frame): 30 | if isinstance(frame, PayloadFrame): 31 | self._future.set_result(payload_from_frame(frame)) 32 | self._finish_stream() 33 | elif isinstance(frame, ErrorFrame): 34 | self._future.set_exception(error_frame_to_exception(frame)) 35 | self._finish_stream() 36 | 37 | def _on_future_complete(self, future: asyncio.Future): 38 | if future.cancelled(): 39 | self.cancel() 40 | 41 | def cancel(self): 42 | self.send_cancel() 43 | self._finish_stream() 44 | -------------------------------------------------------------------------------- /rsocket/queue_peekable.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from asyncio import Queue, QueueEmpty 3 | 4 | 5 | class QueuePeekable(Queue): 6 | 7 | async def peek(self): 8 | """Peek the next item in the queue. 9 | 10 | If queue is empty, wait until an item is available. 11 | """ 12 | while self.empty(): 13 | getter = self._get_loop().create_future() 14 | self._getters.append(getter) 15 | try: 16 | await getter 17 | except Exception: 18 | getter.cancel() # Just in case getter is not done yet. 19 | try: 20 | # Clean self._getters from canceled getters. 21 | self._getters.remove(getter) 22 | except ValueError: 23 | # The getter could be removed from self._getters by a 24 | # previous put_nowait call. 25 | pass 26 | if not self.empty() and not getter.cancelled(): 27 | # We were woken up by put_nowait(), but can't take 28 | # the call. Wake up the next in line. 29 | self._wakeup_next(self._getters) 30 | raise 31 | return self.peek_nowait() 32 | 33 | def peek_nowait(self): 34 | """Peek the next item in the queue. 35 | 36 | Return an item if one is immediately available, else raise QueueEmpty. 37 | """ 38 | if self.empty(): 39 | raise QueueEmpty 40 | 41 | item = self._queue[0] 42 | self._wakeup_next(self._putters) 43 | return item 44 | 45 | 46 | if sys.version_info < (3, 10): 47 | class QueuePeekableBackwardCompatible(QueuePeekable): 48 | def _get_loop(self): 49 | return self._loop 50 | 51 | 52 | QueuePeekable = QueuePeekableBackwardCompatible 53 | -------------------------------------------------------------------------------- /rsocket/transports/quart_websocket.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from quart import websocket 4 | 5 | from rsocket.frame import Frame 6 | from rsocket.helpers import wrap_transport_exception 7 | from rsocket.logger import logger 8 | from rsocket.rsocket_server import RSocketServer 9 | from rsocket.transports.abstract_messaging import AbstractMessagingTransport 10 | 11 | 12 | async def websocket_handler(on_server_create=None, **kwargs): 13 | """ 14 | Helper method to instantiate an RSocket server using a quart websocket connection. 15 | 16 | :param on_server_create: callback to be called when the server is created 17 | :param kwargs: parameters passed to the server 18 | """ 19 | 20 | transport = TransportQuartWebsocket() 21 | server = RSocketServer(transport, **kwargs) 22 | 23 | if on_server_create is not None: 24 | on_server_create(server) 25 | 26 | await transport.handle_incoming_ws_messages() 27 | 28 | 29 | class TransportQuartWebsocket(AbstractMessagingTransport): 30 | """ 31 | RSocket transport over server side quart websocket. Use the `websocket_handler ` helper method to instantiate. 32 | """ 33 | 34 | async def handle_incoming_ws_messages(self): 35 | try: 36 | while True: 37 | data = await websocket.receive() 38 | 39 | async for frame in self._frame_parser.receive_data(data, 0): 40 | self._incoming_frame_queue.put_nowait(frame) 41 | except asyncio.CancelledError: 42 | logger().debug('Asyncio task canceled: quart_handle_incoming_ws_messages') 43 | 44 | async def send_frame(self, frame: Frame): 45 | with wrap_transport_exception(): 46 | await websocket.send(frame.serialize()) 47 | 48 | async def close(self): 49 | pass 50 | -------------------------------------------------------------------------------- /rsocket/transports/tcp.py: -------------------------------------------------------------------------------- 1 | from asyncio import StreamReader, StreamWriter 2 | 3 | from rsocket.frame import Frame, serialize_prefix_with_frame_size_header 4 | from rsocket.helpers import wrap_transport_exception 5 | from rsocket.transports.transport import Transport 6 | 7 | 8 | class TransportTCP(Transport): 9 | """ 10 | RSocket transport over asyncio TCP connection. 11 | 12 | :param reader: asyncio connection reader stream 13 | :param writer: asyncio connection writer stream 14 | """ 15 | 16 | def __init__(self, 17 | reader: StreamReader, 18 | writer: StreamWriter, 19 | read_buffer_size=1024): 20 | super().__init__() 21 | self._read_buffer_size = read_buffer_size 22 | self._writer = writer 23 | self._reader = reader 24 | 25 | async def send_frame(self, frame: Frame): 26 | await self.serialize_partial(frame) 27 | 28 | async def serialize_partial(self, frame: Frame): 29 | with wrap_transport_exception(): 30 | self._writer.write(serialize_prefix_with_frame_size_header(frame)) 31 | frame.write_data_metadata(self._writer.write) 32 | await self._writer.drain() 33 | 34 | async def on_send_queue_empty(self): 35 | with wrap_transport_exception(): 36 | await self._writer.drain() 37 | 38 | async def close(self): 39 | self._writer.close() 40 | await self._writer.wait_closed() 41 | 42 | async def next_frame_generator(self): 43 | with wrap_transport_exception(): 44 | data = await self._reader.read(self._read_buffer_size) 45 | 46 | if not data: 47 | self._writer.close() 48 | return 49 | 50 | return self._frame_parser.receive_data(data) 51 | 52 | def requires_length_header(self) -> bool: 53 | return True 54 | -------------------------------------------------------------------------------- /tests/tools/fixtures_websockets.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import Event 3 | from contextlib import asynccontextmanager 4 | from typing import Optional 5 | 6 | import websockets 7 | 8 | from rsocket.rsocket_base import RSocketBase 9 | from rsocket.rsocket_server import RSocketServer 10 | from tests.rsocket.helpers import assert_no_open_streams 11 | 12 | 13 | @asynccontextmanager 14 | async def pipe_factory_websockets(unused_tcp_port, client_arguments=None, server_arguments=None): 15 | from rsocket.transports.websockets_transport import WebsocketsTransport 16 | from rsocket.transports.aiohttp_websocket import websocket_client 17 | 18 | server: Optional[RSocketBase] = None 19 | wait_for_server = Event() 20 | stop_websocket_server = Event() 21 | 22 | async def endpoint(websocket): 23 | nonlocal server 24 | transport = WebsocketsTransport() 25 | server = RSocketServer(transport, **(server_arguments or {})) 26 | wait_for_server.set() 27 | await transport.handler(websocket) 28 | 29 | async def server_app(): 30 | async with websockets.serve(endpoint, "localhost", unused_tcp_port): 31 | await stop_websocket_server.wait() 32 | 33 | server_task = asyncio.create_task(server_app()) 34 | 35 | try: 36 | async with websocket_client('http://localhost:{}'.format(unused_tcp_port), 37 | **(client_arguments or {})) as client: 38 | await wait_for_server.wait() 39 | yield server, client 40 | 41 | finally: 42 | stop_websocket_server.set() 43 | if server is not None: 44 | await server.close() 45 | 46 | assert_no_open_streams(client, server) 47 | 48 | try: 49 | server_task.cancel() 50 | await server_task 51 | except asyncio.CancelledError: 52 | pass 53 | -------------------------------------------------------------------------------- /tests/rsocket/cloudevents/test_route_cloud_events.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from cloudevents.conversion import to_json, from_json 4 | from cloudevents.pydantic import CloudEvent 5 | 6 | from rsocket.cloudevents.serialize import cloud_event_deserialize, cloud_event_serialize 7 | from rsocket.extensions.helpers import composite, route 8 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 9 | from rsocket.payload import Payload 10 | from rsocket.routing.request_router import RequestRouter 11 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 12 | 13 | 14 | async def test_routed_cloudevents(lazy_pipe): 15 | router = RequestRouter(cloud_event_deserialize, 16 | cloud_event_serialize) 17 | 18 | def handler_factory(): 19 | return RoutingRequestHandler(router) 20 | 21 | @router.response('cloud_event') 22 | async def response_request(value: CloudEvent) -> CloudEvent: 23 | return CloudEvent.create(attributes={ 24 | 'type': 'io.spring.event.Foo', 25 | 'source': 'https://spring.io/foos' 26 | }, data=json.dumps(json.loads(value.data))) 27 | 28 | async with lazy_pipe( 29 | client_arguments={'metadata_encoding': WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA}, 30 | server_arguments={'handler_factory': handler_factory}) as (server, client): 31 | event = CloudEvent.create(attributes={ 32 | 'type': 'io.spring.event.Foo', 33 | 'source': 'https://spring.io/foos' 34 | }, data=json.dumps({'value': 'Dave'})) 35 | 36 | response = await client.request_response(Payload(data=to_json(event), metadata=composite(route('cloud_event')))) 37 | 38 | response_event = from_json(CloudEvent, response.data) 39 | response_data = json.loads(response_event.data) 40 | 41 | assert response_data['value'] == 'Dave' 42 | -------------------------------------------------------------------------------- /examples/graphql/java/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 4.0.0 5 | 6 | org.springframework.boot 7 | spring-boot-starter-parent 8 | 3.1.4 9 | 10 | 11 | com.example 12 | demo 13 | 0.0.1-SNAPSHOT 14 | demo 15 | Demo project for Spring Boot 16 | 17 | 17 18 | 19 | 20 | 21 | org.springframework.boot 22 | spring-boot-starter-graphql 23 | 24 | 25 | org.springframework.boot 26 | spring-boot-starter-rsocket 27 | 28 | 29 | 30 | org.springframework.boot 31 | spring-boot-starter-test 32 | test 33 | 34 | 35 | io.projectreactor 36 | reactor-test 37 | test 38 | 39 | 40 | org.springframework.graphql 41 | spring-graphql-test 42 | test 43 | 44 | 45 | 46 | 47 | 48 | 49 | org.springframework.boot 50 | spring-boot-maven-plugin 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | - name: Lint with flake8 31 | run: | 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 35 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | - name: Test with pytest 37 | run: | 38 | pytest -n 1 --reruns 5 --cov-report=html --cov --ignore=examples tests 39 | - name: Archive code coverage html report 40 | uses: actions/upload-artifact@v4 41 | with: 42 | name: code-coverage-report 43 | path: coverage_report_html 44 | overwrite: true 45 | - name: Publish coveralls 46 | continue-on-error: true 47 | env: 48 | GITHUB_TOKEN: ${{ secrets.COVERALLS_TOKEN }} 49 | run: | 50 | coveralls 51 | -------------------------------------------------------------------------------- /examples/server_aiohttp_websocket.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import ssl 3 | 4 | import asyncclick as click 5 | from aiohttp import web 6 | 7 | from fixtures import generate_certificate_and_key 8 | from rsocket.helpers import create_future 9 | from rsocket.local_typing import Awaitable 10 | from rsocket.payload import Payload 11 | from rsocket.request_handler import BaseRequestHandler 12 | from rsocket.rsocket_server import RSocketServer 13 | from rsocket.transports.aiohttp_websocket import TransportAioHttpWebsocket 14 | 15 | 16 | class Handler(BaseRequestHandler): 17 | 18 | async def request_response(self, payload: Payload) -> Awaitable[Payload]: 19 | return create_future(Payload(b'pong')) 20 | 21 | 22 | def websocket_handler_factory(**kwargs): 23 | async def websocket_handler(request): 24 | ws = web.WebSocketResponse() 25 | await ws.prepare(request) 26 | transport = TransportAioHttpWebsocket(ws) 27 | RSocketServer(transport, **kwargs) 28 | await transport.handle_incoming_ws_messages() 29 | return ws 30 | 31 | return websocket_handler 32 | 33 | 34 | @click.command() 35 | @click.option('--port', help='Port to listen on', default=6565, type=int) 36 | @click.option('--with-ssl', is_flag=True, help='Enable SSL mode') 37 | async def start_server(with_ssl: bool, port: int): 38 | logging.basicConfig(level=logging.DEBUG) 39 | app = web.Application() 40 | app.add_routes([web.get('/', websocket_handler_factory(handler_factory=Handler))]) 41 | 42 | if with_ssl: 43 | ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 44 | 45 | with generate_certificate_and_key() as (certificate, key): 46 | ssl_context.load_cert_chain(certificate, key) 47 | else: 48 | ssl_context = None 49 | 50 | await web._run_app(app, port=port, ssl_context=ssl_context) 51 | 52 | 53 | if __name__ == '__main__': 54 | start_server() 55 | -------------------------------------------------------------------------------- /tests/test_integrations/test_cloudevents.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from cloudevents.conversion import to_json, from_json 4 | from cloudevents.pydantic import CloudEvent 5 | 6 | from rsocket.extensions.helpers import route, composite 7 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 8 | from rsocket.helpers import create_response 9 | from rsocket.payload import Payload 10 | from rsocket.routing.request_router import RequestRouter 11 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 12 | 13 | 14 | async def test_routed_request_cloud_event(lazy_pipe): 15 | router = RequestRouter() 16 | 17 | def handler_factory(): 18 | return RoutingRequestHandler(router) 19 | 20 | @router.response('event') 21 | async def single_request_response(payload): 22 | received_event = from_json(CloudEvent, payload.data) 23 | received_data = json.loads(received_event.data) 24 | 25 | event = CloudEvent.create(attributes={ 26 | 'type': 'io.spring.event.Foo', 27 | 'source': 'https://spring.io/foos' 28 | }, data=json.dumps(received_data)) 29 | 30 | return create_response(to_json(event)) 31 | 32 | async with lazy_pipe( 33 | client_arguments={'metadata_encoding': WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA}, 34 | server_arguments={'handler_factory': handler_factory}) as (server, client): 35 | event = CloudEvent.create(attributes={ 36 | 'type': 'io.spring.event.Foo', 37 | 'source': 'https://spring.io/foos' 38 | }, data=json.dumps({'value': 'Dave'})) 39 | 40 | data = to_json(event) 41 | response = await client.request_response(Payload(data=data, metadata=composite(route('event')))) 42 | 43 | event = from_json(CloudEvent, response.data) 44 | response_data = json.loads(event.data) 45 | 46 | assert response_data['value'] == 'Dave' 47 | -------------------------------------------------------------------------------- /rsocket/awaitable/collector_subscriber.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from reactivestreams.subscriber import Subscriber 4 | from reactivestreams.subscription import DefaultSubscription, Subscription 5 | from rsocket.frame import MAX_REQUEST_N 6 | 7 | 8 | class CollectorSubscriber(Subscriber, Subscription): 9 | 10 | def __init__(self, limit_rate=MAX_REQUEST_N, limit_count=None) -> None: 11 | self._limit_count = limit_count 12 | self._limit_rate = limit_rate 13 | self._received_count = 0 14 | self._total_received_count = 0 15 | self.is_done = asyncio.Event() 16 | self.error = None 17 | self.values = [] 18 | self.subscription = None 19 | 20 | def on_complete(self): 21 | self.is_done.set() 22 | 23 | def on_subscribe(self, subscription: DefaultSubscription): 24 | self.subscription = subscription 25 | 26 | def cancel(self): 27 | self.subscription.cancel() 28 | 29 | def request(self, n: int): 30 | self.subscription.request(n) 31 | 32 | def on_next(self, value, is_complete=False): 33 | self.values.append(value) 34 | 35 | self._received_count += 1 36 | self._total_received_count += 1 37 | 38 | if is_complete: 39 | self.is_done.set() 40 | elif self._limit_count is not None and self._limit_count == self._total_received_count: 41 | self.subscription.cancel() 42 | self.is_done.set() 43 | else: 44 | if self._received_count == self._limit_rate: 45 | self._received_count = 0 46 | self.subscription.request(self._limit_rate) 47 | 48 | def on_error(self, exception: Exception): 49 | self.error = exception 50 | self.is_done.set() 51 | 52 | async def run(self): 53 | await self.is_done.wait() 54 | 55 | if self.error: 56 | raise self.error 57 | 58 | return self.values 59 | -------------------------------------------------------------------------------- /rsocket/extensions/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from rsocket.extensions.authentication import AuthenticationBearer, AuthenticationSimple 4 | from rsocket.extensions.authentication_content import AuthenticationContent 5 | from rsocket.extensions.composite_metadata import CompositeMetadata, CompositeMetadataItem 6 | from rsocket.extensions.mimetypes import WellKnownMimeType, WellKnownMimeTypes 7 | from rsocket.extensions.routing import RoutingMetadata 8 | from rsocket.extensions.stream_data_mimetype import StreamDataMimetype, StreamDataMimetypes 9 | 10 | 11 | def composite(*items) -> bytes: 12 | metadata = CompositeMetadata() 13 | metadata.extend(*items) 14 | return metadata.serialize() 15 | 16 | 17 | def metadata_item(data: bytes, 18 | encoding: Union[bytes, WellKnownMimeTypes, WellKnownMimeType]) -> CompositeMetadataItem: 19 | return CompositeMetadataItem(encoding, data) 20 | 21 | 22 | def authenticate_simple(username: str, password: str) -> CompositeMetadataItem: 23 | return AuthenticationContent(AuthenticationSimple(username, password)) 24 | 25 | 26 | def authenticate_bearer(token: str) -> CompositeMetadataItem: 27 | return AuthenticationContent(AuthenticationBearer(token)) 28 | 29 | 30 | def route(*paths: str) -> CompositeMetadataItem: 31 | return RoutingMetadata(list(paths)) 32 | 33 | 34 | def data_mime_type(metadata_mime_type: Union[bytes, WellKnownMimeType]) -> StreamDataMimetype: 35 | return StreamDataMimetype(metadata_mime_type) 36 | 37 | 38 | def data_mime_types(*metadata_mime_types: Union[bytes, WellKnownMimeType]) -> StreamDataMimetypes: 39 | return StreamDataMimetypes(list(metadata_mime_types)) 40 | 41 | 42 | def require_route(composite_metadata: CompositeMetadata) -> str: 43 | for item in composite_metadata.items: 44 | if isinstance(item, RoutingMetadata): 45 | return item.tags[0].decode() 46 | 47 | raise Exception('No route found in request') 48 | -------------------------------------------------------------------------------- /rsocket/load_balancer/load_balancer_rsocket.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Union, Optional, Any 3 | 4 | from reactivestreams.publisher import Publisher 5 | from rsocket.load_balancer.load_balancer_strategy import LoadBalancerStrategy 6 | from rsocket.local_typing import Awaitable 7 | from rsocket.payload import Payload 8 | from rsocket.rsocket import RSocket 9 | from rsocket.streams.backpressureapi import BackpressureApi 10 | 11 | 12 | class LoadBalancerRSocket(RSocket): 13 | 14 | def __init__(self, strategy: LoadBalancerStrategy): 15 | self._strategy = strategy 16 | 17 | def request_channel(self, 18 | payload: Payload, 19 | publisher: Optional[Publisher] = None, 20 | sending_done: Optional[asyncio.Event] = None) -> Union[Any, Publisher]: 21 | return self._select_client().request_channel( 22 | payload, publisher, sending_done 23 | ) 24 | 25 | def request_response(self, payload: Payload) -> Awaitable[Payload]: 26 | return self._select_client().request_response(payload) 27 | 28 | def fire_and_forget(self, payload: Payload) -> Awaitable[None]: 29 | return self._select_client().fire_and_forget(payload) 30 | 31 | def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: 32 | return self._select_client().request_stream(payload) 33 | 34 | def metadata_push(self, metadata: bytes) -> Awaitable[None]: 35 | return self._select_client().metadata_push(metadata) 36 | 37 | async def connect(self): 38 | await self._strategy.connect() 39 | 40 | async def close(self): 41 | await self._strategy.close() 42 | 43 | async def __aenter__(self) -> RSocket: 44 | await self._strategy.connect() 45 | return self 46 | 47 | async def __aexit__(self, exc_type, exc_val, exc_tb): 48 | await self._strategy.close() 49 | 50 | def _select_client(self): 51 | return self._strategy.select() 52 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/SimpleRSocketAcceptor.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.netty.buffer.Unpooled; 4 | import io.netty.util.CharsetUtil; 5 | import io.rsocket.ConnectionSetupPayload; 6 | import io.rsocket.Payload; 7 | import io.rsocket.RSocket; 8 | import io.rsocket.SocketAcceptor; 9 | import io.rsocket.metadata.CompositeMetadata; 10 | import io.rsocket.util.DefaultPayload; 11 | import reactor.core.publisher.Flux; 12 | import reactor.core.publisher.Mono; 13 | 14 | import java.util.ArrayList; 15 | 16 | public class SimpleRSocketAcceptor implements SocketAcceptor { 17 | 18 | @Override 19 | public Mono accept(ConnectionSetupPayload connectionSetupPayload, RSocket rSocket) { 20 | return Mono.just(new RSocket() { 21 | public Mono fireAndForget(Payload payload) { 22 | var str = payload.getDataUtf8(); 23 | System.out.println("Received :: " + str); 24 | return Mono.empty(); 25 | } 26 | 27 | public Mono requestResponse(Payload payload) { 28 | var str = payload.getDataUtf8(); 29 | return Mono.just(DefaultPayload.create(str.toUpperCase())); 30 | } 31 | 32 | public Flux requestStream(Payload payload) { 33 | var metadata = Unpooled.wrappedBuffer(payload.getMetadata()); 34 | 35 | var route = new ArrayList(); 36 | 37 | new CompositeMetadata(metadata, true).forEach(entry -> 38 | route.add(entry.getContent().toString(CharsetUtil.US_ASCII)) 39 | ); 40 | 41 | var data = payload.getDataUtf8(); 42 | 43 | return Flux.concat(Flux.fromStream(route.stream()), Flux.range(1, 10) 44 | .map(index -> data + "-" + index) 45 | ) 46 | .map(DefaultPayload::create); 47 | } 48 | }); 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /tests/test_reactivex/test_concurrency.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Tuple, Optional 3 | 4 | import reactivex 5 | from reactivex import operators 6 | 7 | from rsocket.frame_helpers import ensure_bytes 8 | from rsocket.helpers import utf8_decode 9 | from rsocket.payload import Payload 10 | from rsocket.reactivex.reactivex_client import ReactiveXClient 11 | from rsocket.reactivex.reactivex_handler import BaseReactivexHandler 12 | from rsocket.reactivex.reactivex_handler_adapter import reactivex_handler_factory 13 | from rsocket.rsocket_client import RSocketClient 14 | from rsocket.rsocket_server import RSocketServer 15 | from tests.tools.helpers import measure_time 16 | 17 | 18 | class Handler(BaseReactivexHandler): 19 | 20 | def __init__(self, server_done: Optional[asyncio.Event] = None): 21 | self._server_done = server_done 22 | 23 | async def request_stream(self, payload: Payload): 24 | count = int(utf8_decode(payload.data)) 25 | return reactivex.from_iterable( 26 | (Payload(ensure_bytes('Feed Item: {}/{}'.format(index, count))) for index in range(count))) 27 | 28 | 29 | async def test_concurrent_streams(pipe: Tuple[RSocketServer, RSocketClient]): 30 | server, client = pipe 31 | 32 | server.set_handler_using_factory(reactivex_handler_factory(Handler)) 33 | 34 | request_1 = asyncio.create_task(measure_time(ReactiveXClient(client).request_stream(Payload(b'2000')).pipe( 35 | operators.map(lambda payload: payload.data), 36 | operators.do_action(on_next=lambda x: print(x)), 37 | operators.to_list() 38 | ))) 39 | 40 | request_2 = asyncio.create_task(measure_time(ReactiveXClient(client).request_stream(Payload(b'10')).pipe( 41 | operators.map(lambda payload: payload.data), 42 | operators.do_action(on_next=lambda x: print(x)), 43 | operators.to_list() 44 | ))) 45 | 46 | results = (await request_1, await request_2) 47 | 48 | delta = abs(results[0].delta - results[1].delta) 49 | 50 | assert delta > 0.2 51 | -------------------------------------------------------------------------------- /rsocket/handlers/request_stream_requester.py: -------------------------------------------------------------------------------- 1 | from reactivestreams.subscriber import Subscriber 2 | from rsocket.frame import ErrorFrame, PayloadFrame, Frame, error_frame_to_exception 3 | from rsocket.frame_builders import to_request_stream_frame 4 | from rsocket.handlers.interfaces import Requester 5 | from rsocket.helpers import payload_from_frame, DefaultPublisherSubscription 6 | from rsocket.payload import Payload 7 | from rsocket.rsocket import RSocket 8 | from rsocket.streams.stream_handler import StreamHandler 9 | 10 | 11 | class RequestStreamRequester(StreamHandler, DefaultPublisherSubscription, Requester): 12 | def __init__(self, socket: RSocket, payload: Payload): 13 | super().__init__(socket) 14 | self.payload = payload 15 | 16 | def setup(self): 17 | pass 18 | 19 | def subscribe(self, subscriber: Subscriber): 20 | super().subscribe(subscriber) 21 | self._send_stream_request(self.payload) 22 | 23 | def cancel(self): 24 | self.send_cancel() 25 | self._finish_stream() 26 | 27 | def request(self, n: int): 28 | self.send_request_n(n) 29 | 30 | def frame_received(self, frame: Frame): 31 | if isinstance(frame, PayloadFrame): 32 | if frame.flags_next: 33 | self._subscriber.on_next(payload_from_frame(frame), 34 | is_complete=frame.flags_complete) 35 | elif frame.flags_complete: 36 | self._subscriber.on_complete() 37 | 38 | if frame.flags_complete: 39 | self._finish_stream() 40 | elif isinstance(frame, ErrorFrame): 41 | self._subscriber.on_error(error_frame_to_exception(frame)) 42 | self._finish_stream() 43 | 44 | def _send_stream_request(self, payload: Payload): 45 | self.socket.send_request(to_request_stream_frame( 46 | stream_id=self.stream_id, 47 | payload=payload, 48 | initial_request_n=self._initial_request_n, 49 | fragment_size_bytes=self.socket.get_fragment_size_bytes() 50 | )) 51 | -------------------------------------------------------------------------------- /examples/client_reconnect.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from typing import Optional 5 | 6 | from rsocket.extensions.helpers import route, composite, authenticate_simple 7 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 8 | from rsocket.payload import Payload 9 | from rsocket.request_handler import BaseRequestHandler 10 | from rsocket.rsocket_client import RSocketClient 11 | from rsocket.transports.tcp import TransportTCP 12 | 13 | 14 | async def request_response(client: RSocketClient) -> Payload: 15 | payload = Payload(b'The quick brown fox', composite( 16 | route('single_request'), 17 | authenticate_simple('user', '12345') 18 | )) 19 | 20 | return await client.request_response(payload) 21 | 22 | 23 | class Handler(BaseRequestHandler): 24 | 25 | async def on_close(self, rsocket, exception: Optional[Exception] = None): 26 | await asyncio.sleep(5) 27 | await rsocket.reconnect() 28 | 29 | 30 | async def main(server_port): 31 | logging.info('Connecting to server at localhost:%s', server_port) 32 | 33 | async def transport_provider(max_reconnect): 34 | for i in range(max_reconnect): 35 | connection = await asyncio.open_connection('localhost', server_port) 36 | yield TransportTCP(*connection) 37 | 38 | async with RSocketClient(transport_provider(3), 39 | metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA, 40 | handler_factory=Handler) as client: 41 | result1 = await request_response(client) 42 | assert result1.data == b'single_response' 43 | 44 | await asyncio.sleep(10) 45 | 46 | result2 = await request_response(client) 47 | assert result2.data == b'single_response' 48 | 49 | result3 = await request_response(client) 50 | assert result3.data == b'single_response' 51 | 52 | 53 | if __name__ == '__main__': 54 | port = sys.argv[1] if len(sys.argv) > 1 else 6565 55 | logging.basicConfig(level=logging.DEBUG) 56 | asyncio.run(main(port)) 57 | -------------------------------------------------------------------------------- /examples/graphql/server_graphql.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from pathlib import Path 5 | from typing import Dict 6 | 7 | from graphql import build_schema 8 | 9 | from rsocket.graphql.server_helper import graphql_handler 10 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 11 | from rsocket.rsocket_server import RSocketServer 12 | from rsocket.transports.tcp import TransportTCP 13 | 14 | stored_message = "" 15 | 16 | 17 | async def greeting(*args) -> Dict: 18 | return { 19 | 'message': "Hello world" 20 | } 21 | 22 | 23 | async def get_message(*args) -> str: 24 | return stored_message 25 | 26 | 27 | async def set_message(root, _info, message) -> Dict: 28 | global stored_message 29 | stored_message = message 30 | return { 31 | "message": message 32 | } 33 | 34 | 35 | def greetings(*args): 36 | async def results(): 37 | for i in range(10): 38 | yield {'greetings': {'message': f"Hello world {i}"}} 39 | await asyncio.sleep(1) 40 | 41 | return results() 42 | 43 | 44 | with (Path(__file__).parent / 'rsocket.graphqls').open() as fd: 45 | schema = build_schema(fd.read()) 46 | 47 | schema.query_type.fields['greeting'].resolve = greeting 48 | schema.query_type.fields['getMessage'].resolve = get_message 49 | schema.mutation_type.fields['setMessage'].resolve = set_message 50 | schema.subscription_type.fields['greetings'].subscribe = greetings 51 | 52 | 53 | def handler_factory(): 54 | return RoutingRequestHandler(graphql_handler(schema, 'graphql')) 55 | 56 | 57 | async def run_server(server_port): 58 | logging.info('Starting server at localhost:%s', server_port) 59 | 60 | def session(*connection): 61 | RSocketServer(TransportTCP(*connection), handler_factory=handler_factory) 62 | 63 | server = await asyncio.start_server(session, 'localhost', server_port) 64 | 65 | async with server: 66 | await server.serve_forever() 67 | 68 | 69 | if __name__ == '__main__': 70 | port = sys.argv[1] if len(sys.argv) > 1 else 9191 71 | logging.basicConfig(level=logging.DEBUG) 72 | asyncio.run(run_server(port)) 73 | -------------------------------------------------------------------------------- /examples/java/src/main/java/io/rsocket/pythontest/ClientChannelHandler.java: -------------------------------------------------------------------------------- 1 | package io.rsocket.pythontest; 2 | 3 | import io.netty.buffer.Unpooled; 4 | import io.rsocket.Payload; 5 | import io.rsocket.util.DefaultPayload; 6 | import org.reactivestreams.Publisher; 7 | import org.reactivestreams.Subscriber; 8 | import org.reactivestreams.Subscription; 9 | 10 | public class ClientChannelHandler implements Publisher, Subscription, Subscriber { 11 | 12 | 13 | private Subscriber subscriber; 14 | 15 | private final Payload request; 16 | 17 | private String lastMessage; 18 | 19 | private boolean routingRequestSent = false; 20 | 21 | public ClientChannelHandler(Payload request) { 22 | 23 | this.request = request; 24 | } 25 | 26 | @Override 27 | public void subscribe(Subscriber subscriber) { 28 | this.subscriber = subscriber; 29 | subscriber.onSubscribe(this); 30 | } 31 | 32 | @Override 33 | public void request(long l) { 34 | System.out.println("Received request for " + l + " Frames"); 35 | if (!routingRequestSent) { 36 | subscriber.onNext(request); 37 | routingRequestSent = true; 38 | } else { 39 | subscriber.onNext(DefaultPayload.create(Unpooled.wrappedBuffer(("from client: " + lastMessage).getBytes()))); 40 | } 41 | } 42 | 43 | @Override 44 | public void cancel() { 45 | System.out.println("Canceled"); 46 | } 47 | 48 | @Override 49 | public void onSubscribe(Subscription s) { 50 | 51 | } 52 | 53 | @Override 54 | public void onNext(Payload payload) { 55 | lastMessage = payload.getDataUtf8(); 56 | System.out.println("Response from server stream :: " + lastMessage); 57 | 58 | if (lastMessage.endsWith("2")) { 59 | subscriber.onComplete(); 60 | } 61 | } 62 | 63 | @Override 64 | public void onError(Throwable t) { 65 | System.out.println("Error from server" + t.getMessage()); 66 | } 67 | 68 | @Override 69 | public void onComplete() { 70 | System.out.println("Complete from server"); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /examples/response_channel.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from typing import AsyncGenerator, Tuple, Optional 4 | 5 | from reactivestreams.subscriber import Subscriber 6 | from reactivestreams.subscription import Subscription 7 | from rsocket.payload import Payload 8 | from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator 9 | 10 | 11 | def sample_async_response_stream(response_count: int = 3, 12 | local_subscriber: Optional[Subscriber] = None, 13 | is_infinite_stream: bool = False): 14 | async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: 15 | try: 16 | current_response = 0 17 | 18 | def range_counter(): 19 | return range(response_count) 20 | 21 | if not is_infinite_stream: 22 | counter = range_counter 23 | else: 24 | counter = itertools.count 25 | 26 | for i in counter(): 27 | if is_infinite_stream: 28 | is_complete = False 29 | else: 30 | is_complete = (current_response + 1) == response_count 31 | 32 | message = 'Item on channel: %s' % current_response 33 | yield Payload(message.encode('utf-8')), is_complete 34 | 35 | if local_subscriber is not None: 36 | local_subscriber.subscription.request(2) 37 | 38 | if is_complete: 39 | break 40 | 41 | current_response += 1 42 | finally: 43 | logging.info('Closing async stream generator') 44 | 45 | return StreamFromAsyncGenerator(generator) 46 | 47 | 48 | class LoggingSubscriber(Subscriber): 49 | def on_subscribe(self, subscription: Subscription): 50 | self.subscription = subscription 51 | 52 | def on_next(self, value: Payload, is_complete=False): 53 | logging.info('From client on channel: ' + value.data.decode('utf-8')) 54 | 55 | def on_error(self, exception: Exception): 56 | logging.error('Error on channel ' + str(exception)) 57 | 58 | def on_complete(self): 59 | logging.info('Completed on channel') 60 | -------------------------------------------------------------------------------- /examples/java/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4.0.0 3 | 4 | com.mycompany.app 5 | rsocket-examples 6 | 1 7 | 8 | 9 | 10 | org.apache.maven.plugins 11 | maven-compiler-plugin 12 | 13 | 11 14 | 11 15 | 16 | 17 | 18 | org.apache.maven.plugins 19 | maven-shade-plugin 20 | 3.2.2 21 | 22 | 23 | package 24 | 25 | shade 26 | 27 | 28 | 29 | 30 | *:* 31 | 32 | 33 | false 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | ch.qos.logback 43 | logback-classic 44 | 1.3.12 45 | 46 | 47 | io.rsocket 48 | rsocket-core 49 | 1.1.3 50 | 51 | 52 | io.rsocket 53 | rsocket-transport-netty 54 | 1.1.3 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /docs/extensions.rst: -------------------------------------------------------------------------------- 1 | Extensions 2 | ========== 3 | 4 | Transports 5 | ---------- 6 | 7 | TCP 8 | ~~~ 9 | 10 | .. automodule:: rsocket.transports.tcp 11 | :members: 12 | 13 | Websocket 14 | ~~~~~~~~~ 15 | 16 | aiohttp 17 | +++++++ 18 | 19 | .. automodule:: rsocket.transports.aiohttp_websocket 20 | :members: 21 | 22 | quart 23 | +++++ 24 | 25 | .. automodule:: rsocket.transports.quart_websocket 26 | :members: 27 | 28 | websockets 29 | ++++++++++ 30 | 31 | .. automodule:: rsocket.transports.websockets_transport 32 | :members: 33 | 34 | asyncwebsockets 35 | +++++++++++++++ 36 | 37 | .. automodule:: rsocket.transports.asyncwebsockets_transport 38 | :members: 39 | 40 | quic 41 | ~~~~ 42 | 43 | .. automodule:: rsocket.transports.aioquic_transport 44 | :members: 45 | 46 | http3 47 | ~~~~~ 48 | 49 | .. automodule:: rsocket.transports.http3_transport 50 | :members: 51 | 52 | Routing 53 | ------- 54 | 55 | RequestRouter 56 | ~~~~~~~~~~~~~ 57 | 58 | .. automodule:: rsocket.routing.request_router 59 | :members: 60 | 61 | RoutingRequestHandler 62 | ~~~~~~~~~~~~~~~~~~~~~ 63 | 64 | .. automodule:: rsocket.routing.routing_request_handler 65 | :members: 66 | 67 | 68 | Load Balancer 69 | ------------- 70 | 71 | .. automodule:: rsocket.load_balancer.load_balancer_rsocket 72 | :members: 73 | 74 | Strategies 75 | ~~~~~~~~~~ 76 | 77 | .. automodule:: rsocket.load_balancer.round_robin 78 | :members: 79 | :inherited-members: 80 | 81 | .. automodule:: rsocket.load_balancer.random_client 82 | :members: 83 | :inherited-members: 84 | 85 | 86 | ReactiveX 87 | --------- 88 | 89 | ReactiveX 4 90 | ~~~~~~~~~~~ 91 | 92 | .. automodule:: rsocket.reactivex.reactivex_handler 93 | :members: 94 | :inherited-members: 95 | 96 | 97 | .. automodule:: rsocket.reactivex.reactivex_handler_adapter 98 | :members: 99 | :inherited-members: 100 | 101 | 102 | ReactiveX 3 103 | ~~~~~~~~~~~ 104 | 105 | .. automodule:: rsocket.rx_support.rx_handler 106 | :members: 107 | :inherited-members: 108 | 109 | 110 | .. automodule:: rsocket.rx_support.rx_handler_adapter 111 | :members: 112 | :inherited-members: 113 | -------------------------------------------------------------------------------- /examples/certificates/ssl_cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV 3 | BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW 4 | MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx 5 | NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj 6 | MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv 7 | Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k 8 | YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA 9 | 3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug 10 | U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 11 | pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA 12 | hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC 13 | WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU 14 | NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 15 | EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB 16 | wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV 17 | HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E 18 | FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK 19 | b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m 20 | dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst 21 | gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 22 | Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw 23 | AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD 24 | VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 25 | Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ 26 | uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY 27 | oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb 28 | iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 29 | KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP 30 | IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr 31 | +UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI 32 | AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv 33 | StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== 34 | -----END CERTIFICATE----- 35 | -------------------------------------------------------------------------------- /tests/tools/fixtures_http3.py: -------------------------------------------------------------------------------- 1 | from asyncio import Event 2 | from contextlib import asynccontextmanager 3 | from typing import Optional 4 | 5 | from rsocket.helpers import single_transport_provider 6 | from rsocket.rsocket_base import RSocketBase 7 | from rsocket.rsocket_client import RSocketClient 8 | from tests.rsocket.helpers import assert_no_open_streams 9 | 10 | 11 | @asynccontextmanager 12 | async def pipe_factory_http3(generate_test_certificates, 13 | unused_tcp_port, 14 | client_arguments=None, 15 | server_arguments=None): 16 | from tests.tools.http3_client import http3_ws_transport 17 | from tests.tools.http3_server import start_http_server 18 | 19 | certificate, private_key = generate_test_certificates 20 | 21 | server: Optional[RSocketBase] = None 22 | client: Optional[RSocketBase] = None 23 | wait_for_server = Event() 24 | 25 | def store_server(new_server): 26 | nonlocal server 27 | server = new_server 28 | wait_for_server.set() 29 | 30 | http3_server = await start_http_server(host='localhost', 31 | port=unused_tcp_port, 32 | certificate=certificate, 33 | private_key=private_key, 34 | on_server_create=store_server, 35 | **(server_arguments or {})) 36 | try: 37 | # from datetime import timedelta 38 | # test_overrides = {'keep_alive_period': timedelta(minutes=20)} 39 | client_arguments = client_arguments or {} 40 | # client_arguments.update(test_overrides) 41 | async with http3_ws_transport(certificate, f'wss://localhost:{unused_tcp_port}/ws') as transport: 42 | async with RSocketClient(single_transport_provider(transport), 43 | **client_arguments) as client: 44 | await wait_for_server.wait() 45 | yield server, client 46 | finally: 47 | if server is not None: 48 | await server.close() 49 | 50 | assert_no_open_streams(client, server) 51 | 52 | http3_server.close() 53 | -------------------------------------------------------------------------------- /rsocket/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from rsocket.error_codes import ErrorCode 4 | 5 | 6 | class ParseError(ValueError): 7 | pass 8 | 9 | 10 | class RSocketError(Exception): 11 | pass 12 | 13 | 14 | class RSocketUnknownMimetype(RSocketError): 15 | def __init__(self, mimetype_id): 16 | self.mimetype_id = mimetype_id 17 | 18 | 19 | class RSocketUnknownAuthType(RSocketError): 20 | def __init__(self, auth_type_id): 21 | self.auth_type_id = auth_type_id 22 | 23 | 24 | class RSocketMimetypeTooLong(RSocketError): 25 | def __init__(self, mimetype): 26 | self.mimetype = mimetype 27 | 28 | 29 | class RSocketUnknownFrameType(RSocketError): 30 | def __init__(self, frame_type_id): 31 | self.frame_type_id = frame_type_id 32 | 33 | 34 | class RSocketApplicationError(RSocketError): 35 | pass 36 | 37 | 38 | class RSocketEmptyRoute(RSocketApplicationError): 39 | def __init__(self, method_name: str): 40 | self.method_name = method_name 41 | 42 | def __str__(self) -> str: 43 | return f'Empty route set on method {self.method_name}' 44 | 45 | 46 | class RSocketUnknownRoute(RSocketApplicationError): 47 | def __init__(self, route_id: str): 48 | self.route_id = route_id 49 | 50 | 51 | class RSocketStreamAllocationFailure(RSocketError): 52 | pass 53 | 54 | 55 | class RSocketValueError(RSocketError): 56 | pass 57 | 58 | 59 | class RSocketProtocolError(RSocketError): 60 | def __init__(self, error_code: ErrorCode, data: Optional[str] = None): 61 | self.error_code = error_code 62 | self.data = data 63 | 64 | def __str__(self) -> str: 65 | return 'RSocket error %s(%s): "%s"' % (self.error_code.name, self.error_code.value, self.data or '') 66 | 67 | 68 | class RSocketStreamIdInUse(RSocketProtocolError): 69 | 70 | def __init__(self, stream_id: int): 71 | super().__init__(ErrorCode.REJECTED) 72 | self.stream_id = stream_id 73 | 74 | 75 | class RSocketFrameFragmentDifferentType(RSocketError): 76 | pass 77 | 78 | 79 | class RSocketTransportError(RSocketError): 80 | pass 81 | 82 | 83 | class RSocketTransportClosed(RSocketError): 84 | pass 85 | 86 | 87 | class RSocketNoAvailableTransport(RSocketError): 88 | pass 89 | -------------------------------------------------------------------------------- /tests/rsocket/test_extentions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import struct 3 | 4 | from rsocket.exceptions import RSocketError 5 | from rsocket.extensions.authentication import AuthenticationBearer, AuthenticationSimple 6 | from rsocket.extensions.mimetypes import WellKnownMimeTypes 7 | from rsocket.extensions.routing import RoutingMetadata 8 | from rsocket.extensions.tagging import TaggingMetadata 9 | 10 | 11 | def test_authentication_bearer(): 12 | data = b'1234' 13 | 14 | authentication = AuthenticationBearer(b'1234') 15 | 16 | assert authentication.serialize() == data 17 | 18 | parsed = AuthenticationBearer() 19 | parsed.parse(data) 20 | 21 | assert parsed == authentication 22 | 23 | 24 | def test_authentication_simple(): 25 | data = b'\x00\x041234abcd' 26 | 27 | authentication = AuthenticationSimple(b'1234', b'abcd') 28 | 29 | assert authentication.serialize() == data 30 | 31 | parsed = AuthenticationSimple() 32 | parsed.parse(data) 33 | 34 | assert parsed == authentication 35 | 36 | 37 | def test_routing(): 38 | data = b'\nroute.path\x0bother.route' 39 | 40 | routing = RoutingMetadata([b'route.path', b'other.route']) 41 | 42 | assert routing.serialize() == data 43 | 44 | parsed = RoutingMetadata() 45 | parsed.parse(data) 46 | 47 | assert parsed == routing 48 | 49 | 50 | def test_tagging_metadata_serialize_max_length(): 51 | tag = 's' * 255 52 | meta = TaggingMetadata(WellKnownMimeTypes.MESSAGE_RSOCKET_ROUTING, [tag]) 53 | 54 | serialized = meta.serialize() 55 | 56 | length = struct.pack('>B', len(tag)) 57 | assert length + bytes(tag, 'utf-8') == serialized 58 | 59 | 60 | def test_tagging_metadata_serialize_exception_length(): 61 | tag = 's' * 256 62 | meta = TaggingMetadata(WellKnownMimeTypes.MESSAGE_RSOCKET_ROUTING, [tag]) 63 | 64 | with pytest.raises(RSocketError) as e_info: 65 | meta.serialize() 66 | 67 | assert e_info.match(f'Tag length longer than 255 characters: "b\'{tag}\'"') 68 | 69 | 70 | def test_tagging_metadata_parse(): 71 | meta = TaggingMetadata(WellKnownMimeTypes.MESSAGE_RSOCKET_ROUTING) 72 | tag = 's' * 255 73 | length = struct.pack('>B', len(tag)) 74 | 75 | meta.parse(length + bytes(tag, 'utf-8')) 76 | assert tag == meta.tags[0].decode() 77 | -------------------------------------------------------------------------------- /examples/cli_demo_server/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import asyncclick as click 5 | import reactivex 6 | from aiohttp import web 7 | from reactivex import Observable 8 | 9 | from rsocket.payload import Payload 10 | from rsocket.routing.request_router import RequestRouter 11 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 12 | from rsocket.rsocket_server import RSocketServer 13 | from rsocket.transports.aiohttp_websocket import TransportAioHttpWebsocket 14 | from rsocket.transports.tcp import TransportTCP 15 | 16 | router = RequestRouter() 17 | 18 | 19 | @router.response('echo') 20 | async def echo(payload: Payload) -> Observable: 21 | return reactivex.just(Payload(payload.data)) 22 | 23 | 24 | def websocket_handler_factory(**kwargs): 25 | async def websocket_handler(request): 26 | ws = web.WebSocketResponse() 27 | await ws.prepare(request) 28 | transport = TransportAioHttpWebsocket(ws) 29 | RSocketServer(transport, **kwargs) 30 | await transport.handle_incoming_ws_messages() 31 | return ws 32 | 33 | return websocket_handler 34 | 35 | 36 | @click.command() 37 | @click.option('--port', help='Port to listen on', default=6565, type=int) 38 | @click.option('--transport', is_flag=False, default='tcp') 39 | async def start_server(port: int, transport: str): 40 | logging.basicConfig(level=logging.DEBUG) 41 | 42 | logging.info(f'Starting {transport} server at localhost:{port}') 43 | 44 | if transport in ['ws']: 45 | app = web.Application() 46 | app.add_routes([web.get('/', websocket_handler_factory( 47 | handler_factory=lambda: RoutingRequestHandler(router) 48 | ))]) 49 | 50 | await web._run_app(app, port=port) 51 | elif transport == 'tcp': 52 | def handle_client(reader, writer): 53 | RSocketServer(TransportTCP(reader, writer), 54 | handler_factory=lambda: RoutingRequestHandler(router)) 55 | 56 | server = await asyncio.start_server(handle_client, 'localhost', port) 57 | 58 | async with server: 59 | await server.serve_forever() 60 | else: 61 | raise Exception(f'Unsupported transport {transport}') 62 | 63 | 64 | if __name__ == '__main__': 65 | start_server() 66 | -------------------------------------------------------------------------------- /tests/rsocket/test_metadata_push.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket 5 | from rsocket.payload import Payload 6 | from rsocket.request_handler import BaseRequestHandler 7 | from rsocket.rsocket_server import RSocketServer 8 | from tests.rsocket.helpers import get_components 9 | 10 | 11 | class MetadataPushHandler(BaseRequestHandler): 12 | def __init__(self): 13 | self.received = asyncio.Event() 14 | self.received_payload: Optional[Payload] = None 15 | 16 | async def on_metadata_push(self, payload: Payload): 17 | self.received_payload = payload 18 | self.received.set() 19 | 20 | 21 | async def test_metadata_push(pipe): 22 | handler: Optional[MetadataPushHandler] = None 23 | 24 | def handler_factory(): 25 | nonlocal handler 26 | handler = MetadataPushHandler() 27 | return handler 28 | 29 | server, client = get_components(pipe) 30 | server.set_handler_using_factory(handler_factory) 31 | 32 | await client.metadata_push(b'cat') 33 | 34 | await handler.received.wait() 35 | 36 | assert handler.received_payload.data is None 37 | assert handler.received_payload.metadata == b'cat' 38 | 39 | 40 | async def test_metadata_push_await(pipe): 41 | handler: Optional[MetadataPushHandler] = None 42 | 43 | def handler_factory(): 44 | nonlocal handler 45 | handler = MetadataPushHandler() 46 | return handler 47 | 48 | server, client = get_components(pipe) 49 | server.set_handler_using_factory(handler_factory) 50 | 51 | await client.metadata_push(b'cat') 52 | 53 | await handler.received.wait() 54 | 55 | 56 | async def test_metadata_push_awaitable_client(pipe): 57 | handler: Optional[MetadataPushHandler] = None 58 | 59 | def handler_factory(): 60 | nonlocal handler 61 | handler = MetadataPushHandler() 62 | return handler 63 | 64 | server: RSocketServer = pipe[0] 65 | client = AwaitableRSocket(pipe[1]) 66 | server.set_handler_using_factory(handler_factory) 67 | 68 | await client.metadata_push(b'cat') 69 | 70 | await handler.received.wait() 71 | 72 | assert handler.received_payload.data is None 73 | assert handler.received_payload.metadata == b'cat' 74 | -------------------------------------------------------------------------------- /examples/bugs/streaming_file/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import logging 4 | from typing import Callable 5 | 6 | import aiofiles 7 | from reactivex import Subject, Observable, operators 8 | 9 | from rsocket.payload import Payload 10 | from rsocket.reactivex.back_pressure_publisher import from_observable_with_backpressure, observable_from_queue 11 | from rsocket.reactivex.reactivex_handler_adapter import reactivex_handler_factory 12 | from rsocket.routing.request_router import RequestRouter 13 | from rsocket.routing.routing_request_handler import RoutingRequestHandler 14 | from rsocket.rsocket_server import RSocketServer 15 | from rsocket.transports.tcp import TransportTCP 16 | 17 | 18 | async def read_file(o: asyncio.Queue): 19 | async with aiofiles.open("my_file", 'rb') as f: 20 | while True: 21 | buffer = await f.read(1024) 22 | logging.info("Reading buffer") 23 | if buffer == b'': 24 | o.put_nowait(None) 25 | break 26 | await o.put(base64.b64encode(buffer)) 27 | 28 | 29 | def handler_factory() -> RoutingRequestHandler: 30 | router = RequestRouter() 31 | 32 | @router.stream('audio') 33 | async def audio() -> Callable[[Subject], Observable]: 34 | q = asyncio.Queue() 35 | # await q.put(b'sf') # testing whether a simpler queuing scheme fixes the issue 36 | asyncio.create_task(read_file(q)) 37 | return from_observable_with_backpressure( 38 | lambda backpressure: observable_from_queue( 39 | q, backpressure=backpressure).pipe( 40 | operators.map(lambda buffer: Payload(buffer)) 41 | ) 42 | ) 43 | 44 | return RoutingRequestHandler(router) 45 | 46 | 47 | async def run_server(): 48 | def session(*connection): 49 | RSocketServer(TransportTCP(*connection), 50 | handler_factory=reactivex_handler_factory(handler_factory), 51 | fragment_size_bytes=1_000_000) 52 | 53 | async with await asyncio.start_server(session, 'localhost', 8000) as server: 54 | await server.serve_forever() 55 | 56 | 57 | def main(): 58 | logging.basicConfig(level=logging.DEBUG) 59 | logging.info("Starting server...") 60 | asyncio.run(run_server()) 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /rsocket/awaitable/awaitable_rsocket.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List, Optional 3 | 4 | from reactivestreams.publisher import Publisher 5 | from rsocket.awaitable.collector_subscriber import CollectorSubscriber 6 | from rsocket.frame import MAX_REQUEST_N 7 | from rsocket.local_typing import Awaitable 8 | from rsocket.payload import Payload 9 | from rsocket.rsocket import RSocket 10 | 11 | 12 | class AwaitableRSocket: 13 | 14 | def __init__(self, rsocket: RSocket): 15 | self._rsocket = rsocket 16 | 17 | def fire_and_forget(self, payload: Payload) -> Awaitable[None]: 18 | return self._rsocket.fire_and_forget(payload) 19 | 20 | def metadata_push(self, metadata: bytes) -> Awaitable[None]: 21 | return self._rsocket.metadata_push(metadata) 22 | 23 | async def request_response(self, payload: Payload) -> Payload: 24 | return await self._rsocket.request_response(payload) 25 | 26 | async def request_stream(self, 27 | payload: Payload, 28 | limit_rate=MAX_REQUEST_N) -> List[Payload]: 29 | subscriber = CollectorSubscriber(limit_rate) 30 | 31 | self._rsocket.request_stream(payload).initial_request_n(limit_rate).subscribe(subscriber) 32 | 33 | return await subscriber.run() 34 | 35 | async def request_channel(self, 36 | payload: Payload, 37 | publisher: Optional[Publisher] = None, 38 | limit_rate=MAX_REQUEST_N, 39 | sending_done: Optional[asyncio.Event] = None) -> List[Payload]: 40 | subscriber = CollectorSubscriber(limit_rate) 41 | 42 | self._rsocket.request_channel(payload, 43 | publisher=publisher, 44 | sending_done=sending_done).initial_request_n(limit_rate).subscribe(subscriber) 45 | 46 | return await subscriber.run() 47 | 48 | async def __aenter__(self): 49 | await self._rsocket.__aenter__() 50 | return self 51 | 52 | async def __aexit__(self, exc_type, exc_val, exc_tb): 53 | await self._rsocket.__aexit__(exc_type, exc_val, exc_tb) 54 | 55 | async def connect(self): 56 | return await self._rsocket.connect() 57 | 58 | def close(self): 59 | self._rsocket.close() 60 | -------------------------------------------------------------------------------- /tests/tools/fixtures_shared.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import ipaddress 3 | from typing import Optional 4 | 5 | import pytest 6 | from cryptography import x509 7 | from cryptography.hazmat.primitives import hashes 8 | from cryptography.hazmat.primitives.asymmetric.ec import generate_private_key, SECP256R1 9 | 10 | 11 | def dns_name_or_ip_address(name): 12 | try: 13 | ip = ipaddress.ip_address(name) 14 | except ValueError: 15 | return x509.DNSName(name) 16 | else: 17 | return x509.IPAddress(ip) 18 | 19 | 20 | def generate_certificate(*, alternative_names: Optional[list], common_name: str, hash_algorithm, key): 21 | subject = issuer = x509.Name( 22 | [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] 23 | ) 24 | 25 | builder = (x509.CertificateBuilder() 26 | .subject_name(subject) 27 | .issuer_name(issuer) 28 | .public_key(key.public_key()) 29 | .serial_number(x509.random_serial_number()) 30 | .not_valid_before(datetime.datetime.utcnow()) 31 | .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=10)) 32 | ) 33 | 34 | builder = builder.add_extension( 35 | x509.SubjectAlternativeName([ 36 | x509.DNSName(u"localhost") 37 | ]), 38 | critical=False 39 | ) 40 | 41 | if alternative_names: 42 | builder = builder.add_extension( 43 | x509.SubjectAlternativeName( 44 | [dns_name_or_ip_address(name) for name in alternative_names] 45 | ), 46 | critical=False, 47 | ) 48 | cert = builder.sign(key, hash_algorithm) 49 | return cert, key 50 | 51 | 52 | def generate_ec_certificate(common_name: str, alternative_names: Optional[list] = None, curve=None): 53 | if alternative_names is None: 54 | alternative_names = [] 55 | 56 | if curve is None: 57 | curve = SECP256R1() 58 | 59 | key = generate_private_key(curve=curve) 60 | return generate_certificate( 61 | alternative_names=alternative_names, 62 | common_name=common_name, 63 | hash_algorithm=hashes.SHA256(), 64 | key=key, 65 | ) 66 | 67 | 68 | @pytest.fixture(scope="session") 69 | def generate_test_certificates(): 70 | return generate_ec_certificate(common_name='localhost') 71 | -------------------------------------------------------------------------------- /tests/rsocket/test_internal.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from collections import namedtuple 4 | from weakref import WeakKeyDictionary 5 | 6 | import pytest 7 | 8 | from tests.rsocket.helpers import measure_runtime 9 | 10 | 11 | async def test_reader(event_loop: asyncio.AbstractEventLoop): 12 | stream = asyncio.StreamReader(loop=event_loop) 13 | stream.feed_data(b'data') 14 | stream.feed_eof() 15 | data = await stream.read() 16 | assert data == b'data' 17 | 18 | 19 | @pytest.mark.xfail( 20 | reason='This is testing the fixture which should cause the test to fail if there is an error log') 21 | async def test_fail_on_error_log(fail_on_error_log): 22 | logging.error("this should not happen") 23 | 24 | 25 | def test_weak_ref(): 26 | class S(str): 27 | pass 28 | 29 | d = WeakKeyDictionary() 30 | a = S('abc') 31 | d[a] = 1 32 | assert len(d) == 1 33 | 34 | del a 35 | 36 | assert len(d) == 0 37 | 38 | 39 | async def test_range(): 40 | async def loop(ii): 41 | for i in range(100): 42 | await asyncio.sleep(0) 43 | print(ii + str(i)) 44 | 45 | await asyncio.gather(loop('a'), loop('b')) 46 | 47 | 48 | def test_instantiate_named_tuple(): 49 | routing = namedtuple('routing', ['tags', 'content']) 50 | count = 100000 51 | with measure_runtime() as result: 52 | for i in range(count): 53 | r = routing(b'abc', b'abcdefg') 54 | 55 | print(result.time.total_seconds() / count * 1000) 56 | 57 | 58 | def test_instantiate_slotted_class(): 59 | class routing: 60 | __slots__ = ['tags', 'content'] 61 | 62 | def __init__(self, tags, content): 63 | self.tags = tags 64 | self.content = content 65 | 66 | count = 100000 67 | with measure_runtime() as result: 68 | for i in range(count): 69 | r = routing(b'abc', b'abcdefg') 70 | 71 | print(result.time.total_seconds() / count * 1000) 72 | 73 | 74 | def test_instantiate_class(): 75 | class routing: 76 | 77 | def __init__(self, tags, content): 78 | self.tags = tags 79 | self.content = content 80 | 81 | count = 100000 82 | with measure_runtime() as result: 83 | for i in range(count): 84 | r = routing(b'abc', b'abcdefg') 85 | 86 | print(result.time.total_seconds() / count * 1000) 87 | --------------------------------------------------------------------------------