├── chunnel ├── main.py ├── __init__.py ├── transports │ ├── __init__.py │ ├── base.py │ └── websocket.py ├── utils.py ├── messages.py ├── channel.py └── socket.py ├── test ├── __init__.py ├── test_utils.py ├── conftest.py ├── test_against_phoenix.py ├── shared.py ├── test_socket.py └── test_channel.py ├── requirements.txt ├── CHANGELOG.md ├── dev-requirements.txt ├── .travis.yml ├── setup.py ├── LICENSE ├── README.md └── .gitignore /chunnel/main.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chunnel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | websockets==3.1 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ### v0.1.0 2 | 3 | - Initial alpha release. Supports joining channels, sending & receiving 4 | messages. 
5 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | pytest==2.9.2 3 | pytest-mock==1.1 4 | pytest-asyncio==0.4.1 5 | pytest-timeout==1.0.0 6 | flake8 7 | requests==2.10.0 8 | -------------------------------------------------------------------------------- /chunnel/transports/__init__.py: -------------------------------------------------------------------------------- 1 | from .websocket import WebsocketTransport 2 | from .base import TransportMessage, OutgoingTransportMessage 3 | 4 | 5 | __all__ = [ 6 | 'WebsocketTransport', 'TransportMessage', 'OutgoingTransportMessage' 7 | ] 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 3.5.0 4 | - 3.5.1 5 | - 3.5.2 6 | env: 7 | - TEST_CARD_VER=0.1.1 8 | install: 9 | - pip install -r dev-requirements.txt 10 | - mkdir test_card 11 | - cd test_card 12 | - wget https://github.com/obmarg/test_card/releases/download/v${TEST_CARD_VER}/test_card-${TEST_CARD_VER}-ubuntu.tar.gz -O test_card.tar.gz 13 | - tar xzf test_card.tar.gz 14 | - PORT=4000 bin/test_card start 15 | - sleep 1 16 | - cat log/* 17 | - bin/test_card ping 18 | - cd .. 
19 | script: TEST_CARD_URL=localhost:4000 py.test test --timeout=5 20 | after_script: cd test_card && bin/test_card stop 21 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from chunnel.utils import get_unless_done, DONE 6 | 7 | 8 | @pytest.mark.asyncio 9 | async def test_get_unless_done_getting(): 10 | queue = asyncio.Queue() 11 | await queue.put(1) 12 | await queue.put(2) 13 | future = asyncio.Future() 14 | assert await get_unless_done(queue.get(), future) == 1 15 | assert await get_unless_done(queue.get(), future) == 2 16 | 17 | 18 | @pytest.mark.asyncio 19 | async def test_get_unless_done_when_done(): 20 | queue = asyncio.Queue() 21 | future = asyncio.Future() 22 | future.set_result(True) 23 | assert await get_unless_done(queue.get(), future) is DONE 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_get_unless_done_when_done_and_queue(): 28 | queue = asyncio.Queue() 29 | future = asyncio.Future() 30 | future.set_result(True) 31 | await queue.put(1) 32 | assert await get_unless_done(queue.get(), future) is DONE 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | REQUIREMENTS = ['websockets'] 6 | LONG_DESCRIPTION = ''' 7 | Chunnel 8 | ----- 9 | 10 | A python client for phoenix channels. 11 | 12 | See the README at https://github.com/obmarg/chunnel for more details. 
13 | 14 | ''' 15 | 16 | setup( 17 | name='chunnel', 18 | version='0.1.0', 19 | url='https://github.com/obmarg/chunnel', 20 | description='Phoenix channels client library', 21 | long_description=LONG_DESCRIPTION, 22 | author='Graeme Coupar', 23 | author_email='grambo@grambo.me.uk', 24 | packages=find_packages(exclude=['tests']), 25 | install_requires=REQUIREMENTS, 26 | zip_safe=False, 27 | include_package_data=True, 28 | classifiers=[ 29 | 'Development Status :: 3 - Alpha', 30 | 'Intended Audience :: Developers', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3.5' 33 | ] 34 | ) 35 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import sentinel 2 | import asyncio 3 | import logging 4 | 5 | import pytest 6 | 7 | from chunnel.socket import Socket 8 | 9 | from .shared import TestTransport 10 | 11 | 12 | @pytest.fixture(autouse=True) 13 | def setup_logging(): 14 | logging.basicConfig(level=logging.DEBUG) 15 | 16 | 17 | @pytest.yield_fixture 18 | def event_loop(): 19 | """ 20 | Create an instance of the default event loop for each test case. 
21 | 22 | This implementation is mostly a workaround for pytest-asyncio issues #29 & 23 | #30 24 | """ 25 | policy = asyncio.get_event_loop_policy() 26 | res = policy.new_event_loop() 27 | asyncio.set_event_loop(res) 28 | res._close = res.close 29 | res.close = lambda: None  # no-op; the real close runs after yield 30 | 31 | yield res 32 | 33 | res._close() 34 | 35 | 36 | @pytest.fixture 37 | def socket(mocker, event_loop): 38 | mocker.patch.dict(Socket.TRANSPORTS, {'ws': TestTransport}) 39 | return Socket('ws://localhost', sentinel.connect_params) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Graeme Coupar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /chunnel/utils.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import FIRST_COMPLETED 2 | import asyncio 3 | 4 | __all__ = ['DONE', 'get_unless_done'] 5 | 6 | 7 | class DONE(): 8 | ''' 9 | Singleton that indicates a DONE when returned from get_unless_done. 10 | ''' 11 | pass 12 | 13 | 14 | # TODO: Tests 15 | async def get_unless_done(getter_future_or_coro, done_future): 16 | ''' 17 | Wraps a get operation & a future that indicates when we should stop 18 | getting. 19 | 20 | If the done_future is resolved while waiting for the get, we cancel the get 21 | and return DONE. 22 | 23 | :params getter_task_or_coro: A getter. 24 | :params done_future: A future that indicates we are done. 25 | ''' 26 | getter_future = asyncio.ensure_future(getter_future_or_coro) 27 | done, pending = await asyncio.wait( 28 | (getter_future, done_future), 29 | return_when=FIRST_COMPLETED 30 | ) 31 | if done_future in done: 32 | if not getter_future.done(): 33 | getter, = pending 34 | getter.cancel() 35 | 36 | return DONE 37 | 38 | return await getter_future 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Chunnel 2 | --- 3 | 4 | A python client for [phoenix](http://www.phoenixframework.org/) 5 | [channels](http://www.phoenixframework.org/docs/channels). 
6 | 7 | Usage 8 | --- 9 | 10 | ```python 11 | from chunnel.socket import Socket 12 | 13 | socket = Socket('ws://example.com/socket', params={'token': 'blah'}) 14 | async with socket: 15 | channel = socket.channel('room:lobby', {}) 16 | await channel.join() 17 | incoming = await channel.receive() 18 | await incoming.reply('ok', {'blah': 'whatever'}) 19 | msg = await channel.push('something', {}) 20 | response = await msg.response() 21 | ``` 22 | 23 | Status 24 | --- 25 | 26 | Chunnel is very much in alpha status right now. Its API can (and probably will) 27 | change, there are many edge cases that are not currently handled, and many TODOs 28 | littered about the code. 29 | 30 | Currently implemented: 31 | 32 | - Joining channels 33 | - Receiving messages 34 | - Sending messages 35 | 36 | Not implemented: 37 | 38 | - Documentation 39 | - Incoming channel leave messages 40 | - Connection errors/reconnecting. 41 | - Much other error handling 42 | - Presence 43 | - Probably other things 44 | 45 | Pull requests welcome. 46 | 47 | Thanks 48 | --- 49 | 50 | Some thanks to: 51 | 52 | - `@ChaseGilliam` on Twitter for the name. 53 | - Various others for suggestions. 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | .dir-locals.el 92 | -------------------------------------------------------------------------------- /chunnel/messages.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from enum import Enum 3 | 4 | 5 | class MessageStatus(Enum): 6 | ok = 'ok' 7 | error = 'error' 8 | 9 | 10 | class IncomingMessage:  # wraps a received TransportMessage 11 | def __init__(self, transport_message, socket): 12 | self._transport_message = transport_message 13 | self._socket = socket 14 | 15 | @property 16 | def event(self): 17 | return self._transport_message.event 18 | 19 | @property 20 | def payload(self): 21 | return self._transport_message.payload 22 | 23 | async def reply(self, status, response):  # phx_reply, reusing this ref 24 | await self._socket._send_message( 25 | self._transport_message.topic, 26 | ChannelEvents.reply.value, 27 | {'status': status, 'response': response}, 28 | self._transport_message.ref 29 | ) 30 | 31 | 32 | # TODO: Think about where this belongs...
33 | class ChannelEvents(Enum): 34 | close = "phx_close", 35 | error = "phx_error" 36 | join = "phx_join" 37 | reply = "phx_reply" 38 | leave = "phx_leave" 39 | 40 | 41 | # TODO: PushedMessage? 42 | class SentMessage: 43 | def __init__(self, response_future): 44 | self._response_future = response_future 45 | 46 | async def response(self): 47 | # TODO: Definitely need to think more about timeouts... 48 | # Currently the self._future is wrapped in a timeout, but is that what 49 | # I want? 50 | resp = await self._response_future 51 | return resp 52 | -------------------------------------------------------------------------------- /chunnel/transports/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections import namedtuple 3 | 4 | 5 | TransportMessage = namedtuple( 6 | 'TransportMessage', ['event', 'topic', 'payload', 'ref'] 7 | ) 8 | 9 | # TODO: Could maybe call this Push to mirror it's name in phoenix js imp. 10 | OutgoingTransportMessage = namedtuple( 11 | 'OutgoingTransportMessage', ['message', 'sent'] 12 | ) 13 | 14 | 15 | class BaseTransport: 16 | ''' 17 | The base class for a transport. 18 | 19 | Transports are used to implement the sending & receiving of messages in 20 | chunnel. Each transport is constructed with 2 queues - a queue for incoming 21 | messages and a queue for outgoing messages. 22 | 23 | The incoming message queue should contain TransportMessage namedtuples. 24 | 25 | The outgoing message queue should contain OutgoingTransportMessage 26 | namedtuples. 27 | 28 | The transport should read messages from the outgoing queue and send them 29 | onto a phoenix server, and put any incoming messages onto the incoming 30 | queue. 31 | 32 | Transports are not responsible for interpreting the messages in any way, 33 | they just handle the communication. 
34 | ''' 35 | def __init__(self, incoming_queue, outgoing_queue): 36 | self.incoming = incoming_queue 37 | self.outgoing = outgoing_queue 38 | self.ready = asyncio.Future()  # see run(): resolved once connected 39 | 40 | async def run(self): 41 | ''' 42 | Connects the transport and runs its main loop. 43 | 44 | Will resolve `self.ready` when a connection has been made. 45 | ''' 46 | raise NotImplementedError 47 | 48 | async def stop(self): 49 | ''' 50 | Signals to the transport that it should stop. 51 | ''' 52 | raise NotImplementedError 53 | -------------------------------------------------------------------------------- /test/test_against_phoenix.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | import asyncio 3 | import os 4 | 5 | import pytest 6 | import requests 7 | 8 | from chunnel.socket import Socket 9 | 10 | TEST_CARD_URL = os.getenv('TEST_CARD_URL') 11 | SKIP_TESTS = TEST_CARD_URL is None 12 | 13 | 14 | @pytest.fixture 15 | def user_id(): 16 | id_ = str(uuid4()) 17 | response = requests.post( 18 | 'http://{}/api/users'.format(TEST_CARD_URL), 19 | json={"user": {"id": id_, "rooms": ["lobby"]}} 20 | ) 21 | assert response.status_code == 201 22 | return id_ 23 | 24 | 25 | def make_socket(user_id): 26 | return Socket( 27 | 'ws://{}/socket/websocket'.format(TEST_CARD_URL), 28 | {'user_id': user_id} 29 | ) 30 | 31 | 32 | @pytest.fixture 33 | def socket(event_loop, user_id): 34 | return make_socket(user_id) 35 | 36 | 37 | @pytest.fixture 38 | def socket2(event_loop, user_id): 39 | return make_socket(user_id) 40 | 41 | 42 | @pytest.mark.skipif(SKIP_TESTS, reason="TEST_CARD_URL env var not set") 43 | @pytest.mark.asyncio 44 | async def test_join_and_ping(socket): 45 | async with socket: 46 | channel = socket.channel("room:lobby", {'join': 'params'}) 47 | assert await channel.join() == {'join': 'params'} 48 | ping = await channel.push("ping", {"some": "data"}) 49 | assert await ping.response() == {"some": "data"} 50 | 51 | 52 |
@pytest.mark.skipif(SKIP_TESTS, reason="TEST_CARD_URL env var not set") 53 | @pytest.mark.asyncio 54 | async def test_join_and_shout(socket, socket2): 55 | async with socket: 56 | async with socket2: 57 | socket1_channel = socket.channel("room:lobby", {}) 58 | socket2_channel = socket2.channel("room:lobby", {}) 59 | await asyncio.gather( 60 | socket1_channel.join(), socket2_channel.join() 61 | ) 62 | shout_payload = {"hello": "is anybody there?"} 63 | await socket1_channel.push("shout", shout_payload) 64 | 65 | incoming = await socket1_channel.receive() 66 | assert incoming.payload == shout_payload 67 | 68 | incoming = await socket2_channel.receive() 69 | assert incoming.payload == shout_payload 70 | -------------------------------------------------------------------------------- /test/shared.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from chunnel.messages import IncomingMessage 4 | from chunnel.transports.base import BaseTransport, TransportMessage 5 | 6 | 7 | class TestTransport(BaseTransport): 8 | ''' 9 | A transport implementation for testing. 10 | ''' 11 | RESOLVE_READY = True  # if False, tests resolve self.ready themselves 12 | 13 | def __init__(self, url, params, incoming, outgoing): 14 | self.url = url 15 | self.params = params 16 | super().__init__(incoming, outgoing) 17 | self._future = asyncio.Future()  # stop() resolves this, ending run() 18 | if self.RESOLVE_READY: 19 | self.ready.set_result(True) 20 | 21 | async def run(self): 22 | self.running = True 23 | await self._future 24 | self.running = False 25 | 26 | async def stop(self): 27 | self._future.set_result(True) 28 | 29 | 30 | # TODO: Decide if this class is worth it. 31 | # Currently just used in one place, and quite easy to do a gather w/ 32 | # set_reply instead... 33 | class TestSender(): 34 | ''' 35 | Helper class for sending test messages and setting their responses.
36 | 37 | Example: 38 | 39 | sender = TestSender('test:topic', 'an_event', {}) 40 | sender.set_reply({'whatever': 'you_want'}) 41 | msg = await sender.send(socket) 42 | response = await msg.response() 43 | assert response == {'whatever': 'you_want'} 44 | ''' 45 | def __init__(self, topic, event, payload): 46 | self._outgoing = { 47 | 'topic': topic, 48 | 'event': event, 49 | 'payload': payload 50 | } 51 | 52 | def set_reply(self, payload): 53 | self._reply_payload = payload 54 | 55 | async def send(self, socket): 56 | if self._reply_payload: 57 | other_future = set_reply( 58 | socket, self._outgoing['topic'], self._reply_payload 59 | ) 60 | else: 61 | other_future = _mark_sent(socket) 62 | 63 | _, sent_message = await asyncio.gather( 64 | other_future, socket._send_message(**self._outgoing) 65 | ) 66 | return sent_message 67 | 68 | 69 | async def _mark_sent(socket): 70 | ''' 71 | Marks the next outgoing message as sent. 72 | ''' 73 | msg = await socket.transport.outgoing.get() 74 | msg.sent.set_result(True) 75 | 76 | 77 | async def set_reply(socket, topic, payload): 78 | ''' 79 | Sets the reply to the next sent message. 80 | 81 | Also marks the message as sent if it is not already. 82 | ''' 83 | msg = await socket.transport.outgoing.get() 84 | if not msg.sent.done(): 85 | msg.sent.set_result(True) 86 | 87 | await socket.transport.incoming.put( 88 | TransportMessage('phx_reply', topic, payload, msg.message.ref) 89 | ) 90 | -------------------------------------------------------------------------------- /chunnel/channel.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from .messages import ChannelEvents 4 | 5 | 6 | class ChannelJoinFailure(Exception): 7 | pass 8 | 9 | 10 | class ChannelLeaveFailure(Exception): 11 | pass 12 | 13 | # TODO: Random thought, but _might_ be nice to ditch the 14 | # mutable-ness of these classes. 15 | # Like a Channel can be joined or not. 
16 | # Maybe we should have a JoinedChannel class vs a Channel class. 17 | # `.channel` always returns the Channel. 18 | # `.join` returns the JoinedChannel (and if already joined does nothing extra). 19 | # Not sure if it'd make a good API, but worth thinking about... 20 | 21 | 22 | class Channel: 23 | ''' 24 | A channel on a phoenix server. 25 | 26 | Should not be instantiated directly, but through a socket. 27 | ''' 28 | def __init__(self, socket, topic, params): 29 | self.socket = socket 30 | self.topic = topic 31 | self.params = params 32 | self._incoming_messages = asyncio.Queue() 33 | # TODO: Consider something like channel_states in js lib? 34 | 35 | async def join(self): 36 | ''' 37 | Joins the channel. 38 | ''' 39 | join = await self.socket._send_message( 40 | self.topic, ChannelEvents.join.value, self.params 41 | ) 42 | try: 43 | response = await join.response() 44 | except Exception as e: 45 | # TODO: this needs some work. 46 | raise ChannelJoinFailure() from e 47 | 48 | return response 49 | 50 | async def leave(self): 51 | ''' 52 | Leaves the channel. 53 | ''' 54 | leave = await self.socket._send_message( 55 | self.topic, ChannelEvents.leave.value, self.params 56 | ) 57 | try: 58 | response = await leave.response() 59 | except Exception as e: 60 | # TODO: this needs some work. 61 | raise ChannelLeaveFailure() from e 62 | 63 | async def push(self, event, payload): 64 | ''' 65 | Pushes a message to a channel. 66 | 67 | :param event: The event to push. 68 | :param payload: The payload for the event. 69 | ''' 70 | msg = await self.socket._send_message(self.topic, event, payload) 71 | return msg 72 | 73 | # TODO: could be nice to just expose a "read only queue" under .incoming 74 | # With get, get_nowait & an async iterator interface? 
# TODO: Otherwise should maybe be called pull (to go with push) 76 | async def receive(self):  # waits for this channel's next message 77 | msg = await self._incoming_messages.get() 78 | return msg 79 | 80 | async def __aenter__(self): 81 | resp = await self.join() 82 | return self, resp  # `async with` binds (channel, response) 83 | 84 | async def __aexit__(self, exc_type, exc_value, tb): 85 | await self.leave() 86 | -------------------------------------------------------------------------------- /test/test_socket.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import sentinel 2 | import asyncio 3 | 4 | import pytest 5 | 6 | from chunnel.transports import TransportMessage 7 | 8 | from .shared import TestTransport, TestSender 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_connecting(socket, mocker): 13 | mocker.patch.object(TestTransport, 'RESOLVE_READY', new=False) 14 | 15 | connect_future = asyncio.ensure_future(socket.connect()) 16 | await asyncio.sleep(0) 17 | assert socket.transport.params == sentinel.connect_params 18 | 19 | assert socket.transport 20 | assert socket.transport.url == 'ws://localhost' 21 | assert not connect_future.done() 22 | 23 | socket.transport.ready.set_result(True) 24 | await asyncio.sleep(0) 25 | assert connect_future.done() 26 | assert socket.connected 27 | await connect_future 28 | 29 | await socket.disconnect() 30 | 31 | 32 | @pytest.mark.asyncio 33 | async def test_connect_as_context_manager(socket): 34 | async with socket as socket: 35 | assert socket.connected 36 | 37 | assert not socket.connected 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_creating_channel(socket): 42 | async with socket: 43 | channel = socket.channel("test:topic", sentinel.channel_params) 44 | assert channel.topic == 'test:topic' 45 | assert channel.params == sentinel.channel_params 46 | assert 'test:topic' in socket.channels 47 | assert socket.channels['test:topic'] == channel 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_sending_message(socket): 52
| async with socket: 53 | send_future = asyncio.ensure_future( 54 | socket._send_message( 55 | sentinel.topic, sentinel.event, sentinel.payload 56 | ) 57 | ) 58 | await asyncio.sleep(0)  # let _send_message enqueue the message 59 | assert socket.transport.outgoing.qsize() == 1 60 | message, sent_future = socket.transport.outgoing.get_nowait() 61 | assert message.event == sentinel.event 62 | assert message.topic == sentinel.topic 63 | assert message.payload == sentinel.payload 64 | assert message.ref 65 | 66 | sent_future.set_result(True) 67 | 68 | await asyncio.sleep(0) 69 | assert send_future.done() 70 | assert await send_future 71 | 72 | 73 | @pytest.mark.asyncio 74 | async def test_replies_routed_correctly(socket): 75 | async with socket: 76 | sender = TestSender(sentinel.topic, sentinel.event, sentinel.payload) 77 | sender.set_reply({'status': 'ok', 'response': sentinel.response}) 78 | 79 | sent_message = await sender.send(socket) 80 | 81 | assert await sent_message.response() == sentinel.response 82 | 83 | 84 | @pytest.mark.asyncio 85 | async def test_channel_messages_routed_correctly(socket): 86 | async with socket: 87 | channel = socket.channel("test:topic", sentinel.channel_params) 88 | message = TransportMessage( 89 | sentinel.event, "test:topic", sentinel.payload, sentinel.ref 90 | ) 91 | await socket.transport.incoming.put(message) 92 | message = await channel.receive() 93 | assert message.event == sentinel.event 94 | assert message.payload == sentinel.payload 95 | -------------------------------------------------------------------------------- /chunnel/transports/websocket.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlencode 2 | import asyncio 3 | import json 4 | import logging 5 | 6 | import websockets 7 | 8 | from .base import BaseTransport, TransportMessage 9 | from ..utils import get_unless_done, DONE 10 | 11 | __all__ = ['WebsocketTransport'] 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class
WebsocketTransport(BaseTransport): 17 | ''' 18 | Implements the websocket transport for talking to phoenix servers. 19 | ''' 20 | def __init__(self, url, params, incoming_queue, outgoing_queue): 21 | super().__init__( 22 | incoming_queue=incoming_queue, outgoing_queue=outgoing_queue 23 | ) 24 | qs_params = {'vsn': '1.0.0', **params} 25 | self.url = url + '?' + urlencode(qs_params) 26 | print(self.url) 27 | self.ready = asyncio.Future() 28 | self._done = asyncio.Future() 29 | 30 | async def run(self): 31 | try: 32 | async with websockets.connect(self.url) as websocket: 33 | self.ready.set_result(True) 34 | # TODO: Think about error propagation at some point? 35 | # If one of these crashes, does it cancel the other? 36 | # It should, not sure if it does. 37 | await asyncio.gather( 38 | self._recv_loop(websocket), self._send_loop(websocket) 39 | ) 40 | except Exception as e: 41 | if not self.ready.done(): 42 | self.ready.set_exception(e) 43 | 44 | raise 45 | 46 | async def stop(self): 47 | self._done.set_result(True) 48 | 49 | async def _recv_loop(self, websocket): 50 | while True: 51 | message_data = await get_unless_done(websocket.recv(), self._done) 52 | if message_data is DONE: 53 | return 54 | 55 | logger.debug("received: %s", message_data) 56 | # TODO: This needs updates. 57 | message = _load_incoming_message(json.loads(message_data)) 58 | await self.incoming.put(message) 59 | logger.debug("sent") 60 | 61 | async def _send_loop(self, websocket): 62 | while True: 63 | message = await get_unless_done(self.outgoing.get(), self._done) 64 | if message is DONE: 65 | return 66 | 67 | logger.debug("sending: %s", message) 68 | try: 69 | # TODO: This needs updates. 
70 | message_data = json.dumps( 71 | dump_outgoing_message(message.message) 72 | ) 73 | await websocket.send(message_data) 74 | message.sent.set_result(True) 75 | except Exception as e: 76 | message.sent.set_exception(e) 77 | logger.debug("sent") 78 | 79 | 80 | def _load_incoming_message(message_data): 81 | return TransportMessage( 82 | message_data['event'], 83 | message_data['topic'], 84 | message_data['payload'], 85 | message_data.get('ref') 86 | ) 87 | 88 | 89 | def dump_outgoing_message(message): 90 | return { 91 | 'event': message.event, 92 | 'topic': message.topic, 93 | 'ref': message.ref, 94 | 'payload': message.payload 95 | } 96 | -------------------------------------------------------------------------------- /test/test_channel.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import sentinel 2 | import asyncio 3 | 4 | import pytest 5 | 6 | from chunnel.channel import ChannelJoinFailure, ChannelLeaveFailure 7 | from chunnel.messages import ChannelEvents 8 | from chunnel.transports import TransportMessage 9 | 10 | from .shared import set_reply 11 | 12 | 13 | @pytest.yield_fixture 14 | def socket(socket, event_loop): 15 | event_loop.run_until_complete(socket.connect()) 16 | yield socket 17 | event_loop.run_until_complete(socket.disconnect()) 18 | 19 | 20 | @pytest.fixture 21 | def channel(socket): 22 | return socket.channel("test:lobby", {}) 23 | 24 | 25 | @pytest.mark.asyncio 26 | async def test_join(socket, channel): 27 | response, _ = await asyncio.gather( 28 | channel.join(), 29 | set_reply( 30 | socket, 31 | None, 32 | {'status': 'ok', 'response': sentinel.response} 33 | ) 34 | ) 35 | assert response == sentinel.response 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_join_failure(socket, channel): 40 | with pytest.raises(ChannelJoinFailure): 41 | await asyncio.gather( 42 | channel.join(), 43 | set_reply(socket, None, {'status': 'error'}) 44 | ) 45 | 46 | 47 | @pytest.mark.asyncio 48 | 
async def test_leave(socket, channel):
    await asyncio.gather(
        channel.join(),
        set_reply(socket, None, {'status': 'ok', 'response': {}})
    )
    await asyncio.gather(
        channel.leave(),
        set_reply(socket, None, {'status': 'ok'})
    )


@pytest.mark.asyncio
async def test_leave_failure(socket, channel):
    await asyncio.gather(
        channel.join(),
        set_reply(socket, None, {'status': 'ok', 'response': {}})
    )
    with pytest.raises(ChannelLeaveFailure):
        await asyncio.gather(
            channel.leave(),
            set_reply(socket, None, {'status': 'error'})
        )


@pytest.mark.asyncio
async def test_context_manager_join(socket):
    asyncio.ensure_future(
        set_reply(
            socket, None, {'status': 'ok', 'response': sentinel.response}
        )
    )
    # TODO: Not sure about this api...
    async with socket.channel("test:lobby", {}) as (channel, response):
        assert response == sentinel.response
        assert channel == socket.channels['test:lobby']
        # Queue the reply to the leave that happens when the block exits.
        asyncio.ensure_future(
            set_reply(socket, None, {'status': 'ok'})
        )


@pytest.mark.asyncio
async def test_send_message(socket, channel):
    send_future = asyncio.ensure_future(
        channel.push(sentinel.event, sentinel.payload)
    )
    # Let push() run far enough to enqueue the outgoing message.
    await asyncio.sleep(0)
    assert socket.transport.outgoing.qsize() == 1
    msg, sent_future = socket.transport.outgoing.get_nowait()
    assert msg.topic == channel.topic
    assert msg.event == sentinel.event
    assert msg.payload == sentinel.payload
    sent_future.set_result(True)

    assert await send_future


@pytest.mark.asyncio
async def test_message_replies(socket, channel):
    sent_message, _ = await asyncio.gather(
        channel.push(sentinel.event, sentinel.payload),
        set_reply(socket, None, {'status': 'ok', 'response': 'abcd'})
    )
    response = await sent_message.response()
    assert response == 'abcd'


@pytest.mark.asyncio
async def test_message_receive(socket, channel):
    message = TransportMessage(
        sentinel.event, channel.topic, sentinel.payload, sentinel.ref
    )
    # NOTE(review): the same message is queued twice but received only
    # once - confirm whether the second put is intentional.
    await socket.transport.incoming.put(message)
    await socket.transport.incoming.put(message)
    message = await channel.receive()
    assert message.event == sentinel.event
    assert message.payload == sentinel.payload


@pytest.mark.asyncio
async def test_reply_to_received_message(socket, channel):
    await socket.transport.incoming.put(
        TransportMessage(
            sentinel.event, channel.topic, sentinel.payload, sentinel.ref
        )
    )
    incoming_message = await channel.receive()

    reply_future = asyncio.ensure_future(
        incoming_message.reply(sentinel.status, sentinel.response)
    )
    # Let reply() run far enough to enqueue the outgoing reply.
    await asyncio.sleep(0)
    reply, sent_future = await socket.transport.outgoing.get()
    assert reply.topic == channel.topic
    assert reply.event == ChannelEvents.reply.value
    assert reply.payload == {
        'status': sentinel.status, 'response': sentinel.response
    }
    assert reply.ref == sentinel.ref

    sent_future.set_result(True)
    # A Future object is always truthy, so asserting on it proves nothing.
    # Await it instead so an error raised while sending the reply fails
    # the test.
    await reply_future
# --------------------------------------------------------------------------------
# /chunnel/socket.py:
# --------------------------------------------------------------------------------
from concurrent.futures import FIRST_COMPLETED
from urllib.parse import urlsplit
import asyncio
import logging

from .transports import (
    WebsocketTransport, TransportMessage, OutgoingTransportMessage
)
from .channel import Channel
from .messages import SentMessage, ChannelEvents, IncomingMessage
from .utils import get_unless_done, DONE

__all__ = ['Socket']

logger = logging.getLogger(__name__)


# TODO: Should this be called Socket?
# Dunno if it matches up with phoenix too
# well..
class Socket:
    '''
    A connection to a phoenix server.

    A transport will automatically be selected based on the URL provided.
    See the TRANSPORTS dict for more details.

    :param url: The URL of the phoenix server to connect to.
    :param params: Optional parameters to use when connecting.
    '''

    # A mapping of url scheme -> transport.
    TRANSPORTS = {
        'ws': WebsocketTransport,
        'wss': WebsocketTransport
    }

    # TODO: Should these parameters be passed to connect? Maybe not..
    def __init__(self, url, params):
        self.url = url
        self.params = params
        self.connected = False
        # topic -> Channel, filled in by channel().
        self.channels = {}
        # Queues handed to the transport: _recv_loop reads messages from
        # _incoming and _send_message writes messages to _outgoing.
        self._incoming = asyncio.Queue()
        self._outgoing = asyncio.Queue()
        # Next ref to use for messages sent without an explicit ref.
        self._ref = 1
        # ref -> Future, resolved by _recv_loop when a reply for that ref
        # arrives.
        self._response_futures = {}

    async def connect(self):
        '''
        Connects to the server and starts routing incoming messages.

        Picks the transport class from TRANSPORTS by URL scheme, runs it,
        waits for it to become ready and then starts the receive loop.

        :raises Exception: If the socket is already connected.
        :raises KeyError: If the URL scheme has no registered transport.
        '''
        if self.connected:
            raise Exception("Already connected!")

        transport_class = self.TRANSPORTS[urlsplit(self.url).scheme]
        self.transport = transport_class(
            self.url, self.params, self._incoming, self._outgoing
        )
        transport_task = asyncio.ensure_future(self.transport.run())

        try:
            await self.transport.ready
        except BaseException:
            # The connection failed, or connect() itself was cancelled.
            # Don't leave the transport task running, and fetch its outcome
            # once it ends. WebsocketTransport.run re-raises the error it
            # sets on ready, so without this asyncio would log "Task
            # exception was never retrieved" for an error raised here anyway.
            transport_task.cancel()
            transport_task.add_done_callback(
                lambda task: task.cancelled() or task.exception()
            )
            raise

        self._transport_task = transport_task
        self._done_recv = asyncio.Future()
        self._recv_task = asyncio.ensure_future(self._recv_loop())
        # TODO: Ok, so this is cool - but how to tell if our transport_task has
        # failed.
        self.connected = True

    async def disconnect(self):
        '''
        Stops the receive loop and the transport and waits for both to end.

        The socket is marked as disconnected even if either task failed, so
        connect() can be called again. Any such failure is still raised.

        :raises Exception: If the socket is not connected.
        '''
        if not self.connected:
            raise Exception("Not connected!")

        try:
            self._done_recv.set_result(True)
            await self.transport.stop()
            await asyncio.gather(self._recv_task, self._transport_task)
        finally:
            self.connected = False

    def channel(self, topic, params):
        '''
        Creates a Channel for a topic and registers it on this socket.

        Incoming non-reply messages for the topic are routed to the
        registered channel. Calling this again with the same topic replaces
        the previously registered channel.

        :param topic: The topic of the channel.
        :param params: Parameters passed to the Channel.
        :returns: The new Channel.
        '''
        # TODO: What to do if we already have this channel?
        channel = Channel(self, topic, params)
        self.channels[topic] = channel
        return channel

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.disconnect()

    async def _check_transport(self):
        '''
        Raises if the socket is unusable.

        :raises Exception: If the socket is not connected.
        Re-raises the transport task's exception if that task has finished
        with one.
        '''
        # TODO: More thought around this function...
        if not self.connected:
            raise Exception("Not connected!")
        if self._transport_task.done():
            # We've probably excepted.
            # TODO: Do something more thorough here...
            self._transport_task.result()

    # TODO: _push_message?
    async def _send_message(self, topic, event, payload, ref=None):
        '''
        Sends a message to the remote.

        :param topic: The topic to send the message on.
        :param event: The name of the event to send.
        :param payload: The payload of the event.
        :param ref: Optional ref to use for sending. If None, the next
                    socket-generated ref is used.
        :returns: A SentMessage wrapping the future that resolves when a
                  reply with this message's ref is received.
        '''
        if ref is None:
            ref = self._ref
            self._ref += 1

        message = OutgoingTransportMessage(
            TransportMessage(event, topic, payload, ref),
            asyncio.Future()
        )
        # TODO: futures for messages that never get a reply are never
        # removed from self._response_futures.
        resp_future = asyncio.Future()
        self._response_futures[ref] = resp_future
        try:
            await self._outgoing.put(message)
            await message.sent
        except BaseException:
            # The message was not sent, so no reply will resolve this future.
            # Only drop the entry if it is still ours.
            if self._response_futures.get(ref) is resp_future:
                del self._response_futures[ref]
            raise
        # TODO: Return something slightly different....
        return SentMessage(resp_future)

    async def _recv_loop(self):
        '''
        Runs the socket receive loop.

        This will read incoming messages off the queue and attempt to route
        them to an appropriate place. Replies resolve the response future
        registered for their ref. Other messages go to the channel
        registered for their topic, and are dropped if there is none.
        '''
        while True:
            message = await get_unless_done(
                self._incoming.get(), self._done_recv
            )
            if message is DONE:
                break
            # TODO: Definitely need to handle phx_close...
            if message.event == ChannelEvents.reply.value:
                self._handle_reply(message)
            else:
                channel = self.channels.get(message.topic)
                if channel:
                    await channel._incoming_messages.put(
                        IncomingMessage(message, self)
                    )

    def _handle_reply(self, message):
        '''
        Resolves and forgets the response future registered for a reply.

        A reply whose status is 'ok' resolves the future with the payload's
        'response'. Any other status sets an exception on the future.
        '''
        future = self._response_futures.pop(message.ref, None)
        if future is None or future.done():
            # Unknown ref, duplicate reply, or the waiter was cancelled.
            # Indexing or resolving would raise here and kill the receive
            # loop, so drop the reply instead.
            logger.warning(
                "Ignoring reply for unknown or resolved ref: %r", message.ref
            )
            return

        payload = message.payload or {}
        if payload.get('status') == 'ok':
            future.set_result(payload.get('response'))
        else:
            # TODO: we can do better than this...
            future.set_exception(Exception("Response not ok!"))
# --------------------------------------------------------------------------------