├── .editorconfig ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── pubsub_broadcaster_server_example.py ├── pubsub_client_example.py └── pubsub_server_example.py ├── fastapi_websocket_pubsub ├── __init__.py ├── event_broadcaster.py ├── event_notifier.py ├── exceptions.py ├── logger.py ├── pub_sub_client.py ├── pub_sub_server.py ├── rpc_event_methods.py └── websocket_rpc_event_notifier.py ├── requirements.txt ├── scripts └── publish.sh ├── setup.cfg ├── setup.py └── tests ├── basic_test.py ├── multiprocess_test.py ├── reconnect_test.py ├── requirements.txt ├── server_subscribe_events_test.py ├── server_subscriber_test.py └── server_with_remote_id_test.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root=true 2 | 3 | [*] 4 | charset=utf-8 5 | end_of_line=lf 6 | insert_final_newline=false 7 | indent_style=space 8 | indent_size=2 9 | trim_trailing_whitespace=true 10 | 11 | [*.py] 12 | indent_size=4 -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Tests 5 | 6 | on: 7 | workflow_dispatch: 8 | push: 9 | branches: [ master ] 10 | pull_request: 11 | branches: [ master ] 12 | 13 | jobs: 14 | build: 15 | 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: [3.7, 3.8, 3.9] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install 
--upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | if [ -f ./tests/requirements.txt ]; then pip install -r ./tests/requirements.txt; fi 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | - name: Test with pytest 40 | run: | 41 | pytest 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # editors 132 | .vscode/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) <2020> 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software 4 | and associated documentation files (the "Software"), to deal in the Software without restriction, 5 | including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 6 | and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, 7 | subject to the following conditions: 8 | The above copyright notice and this permission notice 9 | shall be included in all copies or substantial portions of the Software. 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 11 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 12 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 13 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 14 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 15 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.md LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # This is a fork 3 | 4 | This repo is a fork, but we didn't go anywhere :) 5 | 6 | For the official repo, go here: [https://github.com/permitio/fastapi_websocket_pubsub](https://github.com/permitio/fastapi_websocket_pubsub) 7 | 8 | 9 | ---- 10 | 11 | 12 |

13 | pubsub 14 |

15 | 16 | # 17 | 18 | # ⚡🗞️ FastAPI Websocket Pub/Sub 19 | 20 | 21 | Tests 22 | 23 | 24 | 25 | Package 26 | 27 | 28 | Downloads 29 | 30 | 31 | 32 | A fast and durable Pub/Sub channel over Websockets. 33 | The easiest way to create a live publish / subscribe multi-cast over the web. 34 | 35 | Supports and tested on Python >= 3.7 36 | 37 | As seen at PyCon IL 2021 and EuroPython 2021 38 | 39 | 40 | ## Installation 🛠️ 41 | ``` 42 | pip install fastapi_websocket_pubsub 43 | ``` 44 | 45 | 46 | ## Intro 47 | The classic pub/sub pattern made easily accessible and scalable over the web and across your cloud in realtime; while enjoying the benefits of FastAPI (e.g. dependency injection). 48 | 49 | FastAPI + WebSockets + PubSub == ⚡💪 ❤️ 50 | 51 | 52 | - Subscribe 53 | - Clients subscribe to topics (arbitrary strings) and receive relevant events along with structured data (serialized with Pydantic). 54 | ```python 55 | # Callback to be called upon event being published on server 56 | async def on_event(data): 57 | print("We got an event! with data- ", data) 58 | # Subscribe for the event 59 | client.subscribe("my event", on_event) 60 | ``` 61 | 62 | - Publish 63 | - Directly from server code to connected clients. 64 | ```python 65 | app = FastAPI() 66 | endpoint = PubSubEndpoint() 67 | endpoint.register_route(app, "/pubsub") 68 | endpoint.publish(["my_event_topic"], data=["my", "data", 1]) 69 | ``` 70 | - From client to client (through the servers) 71 | ```python 72 | async with PubSubClient(server_uri="ws://localhost/pubsub") as client: 73 | endpoint.publish(["my_event_topic"], data=["my", "data", 1]) 74 | ``` 75 | - Across server instances (using [broadcaster](https://pypi.org/project/broadcaster/) and a backend medium (e.g. 
Redis, Kafka, ...)) 76 | - No matter which server a client connects to - it will get the messages it subscribes to 77 | ```python 78 | app = FastAPI() 79 | endpoint = PubSubEndpoint(broadcaster="postgres://localhost:5432/") 80 | 81 | @app.websocket("/pubsub") 82 | async def websocket_rpc_endpoint(websocket: WebSocket): 83 | async with endpoint.broadcaster: 84 | await endpoint.main_loop(websocket) 85 | ``` 86 | See [examples/pubsub_broadcaster_server_example.py](examples/pubsub_broadcaster_server_example.py) for a full usage example 87 | 88 | 89 | 90 | ## Usage example (server publishing following HTTP trigger): 91 | In the code below, a client connects to the server and subscribes to a topic named "triggered". 92 | Aside from the PubSub websocket, the server also exposes a regular HTTP route, which triggers publication of the event. 93 | 94 | ### Server: 95 | ```python 96 | import asyncio 97 | import uvicorn 98 | from fastapi import FastAPI 99 | from fastapi.routing import APIRouter 100 | 101 | from fastapi_websocket_pubsub import PubSubEndpoint 102 | app = FastAPI() 103 | # Init endpoint 104 | endpoint = PubSubEndpoint() 105 | # register the endpoint on the app 106 | endpoint.register_route(app, "/pubsub") 107 | # Register a regular HTTP route 108 | @app.get("/trigger") 109 | async def trigger_events(): 110 | # Upon request trigger an event 111 | await endpoint.publish(["triggered"]) 112 | ``` 113 | ### Client: 114 | ```python 115 | from fastapi_websocket_pubsub import PubSubClient 116 | # Callback to be called upon event being published on server 117 | async def on_trigger(data): 118 | print("Trigger URL was accessed") 119 | 120 | async with PubSubClient(server_uri="ws://localhost/pubsub") as client: 121 | # Subscribe for the event 122 | client.subscribe("triggered", on_trigger) 123 | 124 | ``` 125 | 126 | ## More Examples 127 | - See the [examples](/examples) and [tests](/tests) folders for more server and client examples.
128 | - See [fastapi-websocket-rpc depends example](https://github.com/authorizon/fastapi_websocket_rpc/blob/master/tests/fast_api_depends_test.py) to see how to combine it with FastAPI dependency injection 129 | 130 | ## What can I do with this? 131 | The combination of WebSockets and bi-directional Pub/Sub is ideal for creating realtime data propagation solutions at scale over the web. 132 | - Update mechanism 133 | - Remote control mechanism 134 | - Data processing 135 | - Distributed computing 136 | - Realtime communications over the web 137 | 138 | 139 | ## Foundations: 140 | 141 | - Based on [fastapi-websocket-rpc](https://github.com/authorizon/fastapi_websocket_rpc) for a robust realtime bidirectional channel 142 | 143 | - Based on [broadcaster](https://pypi.org/project/broadcaster/) for syncing server instances 144 | 145 | - Server Endpoint: 146 | 147 | - Based on [FastAPI](https://github.com/tiangolo/fastapi): enjoy all the benefits of a full ASGI platform, including Async-io and dependency injections (for example to authenticate connections) 148 | 149 | - Based on [Pydantic](https://pydantic-docs.helpmanual.io/): easily serialize structured data as part of RPC requests and responses. Simply pass Pydantic data models as PubSub published data to have them available as part of an event. 150 | 151 | - Client: 152 | - Based on [Tenacity](https://tenacity.readthedocs.io/en/latest/index.html): allowing configurable retries to keep the connection alive 153 | - see WebSocketRpcClient.__init__'s retry_config 154 | 155 | - Based on python [websockets](https://websockets.readthedocs.io/en/stable/intro.html) - a more comprehensive client than the one offered by FastAPI 156 | 157 | ## Logging 158 | fastapi-websocket-pubsub uses fastapi-websocket-rpc for logging config. 159 | It provides a helper logging module to control how it produces logs for you. 160 | See [fastapi_websocket_rpc/logger.py](fastapi_websocket_rpc/logger.py).
161 | Use ```logging_config.set_mode``` or the 'WS_RPC_LOGGING' environment variable to choose the logging method you prefer. 162 | Or override completely via default logging config (e.g. 'logging.config.dictConfig'), all logger name start with: 'fastapi.ws_rpc.pubsub' 163 | 164 | example: 165 | ```python 166 | # set RPC to log like UVICORN 167 | from fastapi_websocket_rpc.logger import logging_config, LoggingModes 168 | logging_config.set_mode(LoggingModes.UVICORN) 169 | ``` 170 | 171 | ## Pull requests - welcome! 172 | - Please include tests for new features 173 | 174 | 175 | -------------------------------------------------------------------------------- /examples/pubsub_broadcaster_server_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiple Servers linked via broadcaster example. 3 | 4 | To run this example. 5 | - 0. Setup a broadcast medium and pass its configuration to the endpoint (e.g. postgres on 'postgres://localhost:5432/' ) 6 | - 1. run this script for the servers (as many instances as you'd like) - use the PORT env-variable to run them on different ports 7 | - 2. once the servers are up, run notifier_client_test.py and connect to one of them 8 | - 3. send get request to one server on: '/trigger' 9 | - 4. 
See that the client recives the event -no matter which server you connected it to, or which server got the initial trigger to publish 10 | """ 11 | import sys 12 | import os 13 | sys.path.append(os.path.abspath(os.path.join(os.path.basename(__file__), ".."))) 14 | 15 | 16 | from fastapi_websocket_pubsub import PubSubEndpoint 17 | import asyncio 18 | import os 19 | from starlette.websockets import WebSocket 20 | import uvicorn 21 | from fastapi import FastAPI 22 | from fastapi.routing import APIRouter 23 | 24 | PORT = int(os.environ.get("PORT") or "8000") 25 | 26 | 27 | app = FastAPI() 28 | router = APIRouter() 29 | endpoint = PubSubEndpoint(broadcaster="postgres://localhost:5432/") 30 | 31 | @router.websocket("/pubsub") 32 | async def websocket_rpc_endpoint(websocket: WebSocket): 33 | async with endpoint.broadcaster: 34 | await endpoint.main_loop(websocket) 35 | 36 | app.include_router(router) 37 | 38 | 39 | async def events(): 40 | await asyncio.sleep(1) 41 | await endpoint.publish(["guns", "germs"]) 42 | await asyncio.sleep(1) 43 | await endpoint.publish(["germs"]) 44 | await asyncio.sleep(1) 45 | await endpoint.publish(["steel"]) 46 | 47 | 48 | @app.get("/trigger") 49 | async def trigger_events(): 50 | asyncio.create_task(events()) 51 | 52 | 53 | uvicorn.run(app, host="0.0.0.0", port=PORT) 54 | -------------------------------------------------------------------------------- /examples/pubsub_client_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | See pubsub_sever_example.py for running instructions 3 | 4 | A very simple client 5 | """ 6 | import asyncio 7 | import logging 8 | import os 9 | import sys 10 | 11 | sys.path.append(os.path.abspath(os.path.join(os.path.basename(__file__), ".."))) 12 | from fastapi_websocket_pubsub import PubSubClient 13 | 14 | 15 | PORT = int(os.environ.get("PORT") or "8000") 16 | 17 | 18 | async def on_events(data, topic): 19 | print(f"running callback for {topic}!") 20 | 21 | 22 
| async def main(): 23 | # Create a client and subscribe to topics 24 | client = PubSubClient(["guns", "germs"], callback=on_events) 25 | 26 | async def on_steel(data, topic): 27 | print("running callback steel!") 28 | print("Got data", data) 29 | asyncio.create_task(client.disconnect()) 30 | 31 | client.subscribe("steel", on_steel) 32 | client.start_client(f"ws://localhost:{PORT}/pubsub") 33 | await client.wait_until_done() 34 | 35 | 36 | asyncio.run(main()) 37 | -------------------------------------------------------------------------------- /examples/pubsub_server_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | To run this test. 3 | - 1. run this script for the server 4 | - 2. once the server is up, run pubsub_client_example.py 5 | - 3. send get request to server on: 'http://localhost:8000/trigger' 6 | """ 7 | 8 | import asyncio 9 | 10 | import uvicorn 11 | from fastapi import FastAPI 12 | from fastapi.routing import APIRouter 13 | 14 | from fastapi_websocket_pubsub import PubSubEndpoint 15 | 16 | app = FastAPI() 17 | router = APIRouter() 18 | endpoint = PubSubEndpoint() 19 | endpoint.register_route(router) 20 | app.include_router(router) 21 | 22 | async def events(): 23 | await asyncio.sleep(1) 24 | # Publish multiple topics (without data) 25 | await endpoint.publish(["guns", "germs"]) 26 | await asyncio.sleep(1) 27 | # Publish single topic (without data) 28 | await endpoint.publish(["germs"]) 29 | await asyncio.sleep(1) 30 | # Publish single topic (with data) 31 | await endpoint.publish(["steel"], data={"author": "Jared Diamond"}) 32 | 33 | @app.get("/trigger") 34 | async def trigger_events(): 35 | asyncio.create_task(events()) 36 | 37 | uvicorn.run(app, host="0.0.0.0", port=8000) 38 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.pub_sub_server import PubSubEndpoint 2 | from .pub_sub_client import PubSubClient 3 | from .event_broadcaster import EventBroadcaster, EventBroadcasterException 4 | from .event_notifier import EventNotifier, SubscriberId , SubscriptionId, Topic, TopicList, Subscription, ALL_TOPICS -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/event_broadcaster.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Union 3 | from pydantic.main import BaseModel 4 | from .event_notifier import EventNotifier, Subscription, TopicList, ALL_TOPICS 5 | from broadcaster import Broadcast 6 | 7 | from .logger import get_logger 8 | from fastapi_websocket_rpc.utils import gen_uid 9 | 10 | 11 | logger = get_logger('EventBroadcaster') 12 | 13 | 14 | # Cross service broadcast consts 15 | NotifierId = str 16 | 17 | 18 | class BroadcastNotification(BaseModel): 19 | notifier_id: NotifierId 20 | topics: TopicList 21 | data: Any 22 | 23 | 24 | class EventBroadcasterException(Exception): 25 | pass 26 | 27 | 28 | class BroadcasterAlreadyStarted(EventBroadcasterException): 29 | pass 30 | 31 | 32 | class EventBroadcasterContextManager: 33 | """ 34 | Manages the context for the EventBroadcaster 35 | Friend-like class of EventBroadcaster (accessing "protected" members ) 36 | """ 37 | 38 | def __init__(self, event_broadcaster: "EventBroadcaster", listen: bool = True, share: bool = True) -> None: 39 | """ 40 | Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster 41 | and if it subscribes to the internal notifier to share its events with the broadcaster 42 | 43 | Args: 44 | event_broadcaster (EventBroadcaster): the broadcaster we manage the context for. 45 | share (bool, optional): Should we share events with the broadcaster. Defaults to True. 
46 | listen (bool, optional): Should we listen for incoming events from the broadcaster. Defaults to True. 47 | """ 48 | self._event_broadcaster = event_broadcaster 49 | self._share: bool = share 50 | self._listen: bool = listen 51 | self._lock = asyncio.Lock() 52 | 53 | async def __aenter__(self): 54 | async with self._lock: 55 | if self._listen: 56 | self._event_broadcaster._listen_count += 1 57 | if self._event_broadcaster._listen_count == 1: 58 | # We have our first listener start the read-task for it (And all those who'd follow) 59 | logger.info("Listening for incoming events from broadcast channel (first listener started)") 60 | # Start task listening on incoming broadcasts 61 | self._event_broadcaster.start_reader_task() 62 | 63 | if self._share: 64 | self._event_broadcaster._share_count += 1 65 | if self._event_broadcaster._share_count == 1: 66 | # We have our first publisher 67 | # Init the broadcast used for sharing (reading has its own) 68 | self._event_broadcaster._acquire_sharing_broadcast_channel() 69 | logger.debug("Subscribing to ALL_TOPICS, and sharing messages with broadcast channel") 70 | # Subscribe to internal events form our own event notifier and broadcast them 71 | await self._event_broadcaster._subscribe_to_all_topics() 72 | else: 73 | logger.debug(f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}") 74 | return self 75 | 76 | async def __aexit__(self, exc_type, exc, tb): 77 | async with self._lock: 78 | try: 79 | if self._listen: 80 | self._event_broadcaster._listen_count -= 1 81 | # if this was last listener - we can stop the reading task 82 | if self._event_broadcaster._listen_count == 0: 83 | # Cancel task reading broadcast subscriptions 84 | if self._event_broadcaster._subscription_task is not None: 85 | logger.info("Cancelling broadcast listen task") 86 | self._event_broadcaster._subscription_task.cancel() 87 | self._event_broadcaster._subscription_task = None 88 | 89 | if self._share: 90 | 
self._event_broadcaster._share_count -= 1 91 | # if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore 92 | if self._event_broadcaster._share_count == 0: 93 | # Unsubscribe from internal events 94 | logger.debug("Unsubscribing from ALL TOPICS") 95 | await self._event_broadcaster._unsubscribe_from_topics() 96 | 97 | except: 98 | logger.exception("Failed to exit EventBroadcaster context") 99 | 100 | 101 | class EventBroadcaster: 102 | """ 103 | Bridge EventNotifier to work across processes and machines by sharing their events through a broadcasting channel 104 | 105 | Usage: 106 | uri = "postgres://localhost:5432/db_name" #postgres example (also supports REDIS, Kafka, ...) 107 | # start litsening for broadcast publications notifying the internal event-notifier, and subscribing to the internal notifier, broadcasting its notes 108 | broadcaster = EventBroadcaster(uri, notifier): 109 | async with broadcaster.get_context(): 110 | 111 | """ 112 | 113 | def __init__(self, broadcast_url: str, notifier: EventNotifier, channel="EventNotifier", 114 | broadcast_type=None, is_publish_only=False) -> None: 115 | """ 116 | 117 | Args: 118 | broadcast_url (str): the URL of the broadcasting service 119 | notifier (EventNotifier): the event notifier managing our internal events - which will be bridge via the broadcaster 120 | channel (str, optional): Channel name. Defaults to "EventNotifier". 121 | broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast. 122 | is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. 
Defaults to False 123 | """ 124 | # Broadcast init params 125 | self._broadcast_url = broadcast_url 126 | self._broadcast_type = broadcast_type or Broadcast 127 | # Publish broadcast (initialized within async with statement) 128 | self._sharing_broadcast_channel = None 129 | # channel to operate on 130 | self._channel = channel 131 | # Async-io task for reading broadcasts (initialized within async with statement) 132 | self._subscription_task = None 133 | # Uniqueue instance id (used to avoid reading own notifications sent in broadcast) 134 | self._id = gen_uid() 135 | # The internal events notifier 136 | self._notifier = notifier 137 | self._is_publish_only = is_publish_only 138 | self._publish_lock = None 139 | # used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share) 140 | self._listen_count: int = 0 141 | self._share_count: int = 0 142 | # If we opt to manage the context directly (i.e. call async with on the event broadcaster itself) 143 | self._context_manager = None 144 | 145 | 146 | async def __broadcast_notifications__(self, subscription: Subscription, data): 147 | """ 148 | Share incoming internal notifications with the entire broadcast channel 149 | 150 | Args: 151 | subscription (Subscription): the subscription that got triggered 152 | data: the event data 153 | """ 154 | logger.info("Broadcasting incoming event: {}".format({'topic': subscription.topic, 'notifier_id': self._id})) 155 | note = BroadcastNotification(notifier_id=self._id, topics=[ 156 | subscription.topic], data=data) 157 | # Publish event to broadcast 158 | async with self._publish_lock: 159 | async with self._sharing_broadcast_channel: 160 | await self._sharing_broadcast_channel.publish(self._channel, note.json()) 161 | 162 | def _acquire_sharing_broadcast_channel(self): 163 | """ 164 | Initialize the elements needed for sharing events with the broadcast channel 165 | """ 166 | self._publish_lock = asyncio.Lock() 167 
| self._sharing_broadcast_channel = self._broadcast_type(self._broadcast_url) 168 | 169 | async def _subscribe_to_all_topics(self): 170 | return await self._notifier.subscribe(self._id, 171 | ALL_TOPICS, 172 | self.__broadcast_notifications__) 173 | 174 | async def _unsubscribe_from_topics(self): 175 | return await self._notifier.unsubscribe(self._id) 176 | 177 | def get_context(self, listen=True, share=True): 178 | """ 179 | Create a new context manager you can call 'async with' on, configuring the broadcaster for listening, sharing, or both. 180 | 181 | Args: 182 | listen (bool, optional): Should we listen for events incoming from the broadcast channel. Defaults to True. 183 | share (bool, optional): Should we share events with the broadcast channel. Defaults to True. 184 | 185 | Returns: 186 | EventBroadcasterContextManager: the context 187 | """ 188 | return EventBroadcasterContextManager(self, listen=listen, share=share) 189 | 190 | def get_listening_context(self): 191 | return EventBroadcasterContextManager(self, listen=True, share=False) 192 | 193 | def get_sharing_context(self): 194 | return EventBroadcasterContextManager(self, listen=False, share=True) 195 | 196 | async def __aenter__(self): 197 | """ 198 | Convince caller (also backward compaltability) 199 | """ 200 | if self._context_manager is None: 201 | self._context_manager = self.get_context(listen=not self._is_publish_only) 202 | return await self._context_manager.__aenter__() 203 | 204 | 205 | async def __aexit__(self, exc_type, exc, tb): 206 | await self._context_manager.__aexit__(exc_type, exc, tb) 207 | 208 | def start_reader_task(self): 209 | """Spawn a task reading incoming broadcasts and posting them to the intreal notifier 210 | Raises: 211 | BroadcasterAlreadyStarted: if called more than once per context 212 | Returns: 213 | the spawned task 214 | """ 215 | # Make sure a task wasn't started already 216 | if self._subscription_task is not None: 217 | # we already started a task for this 
worker process 218 | logger.debug("No need for listen task, already started broadcast listen task for this notifier") 219 | return 220 | # Trigger the task 221 | logger.debug("Spawning broadcast listen task") 222 | self._subscription_task = asyncio.create_task( 223 | self.__read_notifications__()) 224 | return self._subscription_task 225 | 226 | async def __read_notifications__(self): 227 | """ 228 | read incoming broadcasts and posting them to the intreal notifier 229 | """ 230 | logger.info("Starting broadcaster listener") 231 | # Init new broadcast channel for reading 232 | listening_broadcast_channel = self._broadcast_type(self._broadcast_url) 233 | async with listening_broadcast_channel: 234 | # Subscribe to our channel 235 | async with listening_broadcast_channel.subscribe(channel=self._channel) as subscriber: 236 | async for event in subscriber: 237 | try: 238 | notification = BroadcastNotification.parse_raw( 239 | event.message) 240 | # Avoid re-publishing our own broadcasts 241 | if notification.notifier_id != self._id: 242 | logger.info("Handling incoming broadcast event: {}".format({'topics': notification.topics, 'src': notification.notifier_id})) 243 | # Notify subscribers of message received from broadcast 244 | await self._notifier.notify(notification.topics, notification.data, notifier_id=self._id) 245 | except: 246 | logger.exception("Failed handling incoming broadcast") 247 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/event_notifier.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import copy 3 | from typing import Any, Callable, Coroutine, Dict, List, Optional, Union 4 | 5 | from fastapi_websocket_rpc.utils import gen_uid 6 | from pydantic import BaseModel # pylint: disable=no-name-in-module 7 | 8 | from .logger import get_logger 9 | 10 | logger = get_logger('EventNotifier') 11 | 12 | # Magic topic - meaning subscribe to all 
topics 13 | ALL_TOPICS = "__EventNotifier_ALL_TOPICS__" 14 | 15 | 16 | # Basic Pub/Sub consts 17 | SubscriberId = str 18 | SubscriptionId = str 19 | Topic = str 20 | TopicList = List[Topic] 21 | 22 | 23 | class Subscription(BaseModel): 24 | """ 25 | Data model to be stored per subscription, and sent to each subscriber via the callback 26 | This allows for serializing the data down the line and sending to potential remote subscribers (via the callback), 27 | in which case the callback field itself should be removed first. 28 | """ 29 | id: SubscriptionId 30 | subscriber_id: SubscriberId 31 | topic: Topic 32 | callback: Callable = None 33 | notifier_id: Optional[str] = None 34 | 35 | # Publish event callback signature 36 | def EventCallback(subscription: Subscription, data: Any): 37 | pass 38 | 39 | class EventNotifier: 40 | """ 41 | A Basic Pub/Sub class using callback functions as the 42 | Subscribers subscribe using self.subscribe, choosing topics to subscribe to 43 | and passing a callback that will be called on a publish/notify event (with the topic and data) 44 | 45 | 46 | Usage example: 47 | notifier = EventNotifier() 48 | 49 | #subscriber 50 | notifier.subscribe( notifier.gen_subscriber_id(), ["dinner is served", "breakfast is served"], 51 | lambda topic, data: print(f"{topic}, let's eat. 
We have: {data}") ) 52 | 53 | #publisher 54 | notifier.notify(["breakfast is served"], "Pancakes!") 55 | """ 56 | 57 | def __init__(self): 58 | # Topics->subscribers->subscription mapping 59 | self._topics: Dict[Topic, Dict[SubscriberId, List[Subscription]]] = {} 60 | # Lock used to sync access to mapped subscriptions 61 | # Initialized JIT to be sure to grab the right asyncio-loop 62 | self._lock: asyncio.Lock = None 63 | # List of events to call when client subscribed 64 | self._on_subscribe_events = [] 65 | # List of events to call when client unsubscribed 66 | self._on_unsubscribe_events = [] 67 | 68 | 69 | def gen_subscriber_id(self): 70 | return gen_uid() 71 | 72 | def gen_subscription_id(self): 73 | return gen_uid() 74 | 75 | def _get_subscribers_lock(self): 76 | """ 77 | Init lock once - on current loop 78 | """ 79 | if self._lock is None: 80 | self._lock = asyncio.Lock() 81 | return self._lock 82 | 83 | 84 | async def subscribe(self, subscriber_id: SubscriberId, topics: Union[TopicList, ALL_TOPICS], callback: EventCallback) -> List[Subscription]: 85 | """ 86 | Subscribe to a set of topics. 87 | Once a notification (i.e. 
publish) of a topic is received the provided callback function will be called (with topic and data) 88 | 89 | 90 | Args: 91 | subscriber_id (SubscriberID): A UUID identifying the subscriber 92 | topics (TopicList, ALL_TOPICS): A list of topic to subscribe to (Each topic is saved in a separate subscription) 93 | ALL_TOPICS can be passed to subscribe to everything (all current and future topics) 94 | callback (Callable): the callback function to call upon a publish event 95 | """ 96 | new_subscriptions = [] 97 | async with self._get_subscribers_lock(): 98 | if topics == ALL_TOPICS: 99 | topics = [ALL_TOPICS] 100 | for topic in topics: 101 | subscribers = self._topics[topic] = self._topics.get(topic, {}) 102 | subscriptions = subscribers[subscriber_id] = subscribers.get( 103 | subscriber_id, []) 104 | # Create new subscription for each Topic x Subscriber x Callback combo 105 | new_subscription = Subscription(id=self.gen_subscription_id(), 106 | subscriber_id=subscriber_id, 107 | topic=topic, 108 | callback=callback) 109 | subscriptions.append(new_subscription) 110 | new_subscriptions.append(new_subscription) 111 | logger.info(f"New subscription {new_subscription.dict()}") 112 | await EventNotifier.trigger_events(self._on_subscribe_events, subscriber_id, topics) 113 | return new_subscriptions 114 | 115 | async def unsubscribe(self, subscriber_id: SubscriberId, topics: Union[TopicList, None] = None): 116 | """ 117 | Unsubscribe from given topics. 
118 | Pass topics=None to unsubscribe the given subscriber from all topics 119 | 120 | Args: 121 | subscriber_id (SubscriberID): A UUID identifying the subscriber 122 | topics (Union[TopicList, None]): Topics to unsubscribe from 123 | """ 124 | async with self._get_subscribers_lock(): 125 | # if no topics are given then unsubscribe from all topics 126 | if topics is None: 127 | topics = list(self._topics.keys()) 128 | for topic in topics: 129 | subscribers = self._topics.get(topic, {}) # unknown topics are skipped instead of raising KeyError 130 | if subscriber_id in subscribers: 131 | logger.info(f"Removing Subscription of topic='{topic}' for subscriber={subscriber_id}") 132 | del subscribers[subscriber_id] 133 | await EventNotifier.trigger_events(self._on_unsubscribe_events, subscriber_id, topics) 134 | 135 | @staticmethod 136 | async def trigger_events(event_callbacks: List[Coroutine], *args): 137 | callbacks_with_params = [] 138 | for callback in event_callbacks: 139 | callbacks_with_params.append(callback(*args)) 140 | await asyncio.gather(*callbacks_with_params) 141 | 142 | 143 | async def trigger_callback(self, data, topic: Topic, subscriber_id: SubscriberId, subscription: Subscription): 144 | await subscription.callback(subscription, data) 145 | 146 | async def callback_subscribers(self, subscribers: Dict[SubscriberId, List[Subscription]], 147 | topic: Topic, 148 | data, notifier_id: SubscriberId = None, override_topic=False): 149 | """ 150 | Trigger callbacks for given subscribers 151 | Args: 152 | subscribers (Dict[SubscriberId,Subscription]): the subscribers to notify of the event 153 | topic (Topic): the topic of the event 154 | data: event data 155 | notifier_id (SubscriberId, optional): id of the event sender. Defaults to None. 156 | override_topic (bool, optional): Should the event/subscription topic be updated to match the given topic. Defaults to False.
157 | """ 158 | for subscriber_id, subscriptions in subscribers.items(): 159 | try: 160 | # Don't notify the notifier 161 | if subscriber_id != notifier_id: 162 | for subscription in subscriptions: 163 | if override_topic: 164 | # Report actual topic instead of ALL_TOPICS (or whatever is saved in the subscription) 165 | event = subscription.copy() 166 | event.topic = topic 167 | original_topic = 'ALL_TOPICS' if (subscription.topic == ALL_TOPICS) else subscription.topic 168 | logger.info(f"calling subscription callbacks: topic={topic} ({original_topic}), subscription_id={subscription.id}, subscriber_id={subscriber_id}") 169 | else: 170 | event = subscription 171 | logger.info(f"calling subscription callbacks: topic={topic}, subscription_id={subscription.id}, subscriber_id={subscriber_id}") 172 | # call callback with subscription-info and provided data 173 | await self.trigger_callback(data, topic, subscriber_id, event) 174 | except: 175 | logger.exception(f"Failed to notify subscriber sub_id={subscriber_id} with topic={topic}") 176 | 177 | 178 | async def notify(self, topics: Union[TopicList, Topic], data=None, notifier_id=None): 179 | """ 180 | Notify subscribers of a new event per topic. (i.e. Publish events) 181 | 182 | Args: 183 | topics (Union[TopicList, Topic]): Topics to trigger a publish event for (Calling the callbacks of all their subscribers) 184 | data ([type], optional): Arbitrary data to pass each callback. Defaults to None. 
185 | notifier_id (str): an id of the entity sending the notification, use the same id as subscriber id to avoid getting your own notifications 186 | """ 187 | # allow caller to pass a single topic without a list 188 | if isinstance(topics, Topic): 189 | topics = [topics] 190 | 191 | # get ALL_TOPICS subscribers 192 | subscribers_to_all = self._topics.get(ALL_TOPICS, {}) 193 | 194 | callbacks = [] 195 | # TODO improve with reader/writer lock pattern - so multiple notifications can happen at once 196 | async with self._get_subscribers_lock(): 197 | for topic in topics: 198 | subscribers = self._topics.get(topic, {}) 199 | # handle direct topic subscribers (work on copy to avoid changes after we got the callbacks running) 200 | callbacks.append(self.callback_subscribers(copy.copy(subscribers), topic, data, notifier_id)) 201 | # handle ALL_TOPICS subscribers (work on copy to avoid changes after we got the callbacks running) 202 | # Use actual topic instead of ALL_TOPICS 203 | callbacks.append(self.callback_subscribers(copy.copy(subscribers_to_all), topic, data, notifier_id, override_topic=True)) 204 | # call the subscribers outside of the lock - if they disconnect in the middle of the handling the with statement may fail 205 | # -- (issue with interrupts https://bugs.python.org/issue29988) 206 | await asyncio.gather(*callbacks) 207 | 208 | 209 | def register_subscribe_event(self, callback: Coroutine): 210 | """ 211 | Add a callback function to be triggered when new subscriber joins. 212 | 213 | Args: 214 | callback (Callable): the callback function to call upon a new subscription 215 | """ 216 | self._on_subscribe_events.append(callback) 217 | 218 | def register_unsubscribe_event(self, callback: Coroutine): 219 | """ 220 | Add a callback function to be triggered when a subscriber disconnects. 
221 | 222 | Args: 223 | callback (Callable): the callback function to call upon a client unsubscribe 224 | """ 225 | self._on_unsubscribe_events.append(callback) 226 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/exceptions.py: -------------------------------------------------------------------------------- 1 | class PubSubException(Exception): 2 | pass 3 | 4 | 5 | class PubSubClientException(Exception): 6 | pass 7 | 8 | 9 | class PubSubClientInvalidStateException(PubSubClientException): 10 | """ 11 | Raised when an operation is attempted on an PubSubClient which isn't in the right state. 12 | Common examples - trying to publish events before conencting; trying to subscribe after connection is established. 13 | """ 14 | pass 15 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/logger.py: -------------------------------------------------------------------------------- 1 | from fastapi_websocket_rpc.logger import get_logger as rpc_get_logger 2 | 3 | def get_logger(name): 4 | return rpc_get_logger(f"pubsub.{name}") 5 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/pub_sub_client.py: -------------------------------------------------------------------------------- 1 | from types import coroutine 2 | from fastapi_websocket_rpc.rpc_channel import RpcChannel, OnDisconnectCallback, OnConnectCallback 3 | from fastapi_websocket_pubsub.exceptions import PubSubClientInvalidStateException 4 | 5 | import websockets 6 | import asyncio 7 | from typing import Coroutine, Dict, List 8 | 9 | from .logger import get_logger 10 | from .event_notifier import ALL_TOPICS, Topic, TopicList 11 | from fastapi_websocket_rpc import RpcMethodsBase 12 | from fastapi_websocket_rpc import WebSocketRpcClient 13 | from .event_notifier import Topic 14 | from .rpc_event_methods import RpcEventClientMethods 15 | 16 | logger = 
get_logger('PubSubClient') 17 | 18 | # Callback signature 19 | async def PubSubOnConnectCallback(pubsub_client, channel:RpcChannel): 20 | pass 21 | 22 | class PubSubClient: 23 | """ 24 | pub/sub client (RPC based) 25 | 26 | Usage as subscriber: 27 | Simple usage example (init class with subscription topics): 28 | client = PubSubClient(["guns", "germs", "steel"], callback_coroutine) 29 | client.start_client("ws://localhost:8000/pubsub") 30 | 31 | If you want to register separate callbacks per topic: 32 | client = PubSubClient() 33 | # guns_coroutine will be awaited on when event arrives on "guns" topic 34 | client.subscribe("guns", guns_coroutine) 35 | client.subscribe("germs", germs_coroutine) 36 | 37 | When you are done registering callbacks (once you do, you cannot subscribe to more topics) call: 38 | client.start_client("ws://localhost:8000/pubsub") 39 | 40 | Another more compact option - using async with - 41 | async with PubSubClient(["guns","germs"], both_events_coroutine, server_uri="ws://localhost:8000/pubsub") as client: 42 | # will not end until client.disconnect() is called (by another task / callback) 43 | await client.wait_until_done() 44 | 45 | 46 | Usage as publisher: 47 | client = PubSubClient() 48 | client.start_client("ws://localhost:8000/pubsub") 49 | # Channel must be ready before we can publish on it 50 | await client.wait_until_ready() 51 | await client.publish(["Breakfast Options"], data=["spam", "eggs and spam", {"no spam": "egg bacon spam and sausage"} ]) 52 | 53 | """ 54 | 55 | def __init__(self, topics: List[Topic] = None, 56 | callback=None, 57 | methods_class: RpcMethodsBase = None, 58 | retry_config=None, 59 | keep_alive: float = 0, 60 | on_connect: List[PubSubOnConnectCallback] = None, 61 | on_disconnect: List[OnDisconnectCallback] = None, 62 | server_uri = None, 63 | **kwargs) -> None: 64 | """ 65 | Args: 66 | topics (List[Topic]): topics client should subscribe to, Defaults to None. 
67 | methods_class ([RpcMethodsBase], optional): RPC Methods exposed by client. Defaults to RpcEventClientMethods. 68 | retry_config (Dict, optional): Tenacity (https://tenacity.readthedocs.io/) retry kwargs. Defaults to {'wait': wait.wait_random_exponential(max=45)} 69 | retry_config is used both for initial connection failures and reconnects upon connection loss 70 | keep_alive(float): interval in seconds to send a keep-alive ping over the underlying RPC channel, Defaults to 0, which means keep alive is disabled. 71 | on_connect (List[Coroutine]): callbacks on connection being established (each callback is called with the PubSub-client and the rpc-channel) 72 | @note exceptions thrown in on_connect callbacks propagate to the client and will cause connection restart! 73 | on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the rpc-channel) 74 | """ 75 | # Should async with start the client and connect to the server, and on which address 76 | self._server_uri = server_uri 77 | # init our methods with access to the client object (i.e. 
self) so they can trigger our callbacks 78 | self._methods = methods_class(self) if methods_class is not None else RpcEventClientMethods(self) 79 | # Subscription topics 80 | self._topics = set() 81 | # Subscription callbacks 82 | self._callbacks:Dict[Topic,List[Coroutine]] = {} 83 | self._connect_kwargs = kwargs 84 | # Tenacity retry configuration 85 | self._retry_config = retry_config 86 | # Keep alive config 87 | self._keep_alive = keep_alive 88 | # Core event handlers 89 | self._on_connect = on_connect 90 | self._on_disconnect = on_disconnect if on_disconnect is not None else [] 91 | # internal asyncio tasks 92 | self._run_task: asyncio.Task = None 93 | # event to force termination 94 | self._disconnect_signal:asyncio.Event = None 95 | # event to indicate the connection is ready for use 96 | self._ready_event:asyncio.Event = None 97 | # The RpcChannel initialized - used to access the client from other asyncio tasks 98 | self._rpc_channel = None 99 | # register given topics (if we got any) 100 | if isinstance(topics,list): 101 | for topic in topics: 102 | self.subscribe(topic, callback) 103 | 104 | def is_ready(self) -> bool: 105 | if self._ready_event is not None: 106 | return self._ready_event.is_set() 107 | else: 108 | return False 109 | 110 | async def wait_until_done(self): 111 | if self._run_task is not None: 112 | return await self._run_task 113 | 114 | def wait_until_ready(self, create_event=True) -> Coroutine: 115 | """Return a wait coroutine for the internal readiness event 116 | Args: 117 | create_event (bool, optional): Should an event be created if missing 118 | (@Note- avoid creating an event from another loop). Defaults to True. 
119 | 120 | Raises: 121 | PubSubClientInvalidStateException: if has no event to wait on 122 | 123 | Returns: 124 | Coroutine: waiting function on internal event 125 | """ 126 | if self._ready_event is None: 127 | if create_event: 128 | self._ready_event = asyncio.Event() 129 | return self._ready_event.wait() 130 | else: 131 | raise PubSubClientInvalidStateException("Cannot wait on readiness prior to client starting or without creating an event") 132 | else: 133 | return self._ready_event.wait() 134 | 135 | def _init_event_objects(self): 136 | """ 137 | Init asyncio events. This should be done in the correct event loop (And so better avoided at __init__) 138 | """ 139 | if self._ready_event is None: 140 | self._ready_event = asyncio.Event() 141 | if self._disconnect_signal is None: 142 | self._disconnect_signal = asyncio.Event() 143 | 144 | async def __aenter__(self): 145 | if self._server_uri is not None: 146 | self.start_client(self._server_uri) 147 | return self 148 | 149 | async def __aexit__(self, exc_type, exc, tb): 150 | await self.disconnect() 151 | 152 | async def disconnect(self): 153 | """ 154 | Force the internal client to disconnect, and wait for it to do so 155 | """ 156 | if self._disconnect_signal is None: 157 | self._disconnect_signal = asyncio.Event() 158 | self._disconnect_signal.set() 159 | if self._run_task is not None: 160 | await self._run_task 161 | self._run_task = None 162 | 163 | async def run(self, uri, wait_on_reader=True): 164 | """ 165 | runs the rpc client (async api). 166 | if you want to call from a synchronous program, use start_client(). 
167 | """ 168 | # init internal events 169 | self._init_event_objects() 170 | while not self._disconnect_signal.is_set(): 171 | try: 172 | logger.info(f"Trying to connect to Pub/Sub server - {uri}") 173 | client = WebSocketRpcClient(uri, self._methods, 174 | retry_config=self._retry_config, 175 | keep_alive=self._keep_alive, 176 | # Register core event callbacks 177 | on_connect=[self._primary_on_connect], 178 | on_disconnect=self._on_disconnect, 179 | **self._connect_kwargs) 180 | async with client: 181 | try: 182 | logger.info(f"Connected to PubSub server {uri}") 183 | if wait_on_reader: 184 | # Wait on the internal RPC task or until we ar asked to terminate - keeping the client alive meanwhile 185 | # Waiting on the reader tasks allows us to receive exceptions raised on the websocket layer- indicating a need to reset the connection 186 | wait_on_reader_task = client.wait_on_reader() 187 | for task in asyncio.as_completed([wait_on_reader_task, self._disconnect_signal.wait()]): 188 | await task 189 | break 190 | except websockets.exceptions.WebSocketException as err: 191 | logger.info(f"Connection failed with - {err}. 
-- Trying to reconnect.") 192 | except asyncio.CancelledError: 193 | logger.info(f"Connection was actively canceled -- Won't try to reconnect.") 194 | # clean exit (no retrying) 195 | # better support for keyboard interrupt 196 | return 197 | except: 198 | # log unhandled exceptions (which will be swallowed by the with statement otherwise ) 199 | logger.exception(f"Unknown PubSub error -- Trying to reconnect.") 200 | except websockets.exceptions.InvalidStatusCode as err: 201 | logger.exception(f"Connection failed with an invalid status code - {err.status_code} -- Won't try to reconnect.") 202 | raise 203 | except asyncio.CancelledError: 204 | logger.info(f"Connection attempt was canceled -- Won't try to reconnect.") 205 | except: 206 | logger.exception(f"Unknown PubSub init error -- Trying to reconnect.") 207 | 208 | 209 | async def _primary_on_connect(self, channel: RpcChannel): 210 | # Store current channel for additional use by other methods 211 | self._rpc_channel = channel 212 | # subscribe to all the topics we have registered 213 | await self._subscribe_stored_topics(channel) 214 | self._ready_event.set() 215 | # Now that PubSub is alive trigger sub subscribers 216 | if isinstance(self._on_connect, list): 217 | await asyncio.gather(*(callback(self, channel) for callback in self._on_connect)) 218 | 219 | def subscribe(self, topic: Topic, callback: Coroutine): 220 | """ 221 | Subscribe for events (prior to starting the client) 222 | @see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe 223 | 224 | Args: 225 | topic (Topic): the identifier of the event topic with wish to be called 226 | upon events being published - can be a simple string e.g. 227 | 'hello' or a complex path 'a/b/c/d' . 
228 | Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics 229 | callback (Coroutine): the function to call upon relevant event publishing 230 | """ 231 | # TODO: add support for post connection subscriptions 232 | if not self.is_ready(): 233 | self._topics.add(topic) 234 | # init to empty list if no entry 235 | callbacks = self._callbacks[topic] = self._callbacks.get(topic,[]) 236 | # add callback to callbacks list of the topic 237 | callbacks.append(callback) 238 | else: 239 | raise PubSubClientInvalidStateException("Client already connected and subscribed") 240 | 241 | async def publish(self, topics: TopicList, data=None, sync=True, notifier_id=None) -> bool: 242 | """ 243 | Publish an event through the server to subscribers. 244 | @see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.publish 245 | 246 | Args: 247 | topics (TopicList): topics to publish 248 | data (Any, optional): data to pass with the event to the subscribers. Defaults to None. 
249 | sync (bool, optional): Should the server finish publishing before returning to us 250 | notifier_id(str,optional): A unique identifier of the source of the event 251 | use a different id from the channel.id or the subscription id to receive own publications 252 | 253 | Raises: 254 | PubSubClientInvalidStateException 255 | 256 | Returns: 257 | bool: was the publish successful 258 | """ 259 | if self.is_ready() and self._rpc_channel is not None: 260 | return await self._rpc_channel.other.publish(topics=topics, data=data, sync=sync, notifier_id=notifier_id) 261 | else: 262 | raise PubSubClientInvalidStateException("Client not connected") 263 | 264 | async def _subscribe_stored_topics(self, channel): 265 | """ 266 | Communicate topics stored at self._topics to the PubSub Server 267 | """ 268 | if self._topics: 269 | await channel.other.subscribe(topics=list(self._topics)) 270 | 271 | async def trigger_topic(self, topic: Topic, data=None): 272 | """ 273 | Called by RpcEventClientMethods.notify (from RPC) to handle the published event 274 | 275 | Args: 276 | topic (Topic) 277 | data ([Any], optional) 278 | """ 279 | try: 280 | # get callbacks for topic (as a copy - extending the stored list would permanently accumulate ALL_TOPICS callbacks on every event) 281 | callbacks = list(self._callbacks.get(topic,[])) 282 | # get callbacks for ALL_TOPICS 283 | callbacks.extend(self._callbacks.get(ALL_TOPICS,[])) 284 | # Gather coroutine futures to wait together 285 | futures = [callback(data=data, topic=topic) for callback in callbacks] 286 | await asyncio.gather(*futures) 287 | except: 288 | logger.exception("Failed to trigger a pub/sub callback", {'data':data, 'topic': topic}) 289 | 290 | def start_client(self, server_uri, loop: asyncio.AbstractEventLoop = None, wait_on_reader=True): 291 | """ 292 | Start the client (spinning out self.run as an asyncio task) 293 | 294 | Args: 295 | server_uri (str): uri to server pubsub-endpoint (e.g. 
'http://localhost/pubsub') 296 | loop (asyncio.AbstractEventLoop, optional): event loop to run on. Defaults to asyncio.get_event_loop().
297 | wait_on_reader (bool, optional): Wait on task reading from server. Defaults to True. 298 | """ 299 | loop = loop or asyncio.get_event_loop() 300 | # If the loop hasn't started yet - take over 301 | if not loop.is_running(): 302 | loop.run_until_complete(self.run(server_uri, wait_on_reader)) 303 | # Otherwise 304 | else: 305 | self._run_task = asyncio.create_task(self.run(server_uri, wait_on_reader)) 306 | 307 | def start_client_async(self, server_uri, loop: asyncio.AbstractEventLoop = None): 308 | """ 309 | Start the client and return once finished subscribing to events 310 | RPC notifications will still be handeled in the background 311 | Useful only in cases the async-loop is created by the client (i.e. start_client doesn't create a new task on an exiting loop) 312 | """ 313 | self.start_client(server_uri, loop, False) 314 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/pub_sub_server.py: -------------------------------------------------------------------------------- 1 | from typing import Coroutine, List, Union 2 | 3 | from fastapi import WebSocket 4 | from fastapi_websocket_rpc import WebsocketRPCEndpoint 5 | from fastapi_websocket_rpc.rpc_channel import RpcChannel 6 | 7 | from .logger import get_logger 8 | from .event_broadcaster import EventBroadcaster 9 | from .event_notifier import ALL_TOPICS, EventCallback, EventNotifier, Subscription, Topic, TopicList 10 | from .rpc_event_methods import RpcEventServerMethods 11 | from .websocket_rpc_event_notifier import WebSocketRpcEventNotifier 12 | 13 | logger = get_logger('PubSubEndpoint') 14 | 15 | class PubSubEndpoint: 16 | """ 17 | RPC pub/sub server endpoint 18 | """ 19 | 20 | def __init__(self, methods_class=None, 21 | notifier:EventNotifier=None, 22 | broadcaster:Union[EventBroadcaster, str]=None, 23 | on_connect:List[Coroutine]=None, 24 | on_disconnect:List[Coroutine]=None, 25 | rpc_channel_get_remote_id: bool=False): 26 | """ 27 | The PubSub 
endpoint receives subscriptions from clients and publishes data back to them upon receiving relevant publications. 28 | Publications (aka event notifications) can come from: 29 | - Code in the same server calling this instance's '.publish()' 30 | - Connected PubSubClients calling their own publish method (and piping into the servers via RPC) 31 | - Other servers linked through a broadcaster channel such as Redis Pub/Sub, Kafka, or postgres listen/notify 32 | (@see EventBroadcaster and of course https://pypi.org/project/broadcaster/) 33 | 34 | Args: 35 | methods_class (optional): a class deriving from RpcEventServerMethods providing a 'subscribe' rpc method 36 | or None if RpcEventServerMethods should be used as is 37 | 38 | notifier (optional): Instance of WebSocketRpcEventNotifier or None to use WebSocketRpcEventNotifier() as is 39 | Handles to internal event pub/sub logic 40 | 41 | broadcaster (optional): Instance of EventBroadcaster, a URL string to init EventBroadcaster, or None to not use 42 | The broadcaster allows several EventRpcEndpoints across multiple processes / services to share incoming notifications 43 | 44 | on_connect (List[Coroutine]): callbacks on connection being established (each callback is called with the channel) 45 | on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the channel) 46 | """ 47 | self.notifier = notifier if notifier is not None else WebSocketRpcEventNotifier() 48 | self.broadcaster = broadcaster if isinstance(broadcaster, EventBroadcaster) or broadcaster is None else EventBroadcaster(broadcaster, self.notifier) 49 | self.methods = methods_class(self.notifier) if methods_class is not None else RpcEventServerMethods(self.notifier, rpc_channel_get_remote_id=rpc_channel_get_remote_id) 50 | if on_disconnect is None: 51 | on_disconnect = [] 52 | self.endpoint = WebsocketRPCEndpoint(self.methods, on_disconnect=[self.on_disconnect, *on_disconnect], 
on_connect=on_connect, rpc_channel_get_remote_id=rpc_channel_get_remote_id) 53 |
self._rpc_channel_get_remote_id = rpc_channel_get_remote_id 54 | # server id used to publish events for clients 55 | self._id = self.notifier.gen_subscriber_id() 56 | # Separate if for the server to subscribe to its own events 57 | self._subscriber_id:str = self.notifier.gen_subscriber_id() 58 | 59 | async def subscribe(self, topics: Union[TopicList, ALL_TOPICS], callback: EventCallback) -> List[Subscription]: 60 | return await self.notifier.subscribe(self._subscriber_id, topics, callback) 61 | 62 | async def publish(self, topics: Union[TopicList, Topic], data=None): 63 | """ 64 | Publish events to subscribres of given topics currently connected to the endpoint 65 | 66 | Args: 67 | topics (Union[TopicList, Topic]): topics to publish to relevant subscribers 68 | data (Any, optional): Event data to be passed to each subscriber. Defaults to None. 69 | """ 70 | # if we have a broadcaster make sure we share with it (no matter where this call comes from) 71 | # sharing here means - the broadcaster listens in to the notifier as well 72 | logger.debug(f"Publishing message to topics: {topics}") 73 | if self.broadcaster is not None: 74 | logger.debug(f"Acquiring broadcaster sharing context") 75 | async with self.broadcaster.get_context(listen=False, share=True): 76 | await self.notifier.notify(topics, data, notifier_id=self._id) 77 | # otherwise just notify 78 | else: 79 | await self.notifier.notify(topics, data, notifier_id=self._id) 80 | 81 | # canonical name (backward compatability) 82 | notify = publish 83 | 84 | async def on_disconnect(self, channel: RpcChannel): 85 | if self._rpc_channel_get_remote_id: 86 | channel_other_channel_id = await channel.get_other_channel_id() 87 | if channel_other_channel_id is None: 88 | logger.warning("could not fetch remote channel id, using local channel id to unsubscribe") 89 | subscriber_id = channel.id 90 | else: 91 | subscriber_id = channel_other_channel_id 92 | else: 93 | subscriber_id = channel.id 94 | await 
self.notifier.unsubscribe(subscriber_id) 95 | 96 | async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs): 97 | await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs) 98 | 99 | def register_route(self, router, path="/pubsub"): 100 | """ 101 | Register websocket routes on the given router 102 | Args: 103 | router: FastAPI router to load route onto 104 | """ 105 | self.endpoint.register_route(router, path) 106 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/rpc_event_methods.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from fastapi_websocket_rpc import RpcMethodsBase 3 | from .event_notifier import EventNotifier, Subscription, TopicList 4 | from .logger import get_logger 5 | 6 | 7 | class RpcEventServerMethods(RpcMethodsBase): 8 | 9 | def __init__(self, event_notifier: EventNotifier, rpc_channel_get_remote_id: bool=False): 10 | super().__init__() 11 | self.event_notifier = event_notifier 12 | self._rpc_channel_get_remote_id = rpc_channel_get_remote_id 13 | self.logger = get_logger('PubSubServer') 14 | 15 | async def subscribe(self, topics: TopicList = []) -> bool: 16 | """ 17 | provided by the server so that the client can subscribe to topics. 18 | when new events are available on a topic, the server will call the 19 | client's `notify` method. 
20 | """ 21 | try: 22 | async def callback(subscription: Subscription, data): 23 | # remove the actual function 24 | sub = subscription.copy(exclude={"callback"}) 25 | self.logger.info("Notifying other side: {}".format({"subscriber_id": subscription.subscriber_id, "subscription_id": subscription.id, "topic": subscription.topic}), 26 | {"subscription":subscription, 27 | "data":data, "channel_id": self.channel.id}) 28 | await self.channel.other.notify(subscription=sub, data=data) 29 | 30 | if self._rpc_channel_get_remote_id: 31 | # We'll use the remote channel id as our subscriber id 32 | channel_other_channel_id = await self.channel.get_other_channel_id() 33 | if channel_other_channel_id is None: 34 | self.logger.warning("could not fetch remote channel id, using local channel id to subscribe") 35 | sub_id = self.channel.id 36 | else: 37 | sub_id = channel_other_channel_id 38 | else: 39 | # We'll use our channel id as our subscriber id 40 | sub_id = self.channel.id 41 | await self.event_notifier.subscribe(sub_id, topics, callback) 42 | return True 43 | except Exception as err: 44 | self.logger.exception(f"Failed to subscribe to RPC events notifier, topics={topics}") 45 | return False 46 | 47 | async def publish(self, topics: TopicList = [], data=None, sync=True, notifier_id=None) -> bool: 48 | """ 49 | Publish an event through the server to subscribers 50 | 51 | Args: 52 | topics (TopicList): topics to publish 53 | data (Any, optional): data to pass with the event to the subscribers. Defaults to None.
54 | sync (bool, optional): Should the server finish publishing before returning to us 55 | notifier_id(str,optional): A unique identifier of the source of the event 56 | use a different id from the channel.id or the subscription id to receive own publications 57 | 58 | Returns: 59 | bool: was the publish successful 60 | """ 61 | try: 62 | # use the given id or use our channel id 63 | notifier_id = notifier_id if notifier_id is not None else self.channel.id 64 | promise = self.event_notifier.notify(topics, data, notifier_id=notifier_id) 65 | if sync: 66 | await promise 67 | else: 68 | asyncio.create_task(promise) 69 | return True 70 | except Exception as err: 71 | self.logger.error("Failed to publish to events notifier",topics) 72 | return False 73 | 74 | async def ping(self) -> str: 75 | return "pong" 76 | 77 | 78 | class RpcEventClientMethods(RpcMethodsBase): 79 | 80 | def __init__(self, client): 81 | super().__init__() 82 | self.client = client 83 | self.logger = get_logger('PubSubClient') 84 | 85 | async def notify(self, subscription=None, data=None): 86 | self.logger.info("Received notification of event", 87 | {'subscription':subscription, 'data':data}) 88 | await self.client.trigger_topic(topic=subscription["topic"], data=data) 89 | -------------------------------------------------------------------------------- /fastapi_websocket_pubsub/websocket_rpc_event_notifier.py: -------------------------------------------------------------------------------- 1 | from .event_notifier import EventNotifier 2 | 3 | 4 | class WebSocketRpcEventNotifier(EventNotifier): 5 | pass -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi-websocket-rpc>=0.1.21 2 | broadcaster==0.2.0 3 | -------------------------------------------------------------------------------- /scripts/publish.sh: 
-------------------------------------------------------------------------------- 1 | python setup.py sdist bdist_wheel 2 | python -m twine upload dist/* -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | def get_requirements(env=""): 4 | if env: 5 | env = "-{}".format(env) 6 | with open("requirements{}.txt".format(env)) as fp: 7 | return [x.strip() for x in fp.read().split("\n") if not x.startswith("#")] 8 | 9 | with open("README.md", "r", encoding="utf-8") as fh: 10 | long_description = fh.read() 11 | 12 | setup( 13 | name='fastapi_websocket_pubsub', 14 | version='0.1.21', 15 | author='Or Weis', 16 | author_email="or@authorizon.com", 17 | description="A fast and durable PubSub channel over Websockets (using fastapi-websockets-rpc).", 18 | long_description_content_type="text/markdown", 19 | long_description=long_description, 20 | url="https://github.com/authorizon/fastapi_websocket_pubsub", 21 | packages=find_packages(), 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 27 | "Topic :: Internet :: WWW/HTTP :: WSGI" 28 | ], 29 | python_requires='>=3.7', 30 | install_requires=get_requirements(), 31 | ) 32 | -------------------------------------------------------------------------------- /tests/basic_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fastapi_websocket_rpc import logger 5 | 6 | # Add parent path to use local src as package for 
tests 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 8 | 9 | import asyncio 10 | from multiprocessing import Process 11 | import requests 12 | 13 | import pytest 14 | import uvicorn 15 | from fastapi import APIRouter, FastAPI 16 | 17 | 18 | from fastapi_websocket_rpc.logger import get_logger 19 | from fastapi_websocket_rpc.utils import gen_uid 20 | from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient 21 | from fastapi_websocket_pubsub.event_notifier import ALL_TOPICS 22 | 23 | logger = get_logger("Test") 24 | 25 | # Configurable 26 | PORT = int(os.environ.get("PORT") or "7990") 27 | uri = f"ws://localhost:{PORT}/pubsub" 28 | trigger_url = f"http://localhost:{PORT}/trigger" 29 | 30 | DATA = "MAGIC" 31 | EVENT_TOPIC = "event/has-happened" 32 | 33 | def setup_server_rest_route(app, endpoint: PubSubEndpoint): 34 | 35 | @app.get("/trigger") 36 | async def trigger_events(): 37 | logger.info("Triggered via HTTP route - publishing event") 38 | # Publish an event named 'steel' 39 | # Since we are calling back (RPC) to the client- this would deadlock if we wait on it 40 | asyncio.create_task(endpoint.publish([EVENT_TOPIC], data=DATA)) 41 | return "triggered" 42 | 43 | 44 | def setup_server(): 45 | app = FastAPI() 46 | # PubSub websocket endpoint 47 | endpoint = PubSubEndpoint() 48 | endpoint.register_route(app, "/pubsub") 49 | # Regular REST endpoint - that publishes to PubSub 50 | setup_server_rest_route(app, endpoint) 51 | uvicorn.run(app, port=PORT) 52 | 53 | 54 | @pytest.fixture() 55 | def server(): 56 | # Run the server as a separate process 57 | proc = Process(target=setup_server, args=(), daemon=True) 58 | proc.start() 59 | yield proc 60 | proc.kill() # Cleanup after test 61 | 62 | @pytest.mark.asyncio 63 | async def test_subscribe_http_trigger(server): 64 | # finish trigger 65 | finish = asyncio.Event() 66 | # Create a client and subscribe to topics 67 | async with PubSubClient() as client: 68 | async def 
on_event(data, topic): 69 | assert data == DATA 70 | finish.set() 71 | # subscribe for the event 72 | client.subscribe(EVENT_TOPIC, on_event) 73 | # start listentining 74 | client.start_client(uri) 75 | # wait for the client to be ready to receive events 76 | await client.wait_until_ready() 77 | # trigger the server via an HTTP route 78 | requests.get(trigger_url) 79 | # wait for finish trigger 80 | await asyncio.wait_for(finish.wait(),5) 81 | 82 | @pytest.mark.asyncio 83 | async def test_pub_sub(server): 84 | # finish trigger 85 | finish = asyncio.Event() 86 | # Create a client and subscribe to topics 87 | async with PubSubClient() as client: 88 | async def on_event(data, topic): 89 | assert data == DATA 90 | finish.set() 91 | # subscribe for the event 92 | client.subscribe(EVENT_TOPIC, on_event) 93 | # start listentining 94 | client.start_client(uri) 95 | # wait for the client to be ready to receive events 96 | await client.wait_until_ready() 97 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 98 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 99 | assert published.result == True 100 | # wait for finish trigger 101 | await asyncio.wait_for(finish.wait(),5) 102 | 103 | @pytest.mark.asyncio 104 | async def test_pub_sub_with_all_topics(server): 105 | """ 106 | Check client gets event when subscribing via ALL_TOPICS 107 | """ 108 | # finish trigger 109 | finish = asyncio.Event() 110 | # Create a client and subscribe to topics 111 | async with PubSubClient() as client: 112 | async def on_event(data, topic): 113 | assert data == DATA 114 | finish.set() 115 | # subscribe for the event 116 | client.subscribe(ALL_TOPICS, on_event) 117 | # start listentining 118 | client.start_client(uri) 119 | # wait for the client to be ready to receive events 120 | await client.wait_until_ready() 121 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 122 | 
published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 123 | assert published.result == True 124 | # wait for finish trigger 125 | await asyncio.wait_for(finish.wait(),5) 126 | -------------------------------------------------------------------------------- /tests/multiprocess_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pattern: 3 | Publishing-Client -> PubSubServer -> Subscribing->Client 4 | 5 | """ 6 | 7 | import os 8 | import sys 9 | 10 | from fastapi_websocket_rpc import logger 11 | 12 | # Add parent path to use local src as package for tests 13 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 14 | 15 | import asyncio 16 | from multiprocessing import Process, Event as ProcEvent 17 | 18 | import pytest 19 | import uvicorn 20 | from fastapi import APIRouter, FastAPI 21 | 22 | 23 | from fastapi_websocket_rpc.logger import get_logger 24 | from fastapi_websocket_rpc.utils import gen_uid 25 | from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient 26 | 27 | logger = get_logger("Test") 28 | 29 | # Configurable 30 | PORT = int(os.environ.get("PORT") or "7990") 31 | uri = f"ws://localhost:{PORT}/pubsub" 32 | 33 | 34 | DATA = "MAGIC" 35 | EVENT_TOPIC = "event/has-happened" 36 | 37 | CLIENT_START_SYNC = ProcEvent() 38 | 39 | def setup_server(): 40 | app = FastAPI() 41 | router = APIRouter() 42 | # PubSub websocket endpoint 43 | endpoint = PubSubEndpoint() 44 | endpoint.register_route(router, "/pubsub") 45 | app.include_router(router) 46 | uvicorn.run(app, port=PORT) 47 | 48 | def setup_publishing_client(): 49 | """ 50 | this client will publish an event to the main-test client 51 | """ 52 | async def actual(): 53 | # Wait for other client to wake up before publishing to it 54 | CLIENT_START_SYNC.wait(5) 55 | # Create a client and subscribe to topics 56 | client = PubSubClient() 57 | client.start_client(uri) 58 | # wait for 
the client to be ready 59 | await client.wait_until_ready() 60 | # publish event 61 | logger.info("Publishing event") 62 | published = await client.publish([EVENT_TOPIC], data=DATA) 63 | assert published.result == True 64 | logger.info("Starting async publishing client") 65 | asyncio.get_event_loop().run_until_complete(actual()) 66 | 67 | 68 | @pytest.fixture(scope="module") 69 | def server(): 70 | # Run the server as a separate process 71 | proc = Process(target=setup_server, args=(), daemon=True) 72 | proc.start() 73 | yield proc 74 | proc.kill() # Cleanup after test 75 | 76 | @pytest.fixture(scope="module") 77 | def pub_client(): 78 | # Run the server as a separate process 79 | proc = Process(target=setup_publishing_client, args=(), daemon=True) 80 | proc.start() 81 | yield proc 82 | proc.kill() # Cleanup after test 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_pub_sub_multi_client(server, pub_client): 87 | # finish trigger 88 | finish = asyncio.Event() 89 | # Create a client and subscribe to topics 90 | async with PubSubClient() as client: 91 | async def on_event(data, topic): 92 | assert data == DATA 93 | assert topic == EVENT_TOPIC 94 | finish.set() 95 | # subscribe for the event 96 | logger.info("Subscribing for events") 97 | client.subscribe(EVENT_TOPIC, on_event) 98 | # start listentining 99 | client.start_client(uri) 100 | await client.wait_until_ready() 101 | # Let the other client know we're ready 102 | logger.info("First client is ready") 103 | CLIENT_START_SYNC.set() 104 | # wait for finish trigger 105 | await asyncio.wait_for(finish.wait(),10) 106 | -------------------------------------------------------------------------------- /tests/reconnect_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from fastapi_websocket_rpc import logger, RpcChannel 6 | from fastapi_websocket_rpc.rpc_channel import RpcChannelClosedException 7 | 
logger.logging_config.set_mode(logger.LoggingModes.UVICORN) 8 | 9 | # Add parent path to use local src as package for tests 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 11 | 12 | import asyncio 13 | from multiprocessing import Process, Value 14 | 15 | import pytest 16 | import uvicorn 17 | from fastapi import APIRouter, FastAPI 18 | 19 | 20 | 21 | from fastapi_websocket_rpc.logger import get_logger 22 | from fastapi_websocket_rpc.utils import gen_uid 23 | from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient 24 | 25 | logger = get_logger("Test") 26 | 27 | # Configurable 28 | PORT = int(os.environ.get("PORT") or "7990") 29 | uri = f"ws://localhost:{PORT}/pubsub" 30 | trigger_url = f"http://localhost:{PORT}/trigger" 31 | 32 | DATA = "MAGIC" 33 | EVENT_TOPIC = "event/has-happened" 34 | 35 | 36 | def setup_server(disconnect_delay=0): 37 | app = FastAPI() 38 | # Multiprocess shared value 39 | counter = Value("i", 0) 40 | 41 | async def on_connect(channel:RpcChannel): 42 | if counter.value == 0: 43 | # Immediate death 44 | if disconnect_delay == 0: 45 | logger.info("Disconnect once") 46 | await channel.socket.close() 47 | # Delayed death 48 | else: 49 | async def disconn(): 50 | await asyncio.sleep(disconnect_delay) 51 | logger.info("Disconnect once") 52 | await channel.socket.close() 53 | asyncio.create_task(disconn()) 54 | counter.value = 1 55 | 56 | # PubSub websocket endpoint 57 | endpoint = PubSubEndpoint(on_connect=[on_connect]) 58 | endpoint.register_route(app, "/pubsub") 59 | uvicorn.run(app, port=PORT) 60 | 61 | 62 | @pytest.fixture() 63 | def server(): 64 | # Run the server as a separate process 65 | proc = Process(target=setup_server, args=(), daemon=True) 66 | proc.start() 67 | yield proc 68 | proc.kill() # Cleanup after test 69 | 70 | @pytest.fixture(params=[0.001, 0.01, 0.1, 0.2]) 71 | def delayed_death_server(request): 72 | disconnect_delay = request.param 73 | # Run the server as a separate 
process 74 | proc = Process(target=setup_server, args=(disconnect_delay,), daemon=True) 75 | proc.start() 76 | yield proc 77 | proc.kill() # Cleanup after test 78 | 79 | 80 | @pytest.mark.asyncio 81 | async def test_immediate_server_disconnect(server): 82 | """ 83 | Test reconnecting when a server hangups on connect 84 | """ 85 | # finish trigger 86 | finish = asyncio.Event() 87 | # Create a client and subscribe to topics 88 | async with PubSubClient() as client: 89 | async def on_event(data, topic): 90 | assert data == DATA 91 | finish.set() 92 | # subscribe for the event 93 | client.subscribe(EVENT_TOPIC, on_event) 94 | # start listentining 95 | client.start_client(uri) 96 | # wait for the client to be ready to receive events 97 | await client.wait_until_ready() 98 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 99 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 100 | assert published.result == True 101 | # wait for finish trigger 102 | await asyncio.wait_for(finish.wait(),5) 103 | 104 | @pytest.mark.asyncio 105 | async def test_delayed_server_disconnect(delayed_death_server): 106 | """ 107 | Test reconnecting when a server hangups AFTER connect 108 | """ 109 | # finish trigger 110 | finish = asyncio.Event() 111 | 112 | async def on_connect(client, channel): 113 | try: 114 | print ("Connected") 115 | # publish events (with sync=False to avoid deadlocks waiting on the publish to ourselves) 116 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 117 | assert published.result == True 118 | except RpcChannelClosedException: 119 | # expected 120 | pass 121 | 122 | # Create a client and subscribe to topics 123 | async with PubSubClient(on_connect=[on_connect]) as client: 124 | async def on_event(data, topic): 125 | assert data == DATA 126 | finish.set() 127 | # subscribe for the event 128 | client.subscribe(EVENT_TOPIC, on_event) 129 | 
# start listentining 130 | client.start_client(uri) 131 | # wait for the client to be ready to receive events 132 | await client.wait_until_ready() 133 | # wait for finish trigger 134 | await asyncio.wait_for(finish.wait(),5) 135 | 136 | 137 | @pytest.mark.asyncio 138 | async def test_disconnect_callback(delayed_death_server): 139 | """ 140 | Test reconnecting when a server hangups AFTER connect and that the disconnect callback work 141 | """ 142 | # finish trigger 143 | finish = asyncio.Event() 144 | disconnected = asyncio.Event() 145 | 146 | 147 | async def on_disconnect(channel): 148 | print ("-------- Disconnected") 149 | disconnected.set() 150 | 151 | async def on_connect(client, channel): 152 | try: 153 | print ("Connected") 154 | # publish events (with sync=False to avoid deadlocks waiting on the publish to ourselves) 155 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 156 | assert published.result == True 157 | except RpcChannelClosedException: 158 | # expected 159 | pass 160 | 161 | # Create a client and subscribe to topics 162 | async with PubSubClient(on_disconnect=[on_disconnect], on_connect=[on_connect]) as client: 163 | async def on_event(data, topic): 164 | assert data == DATA 165 | finish.set() 166 | # subscribe for the event 167 | client.subscribe(EVENT_TOPIC, on_event) 168 | # start listentining 169 | client.start_client(uri) 170 | # wait for the client to be ready to receive events 171 | await client.wait_until_ready() 172 | # wait for finish trigger 173 | await asyncio.wait_for(finish.wait(),5) 174 | 175 | await asyncio.wait_for(disconnected.wait(),1) 176 | assert disconnected.is_set() 177 | 178 | @pytest.mark.asyncio 179 | async def test_disconnect_callback_without_context(delayed_death_server): 180 | """ 181 | Test reconnecting when a server hangups AFTER connect and that the disconnect callback work 182 | """ 183 | # finish trigger 184 | finish = asyncio.Event() 185 | disconnected = 
asyncio.Event() 186 | 187 | async def on_disconnect(channel): 188 | disconnected.set() 189 | 190 | async def on_connect(client, channel): 191 | try: 192 | print ("Connected") 193 | # publish events (with sync=False to avoid deadlocks waiting on the publish to ourselves) 194 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 195 | assert published.result == True 196 | except RpcChannelClosedException: 197 | # expected 198 | pass 199 | 200 | # Create a client and subscribe to topics 201 | client = PubSubClient(on_disconnect=[on_disconnect], on_connect=[on_connect]) 202 | async def on_event(data, topic): 203 | assert data == DATA 204 | finish.set() 205 | # subscribe for the event 206 | client.subscribe(EVENT_TOPIC, on_event) 207 | # start listentining 208 | client.start_client(uri) 209 | # wait for the client to be ready to receive events 210 | await client.wait_until_ready() 211 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 212 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 213 | assert published.result == True 214 | # wait for finish trigger 215 | await asyncio.wait_for(finish.wait(),5) 216 | await client.disconnect() 217 | await asyncio.wait_for(disconnected.wait(),1) 218 | assert disconnected.is_set() 219 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-asyncio -------------------------------------------------------------------------------- /tests/server_subscribe_events_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fastapi_websocket_rpc import logger 5 | 6 | from fastapi_websocket_pubsub.event_notifier import EventNotifier 7 | 8 | # Add parent path to use local src as package for 
tests 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 10 | 11 | import asyncio 12 | from multiprocessing import Process 13 | 14 | import pytest 15 | 16 | from fastapi_websocket_pubsub import Subscription 17 | 18 | TOPIC = "event/has-been-processed" 19 | SUB_ID = "test" 20 | 21 | 22 | @pytest.mark.asyncio 23 | async def test_subscribe_callbacks_unit(): 24 | 25 | notifier = EventNotifier() 26 | 27 | async def event_callback(subscription:Subscription, data): 28 | logger.info(f"Got topic {subscription.topic}") 29 | 30 | subscribed = asyncio.Event() 31 | unsubscribed = asyncio.Event() 32 | async def server_subscribe(*args): 33 | subscribed.set() 34 | 35 | async def server_unsubscribe(*args): 36 | unsubscribed.set() 37 | 38 | notifier.register_subscribe_event(server_subscribe) 39 | notifier.register_unsubscribe_event(server_unsubscribe) 40 | 41 | await notifier.subscribe(SUB_ID, [TOPIC], event_callback) 42 | await asyncio.wait_for(subscribed.wait(),5) 43 | await notifier.unsubscribe(SUB_ID) 44 | await asyncio.wait_for(unsubscribed.wait(),5) 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /tests/server_subscriber_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fastapi_websocket_rpc import logger 5 | 6 | # Add parent path to use local src as package for tests 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 8 | 9 | import asyncio 10 | from multiprocessing import Process 11 | import requests 12 | 13 | import pytest 14 | import uvicorn 15 | from fastapi import APIRouter, FastAPI 16 | 17 | 18 | from fastapi_websocket_rpc.logger import get_logger 19 | from fastapi_websocket_rpc.utils import gen_uid 20 | from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient, Subscription, Topic 21 | 22 | 23 | 24 | logger = get_logger("Test") 25 | 26 | # 
Configurable 27 | PORT = int(os.environ.get("PORT") or "7990") 28 | uri = f"ws://localhost:{PORT}/pubsub" 29 | trigger_url = f"http://localhost:{PORT}/trigger" 30 | 31 | DATA = "MAGIC" 32 | SERVER_TOPIC = "event/has-happened" 33 | CLIENT_TOPIC = "event/has-been-processed" 34 | 35 | def setup_server_rest_route(app, endpoint: PubSubEndpoint): 36 | 37 | @app.get("/trigger") 38 | async def trigger_events(): 39 | logger.info("Triggered via HTTP route - publishing event") 40 | # Publish an event - to the our own server callback / which will trigger another event for the client 41 | # Since we are calling back (RPC) to the client- this would deadlock if we wait on it 42 | asyncio.create_task(endpoint.publish([SERVER_TOPIC], data=DATA)) 43 | return "triggered" 44 | 45 | 46 | def setup_server(): 47 | app = FastAPI() 48 | # PubSub websocket endpoint 49 | endpoint = PubSubEndpoint() 50 | endpoint.register_route(app, "/pubsub") 51 | 52 | # receive an event and publish another (this time for the client) 53 | async def event_callback(subscription:Subscription, data): 54 | logger.info(f"Got topic {subscription.topic} - re-publishing as {CLIENT_TOPIC}") 55 | asyncio.create_task(endpoint.publish([CLIENT_TOPIC], data)) 56 | 57 | @app.on_event("startup") 58 | async def startup(): 59 | # subscribe to our own events 60 | await endpoint.subscribe([SERVER_TOPIC], event_callback) 61 | 62 | # Regular REST endpoint - that publishes to PubSub 63 | setup_server_rest_route(app, endpoint) 64 | uvicorn.run(app, port=PORT) 65 | 66 | 67 | @pytest.fixture() 68 | def server(): 69 | # Run the server as a separate process 70 | proc = Process(target=setup_server, args=(), daemon=True) 71 | proc.start() 72 | yield proc 73 | proc.kill() # Cleanup after test 74 | 75 | @pytest.mark.asyncio 76 | async def test_server_subscribe_http_trigger(server): 77 | # finish trigger 78 | finish = asyncio.Event() 79 | # Create a client and subscribe to topics 80 | async with PubSubClient() as client: 81 | async def 
on_event(data, topic): 82 | assert data == DATA 83 | finish.set() 84 | # subscribe for the event 85 | client.subscribe(CLIENT_TOPIC, on_event) 86 | # start listentining 87 | client.start_client(uri) 88 | # wait for the client to be ready to receive events 89 | await client.wait_until_ready() 90 | # trigger the server via an HTTP route 91 | requests.get(trigger_url) 92 | # wait for finish trigger 93 | await asyncio.wait_for(finish.wait(),5) 94 | -------------------------------------------------------------------------------- /tests/server_with_remote_id_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fastapi_websocket_rpc import logger 5 | from fastapi_websocket_rpc.rpc_channel import RpcChannel 6 | 7 | # Add parent path to use local src as package for tests 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 9 | 10 | import asyncio 11 | from multiprocessing import Process 12 | import requests 13 | 14 | import pytest 15 | import uvicorn 16 | from fastapi import APIRouter, FastAPI 17 | 18 | 19 | from fastapi_websocket_rpc.logger import get_logger 20 | from fastapi_websocket_rpc.utils import gen_uid 21 | from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient, Subscription 22 | from fastapi_websocket_pubsub.event_notifier import ALL_TOPICS 23 | 24 | logger = get_logger("Test") 25 | 26 | # Configurable 27 | PORT = int(os.environ.get("PORT") or "7990") 28 | uri = f"ws://localhost:{PORT}/pubsub" 29 | trigger_url = f"http://localhost:{PORT}/trigger" 30 | ask_remote_id_url = f"http://localhost:{PORT}/ask-remote-id" 31 | 32 | DATA = "MAGIC" 33 | EVENT_TOPIC = "event/has-happened" 34 | 35 | REMOTE_ID_ANSWER_TOPIC = "client/my-remote-id" 36 | 37 | def setup_server_rest_routes(app, endpoint: PubSubEndpoint, remote_id_event: asyncio.Event): 38 | 39 | @app.get("/trigger") 40 | async def trigger_events(): 41 | logger.info("Triggered via HTTP 
route - publishing event") 42 | # Publish an event named 'steel' 43 | # Since we are calling back (RPC) to the client- this would deadlock if we wait on it 44 | asyncio.create_task(endpoint.publish([EVENT_TOPIC], data=DATA)) 45 | return "triggered" 46 | 47 | @app.get("/ask-remote-id") 48 | async def trigger_events(): 49 | logger.info("Got asked if i have the remote id") 50 | answer = "yes" if remote_id_event.is_set() else "no" 51 | asyncio.create_task(endpoint.publish([REMOTE_ID_ANSWER_TOPIC], {"answer": answer})) 52 | return {"answer": answer} 53 | 54 | 55 | def setup_server(): 56 | app = FastAPI() 57 | remote_id_ok = asyncio.Event() 58 | 59 | async def try_to_get_remote_id(channel: RpcChannel): 60 | logger.info(f"trying to get remote channel id") 61 | channel_other_channel_id = await channel.get_other_channel_id() 62 | logger.info(f"finished getting remote channel id") 63 | if channel_other_channel_id is not None: 64 | remote_id_ok.set() 65 | logger.info(f"remote channel id: {channel_other_channel_id}") 66 | logger.info(f"local channel id: {channel_other_channel_id}") 67 | 68 | async def on_connect(channel: RpcChannel): 69 | logger.info(f"Connected to remote channel") 70 | asyncio.create_task(try_to_get_remote_id(channel)) 71 | 72 | # PubSub websocket endpoint - setting up the server with remote id 73 | endpoint = PubSubEndpoint(rpc_channel_get_remote_id=True, on_connect=[on_connect]) 74 | endpoint.register_route(app, "/pubsub") 75 | 76 | # Regular REST endpoint - that publishes to PubSub 77 | setup_server_rest_routes(app, endpoint, remote_id_ok) 78 | uvicorn.run(app, port=PORT) 79 | 80 | 81 | @pytest.fixture() 82 | def server(): 83 | # Run the server as a separate process 84 | proc = Process(target=setup_server, args=(), daemon=True) 85 | proc.start() 86 | yield proc 87 | proc.kill() # Cleanup after test 88 | 89 | @pytest.mark.asyncio 90 | async def test_subscribe_http_trigger_with_remote_id_on(server): 91 | """ 92 | same as the 
basic_test::test_subscribe_http_trigger, but this time makes sure that 93 | the rpc_channel_get_remote_id doesn't break anything. 94 | """ 95 | # finish trigger 96 | finish = asyncio.Event() 97 | # Create a client and subscribe to topics 98 | async with PubSubClient() as client: 99 | async def on_event(data, topic): 100 | assert data == DATA 101 | finish.set() 102 | # subscribe for the event 103 | client.subscribe(EVENT_TOPIC, on_event) 104 | # start listentining 105 | client.start_client(uri) 106 | # wait for the client to be ready to receive events 107 | await client.wait_until_ready() 108 | # trigger the server via an HTTP route 109 | requests.get(trigger_url) 110 | # wait for finish trigger 111 | await asyncio.wait_for(finish.wait(),5) 112 | 113 | @pytest.mark.asyncio 114 | async def test_pub_sub_with_remote_id_on(server): 115 | """ 116 | same as the basic_test::test_pubsub, but this time makes sure that 117 | the rpc_channel_get_remote_id doesn't break anything. 118 | """ 119 | # finish trigger 120 | finish = asyncio.Event() 121 | # Create a client and subscribe to topics 122 | async with PubSubClient() as client: 123 | async def on_event(data, topic): 124 | assert data == DATA 125 | finish.set() 126 | # subscribe for the event 127 | client.subscribe(EVENT_TOPIC, on_event) 128 | # start listentining 129 | client.start_client(uri) 130 | # wait for the client to be ready to receive events 131 | await client.wait_until_ready() 132 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 133 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 134 | assert published.result == True 135 | # wait for finish trigger 136 | await asyncio.wait_for(finish.wait(),5) 137 | 138 | @pytest.mark.asyncio 139 | async def test_pub_sub_with_all_topics_with_remote_id_on(server): 140 | """ 141 | same as the basic_test::test_pub_sub_with_all_topics, but this time makes sure that 142 | the 
rpc_channel_get_remote_id doesn't break anything. 143 | """ 144 | # finish trigger 145 | finish = asyncio.Event() 146 | # Create a client and subscribe to topics 147 | async with PubSubClient() as client: 148 | async def on_event(data, topic): 149 | assert data == DATA 150 | finish.set() 151 | # subscribe for the event 152 | client.subscribe(ALL_TOPICS, on_event) 153 | # start listentining 154 | client.start_client(uri) 155 | # wait for the client to be ready to receive events 156 | await client.wait_until_ready() 157 | # publish events (with sync=False toa void deadlocks waiting on the publish to ourselves) 158 | published = await client.publish([EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid()) 159 | assert published.result == True 160 | # wait for finish trigger 161 | await asyncio.wait_for(finish.wait(),5) 162 | 163 | 164 | @pytest.mark.asyncio 165 | async def test_getting_remote_id(server): 166 | """ 167 | tests that the server managed to get the client's channel id successfully. 
168 | """ 169 | # finish trigger 170 | finish = asyncio.Event() 171 | remote_id_yes = asyncio.Event() 172 | 173 | # Create a client and subscribe to topics 174 | async with PubSubClient() as client: 175 | async def on_event(data, topic): 176 | assert data == DATA 177 | finish.set() 178 | 179 | async def on_answer(data, topic): 180 | assert data.get("answer", None) == "yes" 181 | remote_id_yes.set() 182 | 183 | # subscribe for the event 184 | client.subscribe(EVENT_TOPIC, on_event) 185 | client.subscribe(REMOTE_ID_ANSWER_TOPIC, on_answer) 186 | # start listentining 187 | client.start_client(uri) 188 | # wait for the client to be ready to receive events 189 | await client.wait_until_ready() 190 | # trigger the server via an HTTP route 191 | requests.get(trigger_url) 192 | # wait for finish trigger 193 | await asyncio.wait_for(finish.wait(),5) 194 | # sleep so that the server can finish getting the remote id 195 | await asyncio.sleep(1) 196 | # ask the server if he got the remote id 197 | # will trigger the REMOTE_ID_ANSWER_TOPIC topic and the on_answer() callback 198 | requests.get(ask_remote_id_url) 199 | await asyncio.wait_for(remote_id_yes.wait(),5) 200 | # the client can also try to get it's remote id 201 | # super ugly but it's working: 202 | my_remote_id = await client._rpc_channel._get_other_channel_id() 203 | assert my_remote_id is not None --------------------------------------------------------------------------------