├── tests ├── __init__.py ├── helpers.py ├── test_client.py ├── test_aioclient.py ├── test_server.py └── test_aioserver.py ├── docs ├── _static │ └── css │ │ └── custom.css ├── index.rst ├── Makefile ├── api.rst ├── make.bat ├── conf.py └── intro.rst ├── .github ├── FUNDING.yml └── workflows │ └── tests.yml ├── MANIFEST.in ├── src └── simple_websocket │ ├── __init__.py │ ├── errors.py │ ├── asgi.py │ ├── aiows.py │ └── ws.py ├── .readthedocs.yaml ├── README.md ├── examples ├── asgiechoserver.py ├── echoclient.py ├── echoserver.py ├── aioechoclient.py └── aioechoserver.py ├── tox.ini ├── LICENSE ├── pyproject.toml ├── .gitignore └── CHANGES.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | .py .class, .py .method, .py .property { 2 | margin-top: 20px; 3 | } 4 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: miguelgrinberg 2 | patreon: miguelgrinberg 3 | custom: https://paypal.me/miguelgrinberg 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE tox.ini 2 | recursive-include docs * 3 | recursive-exclude docs/_build * 4 | recursive-include tests * 5 | exclude **/*.pyc 6 | -------------------------------------------------------------------------------- /src/simple_websocket/__init__.py: -------------------------------------------------------------------------------- 1 | from .ws import Server, Client # noqa: F401 2 | from .aiows import AioServer, AioClient # noqa: F401 3 | from .errors import ConnectionError, 
ConnectionClosed # noqa: F401 4 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: 16 | - docs 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | simple-websocket 2 | ================ 3 | 4 | Simple WebSocket server and client for Python. 5 | 6 | ## Resources 7 | 8 | - [Documentation](http://simple-websocket.readthedocs.io/en/latest/) 9 | - [PyPI](https://pypi.python.org/pypi/simple-websocket) 10 | - [Change Log](https://github.com/miguelgrinberg/simple-websocket/blob/main/CHANGES.md) 11 | 12 | -------------------------------------------------------------------------------- /examples/asgiechoserver.py: -------------------------------------------------------------------------------- 1 | from simple_websocket import AioServer, ConnectionClosed 2 | 3 | 4 | async def echo(scope, receive, send): # ASGI application that sends every received WebSocket message back 5 | ws = await AioServer.accept(asgi=(scope, receive, send)) # complete the WebSocket handshake over the ASGI triple 6 | try: 7 | while True: 8 | data = await ws.receive() 9 | await ws.send(data) 10 | except ConnectionClosed: # the connection went away; returning ends this ASGI call 11 | pass 12 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. simple-websocket documentation master file, created by 2 | sphinx-quickstart on Mon Jun 7 14:14:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | simple-websocket 7 | ================ 8 | 9 | Simple WebSocket server and client for Python. 10 | 11 | ..
toctree:: 12 | :maxdepth: 2 13 | 14 | intro 15 | api 16 | 17 | * :ref:`search` 18 | -------------------------------------------------------------------------------- /examples/echoclient.py: -------------------------------------------------------------------------------- 1 | from simple_websocket import Client, ConnectionClosed 2 | 3 | 4 | def main(): # interactive client: send each typed line and print the echoed reply 5 | ws = Client.connect('ws://localhost:5000/echo') 6 | try: 7 | while True: 8 | data = input('> ') 9 | ws.send(data) 10 | data = ws.receive() 11 | print(f'< {data}') 12 | except (KeyboardInterrupt, EOFError, ConnectionClosed): # Ctrl+C, end of input, or the connection closed 13 | ws.close() 14 | 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /examples/echoserver.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request 2 | from simple_websocket import Server, ConnectionClosed 3 | 4 | app = Flask(__name__) 5 | 6 | 7 | @app.route('/echo', websocket=True) 8 | def echo(): # accept the WebSocket from the WSGI environ and echo messages until it closes 9 | ws = Server.accept(request.environ) 10 | try: 11 | while True: 12 | data = ws.receive() 13 | ws.send(data) 14 | except ConnectionClosed: 15 | pass 16 | return '' # the WebSocket session is over; the view still has to return a response value 17 | 18 | 19 | if __name__ == '__main__': 20 | app.run() 21 | -------------------------------------------------------------------------------- /examples/aioechoclient.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from simple_websocket import AioClient, ConnectionClosed 3 | 4 | 5 | async def main(): # asyncio counterpart of the interactive echo client 6 | ws = await AioClient.connect('ws://localhost:5000/echo') 7 | try: 8 | while True: 9 | data = input('> ') # NOTE: input() is synchronous, so it blocks the event loop while waiting 10 | await ws.send(data) 11 | data = await ws.receive() 12 | print(f'< {data}') 13 | except (KeyboardInterrupt, EOFError, ConnectionClosed): 14 | await ws.close() 15 | 16 | 17 | if __name__ == '__main__': 18 | asyncio.run(main()) 19 | --------------------------------------------------------------------------------
/examples/aioechoserver.py: -------------------------------------------------------------------------------- 1 | from aiohttp import web 2 | from simple_websocket import AioServer, ConnectionClosed 3 | 4 | app = web.Application() 5 | 6 | 7 | async def echo(request): 8 | ws = await AioServer.accept(aiohttp=request) 9 | try: 10 | while True: 11 | data = await ws.receive() 12 | await ws.send(data) 13 | except ConnectionClosed: 14 | pass 15 | return web.Response(text='') 16 | 17 | 18 | app.add_routes([web.get('/echo', echo)]) 19 | 20 | if __name__ == '__main__': 21 | web.run_app(app, port=5000) 22 | -------------------------------------------------------------------------------- /src/simple_websocket/errors.py: -------------------------------------------------------------------------------- 1 | from wsproto.frame_protocol import CloseReason 2 | 3 | 4 | class ConnectionError(RuntimeError): # pragma: no cover 5 | """Connection error exception class.""" 6 | def __init__(self, status_code=None): 7 | self.status_code = status_code 8 | super().__init__(f'Connection error: {status_code}') 9 | 10 | 11 | class ConnectionClosed(RuntimeError): 12 | """Connection closed exception class.""" 13 | def __init__(self, reason=CloseReason.NO_STATUS_RCVD, message=None): 14 | self.reason = reason 15 | self.message = message 16 | super().__init__(f'Connection closed: {reason} {message or ""}') 17 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from unittest import mock 3 | 4 | 5 | def AsyncMock(*args, **kwargs): 6 | """Return a mock asynchronous function.""" 7 | m = mock.MagicMock(*args, **kwargs) 8 | 9 | async def mock_coro(*args, **kwargs): 10 | return m(*args, **kwargs) 11 | 12 | mock_coro.mock = m 13 | return mock_coro 14 | 15 | 16 | def _run(coro): 17 | """Run the given coroutine.""" 18 | return 
asyncio.get_event_loop().run_until_complete(coro) 19 | 20 | 21 | def make_sync(coro): 22 | """Wrap a coroutine so that it can be executed by pytest.""" 23 | def wrapper(*args, **kwargs): 24 | return _run(coro(*args, **kwargs)) 25 | return wrapper 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist=flake8,py37,py38,py39,py310,py311,pypy3,docs 3 | skip_missing_interpreters=True 4 | 5 | [gh-actions] 6 | python = 7 | 3.7: py37 8 | 3.8: py38 9 | 3.9: py39 10 | 3.10: py310 11 | 3.11: py311 12 | pypy-3: pypy3 13 | 14 | [testenv] 15 | commands= 16 | pip install -e . 
17 | pytest -p no:logging --cov=simple_websocket --cov-branch --cov-report=term-missing --cov-report=xml 18 | deps= 19 | pytest 20 | pytest-cov 21 | 22 | [testenv:pypy3] 23 | 24 | [testenv:flake8] 25 | deps= 26 | flake8 27 | commands= 28 | flake8 --exclude=".*" src/simple_websocket tests 29 | 30 | [testenv:docs] 31 | changedir=docs 32 | deps= 33 | sphinx 34 | allowlist_externals= 35 | make 36 | commands= 37 | make html 38 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ------------- 3 | 4 | The ``Server`` class 5 | ~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autoclass:: simple_websocket.Server 8 | :inherited-members: 9 | :members: 10 | 11 | The ``AioServer`` class 12 | ~~~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: simple_websocket.AioServer 15 | :inherited-members: 16 | :members: 17 | 18 | The ``Client`` class 19 | ~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: simple_websocket.Client 22 | :inherited-members: 23 | :members: 24 | 25 | The ``AioClient`` class 26 | ~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: simple_websocket.AioClient 29 | :inherited-members: 30 | :members: 31 | 32 | Exceptions 33 | ~~~~~~~~~~ 34 | 35 | .. autoclass:: simple_websocket.ConnectionError 36 | :members: 37 | 38 | .. autoclass:: simple_websocket.ConnectionClosed 39 | :members: 40 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Miguel Grinberg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "simple-websocket" 3 | version = "1.0.1.dev0" 4 | authors = [ 5 | { name = "Miguel Grinberg", email = "miguel.grinberg@gmail.com" }, 6 | ] 7 | description = "Simple WebSocket server and client for Python" 8 | classifiers = [ 9 | "Environment :: Web Environment", 10 | "Intended Audience :: Developers", 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: MIT License", 13 | "Operating System :: OS Independent", 14 | ] 15 | requires-python = ">=3.6" 16 | dependencies = [ 17 | "wsproto", 18 | ] 19 | 20 | [project.readme] 21 | file = "README.md" 22 | content-type = "text/markdown" 23 | 24 | [project.urls] 25 | Homepage = "https://github.com/miguelgrinberg/simple-websocket" 26 | "Bug Tracker" = "https://github.com/miguelgrinberg/simple-websocket/issues" 27 | 28 | [project.optional-dependencies] 29 | docs = [ 30 | "sphinx", 31 | ] 32 | 33 | [tool.setuptools] 34 | zip-safe = false 35 | include-package-data = true 36 | 37 | [tool.setuptools.package-dir] 38 | "" = "src" 39 | 40 | [tool.setuptools.packages.find] 41 | where = [ 42 | "src", 43 | ] 44 | namespaces = false 45 | 46 | [build-system] 47 | requires = [ 48 | "setuptools>=61.2", 49 | ] 50 | build-backend = "setuptools.build_meta" 51 | 52 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | lint: 11 | name: lint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - uses: actions/setup-python@v3 16 | - run: python -m pip install --upgrade pip wheel 17 | - run: pip install tox tox-gh-actions 18 | - run: tox 
-eflake8 19 | - run: tox -edocs 20 | tests: 21 | name: tests 22 | strategy: 23 | matrix: 24 | os: [ubuntu-latest, macos-latest, windows-latest] 25 | python: ['3.8', '3.9', '3.10', '3.11', 'pypy-3.9'] 26 | fail-fast: false 27 | runs-on: ${{ matrix.os }} 28 | steps: 29 | - uses: actions/checkout@v3 30 | - uses: actions/setup-python@v3 31 | with: 32 | python-version: ${{ matrix.python }} 33 | - run: python -m pip install --upgrade pip wheel 34 | - run: pip install tox tox-gh-actions 35 | - run: tox 36 | coverage: 37 | name: coverage 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v3 41 | - uses: actions/setup-python@v3 42 | - run: python -m pip install --upgrade pip wheel 43 | - run: pip install tox tox-gh-actions 44 | - run: tox 45 | - uses: codecov/codecov-action@v3 46 | with: 47 | files: ./coverage.xml 48 | fail_ci_if_error: true 49 | -------------------------------------------------------------------------------- /src/simple_websocket/asgi.py: -------------------------------------------------------------------------------- 1 | from .errors import ConnectionClosed # pragma: no cover 2 | 3 | 4 | class WebSocketASGI: # pragma: no cover 5 | def __init__(self, scope, receive, send, subprotocols=None): 6 | self._scope = scope 7 | self._receive = receive 8 | self._send = send 9 | self.subprotocols = subprotocols or [] 10 | self.subprotocol = None 11 | self.connected = False 12 | 13 | @classmethod 14 | async def accept(cls, scope, receive, send, subprotocols=None): 15 | ws = WebSocketASGI(scope, receive, send, subprotocols=subprotocols) 16 | await ws._accept() 17 | return ws 18 | 19 | async def _accept(self): 20 | connect = await self._receive() 21 | if connect['type'] != 'websocket.connect': 22 | raise ValueError('Expected websocket.connect') 23 | for subprotocol in self._scope['subprotocols']: 24 | if subprotocol in self.subprotocols: 25 | self.subprotocol = subprotocol 26 | break 27 | await self._send({'type': 'websocket.accept', 28 | 
'subprotocol': self.subprotocol}) 29 | 30 | async def receive(self): 31 | message = await self._receive() 32 | if message['type'] == 'websocket.disconnect': 33 | raise ConnectionClosed() 34 | elif message['type'] != 'websocket.receive': 35 | raise OSError(32, 'Websocket message type not supported') 36 | return message.get('text', message.get('bytes')) 37 | 38 | async def send(self, data): 39 | if isinstance(data, str): 40 | await self._send({'type': 'websocket.send', 'text': data}) 41 | else: 42 | await self._send({'type': 'websocket.send', 'bytes': data}) 43 | 44 | async def close(self): 45 | if not self.connected: 46 | self.conncted = False 47 | try: 48 | await self._send({'type': 'websocket.close'}) 49 | except Exception: 50 | pass 51 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../src')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'simple-websocket' 21 | copyright = '2021, Miguel Grinberg' 22 | author = 'Miguel Grinberg' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. 
They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | ] 33 | autodoc_member_order = 'bysource' 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | html_theme = 'alabaster' 50 | 51 | # Add any paths that contain custom static files (such as style sheets) here, 52 | # relative to this directory. They are copied after the builtin static files, 53 | # so a file named "default.css" will overwrite the builtin "default.css". 
54 | html_static_path = ['_static'] 55 | 56 | html_css_files = [ 57 | 'css/custom.css', 58 | ] 59 | 60 | html_theme_options = { 61 | 'github_user': 'miguelgrinberg', 62 | 'github_repo': 'simple-websocket', 63 | 'github_banner': True, 64 | 'github_button': True, 65 | 'github_type': 'star', 66 | 'fixed_sidebar': True, 67 | } 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /docs/intro.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | ``simple-websocket`` includes a collection of WebSocket servers and clients for 5 | Python, including support for both traditional and asynchronous (asyncio) 6 | workflows. The servers are designed to be integrated into larger web 7 | applications if desired. 8 | 9 | Installation 10 | ------------ 11 | 12 | This package is installed with ``pip``:: 13 | 14 | pip install simple-websocket 15 | 16 | Server Example #1: Flask 17 | ------------------------ 18 | 19 | The following example shows how to add a WebSocket route to a 20 | `Flask `_ application. 21 | 22 | :: 23 | 24 | from flask import Flask, request 25 | from simple_websocket import Server, ConnectionClosed 26 | 27 | app = Flask(__name__) 28 | 29 | @app.route('/echo', websocket=True) 30 | def echo(): 31 | ws = Server.accept(request.environ) 32 | try: 33 | while True: 34 | data = ws.receive() 35 | ws.send(data) 36 | except ConnectionClosed: 37 | pass 38 | return '' 39 | 40 | Integration with web applications using other 41 | `WSGI `_ frameworks works in a similar way. The 42 | only requirement is to pass the ``environ`` dictionary to the 43 | ``Server.accept()`` method to initiate the WebSocket handshake. 
44 | 45 | Server Example #2: Aiohttp 46 | -------------------------- 47 | 48 | The following example shows how to add a WebSocket route to a web application 49 | built with the `aiohttp `_ framework. 50 | 51 | :: 52 | 53 | from aiohttp import web 54 | from simple_websocket import AioServer, ConnectionClosed 55 | 56 | app = web.Application() 57 | 58 | async def echo(request): 59 | ws = await AioServer.accept(aiohttp=request) 60 | try: 61 | while True: 62 | data = await ws.receive() 63 | await ws.send(data) 64 | except ConnectionClosed: 65 | pass 66 | return web.Response(text='') 67 | 68 | app.add_routes([web.get('/echo', echo)]) 69 | 70 | if __name__ == '__main__': 71 | web.run_app(app, port=5000) 72 | 73 | Server Example #3: ASGI 74 | ----------------------- 75 | 76 | The next server example shows an asynchronous application that supports the 77 | `ASGI `_ protocol. 78 | 79 | :: 80 | 81 | from simple_websocket import AioServer, ConnectionClosed 82 | 83 | async def echo(scope, receive, send): 84 | ws = await AioServer.accept(asgi=(scope, receive, send)) 85 | try: 86 | while True: 87 | data = await ws.receive() 88 | await ws.send(data) 89 | except ConnectionClosed: 90 | pass 91 | 92 | Client Example #1: Synchronous 93 | ------------------------------ 94 | 95 | The client example that follows can connect to any of the server examples above 96 | using a synchronous interface. 97 | 98 | :: 99 | 100 | from simple_websocket import Client, ConnectionClosed 101 | 102 | def main(): 103 | ws = Client.connect('ws://localhost:5000/echo') 104 | try: 105 | while True: 106 | data = input('> ') 107 | ws.send(data) 108 | data = ws.receive() 109 | print(f'< {data}') 110 | except (KeyboardInterrupt, EOFError, ConnectionClosed): 111 | ws.close() 112 | 113 | if __name__ == '__main__': 114 | main() 115 | 116 | Client Example #2: Asynchronous 117 | ------------------------------- 118 | 119 | The next client uses Python's ``asyncio`` framework. 
120 | 121 | :: 122 | 123 | import asyncio 124 | from simple_websocket import AioClient, ConnectionClosed 125 | 126 | async def main(): 127 | ws = await AioClient.connect('ws://localhost:5000/echo') 128 | try: 129 | while True: 130 | data = input('> ') 131 | await ws.send(data) 132 | data = await ws.receive() 133 | print(f'< {data}') 134 | except (KeyboardInterrupt, EOFError, ConnectionClosed): 135 | await ws.close() 136 | 137 | if __name__ == '__main__': 138 | asyncio.run(main()) 139 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | # simple-websocket change log 2 | 3 | **Release 1.0.0** - 2023-10-05 4 | 5 | - New async client and server [#28](https://github.com/miguelgrinberg/simple-websocket/issues/28) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/57c5ffcb25c14d5c70f1ad4edd0261cdfcd27c94)) 6 | - On a closed connection, return buffered input before raising an exception [#30](https://github.com/miguelgrinberg/simple-websocket/issues/30) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/6c87abe22215c45b3dc0dadc168c3dd061eb2aa4)) 7 | - Do not duplicate SSLSocket instances [#26](https://github.com/miguelgrinberg/simple-websocket/issues/26) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/da42e98bf80f22747089946a6a08840e0bf646a9)) 8 | - Handle broken pipe errors in background thread [#29](https://github.com/miguelgrinberg/simple-websocket/issues/29) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/6f92764754550fc85b25e42182050c1e6636a41d)) 9 | - Remove unused argument ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/245eedcf1e82fd3d199a6f7bf44916047588763d)) 10 | - Eliminate race conditions during testing [#27](https://github.com/miguelgrinberg/simple-websocket/issues/27) 
([commit](https://github.com/miguelgrinberg/simple-websocket/commit/a37c79dc9ec8a54968d8b849c7f0a2e3bca46db8)) 11 | - Remove python 3.7 from unit tests ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/761925a635901b4641ad63b6072c24ff5c4099d5)) 12 | 13 | **Release 0.10.1** - 2023-06-04 14 | 15 | - Duplicate the gevent socket to avoid using it in multiple greenlets [#24](https://github.com/miguelgrinberg/simple-websocket/issues/24) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/ebc12b1a390ab36d8dcd020b45410da282fa8d60)) 16 | - Add Python 3.11 to builds ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/df5c92a8d8b48e3482be5ad7af2628b17e6d6d07)) 17 | 18 | **Release 0.10.0** - 2023-04-08 19 | 20 | - Support custom headers in the client ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/4f5c653378e77026604b4b25b8a5373da48b5f74)) 21 | 22 | **Release 0.9.0** - 2022-11-17 23 | 24 | - Properly clean up closed connections [#19](https://github.com/miguelgrinberg/simple-websocket/issues/19) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/9bda31010405045125b304afd633b9a9a5171335)) (thanks **Carlos Carvalho**!) 
25 | 26 | **Release 0.8.1** - 2022-09-11 27 | 28 | - Correct handling of an empty subprotocol list in server [#22](https://github.com/miguelgrinberg/simple-websocket/issues/22) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/cf336163fbc65281163fac0c253c4281b760c169)) 29 | - Handshake robustness with slow clients such as microcontrollers ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/271f8fc3ee466a0d0bd5a71543b2e50a632891dd)) 30 | - Prevent race condition on client close [#18](https://github.com/miguelgrinberg/simple-websocket/issues/18) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/e17449153b472a801df4bf2246f06a8486d91c9d)) 31 | - Improved documentation for subprotocol negotiation ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/c74785482ff266c552692a330c3c71d2b9d1f438)) 32 | 33 | **Release 0.8.0** - 2022-08-08 34 | 35 | - Support for subprotocol negotiation [#17](https://github.com/miguelgrinberg/simple-websocket/issues/17) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/04baf871e05e99d80c8905e9e9b0ff4be322e71f)) 36 | 37 | **Release 0.7.0** - 2022-07-24 38 | 39 | - More robust handling of ping intervals [#16](https://github.com/miguelgrinberg/simple-websocket/issues/16) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/05185122a0d2548d5cbd7c3d650db9c9dd49fa76) [commit](https://github.com/miguelgrinberg/simple-websocket/commit/08bd663a918669fb12e805e08a73cae7d7aac3a1)) 40 | 41 | **Release 0.6.0** - 2022-07-15 42 | 43 | - Improved performance of multi-part messages [#15](https://github.com/miguelgrinberg/simple-websocket/issues/15) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/ca2ea38520229ef7c881690667f23b99506f54a3)) 44 | 45 | **Release 0.5.2** - 2022-04-12 46 | 47 | - Compression support [#11](https://github.com/miguelgrinberg/simple-websocket/issues/11) 
([commit](https://github.com/miguelgrinberg/simple-websocket/commit/9277e67140a456bd34e09146732d4bdca0c6db12)) 48 | - Update builds for python 3.10 and pypy3.8 ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/f44674fd8ec42b05e6ebc0571cb53ba60d3ce144)) 49 | 50 | **Release 0.5.1** - 2022-02-17 51 | 52 | - Store the detected WebSocket mode in server ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/145e3be63ad1de75eedbcfc193eb304767607bc8)) 53 | 54 | **Release 0.5.0** - 2021-12-04 55 | 56 | - Added optional WebSocket Ping/Pong mechanism [#6](https://github.com/miguelgrinberg/simple-websocket/issues/6) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/6f13cdf74abf8627af53e03df2e84db204392a21)) 57 | - Option to set a maximum message size [#5](https://github.com/miguelgrinberg/simple-websocket/issues/5) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/b285024fc3fd75910d166fa5ad258490b70d1326)) 58 | - Store close reason in `ConnectionClosed` exception [#9](https://github.com/miguelgrinberg/simple-websocket/issues/9) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/91eaa52c659e69307e1b3a64329aafc81e3b4625)) 59 | - Option configure a custom selector class ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/1b3dcf77c2aba7ccc6b0f108744f46575ef190b8)) 60 | 61 | **Release 0.4.0** - 2021-09-23 62 | 63 | - Close the connection if `socket.recv()` returns 0 bytes [#4](https://github.com/miguelgrinberg/simple-websocket/issues/4) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/6a75a742fe28ef6fe30ca901144478c466640967)) 64 | 65 | **Release 0.3.0** - 2021-08-05 66 | 67 | - Handle older versions of gevent [#3](https://github.com/miguelgrinberg/simple-websocket/issues/3) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/5ce50802d053bf04d1f6f8c43569105bc5c0b389)) 68 | - Handle large messages split during transmission 
[#2](https://github.com/miguelgrinberg/simple-websocket/issues/2) ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/e16058daf6d0329028b7f9b81f65f13b64e8e45b)) 69 | - Documentation ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/02cbe78c723b298af9114989c41b8660b8aad3fb)) 70 | - GitHub builds ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/e846f0f86f8bdfed6fb2e7f5fff62abad6de518c)) 71 | - Unit tests 72 | 73 | **Release 0.2.0** - 2021-05-15 74 | 75 | - Make the closing of the connection more resilient to errors ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/6cdf24a8fc1fb782db968e6d4526cced6984d5a4)) 76 | - Unit testing framework ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/35de1658593a153b6926f05b3e3b2eadda814a47)) 77 | 78 | **Release 0.1.0** - 2021-04-18 79 | 80 | - initial commit ([commit](https://github.com/miguelgrinberg/simple-websocket/commit/1ddd63d230950f40683a7771eb3ce6ae7d199c23)) 81 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | from unittest import mock 4 | import pytest # noqa: F401 5 | 6 | from wsproto.events import AcceptConnection, CloseConnection, TextMessage, \ 7 | BytesMessage, Ping 8 | import simple_websocket 9 | 10 | 11 | class SimpleWebSocketClientTestCase(unittest.TestCase): 12 | def get_client(self, mock_wsconn, url, events=[], subprotocols=None, 13 | headers=None): 14 | mock_wsconn().events.side_effect = \ 15 | [iter(ev) for ev in 16 | [[AcceptConnection()]] + events + [[CloseConnection(1000)]]] 17 | mock_wsconn().send = lambda x: str(x).encode('utf-8') 18 | return simple_websocket.Client.connect(url, subprotocols=subprotocols, 19 | headers=headers) 20 | 21 | @mock.patch('simple_websocket.ws.socket.socket') 22 | @mock.patch('simple_websocket.ws.WSConnection') 23 | 
def test_make_client(self, mock_wsconn, mock_socket): 24 | mock_socket.return_value.recv.return_value = b'x' 25 | client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1') 26 | assert client.sock == mock_socket() 27 | assert client.receive_bytes == 4096 28 | assert client.input_buffer == [] 29 | assert client.event.__class__.__name__ == 'Event' 30 | client.sock.send.assert_called_with( 31 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 32 | b"extra_headers=[], subprotocols=[])") 33 | assert not client.is_server 34 | assert client.host == 'example.com' 35 | assert client.port == 80 36 | assert client.path == '/ws?a=1' 37 | 38 | @mock.patch('simple_websocket.ws.socket.socket') 39 | @mock.patch('simple_websocket.ws.WSConnection') 40 | def test_make_client_subprotocol(self, mock_wsconn, mock_socket): 41 | mock_socket.return_value.recv.return_value = b'x' 42 | client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 43 | subprotocols='foo') 44 | assert client.subprotocols == ['foo'] 45 | client.sock.send.assert_called_with( 46 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 47 | b"extra_headers=[], subprotocols=['foo'])") 48 | 49 | @mock.patch('simple_websocket.ws.socket.socket') 50 | @mock.patch('simple_websocket.ws.WSConnection') 51 | def test_make_client_subprotocols(self, mock_wsconn, mock_socket): 52 | mock_socket.return_value.recv.return_value = b'x' 53 | client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 54 | subprotocols=['foo', 'bar']) 55 | assert client.subprotocols == ['foo', 'bar'] 56 | client.sock.send.assert_called_with( 57 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 58 | b"extra_headers=[], subprotocols=['foo', 'bar'])") 59 | 60 | @mock.patch('simple_websocket.ws.socket.socket') 61 | @mock.patch('simple_websocket.ws.WSConnection') 62 | def test_make_client_headers(self, mock_wsconn, mock_socket): 63 | mock_socket.return_value.recv.return_value = b'x' 64 | client = 
self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 65 | headers={'Foo': 'Bar'}) 66 | client.sock.send.assert_called_with( 67 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 68 | b"extra_headers=[('Foo', 'Bar')], subprotocols=[])") 69 | 70 | @mock.patch('simple_websocket.ws.socket.socket') 71 | @mock.patch('simple_websocket.ws.WSConnection') 72 | def test_make_client_headers2(self, mock_wsconn, mock_socket): 73 | mock_socket.return_value.recv.return_value = b'x' 74 | client = self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 75 | headers=[('Foo', 'Bar'), ('Foo', 'Baz')]) 76 | client.sock.send.assert_called_with( 77 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 78 | b"extra_headers=[('Foo', 'Bar'), ('Foo', 'Baz')], " 79 | b"subprotocols=[])") 80 | 81 | @mock.patch('simple_websocket.ws.socket.socket') 82 | @mock.patch('simple_websocket.ws.WSConnection') 83 | def test_send(self, mock_wsconn, mock_socket): 84 | mock_socket.return_value.recv.return_value = b'x' 85 | client = self.get_client(mock_wsconn, 'ws://example.com/ws') 86 | while client.connected: 87 | time.sleep(0.01) 88 | with pytest.raises(simple_websocket.ConnectionClosed): 89 | client.send('hello') 90 | client.connected = True 91 | client.send('hello') 92 | mock_socket().send.assert_called_with( 93 | b"TextMessage(data='hello', frame_finished=True, " 94 | b"message_finished=True)") 95 | client.connected = True 96 | client.send(b'hello') 97 | mock_socket().send.assert_called_with( 98 | b"Message(data=b'hello', frame_finished=True, " 99 | b"message_finished=True)") 100 | 101 | @mock.patch('simple_websocket.ws.socket.socket') 102 | @mock.patch('simple_websocket.ws.WSConnection') 103 | def test_receive(self, mock_wsconn, mock_socket): 104 | mock_socket.return_value.recv.return_value = b'x' 105 | client = self.get_client(mock_wsconn, 'ws://example.com/ws', events=[ 106 | [TextMessage('hello')], 107 | [BytesMessage(b'hello')], 108 | ]) 109 | while client.connected: 
110 | time.sleep(0.01) 111 | client.connected = True 112 | assert client.receive() == 'hello' 113 | assert client.receive() == b'hello' 114 | assert client.receive(timeout=0) is None 115 | 116 | @mock.patch('simple_websocket.ws.socket.socket') 117 | @mock.patch('simple_websocket.ws.WSConnection') 118 | def test_receive_after_close(self, mock_wsconn, mock_socket): 119 | mock_socket.return_value.recv.return_value = b'x' 120 | client = self.get_client(mock_wsconn, 'ws://example.com/ws', events=[ 121 | [TextMessage('hello')], 122 | ]) 123 | while client.connected: 124 | time.sleep(0.01) 125 | assert client.receive() == 'hello' 126 | with pytest.raises(simple_websocket.ConnectionClosed): 127 | client.receive() 128 | 129 | @mock.patch('simple_websocket.ws.socket.socket') 130 | @mock.patch('simple_websocket.ws.WSConnection') 131 | def test_receive_ping(self, mock_wsconn, mock_socket): 132 | mock_socket.return_value.recv.return_value = b'x' 133 | client = self.get_client(mock_wsconn, 'ws://example.com/ws', events=[ 134 | [Ping(b'hello')], 135 | ]) 136 | while client.connected: 137 | time.sleep(0.01) 138 | mock_socket().send.assert_any_call(b"Pong(payload=b'hello')") 139 | 140 | @mock.patch('simple_websocket.ws.socket.socket') 141 | @mock.patch('simple_websocket.ws.WSConnection') 142 | def test_receive_empty(self, mock_wsconn, mock_socket): 143 | mock_socket.return_value.recv.side_effect = [b'x', b'x', b''] 144 | client = self.get_client(mock_wsconn, 'ws://example.com/ws', events=[ 145 | [TextMessage('hello')], 146 | ]) 147 | while client.connected: 148 | time.sleep(0.01) 149 | client.connected = True 150 | assert client.receive() == 'hello' 151 | assert client.receive(timeout=0) is None 152 | 153 | @mock.patch('simple_websocket.ws.socket.socket') 154 | @mock.patch('simple_websocket.ws.WSConnection') 155 | def test_close(self, mock_wsconn, mock_socket): 156 | mock_socket.return_value.recv.return_value = b'x' 157 | client = self.get_client(mock_wsconn, 'ws://example.com/ws') 
158 | while client.connected: 159 | time.sleep(0.01) 160 | with pytest.raises(simple_websocket.ConnectionClosed): 161 | client.close() 162 | client.connected = True 163 | client.close() 164 | assert not client.connected 165 | mock_socket().send.assert_called_with( 166 | b'CloseConnection(code=, ' 167 | b'reason=None)') 168 | -------------------------------------------------------------------------------- /tests/test_aioclient.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest 3 | from unittest import mock 4 | import pytest # noqa: F401 5 | 6 | from wsproto.events import AcceptConnection, CloseConnection, TextMessage, \ 7 | BytesMessage, Ping 8 | import simple_websocket 9 | from .helpers import make_sync, AsyncMock 10 | 11 | 12 | class AioSimpleWebSocketClientTestCase(unittest.TestCase): 13 | async def get_client(self, mock_wsconn, url, events=[], subprotocols=None, 14 | headers=None): 15 | mock_wsconn().events.side_effect = \ 16 | [iter(ev) for ev in 17 | [[AcceptConnection()]] + events + [[CloseConnection(1000)]]] 18 | mock_wsconn().send = lambda x: str(x).encode('utf-8') 19 | return await simple_websocket.AioClient.connect( 20 | url, subprotocols=subprotocols, headers=headers) 21 | 22 | @make_sync 23 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 24 | @mock.patch('simple_websocket.aiows.WSConnection') 25 | async def test_make_client(self, mock_wsconn, mock_open_connection): 26 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 27 | wsock = mock.MagicMock() 28 | mock_open_connection.return_value = (rsock, wsock) 29 | client = await self.get_client(mock_wsconn, 'ws://example.com/ws?a=1') 30 | assert client.rsock == rsock 31 | assert client.wsock == wsock 32 | assert client.receive_bytes == 4096 33 | assert client.input_buffer == [] 34 | assert client.event.__class__.__name__ == 'Event' 35 | client.wsock.write.assert_called_with( 36 | b"Request(host='example.com', 
target='/ws?a=1', extensions=[], " 37 | b"extra_headers=[], subprotocols=[])") 38 | assert not client.is_server 39 | assert client.host == 'example.com' 40 | assert client.port == 80 41 | assert client.path == '/ws?a=1' 42 | 43 | @make_sync 44 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 45 | @mock.patch('simple_websocket.aiows.WSConnection') 46 | async def test_make_client_subprotocol(self, mock_wsconn, 47 | mock_open_connection): 48 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 49 | wsock = mock.MagicMock() 50 | mock_open_connection.return_value = (rsock, wsock) 51 | client = await self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 52 | subprotocols='foo') 53 | assert client.subprotocols == ['foo'] 54 | client.wsock.write.assert_called_with( 55 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 56 | b"extra_headers=[], subprotocols=['foo'])") 57 | 58 | @make_sync 59 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 60 | @mock.patch('simple_websocket.aiows.WSConnection') 61 | async def test_make_client_subprotocols(self, mock_wsconn, 62 | mock_open_connection): 63 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 64 | wsock = mock.MagicMock() 65 | mock_open_connection.return_value = (rsock, wsock) 66 | client = await self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 67 | subprotocols=['foo', 'bar']) 68 | assert client.subprotocols == ['foo', 'bar'] 69 | client.wsock.write.assert_called_with( 70 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 71 | b"extra_headers=[], subprotocols=['foo', 'bar'])") 72 | 73 | @make_sync 74 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 75 | @mock.patch('simple_websocket.aiows.WSConnection') 76 | async def test_make_client_headers(self, mock_wsconn, 77 | mock_open_connection): 78 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 79 | wsock = mock.MagicMock() 80 | mock_open_connection.return_value = (rsock, 
wsock) 81 | client = await self.get_client(mock_wsconn, 'ws://example.com/ws?a=1', 82 | headers={'Foo': 'Bar'}) 83 | client.wsock.write.assert_called_with( 84 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 85 | b"extra_headers=[('Foo', 'Bar')], subprotocols=[])") 86 | 87 | @make_sync 88 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 89 | @mock.patch('simple_websocket.aiows.WSConnection') 90 | async def test_make_client_headers2(self, mock_wsconn, 91 | mock_open_connection): 92 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 93 | wsock = mock.MagicMock() 94 | mock_open_connection.return_value = (rsock, wsock) 95 | client = await self.get_client( 96 | mock_wsconn, 'ws://example.com/ws?a=1', 97 | headers=[('Foo', 'Bar'), ('Foo', 'Baz')]) 98 | client.wsock.write.assert_called_with( 99 | b"Request(host='example.com', target='/ws?a=1', extensions=[], " 100 | b"extra_headers=[('Foo', 'Bar'), ('Foo', 'Baz')], " 101 | b"subprotocols=[])") 102 | 103 | @make_sync 104 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 105 | @mock.patch('simple_websocket.aiows.WSConnection') 106 | async def test_send(self, mock_wsconn, mock_open_connection): 107 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 108 | wsock = mock.MagicMock() 109 | mock_open_connection.return_value = (rsock, wsock) 110 | client = await self.get_client(mock_wsconn, 'ws://example.com/ws') 111 | while client.connected: 112 | await asyncio.sleep(0.01) 113 | with pytest.raises(simple_websocket.ConnectionClosed): 114 | await client.send('hello') 115 | client.connected = True 116 | await client.send('hello') 117 | wsock.write.assert_called_with( 118 | b"TextMessage(data='hello', frame_finished=True, " 119 | b"message_finished=True)") 120 | client.connected = True 121 | await client.send(b'hello') 122 | wsock.write.assert_called_with( 123 | b"Message(data=b'hello', frame_finished=True, " 124 | b"message_finished=True)") 125 | 126 | @make_sync 127 | 
@mock.patch('simple_websocket.aiows.asyncio.open_connection') 128 | @mock.patch('simple_websocket.aiows.WSConnection') 129 | async def test_receive(self, mock_wsconn, mock_open_connection): 130 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 131 | wsock = mock.MagicMock() 132 | mock_open_connection.return_value = (rsock, wsock) 133 | client = await self.get_client( 134 | mock_wsconn, 'ws://example.com/ws', events=[ 135 | [TextMessage('hello')], 136 | [BytesMessage(b'hello')], 137 | ]) 138 | while client.connected: 139 | await asyncio.sleep(0.01) 140 | client.connected = True 141 | assert await client.receive() == 'hello' 142 | assert await client.receive() == b'hello' 143 | assert await client.receive(timeout=0) is None 144 | 145 | @make_sync 146 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 147 | @mock.patch('simple_websocket.aiows.WSConnection') 148 | async def test_receive_after_close(self, mock_wsconn, 149 | mock_open_connection): 150 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 151 | wsock = mock.MagicMock() 152 | mock_open_connection.return_value = (rsock, wsock) 153 | client = await self.get_client( 154 | mock_wsconn, 'ws://example.com/ws', events=[ 155 | [TextMessage('hello')], 156 | ]) 157 | while client.connected: 158 | await asyncio.sleep(0.01) 159 | assert await client.receive() == 'hello' 160 | with pytest.raises(simple_websocket.ConnectionClosed): 161 | await client.receive() 162 | 163 | @make_sync 164 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 165 | @mock.patch('simple_websocket.aiows.WSConnection') 166 | async def test_receive_ping(self, mock_wsconn, mock_open_connection): 167 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 168 | 169 | wsock = mock.MagicMock() 170 | mock_open_connection.return_value = (rsock, wsock) 171 | client = await self.get_client( 172 | mock_wsconn, 'ws://example.com/ws', events=[ 173 | [Ping(b'hello')], 174 | ]) 175 | while client.connected: 176 | 
await asyncio.sleep(0.01) 177 | wsock.write.assert_any_call(b"Pong(payload=b'hello')") 178 | 179 | @make_sync 180 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 181 | @mock.patch('simple_websocket.aiows.WSConnection') 182 | async def test_receive_empty(self, mock_wsconn, mock_open_connection): 183 | rsock = mock.MagicMock(read=AsyncMock(side_effect=[b'x', b'x', b''])) 184 | wsock = mock.MagicMock() 185 | mock_open_connection.return_value = (rsock, wsock) 186 | client = await self.get_client( 187 | mock_wsconn, 'ws://example.com/ws', events=[ 188 | [TextMessage('hello')], 189 | ]) 190 | while client.connected: 191 | await asyncio.sleep(0.01) 192 | client.connected = True 193 | assert await client.receive() == 'hello' 194 | assert await client.receive(timeout=0) is None 195 | 196 | @make_sync 197 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 198 | @mock.patch('simple_websocket.aiows.WSConnection') 199 | async def test_close(self, mock_wsconn, mock_open_connection): 200 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 201 | wsock = mock.MagicMock() 202 | mock_open_connection.return_value = (rsock, wsock) 203 | client = await self.get_client( 204 | mock_wsconn, 'ws://example.com/ws') 205 | while client.connected: 206 | await asyncio.sleep(0.01) 207 | with pytest.raises(simple_websocket.ConnectionClosed): 208 | await client.close() 209 | client.connected = True 210 | await client.close() 211 | assert not client.connected 212 | wsock.write.assert_called_with( 213 | b'CloseConnection(code=, ' 214 | b'reason=None)') 215 | -------------------------------------------------------------------------------- /tests/test_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | from unittest import mock 4 | import pytest # noqa: F401 5 | 6 | from wsproto.events import Request, CloseConnection, TextMessage, \ 7 | BytesMessage, Ping, Pong 8 | import 
simple_websocket 9 | 10 | 11 | class SimpleWebSocketServerTestCase(unittest.TestCase): 12 | def get_server(self, mock_wsconn, environ, events=[], 13 | client_subprotocols=None, server_subprotocols=None, 14 | **kwargs): 15 | mock_wsconn().events.side_effect = \ 16 | [iter(ev) for ev in [[ 17 | Request(host='example.com', target='/ws', 18 | subprotocols=client_subprotocols or [])]] + 19 | events + [[CloseConnection(1000, 'bye')]]] 20 | mock_wsconn().send = lambda x: str(x).encode('utf-8') 21 | environ.update({ 22 | 'HTTP_HOST': 'example.com', 23 | 'HTTP_CONNECTION': 'Upgrade', 24 | 'HTTP_UPGRADE': 'websocket', 25 | 'HTTP_SEC_WEBSOCKET_KEY': 'Iv8io/9s+lYFgZWcXczP8Q==', 26 | 'HTTP_SEC_WEBSOCKET_VERSION': '13', 27 | }) 28 | return simple_websocket.Server.accept( 29 | environ, subprotocols=server_subprotocols, **kwargs) 30 | 31 | @mock.patch('simple_websocket.ws.WSConnection') 32 | def test_werkzeug(self, mock_wsconn): 33 | mock_socket = mock.MagicMock() 34 | mock_socket.recv.return_value = b'x' 35 | server = self.get_server(mock_wsconn, { 36 | 'werkzeug.socket': mock_socket, 37 | }) 38 | assert server.sock == mock_socket 39 | assert server.mode == 'werkzeug' 40 | assert server.receive_bytes == 4096 41 | assert server.input_buffer == [] 42 | assert server.event.__class__.__name__ == 'Event' 43 | mock_wsconn().receive_data.assert_any_call( 44 | b'GET / HTTP/1.1\r\n' 45 | b'Host: example.com\r\n' 46 | b'Connection: Upgrade\r\n' 47 | b'Upgrade: websocket\r\n' 48 | b'Sec-Websocket-Key: Iv8io/9s+lYFgZWcXczP8Q==\r\n' 49 | b'Sec-Websocket-Version: 13\r\n\r\n') 50 | assert server.is_server 51 | 52 | @mock.patch('simple_websocket.ws.WSConnection') 53 | def test_gunicorn(self, mock_wsconn): 54 | mock_socket = mock.MagicMock() 55 | mock_socket.recv.return_value = b'x' 56 | server = self.get_server(mock_wsconn, { 57 | 'gunicorn.socket': mock_socket, 58 | }) 59 | assert server.sock == mock_socket 60 | assert server.mode == 'gunicorn' 61 | assert server.receive_bytes == 4096 62 | 
assert server.input_buffer == [] 63 | assert server.event.__class__.__name__ == 'Event' 64 | mock_wsconn().receive_data.assert_any_call( 65 | b'GET / HTTP/1.1\r\n' 66 | b'Host: example.com\r\n' 67 | b'Connection: Upgrade\r\n' 68 | b'Upgrade: websocket\r\n' 69 | b'Sec-Websocket-Key: Iv8io/9s+lYFgZWcXczP8Q==\r\n' 70 | b'Sec-Websocket-Version: 13\r\n\r\n') 71 | assert server.is_server 72 | 73 | def test_no_socket(self): 74 | with pytest.raises(RuntimeError): 75 | self.get_server(mock.MagicMock(), {}) 76 | 77 | @mock.patch('simple_websocket.ws.WSConnection') 78 | def test_send(self, mock_wsconn): 79 | mock_socket = mock.MagicMock() 80 | mock_socket.recv.return_value = b'x' 81 | server = self.get_server(mock_wsconn, { 82 | 'werkzeug.socket': mock_socket, 83 | }) 84 | while server.connected: 85 | time.sleep(0.01) 86 | with pytest.raises(simple_websocket.ConnectionClosed): 87 | server.send('hello') 88 | server.connected = True 89 | server.send('hello') 90 | mock_socket.send.assert_called_with( 91 | b"TextMessage(data='hello', frame_finished=True, " 92 | b"message_finished=True)") 93 | server.connected = True 94 | server.send(b'hello') 95 | mock_socket.send.assert_called_with( 96 | b"Message(data=b'hello', frame_finished=True, " 97 | b"message_finished=True)") 98 | 99 | @mock.patch('simple_websocket.ws.WSConnection') 100 | def test_receive(self, mock_wsconn): 101 | mock_socket = mock.MagicMock() 102 | mock_socket.recv.return_value = b'x' 103 | server = self.get_server(mock_wsconn, { 104 | 'werkzeug.socket': mock_socket, 105 | }, events=[ 106 | [TextMessage('hello')], 107 | [BytesMessage(b'hello')], 108 | ]) 109 | while server.connected: 110 | time.sleep(0.01) 111 | server.connected = True 112 | assert server.receive() == 'hello' 113 | assert server.receive() == b'hello' 114 | assert server.receive(timeout=0) is None 115 | 116 | @mock.patch('simple_websocket.ws.WSConnection') 117 | def test_receive_after_close(self, mock_wsconn): 118 | mock_socket = mock.MagicMock() 119 | 
mock_socket.recv.return_value = b'x' 120 | server = self.get_server(mock_wsconn, { 121 | 'werkzeug.socket': mock_socket, 122 | }, events=[ 123 | [TextMessage('hello')], 124 | ]) 125 | while server.connected: 126 | time.sleep(0.01) 127 | assert server.receive() == 'hello' 128 | with pytest.raises(simple_websocket.ConnectionClosed): 129 | server.receive() 130 | 131 | @mock.patch('simple_websocket.ws.WSConnection') 132 | def test_receive_split_messages(self, mock_wsconn): 133 | mock_socket = mock.MagicMock() 134 | mock_socket.recv.return_value = b'x' 135 | server = self.get_server(mock_wsconn, { 136 | 'werkzeug.socket': mock_socket, 137 | }, events=[ 138 | [TextMessage('hel', message_finished=False)], 139 | [TextMessage('lo')], 140 | [TextMessage('he', message_finished=False)], 141 | [TextMessage('l', message_finished=False)], 142 | [TextMessage('lo')], 143 | [BytesMessage(b'hel', message_finished=False)], 144 | [BytesMessage(b'lo')], 145 | [BytesMessage(b'he', message_finished=False)], 146 | [BytesMessage(b'l', message_finished=False)], 147 | [BytesMessage(b'lo')], 148 | ]) 149 | while server.connected: 150 | time.sleep(0.01) 151 | server.connected = True 152 | assert server.receive() == 'hello' 153 | assert server.receive() == 'hello' 154 | assert server.receive() == b'hello' 155 | assert server.receive() == b'hello' 156 | assert server.receive(timeout=0) is None 157 | 158 | @mock.patch('simple_websocket.ws.WSConnection') 159 | def test_receive_ping(self, mock_wsconn): 160 | mock_socket = mock.MagicMock() 161 | mock_socket.recv.return_value = b'x' 162 | server = self.get_server(mock_wsconn, { 163 | 'werkzeug.socket': mock_socket, 164 | }, events=[ 165 | [Ping(b'hello')], 166 | ]) 167 | while server.connected: 168 | time.sleep(0.01) 169 | mock_socket.send.assert_any_call(b"Pong(payload=b'hello')") 170 | 171 | @mock.patch('simple_websocket.ws.WSConnection') 172 | def test_receive_empty(self, mock_wsconn): 173 | mock_socket = mock.MagicMock() 174 | 
mock_socket.recv.side_effect = [b'x', b'x', b''] 175 | server = self.get_server(mock_wsconn, { 176 | 'werkzeug.socket': mock_socket, 177 | }, events=[ 178 | [TextMessage('hello')], 179 | ]) 180 | while server.connected: 181 | time.sleep(0.01) 182 | server.connected = True 183 | assert server.receive() == 'hello' 184 | assert server.receive(timeout=0) is None 185 | 186 | @mock.patch('simple_websocket.ws.WSConnection') 187 | def test_receive_large(self, mock_wsconn): 188 | mock_socket = mock.MagicMock() 189 | mock_socket.recv.return_value = b'x' 190 | server = self.get_server(mock_wsconn, { 191 | 'werkzeug.socket': mock_socket, 192 | }, events=[ 193 | [TextMessage('hello')], 194 | [TextMessage('hello1')], 195 | ], max_message_size=5) 196 | while server.connected: 197 | time.sleep(0.01) 198 | server.connected = True 199 | assert server.receive() == 'hello' 200 | assert server.receive(timeout=0) is None 201 | 202 | @mock.patch('simple_websocket.ws.WSConnection') 203 | def test_close(self, mock_wsconn): 204 | mock_socket = mock.MagicMock() 205 | mock_socket.recv.return_value = b'x' 206 | server = self.get_server(mock_wsconn, { 207 | 'werkzeug.socket': mock_socket, 208 | }) 209 | while server.connected: 210 | time.sleep(0.01) 211 | with pytest.raises(simple_websocket.ConnectionClosed) as exc: 212 | server.close() 213 | assert str(exc.value) == 'Connection closed: 1000 bye' 214 | server.connected = True 215 | server.close() 216 | assert not server.connected 217 | mock_socket.send.assert_called_with( 218 | b'CloseConnection(code=, ' 219 | b'reason=None)') 220 | 221 | @mock.patch('simple_websocket.ws.WSConnection') 222 | @mock.patch('simple_websocket.ws.time') 223 | def test_ping_pong(self, mock_time, mock_wsconn): 224 | mock_sel = mock.MagicMock() 225 | mock_sel().select.side_effect = [True, True, False, False] 226 | mock_time.side_effect = [0, 1, 25.01, 25.02, 28, 52, 76] 227 | mock_socket = mock.MagicMock() 228 | mock_socket.recv.side_effect = [b'x', b'x'] 229 | server = 
self.get_server( 230 | mock_wsconn, {'werkzeug.socket': mock_socket}, 231 | events=[ 232 | [TextMessage('hello')], 233 | [Pong()], 234 | ], ping_interval=25, thread_class=mock.MagicMock(), 235 | selector_class=mock_sel) 236 | server._thread() 237 | assert mock_socket.send.call_count == 4 238 | assert mock_socket.send.call_args_list[1][0][0].startswith(b'Ping') 239 | assert mock_socket.send.call_args_list[2][0][0].startswith(b'Ping') 240 | assert mock_socket.send.call_args_list[3][0][0].startswith(b'Close') 241 | 242 | @mock.patch('simple_websocket.ws.WSConnection') 243 | def test_subprotocols(self, mock_wsconn): 244 | mock_socket = mock.MagicMock() 245 | mock_socket.recv.return_value = b'x' 246 | 247 | server = self.get_server(mock_wsconn, { 248 | 'werkzeug.socket': mock_socket, 249 | }, client_subprotocols=['foo', 'bar'], server_subprotocols='bar') 250 | while server.connected: 251 | time.sleep(0.01) 252 | assert server.subprotocol == 'bar' 253 | 254 | server = self.get_server(mock_wsconn, { 255 | 'werkzeug.socket': mock_socket, 256 | }, client_subprotocols=['foo', 'bar'], server_subprotocols=['bar']) 257 | while server.connected: 258 | time.sleep(0.01) 259 | assert server.subprotocol == 'bar' 260 | 261 | server = self.get_server(mock_wsconn, { 262 | 'werkzeug.socket': mock_socket, 263 | }, client_subprotocols=['foo'], server_subprotocols=['foo', 'bar']) 264 | while server.connected: 265 | time.sleep(0.01) 266 | assert server.subprotocol == 'foo' 267 | 268 | server = self.get_server(mock_wsconn, { 269 | 'werkzeug.socket': mock_socket, 270 | }, client_subprotocols=['foo'], server_subprotocols=['bar', 'baz']) 271 | while server.connected: 272 | time.sleep(0.01) 273 | assert server.subprotocol is None 274 | 275 | server = self.get_server(mock_wsconn, { 276 | 'werkzeug.socket': mock_socket, 277 | }, client_subprotocols=['foo'], server_subprotocols=None) 278 | while server.connected: 279 | time.sleep(0.01) 280 | assert server.subprotocol is None 281 | 
-------------------------------------------------------------------------------- /tests/test_aioserver.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest 3 | from unittest import mock 4 | import pytest # noqa: F401 5 | 6 | from wsproto.events import Request, CloseConnection, TextMessage, \ 7 | BytesMessage, Ping, Pong 8 | import simple_websocket 9 | from .helpers import make_sync, AsyncMock 10 | 11 | 12 | class AioSimpleWebSocketServerTestCase(unittest.TestCase): 13 | async def get_server(self, mock_wsconn, request, events=[], 14 | client_subprotocols=None, server_subprotocols=None, 15 | **kwargs): 16 | mock_wsconn().events.side_effect = \ 17 | [iter(ev) for ev in [[ 18 | Request(host='example.com', target='/ws', 19 | subprotocols=client_subprotocols or [])]] + 20 | events + [[CloseConnection(1000, 'bye')]]] 21 | mock_wsconn().send = lambda x: str(x).encode('utf-8') 22 | request.headers.update({ 23 | 'Host': 'example.com', 24 | 'Connection': 'Upgrade', 25 | 'Upgrade': 'websocket', 26 | 'Sec-Websocket-Key': 'Iv8io/9s+lYFgZWcXczP8Q==', 27 | 'Sec-Websocket-Version': '13', 28 | }) 29 | return await simple_websocket.AioServer.accept( 30 | aiohttp=request, subprotocols=server_subprotocols, **kwargs) 31 | 32 | @make_sync 33 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 34 | @mock.patch('simple_websocket.aiows.WSConnection') 35 | async def test_aiohttp(self, mock_wsconn, mock_open_connection): 36 | mock_request = mock.MagicMock(headers={}) 37 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 38 | wsock = mock.MagicMock() 39 | mock_open_connection.return_value = (rsock, wsock) 40 | server = await self.get_server(mock_wsconn, mock_request) 41 | assert server.rsock == rsock 42 | assert server.wsock == wsock 43 | assert server.mode == 'aiohttp' 44 | assert server.receive_bytes == 4096 45 | assert server.input_buffer == [] 46 | assert server.event.__class__.__name__ == 'Event' 
47 | mock_wsconn().receive_data.assert_any_call( 48 | b'GET / HTTP/1.1\r\n' 49 | b'Host: example.com\r\n' 50 | b'Connection: Upgrade\r\n' 51 | b'Upgrade: websocket\r\n' 52 | b'Sec-Websocket-Key: Iv8io/9s+lYFgZWcXczP8Q==\r\n' 53 | b'Sec-Websocket-Version: 13\r\n\r\n') 54 | assert server.is_server 55 | 56 | @make_sync 57 | async def test_invalid_request(self): 58 | with pytest.raises(ValueError): 59 | await simple_websocket.AioServer.accept(aiohttp='foo', asgi='bar') 60 | with pytest.raises(ValueError): 61 | await simple_websocket.AioServer.accept(asgi='bar', sock='baz') 62 | 63 | @make_sync 64 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 65 | @mock.patch('simple_websocket.aiows.WSConnection') 66 | async def test_send(self, mock_wsconn, mock_open_connection): 67 | mock_request = mock.MagicMock(headers={}) 68 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 69 | wsock = mock.MagicMock() 70 | mock_open_connection.return_value = (rsock, wsock) 71 | server = await self.get_server(mock_wsconn, mock_request) 72 | while server.connected: 73 | await asyncio.sleep(0.01) 74 | with pytest.raises(simple_websocket.ConnectionClosed): 75 | await server.send('hello') 76 | server.connected = True 77 | await server.send('hello') 78 | wsock.write.assert_called_with( 79 | b"TextMessage(data='hello', frame_finished=True, " 80 | b"message_finished=True)") 81 | server.connected = True 82 | await server.send(b'hello') 83 | wsock.write.assert_called_with( 84 | b"Message(data=b'hello', frame_finished=True, " 85 | b"message_finished=True)") 86 | 87 | @make_sync 88 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 89 | @mock.patch('simple_websocket.aiows.WSConnection') 90 | async def test_receive(self, mock_wsconn, mock_open_connection): 91 | mock_request = mock.MagicMock(headers={}) 92 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 93 | wsock = mock.MagicMock() 94 | mock_open_connection.return_value = (rsock, wsock) 95 | server = await 
self.get_server(mock_wsconn, mock_request, events=[ 96 | [TextMessage('hello')], 97 | [BytesMessage(b'hello')], 98 | ]) 99 | while server.connected: 100 | await asyncio.sleep(0.01) 101 | server.connected = True 102 | assert await server.receive() == 'hello' 103 | assert await server.receive() == b'hello' 104 | assert await server.receive(timeout=0) is None 105 | 106 | @make_sync 107 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 108 | @mock.patch('simple_websocket.aiows.WSConnection') 109 | async def test_receive_after_close(self, mock_wsconn, 110 | mock_open_connection): 111 | mock_request = mock.MagicMock(headers={}) 112 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 113 | wsock = mock.MagicMock() 114 | mock_open_connection.return_value = (rsock, wsock) 115 | server = await self.get_server(mock_wsconn, mock_request, events=[ 116 | [TextMessage('hello')], 117 | ]) 118 | while server.connected: 119 | await asyncio.sleep(0.01) 120 | assert await server.receive() == 'hello' 121 | with pytest.raises(simple_websocket.ConnectionClosed): 122 | await server.receive() 123 | 124 | @make_sync 125 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 126 | @mock.patch('simple_websocket.aiows.WSConnection') 127 | async def test_receive_split_messages(self, mock_wsconn, 128 | mock_open_connection): 129 | mock_request = mock.MagicMock(headers={}) 130 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 131 | wsock = mock.MagicMock() 132 | mock_open_connection.return_value = (rsock, wsock) 133 | server = await self.get_server(mock_wsconn, mock_request, events=[ 134 | [TextMessage('hel', message_finished=False)], 135 | [TextMessage('lo')], 136 | [TextMessage('he', message_finished=False)], 137 | [TextMessage('l', message_finished=False)], 138 | [TextMessage('lo')], 139 | [BytesMessage(b'hel', message_finished=False)], 140 | [BytesMessage(b'lo')], 141 | [BytesMessage(b'he', message_finished=False)], 142 | [BytesMessage(b'l', 
message_finished=False)], 143 | [BytesMessage(b'lo')], 144 | ]) 145 | while server.connected: 146 | await asyncio.sleep(0.01) 147 | server.connected = True 148 | assert await server.receive() == 'hello' 149 | assert await server.receive() == 'hello' 150 | assert await server.receive() == b'hello' 151 | assert await server.receive() == b'hello' 152 | assert await server.receive(timeout=0) is None 153 | 154 | @make_sync 155 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 156 | @mock.patch('simple_websocket.aiows.WSConnection') 157 | async def test_receive_ping(self, mock_wsconn, mock_open_connection): 158 | mock_request = mock.MagicMock(headers={}) 159 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 160 | wsock = mock.MagicMock() 161 | mock_open_connection.return_value = (rsock, wsock) 162 | server = await self.get_server(mock_wsconn, mock_request, events=[ 163 | [Ping(b'hello')], 164 | ]) 165 | while server.connected: 166 | await asyncio.sleep(0.01) 167 | wsock.write.assert_any_call(b"Pong(payload=b'hello')") 168 | 169 | @make_sync 170 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 171 | @mock.patch('simple_websocket.aiows.WSConnection') 172 | async def test_receive_empty(self, mock_wsconn, mock_open_connection): 173 | mock_request = mock.MagicMock(headers={}) 174 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 175 | wsock = mock.MagicMock() 176 | mock_open_connection.return_value = (rsock, wsock) 177 | server = await self.get_server(mock_wsconn, mock_request, events=[ 178 | [TextMessage('hello')], 179 | ]) 180 | while server.connected: 181 | await asyncio.sleep(0.01) 182 | server.connected = True 183 | assert await server.receive() == 'hello' 184 | assert await server.receive(timeout=0) is None 185 | 186 | @make_sync 187 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 188 | @mock.patch('simple_websocket.aiows.WSConnection') 189 | async def test_receive_large(self, mock_wsconn, 
mock_open_connection): 190 | mock_request = mock.MagicMock(headers={}) 191 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 192 | wsock = mock.MagicMock() 193 | mock_open_connection.return_value = (rsock, wsock) 194 | server = await self.get_server(mock_wsconn, mock_request, events=[ 195 | [TextMessage('hello')], 196 | [TextMessage('hello1')], 197 | ], max_message_size=5) 198 | while server.connected: 199 | await asyncio.sleep(0.01) 200 | server.connected = True 201 | assert await server.receive() == 'hello' 202 | assert await server.receive(timeout=0) is None 203 | 204 | @make_sync 205 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 206 | @mock.patch('simple_websocket.aiows.WSConnection') 207 | async def test_close(self, mock_wsconn, mock_open_connection): 208 | mock_request = mock.MagicMock(headers={}) 209 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 210 | wsock = mock.MagicMock() 211 | mock_open_connection.return_value = (rsock, wsock) 212 | server = await self.get_server(mock_wsconn, mock_request) 213 | while server.connected: 214 | await asyncio.sleep(0.01) 215 | with pytest.raises(simple_websocket.ConnectionClosed) as exc: 216 | await server.close() 217 | assert str(exc.value) == 'Connection closed: 1000 bye' 218 | server.connected = True 219 | await server.close() 220 | assert not server.connected 221 | wsock.write.assert_called_with( 222 | b'CloseConnection(code=, ' 223 | b'reason=None)') 224 | 225 | @make_sync 226 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 227 | @mock.patch('simple_websocket.aiows.WSConnection') 228 | @mock.patch('simple_websocket.aiows.time') 229 | @mock.patch('simple_websocket.aiows.asyncio.wait_for') 230 | async def test_ping_pong(self, mock_wait_for, mock_time, mock_wsconn, 231 | mock_open_connection): 232 | mock_request = mock.MagicMock(headers={}) 233 | rsock = mock.MagicMock(read=AsyncMock()) 234 | wsock = mock.MagicMock() 235 | mock_open_connection.return_value = 
(rsock, wsock) 236 | server = await self.get_server(mock_wsconn, mock_request, events=[ 237 | [TextMessage('hello')], 238 | [Pong()], 239 | ], ping_interval=25) 240 | mock_wait_for.side_effect = [b'x', b'x', asyncio.TimeoutError, 241 | asyncio.TimeoutError] 242 | mock_time.side_effect = [0, 1, 25.01, 25.02, 28, 52, 76] 243 | await server._task() 244 | assert wsock.write.call_count == 4 245 | assert wsock.write.call_args_list[1][0][0].startswith(b'Ping') 246 | assert wsock.write.call_args_list[2][0][0].startswith(b'Ping') 247 | assert wsock.write.call_args_list[3][0][0].startswith(b'Close') 248 | 249 | @make_sync 250 | @mock.patch('simple_websocket.aiows.asyncio.open_connection') 251 | @mock.patch('simple_websocket.aiows.WSConnection') 252 | async def test_subprotocols(self, mock_wsconn, mock_open_connection): 253 | mock_request = mock.MagicMock(headers={}) 254 | rsock = mock.MagicMock(read=AsyncMock(return_value=b'x')) 255 | wsock = mock.MagicMock() 256 | mock_open_connection.return_value = (rsock, wsock) 257 | 258 | server = await self.get_server(mock_wsconn, mock_request, 259 | client_subprotocols=['foo', 'bar'], 260 | server_subprotocols='bar') 261 | while server.connected: 262 | await asyncio.sleep(0.01) 263 | assert server.subprotocol == 'bar' 264 | 265 | server = await self.get_server(mock_wsconn, mock_request, 266 | client_subprotocols=['foo', 'bar'], 267 | server_subprotocols=['bar']) 268 | while server.connected: 269 | await asyncio.sleep(0.01) 270 | assert server.subprotocol == 'bar' 271 | 272 | server = await self.get_server(mock_wsconn, mock_request, 273 | client_subprotocols=['foo'], 274 | server_subprotocols=['foo', 'bar']) 275 | while server.connected: 276 | await asyncio.sleep(0.01) 277 | assert server.subprotocol == 'foo' 278 | 279 | server = await self.get_server(mock_wsconn, mock_request, 280 | client_subprotocols=['foo'], 281 | server_subprotocols=['bar', 'baz']) 282 | while server.connected: 283 | await asyncio.sleep(0.01) 284 | assert 
server.subprotocol is None 285 | 286 | server = await self.get_server(mock_wsconn, mock_request, 287 | client_subprotocols=['foo'], 288 | server_subprotocols=None) 289 | while server.connected: 290 | await asyncio.sleep(0.01) 291 | assert server.subprotocol is None 292 | -------------------------------------------------------------------------------- /src/simple_websocket/aiows.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import ssl 3 | from time import time 4 | from urllib.parse import urlsplit 5 | 6 | from wsproto import ConnectionType, WSConnection 7 | from wsproto.events import ( 8 | AcceptConnection, 9 | RejectConnection, 10 | CloseConnection, 11 | Message, 12 | Request, 13 | Ping, 14 | Pong, 15 | TextMessage, 16 | BytesMessage, 17 | ) 18 | from wsproto.extensions import PerMessageDeflate 19 | from wsproto.frame_protocol import CloseReason 20 | from wsproto.utilities import LocalProtocolError 21 | from .errors import ConnectionError, ConnectionClosed 22 | 23 | 24 | class AioBase: 25 | def __init__(self, connection_type=None, receive_bytes=4096, 26 | ping_interval=None, max_message_size=None): 27 | #: The name of the subprotocol chosen for the WebSocket connection. 
28 | self.subprotocol = None 29 | 30 | self.connection_type = connection_type 31 | self.receive_bytes = receive_bytes 32 | self.ping_interval = ping_interval 33 | self.max_message_size = max_message_size 34 | self.pong_received = True 35 | self.input_buffer = [] 36 | self.incoming_message = None 37 | self.incoming_message_len = 0 38 | self.connected = False 39 | self.is_server = (connection_type == ConnectionType.SERVER) 40 | self.close_reason = CloseReason.NO_STATUS_RCVD 41 | self.close_message = None 42 | 43 | self.rsock = None 44 | self.wsock = None 45 | self.event = asyncio.Event() 46 | self.ws = None 47 | self.task = None 48 | 49 | async def connect(self): 50 | self.ws = WSConnection(self.connection_type) 51 | await self.handshake() 52 | 53 | if not self.connected: # pragma: no cover 54 | raise ConnectionError() 55 | self.task = asyncio.create_task(self._task()) 56 | 57 | async def handshake(self): # pragma: no cover 58 | # to be implemented by subclasses 59 | pass 60 | 61 | async def send(self, data): 62 | """Send data over the WebSocket connection. 63 | 64 | :param data: The data to send. If ``data`` is of type ``bytes``, then 65 | a binary message is sent. Else, the message is sent in 66 | text format. 67 | """ 68 | if not self.connected: 69 | raise ConnectionClosed(self.close_reason, self.close_message) 70 | if isinstance(data, bytes): 71 | out_data = self.ws.send(Message(data=data)) 72 | else: 73 | out_data = self.ws.send(TextMessage(data=str(data))) 74 | self.wsock.write(out_data) 75 | 76 | async def receive(self, timeout=None): 77 | """Receive data over the WebSocket connection. 78 | 79 | :param timeout: Amount of time to wait for the data, in seconds. Set 80 | to ``None`` (the default) to wait indefinitely. Set 81 | to 0 to read without blocking. 82 | 83 | The data received is returned, as ``bytes`` or ``str``, depending on 84 | the type of the incoming message. 
85 | """ 86 | while self.connected and not self.input_buffer: 87 | try: 88 | await asyncio.wait_for(self.event.wait(), timeout=timeout) 89 | except asyncio.TimeoutError: 90 | return None 91 | self.event.clear() # pragma: no cover 92 | try: 93 | return self.input_buffer.pop(0) 94 | except IndexError: 95 | pass 96 | if not self.connected: # pragma: no cover 97 | raise ConnectionClosed(self.close_reason, self.close_message) 98 | 99 | async def close(self, reason=None, message=None): 100 | """Close the WebSocket connection. 101 | 102 | :param reason: A numeric status code indicating the reason of the 103 | closure, as defined by the WebSocket specification. The 104 | default is 1000 (normal closure). 105 | :param message: A text message to be sent to the other side. 106 | """ 107 | if not self.connected: 108 | raise ConnectionClosed(self.close_reason, self.close_message) 109 | out_data = self.ws.send(CloseConnection( 110 | reason or CloseReason.NORMAL_CLOSURE, message)) 111 | try: 112 | self.wsock.write(out_data) 113 | except BrokenPipeError: # pragma: no cover 114 | pass 115 | self.connected = False 116 | 117 | def choose_subprotocol(self, request): # pragma: no cover 118 | # The method should return the subprotocol to use, or ``None`` if no 119 | # subprotocol is chosen. Can be overridden by subclasses that implement 120 | # the server-side of the WebSocket protocol. 
121 | return None 122 | 123 | async def _task(self): 124 | next_ping = None 125 | if self.ping_interval: 126 | next_ping = time() + self.ping_interval 127 | 128 | while self.connected: 129 | try: 130 | in_data = b'' 131 | if next_ping: 132 | now = time() 133 | timed_out = True 134 | if next_ping > now: 135 | timed_out = False 136 | try: 137 | in_data = await asyncio.wait_for( 138 | self.rsock.read(self.receive_bytes), 139 | timeout=next_ping - now) 140 | except asyncio.TimeoutError: 141 | timed_out = True 142 | if timed_out: 143 | # we reached the timeout, we have to send a ping 144 | if not self.pong_received: 145 | await self.close( 146 | reason=CloseReason.POLICY_VIOLATION, 147 | message='Ping/Pong timeout') 148 | break 149 | self.pong_received = False 150 | self.wsock.write(self.ws.send(Ping())) 151 | next_ping = max(now, next_ping) + self.ping_interval 152 | continue 153 | else: 154 | in_data = await self.rsock.read(self.receive_bytes) 155 | if len(in_data) == 0: 156 | raise OSError() 157 | except (OSError, ConnectionResetError): # pragma: no cover 158 | self.connected = False 159 | self.event.set() 160 | break 161 | 162 | self.ws.receive_data(in_data) 163 | self.connected = await self._handle_events() 164 | self.wsock.close() 165 | 166 | async def _handle_events(self): 167 | keep_going = True 168 | out_data = b'' 169 | for event in self.ws.events(): 170 | try: 171 | if isinstance(event, Request): 172 | self.subprotocol = self.choose_subprotocol(event) 173 | out_data += self.ws.send(AcceptConnection( 174 | subprotocol=self.subprotocol, 175 | extensions=[PerMessageDeflate()])) 176 | elif isinstance(event, CloseConnection): 177 | if self.is_server: 178 | out_data += self.ws.send(event.response()) 179 | self.close_reason = event.code 180 | self.close_message = event.reason 181 | self.connected = False 182 | self.event.set() 183 | keep_going = False 184 | elif isinstance(event, Ping): 185 | out_data += self.ws.send(event.response()) 186 | elif isinstance(event, 
Pong): 187 | self.pong_received = True 188 | elif isinstance(event, (TextMessage, BytesMessage)): 189 | self.incoming_message_len += len(event.data) 190 | if self.max_message_size and \ 191 | self.incoming_message_len > self.max_message_size: 192 | out_data += self.ws.send(CloseConnection( 193 | CloseReason.MESSAGE_TOO_BIG, 'Message is too big')) 194 | self.event.set() 195 | keep_going = False 196 | break 197 | if self.incoming_message is None: 198 | # store message as is first 199 | # if it is the first of a group, the message will be 200 | # converted to bytearray on arrival of the second 201 | # part, since bytearrays are mutable and can be 202 | # concatenated more efficiently 203 | self.incoming_message = event.data 204 | elif isinstance(event, TextMessage): 205 | if not isinstance(self.incoming_message, bytearray): 206 | # convert to bytearray and append 207 | self.incoming_message = bytearray( 208 | (self.incoming_message + event.data).encode()) 209 | else: 210 | # append to bytearray 211 | self.incoming_message += event.data.encode() 212 | else: 213 | if not isinstance(self.incoming_message, bytearray): 214 | # convert to mutable bytearray and append 215 | self.incoming_message = bytearray( 216 | self.incoming_message + event.data) 217 | else: 218 | # append to bytearray 219 | self.incoming_message += event.data 220 | if not event.message_finished: 221 | continue 222 | if isinstance(self.incoming_message, (str, bytes)): 223 | # single part message 224 | self.input_buffer.append(self.incoming_message) 225 | elif isinstance(event, TextMessage): 226 | # convert multi-part message back to text 227 | self.input_buffer.append( 228 | self.incoming_message.decode()) 229 | else: 230 | # convert multi-part message back to bytes 231 | self.input_buffer.append(bytes(self.incoming_message)) 232 | self.incoming_message = None 233 | self.incoming_message_len = 0 234 | self.event.set() 235 | else: # pragma: no cover 236 | pass 237 | except LocalProtocolError: # pragma: no 
cover 238 | out_data = b'' 239 | self.event.set() 240 | keep_going = False 241 | if out_data: 242 | self.wsock.write(out_data) 243 | return keep_going 244 | 245 | 246 | class AioServer(AioBase): 247 | """This class implements a WebSocket server. 248 | 249 | Instead of creating an instance of this class directly, use the 250 | ``accept()`` class method to create individual instances of the server, 251 | each bound to a client request. 252 | """ 253 | def __init__(self, request, subprotocols=None, receive_bytes=4096, 254 | ping_interval=None, max_message_size=None): 255 | super().__init__(connection_type=ConnectionType.SERVER, 256 | receive_bytes=receive_bytes, 257 | ping_interval=ping_interval, 258 | max_message_size=max_message_size) 259 | self.request = request 260 | self.headers = {} 261 | self.subprotocols = subprotocols or [] 262 | if isinstance(self.subprotocols, str): 263 | self.subprotocols = [self.subprotocols] 264 | self.mode = 'unknown' 265 | 266 | @classmethod 267 | async def accept(cls, aiohttp=None, asgi=None, sock=None, headers=None, 268 | subprotocols=None, receive_bytes=4096, ping_interval=None, 269 | max_message_size=None): 270 | """Accept a WebSocket connection from a client. 271 | 272 | :param aiohttp: The request object from aiohttp. If this argument is 273 | provided, ``asgi``, ``sock`` and ``headers`` must not 274 | be set. 275 | :param asgi: A (scope, receive, send) tuple from an ASGI request. If 276 | this argument is provided, ``aiohttp``, ``sock`` and 277 | ``headers`` must not be set. 278 | :param sock: A connected socket to use. If this argument is provided, 279 | ``aiohttp`` and ``asgi`` must not be set. The ``headers`` 280 | argument must be set with the incoming request headers. 281 | :param headers: A dictionary with the incoming request headers, when 282 | ``sock`` is used. 283 | :param subprotocols: A list of supported subprotocols, or ``None`` (the 284 | default) to disable subprotocol negotiation. 
285 | :param receive_bytes: The size of the receive buffer, in bytes. The 286 | default is 4096. 287 | :param ping_interval: Send ping packets to clients at the requested 288 | interval in seconds. Set to ``None`` (the 289 | default) to disable ping/pong logic. Enable to 290 | prevent disconnections when the line is idle for 291 | a certain amount of time, or to detect 292 | unresponsive clients and disconnect them. A 293 | recommended interval is 25 seconds. 294 | :param max_message_size: The maximum size allowed for a message, in 295 | bytes, or ``None`` for no limit. The default 296 | is ``None``. 297 | """ 298 | if aiohttp and (asgi or sock): 299 | raise ValueError('aiohttp argument cannot be used with asgi or ' 300 | 'sock') 301 | if asgi and (aiohttp or sock): 302 | raise ValueError('asgi argument cannot be used with aiohttp or ' 303 | 'sock') 304 | if asgi: # pragma: no cover 305 | from .asgi import WebSocketASGI 306 | return await WebSocketASGI.accept(asgi[0], asgi[1], asgi[2], 307 | subprotocols=subprotocols) 308 | 309 | ws = cls({'aiohttp': aiohttp, 'sock': sock, 'headers': headers}, 310 | subprotocols=subprotocols, receive_bytes=receive_bytes, 311 | ping_interval=ping_interval, 312 | max_message_size=max_message_size) 313 | await ws._accept() 314 | return ws 315 | 316 | async def _accept(self): 317 | if self.request['sock']: # pragma: no cover 318 | # custom integration, request is a tuple with (socket, headers) 319 | sock = self.request['sock'] 320 | self.headers = self.request['headers'] 321 | self.mode = 'custom' 322 | elif self.request['aiohttp']: 323 | # default implementation, request is an aiohttp request object 324 | sock = self.request['aiohttp'].transport.get_extra_info( 325 | 'socket').dup() 326 | self.headers = self.request['aiohttp'].headers 327 | self.mode = 'aiohttp' 328 | else: # pragma: no cover 329 | raise ValueError('Invalid request') 330 | self.rsock, self.wsock = await asyncio.open_connection(sock=sock) 331 | await super().connect() 
332 | 333 | async def handshake(self): 334 | in_data = b'GET / HTTP/1.1\r\n' 335 | for header, value in self.headers.items(): 336 | in_data += f'{header}: {value}\r\n'.encode() 337 | in_data += b'\r\n' 338 | self.ws.receive_data(in_data) 339 | self.connected = await self._handle_events() 340 | 341 | def choose_subprotocol(self, request): 342 | """Choose a subprotocol to use for the WebSocket connection. 343 | 344 | The default implementation selects the first protocol requested by the 345 | client that is accepted by the server. Subclasses can override this 346 | method to implement a different subprotocol negotiation algorithm. 347 | 348 | :param request: A ``Request`` object. 349 | 350 | The method should return the subprotocol to use, or ``None`` if no 351 | subprotocol is chosen. 352 | """ 353 | for subprotocol in request.subprotocols: 354 | if subprotocol in self.subprotocols: 355 | return subprotocol 356 | return None 357 | 358 | 359 | class AioClient(AioBase): 360 | """This class implements a WebSocket client. 361 | 362 | Instead of creating an instance of this class directly, use the 363 | ``connect()`` class method to create an instance that is connected to a 364 | server. 365 | """ 366 | def __init__(self, url, subprotocols=None, headers=None, 367 | receive_bytes=4096, ping_interval=None, max_message_size=None, 368 | ssl_context=None): 369 | super().__init__(connection_type=ConnectionType.CLIENT, 370 | receive_bytes=receive_bytes, 371 | ping_interval=ping_interval, 372 | max_message_size=max_message_size) 373 | self.url = url 374 | self.ssl_context = ssl_context 375 | parsed_url = urlsplit(url) 376 | self.is_secure = parsed_url.scheme in ['https', 'wss'] 377 | self.host = parsed_url.hostname 378 | self.port = parsed_url.port or (443 if self.is_secure else 80) 379 | self.path = parsed_url.path 380 | if parsed_url.query: 381 | self.path += '?' 
+ parsed_url.query 382 | self.subprotocols = subprotocols or [] 383 | if isinstance(self.subprotocols, str): 384 | self.subprotocols = [self.subprotocols] 385 | 386 | self.extra_headeers = [] 387 | if isinstance(headers, dict): 388 | for key, value in headers.items(): 389 | self.extra_headeers.append((key, value)) 390 | elif isinstance(headers, list): 391 | self.extra_headeers = headers 392 | 393 | @classmethod 394 | async def connect(cls, url, subprotocols=None, headers=None, 395 | receive_bytes=4096, ping_interval=None, 396 | max_message_size=None, ssl_context=None, 397 | thread_class=None, event_class=None): 398 | """Returns a WebSocket client connection. 399 | 400 | :param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are 401 | accepted. 402 | :param subprotocols: The name of the subprotocol to use, or a list of 403 | subprotocol names in order of preference. Set to 404 | ``None`` (the default) to not use a subprotocol. 405 | :param headers: A dictionary or list of tuples with additional HTTP 406 | headers to send with the connection request. Note that 407 | custom headers are not supported by the WebSocket 408 | protocol, so the use of this parameter is not 409 | recommended. 410 | :param receive_bytes: The size of the receive buffer, in bytes. The 411 | default is 4096. 412 | :param ping_interval: Send ping packets to the server at the requested 413 | interval in seconds. Set to ``None`` (the 414 | default) to disable ping/pong logic. Enable to 415 | prevent disconnections when the line is idle for 416 | a certain amount of time, or to detect an 417 | unresponsive server and disconnect. A recommended 418 | interval is 25 seconds. In general it is 419 | preferred to enable ping/pong on the server, and 420 | let the client respond with pong (which it does 421 | regardless of this setting). 422 | :param max_message_size: The maximum size allowed for a message, in 423 | bytes, or ``None`` for no limit. The default 424 | is ``None``. 
425 | :param ssl_context: An ``SSLContext`` instance, if a default SSL 426 | context isn't sufficient. 427 | """ 428 | ws = cls(url, subprotocols=subprotocols, headers=headers, 429 | receive_bytes=receive_bytes, ping_interval=ping_interval, 430 | max_message_size=max_message_size, ssl_context=ssl_context) 431 | await ws._connect() 432 | return ws 433 | 434 | async def _connect(self): 435 | if self.is_secure: # pragma: no cover 436 | if self.ssl_context is None: 437 | self.ssl_context = ssl.create_default_context( 438 | purpose=ssl.Purpose.SERVER_AUTH) 439 | self.rsock, self.wsock = await asyncio.open_connection( 440 | self.host, self.port, ssl=self.ssl_context) 441 | await super().connect() 442 | 443 | async def handshake(self): 444 | out_data = self.ws.send(Request(host=self.host, target=self.path, 445 | subprotocols=self.subprotocols, 446 | extra_headers=self.extra_headeers)) 447 | self.wsock.write(out_data) 448 | 449 | while True: 450 | in_data = await self.rsock.read(self.receive_bytes) 451 | self.ws.receive_data(in_data) 452 | try: 453 | event = next(self.ws.events()) 454 | except StopIteration: # pragma: no cover 455 | pass 456 | else: # pragma: no cover 457 | break 458 | if isinstance(event, RejectConnection): # pragma: no cover 459 | raise ConnectionError(event.status_code) 460 | elif not isinstance(event, AcceptConnection): # pragma: no cover 461 | raise ConnectionError(400) 462 | self.subprotocol = event.subprotocol 463 | self.connected = True 464 | 465 | async def close(self, reason=None, message=None): 466 | await super().close(reason=reason, message=message) 467 | self.wsock.close() 468 | -------------------------------------------------------------------------------- /src/simple_websocket/ws.py: -------------------------------------------------------------------------------- 1 | import selectors 2 | import socket 3 | import ssl 4 | from time import time 5 | from urllib.parse import urlsplit 6 | 7 | from wsproto import ConnectionType, WSConnection 8 | 
from wsproto.events import ( 9 | AcceptConnection, 10 | RejectConnection, 11 | CloseConnection, 12 | Message, 13 | Request, 14 | Ping, 15 | Pong, 16 | TextMessage, 17 | BytesMessage, 18 | ) 19 | from wsproto.extensions import PerMessageDeflate 20 | from wsproto.frame_protocol import CloseReason 21 | from wsproto.utilities import LocalProtocolError 22 | from .errors import ConnectionError, ConnectionClosed 23 | 24 | 25 | class Base: 26 | def __init__(self, sock=None, connection_type=None, receive_bytes=4096, 27 | ping_interval=None, max_message_size=None, 28 | thread_class=None, event_class=None, selector_class=None): 29 | #: The name of the subprotocol chosen for the WebSocket connection. 30 | self.subprotocol = None 31 | 32 | self.sock = sock 33 | self.receive_bytes = receive_bytes 34 | self.ping_interval = ping_interval 35 | self.max_message_size = max_message_size 36 | self.pong_received = True 37 | self.input_buffer = [] 38 | self.incoming_message = None 39 | self.incoming_message_len = 0 40 | self.connected = False 41 | self.is_server = (connection_type == ConnectionType.SERVER) 42 | self.close_reason = CloseReason.NO_STATUS_RCVD 43 | self.close_message = None 44 | 45 | if thread_class is None: 46 | import threading 47 | thread_class = threading.Thread 48 | if event_class is None: # pragma: no branch 49 | import threading 50 | event_class = threading.Event 51 | if selector_class is None: 52 | selector_class = selectors.DefaultSelector 53 | self.selector_class = selector_class 54 | self.event = event_class() 55 | 56 | self.ws = WSConnection(connection_type) 57 | self.handshake() 58 | 59 | if not self.connected: # pragma: no cover 60 | raise ConnectionError() 61 | self.thread = thread_class(target=self._thread) 62 | self.thread.name = self.thread.name.replace( 63 | '(_thread)', '(simple_websocket.Base._thread)') 64 | self.thread.start() 65 | 66 | def handshake(self): # pragma: no cover 67 | # to be implemented by subclasses 68 | pass 69 | 70 | def send(self, 
data): 71 | """Send data over the WebSocket connection. 72 | 73 | :param data: The data to send. If ``data`` is of type ``bytes``, then 74 | a binary message is sent. Else, the message is sent in 75 | text format. 76 | """ 77 | if not self.connected: 78 | raise ConnectionClosed(self.close_reason, self.close_message) 79 | if isinstance(data, bytes): 80 | out_data = self.ws.send(Message(data=data)) 81 | else: 82 | out_data = self.ws.send(TextMessage(data=str(data))) 83 | self.sock.send(out_data) 84 | 85 | def receive(self, timeout=None): 86 | """Receive data over the WebSocket connection. 87 | 88 | :param timeout: Amount of time to wait for the data, in seconds. Set 89 | to ``None`` (the default) to wait indefinitely. Set 90 | to 0 to read without blocking. 91 | 92 | The data received is returned, as ``bytes`` or ``str``, depending on 93 | the type of the incoming message. 94 | """ 95 | while self.connected and not self.input_buffer: 96 | if not self.event.wait(timeout=timeout): 97 | return None 98 | self.event.clear() 99 | try: 100 | return self.input_buffer.pop(0) 101 | except IndexError: 102 | pass 103 | if not self.connected: # pragma: no cover 104 | raise ConnectionClosed(self.close_reason, self.close_message) 105 | 106 | def close(self, reason=None, message=None): 107 | """Close the WebSocket connection. 108 | 109 | :param reason: A numeric status code indicating the reason of the 110 | closure, as defined by the WebSocket specification. The 111 | default is 1000 (normal closure). 112 | :param message: A text message to be sent to the other side. 
113 | """ 114 | if not self.connected: 115 | raise ConnectionClosed(self.close_reason, self.close_message) 116 | out_data = self.ws.send(CloseConnection( 117 | reason or CloseReason.NORMAL_CLOSURE, message)) 118 | try: 119 | self.sock.send(out_data) 120 | except BrokenPipeError: # pragma: no cover 121 | pass 122 | self.connected = False 123 | 124 | def choose_subprotocol(self, request): # pragma: no cover 125 | # The method should return the subprotocol to use, or ``None`` if no 126 | # subprotocol is chosen. Can be overridden by subclasses that implement 127 | # the server-side of the WebSocket protocol. 128 | return None 129 | 130 | def _thread(self): 131 | sel = None 132 | if self.ping_interval: 133 | next_ping = time() + self.ping_interval 134 | sel = self.selector_class() 135 | try: 136 | sel.register(self.sock, selectors.EVENT_READ, True) 137 | except ValueError: # pragma: no cover 138 | self.connected = False 139 | 140 | while self.connected: 141 | try: 142 | if sel: 143 | now = time() 144 | if next_ping <= now or not sel.select(next_ping - now): 145 | # we reached the timeout, we have to send a ping 146 | if not self.pong_received: 147 | self.close(reason=CloseReason.POLICY_VIOLATION, 148 | message='Ping/Pong timeout') 149 | self.event.set() 150 | break 151 | self.pong_received = False 152 | self.sock.send(self.ws.send(Ping())) 153 | next_ping = max(now, next_ping) + self.ping_interval 154 | continue 155 | in_data = self.sock.recv(self.receive_bytes) 156 | if len(in_data) == 0: 157 | raise OSError() 158 | self.ws.receive_data(in_data) 159 | self.connected = self._handle_events() 160 | except (OSError, ConnectionResetError, 161 | LocalProtocolError): # pragma: no cover 162 | self.connected = False 163 | self.event.set() 164 | break 165 | sel.close() if sel else None 166 | self.sock.close() 167 | 168 | def _handle_events(self): 169 | keep_going = True 170 | out_data = b'' 171 | for event in self.ws.events(): 172 | try: 173 | if isinstance(event, Request): 174 
| self.subprotocol = self.choose_subprotocol(event) 175 | out_data += self.ws.send(AcceptConnection( 176 | subprotocol=self.subprotocol, 177 | extensions=[PerMessageDeflate()])) 178 | elif isinstance(event, CloseConnection): 179 | if self.is_server: 180 | out_data += self.ws.send(event.response()) 181 | self.close_reason = event.code 182 | self.close_message = event.reason 183 | self.connected = False 184 | self.event.set() 185 | keep_going = False 186 | elif isinstance(event, Ping): 187 | out_data += self.ws.send(event.response()) 188 | elif isinstance(event, Pong): 189 | self.pong_received = True 190 | elif isinstance(event, (TextMessage, BytesMessage)): 191 | self.incoming_message_len += len(event.data) 192 | if self.max_message_size and \ 193 | self.incoming_message_len > self.max_message_size: 194 | out_data += self.ws.send(CloseConnection( 195 | CloseReason.MESSAGE_TOO_BIG, 'Message is too big')) 196 | self.event.set() 197 | keep_going = False 198 | break 199 | if self.incoming_message is None: 200 | # store message as is first 201 | # if it is the first of a group, the message will be 202 | # converted to bytearray on arrival of the second 203 | # part, since bytearrays are mutable and can be 204 | # concatenated more efficiently 205 | self.incoming_message = event.data 206 | elif isinstance(event, TextMessage): 207 | if not isinstance(self.incoming_message, bytearray): 208 | # convert to bytearray and append 209 | self.incoming_message = bytearray( 210 | (self.incoming_message + event.data).encode()) 211 | else: 212 | # append to bytearray 213 | self.incoming_message += event.data.encode() 214 | else: 215 | if not isinstance(self.incoming_message, bytearray): 216 | # convert to mutable bytearray and append 217 | self.incoming_message = bytearray( 218 | self.incoming_message + event.data) 219 | else: 220 | # append to bytearray 221 | self.incoming_message += event.data 222 | if not event.message_finished: 223 | continue 224 | if 
isinstance(self.incoming_message, (str, bytes)): 225 | # single part message 226 | self.input_buffer.append(self.incoming_message) 227 | elif isinstance(event, TextMessage): 228 | # convert multi-part message back to text 229 | self.input_buffer.append( 230 | self.incoming_message.decode()) 231 | else: 232 | # convert multi-part message back to bytes 233 | self.input_buffer.append(bytes(self.incoming_message)) 234 | self.incoming_message = None 235 | self.incoming_message_len = 0 236 | self.event.set() 237 | else: # pragma: no cover 238 | pass 239 | except LocalProtocolError: # pragma: no cover 240 | out_data = b'' 241 | self.event.set() 242 | keep_going = False 243 | if out_data: 244 | self.sock.send(out_data) 245 | return keep_going 246 | 247 | 248 | class Server(Base): 249 | """This class implements a WebSocket server. 250 | 251 | Instead of creating an instance of this class directly, use the 252 | ``accept()`` class method to create individual instances of the server, 253 | each bound to a client request. 
254 | """ 255 | def __init__(self, environ, subprotocols=None, receive_bytes=4096, 256 | ping_interval=None, max_message_size=None, thread_class=None, 257 | event_class=None, selector_class=None): 258 | self.environ = environ 259 | self.subprotocols = subprotocols or [] 260 | if isinstance(self.subprotocols, str): 261 | self.subprotocols = [self.subprotocols] 262 | self.mode = 'unknown' 263 | sock = None 264 | if 'werkzeug.socket' in environ: 265 | # extract socket from Werkzeug's WSGI environment 266 | sock = environ.get('werkzeug.socket') 267 | self.mode = 'werkzeug' 268 | elif 'gunicorn.socket' in environ: 269 | # extract socket from Gunicorn WSGI environment 270 | sock = environ.get('gunicorn.socket') 271 | self.mode = 'gunicorn' 272 | elif 'eventlet.input' in environ: # pragma: no cover 273 | # extract socket from Eventlet's WSGI environment 274 | sock = environ.get('eventlet.input').get_socket() 275 | self.mode = 'eventlet' 276 | elif environ.get('SERVER_SOFTWARE', '').startswith( 277 | 'gevent'): # pragma: no cover 278 | # extract socket from Gevent's WSGI environment 279 | wsgi_input = environ['wsgi.input'] 280 | if not hasattr(wsgi_input, 'raw') and hasattr(wsgi_input, 'rfile'): 281 | wsgi_input = wsgi_input.rfile 282 | if hasattr(wsgi_input, 'raw'): 283 | sock = wsgi_input.raw._sock 284 | try: 285 | sock = sock.dup() 286 | except NotImplementedError: 287 | pass 288 | self.mode = 'gevent' 289 | if sock is None: 290 | raise RuntimeError('Cannot obtain socket from WSGI environment.') 291 | super().__init__(sock, connection_type=ConnectionType.SERVER, 292 | receive_bytes=receive_bytes, 293 | ping_interval=ping_interval, 294 | max_message_size=max_message_size, 295 | thread_class=thread_class, event_class=event_class, 296 | selector_class=selector_class) 297 | 298 | @classmethod 299 | def accept(cls, environ, subprotocols=None, receive_bytes=4096, 300 | ping_interval=None, max_message_size=None, thread_class=None, 301 | event_class=None, selector_class=None): 
302 | """Accept a WebSocket connection from a client. 303 | 304 | :param environ: A WSGI ``environ`` dictionary with the request details. 305 | Among other things, this class expects to find the 306 | low-level network socket for the connection somewhere 307 | in this dictionary. Since the WSGI specification does 308 | not cover where or how to store this socket, each web 309 | server does this in its own different way. Werkzeug, 310 | Gunicorn, Eventlet and Gevent are the only web servers 311 | that are currently supported. 312 | :param subprotocols: A list of supported subprotocols, or ``None`` (the 313 | default) to disable subprotocol negotiation. 314 | :param receive_bytes: The size of the receive buffer, in bytes. The 315 | default is 4096. 316 | :param ping_interval: Send ping packets to clients at the requested 317 | interval in seconds. Set to ``None`` (the 318 | default) to disable ping/pong logic. Enable to 319 | prevent disconnections when the line is idle for 320 | a certain amount of time, or to detect 321 | unresponsive clients and disconnect them. A 322 | recommended interval is 25 seconds. 323 | :param max_message_size: The maximum size allowed for a message, in 324 | bytes, or ``None`` for no limit. The default 325 | is ``None``. 326 | :param thread_class: The ``Thread`` class to use when creating 327 | background threads. The default is the 328 | ``threading.Thread`` class from the Python 329 | standard library. 330 | :param event_class: The ``Event`` class to use when creating event 331 | objects. The default is the `threading.Event`` 332 | class from the Python standard library. 333 | :param selector_class: The ``Selector`` class to use when creating 334 | selectors. The default is the 335 | ``selectors.DefaultSelector`` class from the 336 | Python standard library. 
337 | """ 338 | return cls(environ, subprotocols=subprotocols, 339 | receive_bytes=receive_bytes, ping_interval=ping_interval, 340 | max_message_size=max_message_size, 341 | thread_class=thread_class, event_class=event_class, 342 | selector_class=selector_class) 343 | 344 | def handshake(self): 345 | in_data = b'GET / HTTP/1.1\r\n' 346 | for key, value in self.environ.items(): 347 | if key.startswith('HTTP_'): 348 | header = '-'.join([p.capitalize() for p in key[5:].split('_')]) 349 | in_data += f'{header}: {value}\r\n'.encode() 350 | in_data += b'\r\n' 351 | self.ws.receive_data(in_data) 352 | self.connected = self._handle_events() 353 | 354 | def choose_subprotocol(self, request): 355 | """Choose a subprotocol to use for the WebSocket connection. 356 | 357 | The default implementation selects the first protocol requested by the 358 | client that is accepted by the server. Subclasses can override this 359 | method to implement a different subprotocol negotiation algorithm. 360 | 361 | :param request: A ``Request`` object. 362 | 363 | The method should return the subprotocol to use, or ``None`` if no 364 | subprotocol is chosen. 365 | """ 366 | for subprotocol in request.subprotocols: 367 | if subprotocol in self.subprotocols: 368 | return subprotocol 369 | return None 370 | 371 | 372 | class Client(Base): 373 | """This class implements a WebSocket client. 374 | 375 | Instead of creating an instance of this class directly, use the 376 | ``connect()`` class method to create an instance that is connected to a 377 | server. 
378 | """ 379 | def __init__(self, url, subprotocols=None, headers=None, 380 | receive_bytes=4096, ping_interval=None, max_message_size=None, 381 | ssl_context=None, thread_class=None, event_class=None): 382 | parsed_url = urlsplit(url) 383 | is_secure = parsed_url.scheme in ['https', 'wss'] 384 | self.host = parsed_url.hostname 385 | self.port = parsed_url.port or (443 if is_secure else 80) 386 | self.path = parsed_url.path 387 | if parsed_url.query: 388 | self.path += '?' + parsed_url.query 389 | self.subprotocols = subprotocols or [] 390 | if isinstance(self.subprotocols, str): 391 | self.subprotocols = [self.subprotocols] 392 | 393 | self.extra_headeers = [] 394 | if isinstance(headers, dict): 395 | for key, value in headers.items(): 396 | self.extra_headeers.append((key, value)) 397 | elif isinstance(headers, list): 398 | self.extra_headeers = headers 399 | 400 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 401 | if is_secure: # pragma: no cover 402 | if ssl_context is None: 403 | ssl_context = ssl.create_default_context( 404 | purpose=ssl.Purpose.SERVER_AUTH) 405 | sock = ssl_context.wrap_socket(sock, server_hostname=self.host) 406 | sock.connect((self.host, self.port)) 407 | super().__init__(sock, connection_type=ConnectionType.CLIENT, 408 | receive_bytes=receive_bytes, 409 | ping_interval=ping_interval, 410 | max_message_size=max_message_size, 411 | thread_class=thread_class, event_class=event_class) 412 | 413 | @classmethod 414 | def connect(cls, url, subprotocols=None, headers=None, 415 | receive_bytes=4096, ping_interval=None, max_message_size=None, 416 | ssl_context=None, thread_class=None, event_class=None): 417 | """Returns a WebSocket client connection. 418 | 419 | :param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are 420 | accepted. 421 | :param subprotocols: The name of the subprotocol to use, or a list of 422 | subprotocol names in order of preference. Set to 423 | ``None`` (the default) to not use a subprotocol. 
424 | :param headers: A dictionary or list of tuples with additional HTTP 425 | headers to send with the connection request. Note that 426 | custom headers are not supported by the WebSocket 427 | protocol, so the use of this parameter is not 428 | recommended. 429 | :param receive_bytes: The size of the receive buffer, in bytes. The 430 | default is 4096. 431 | :param ping_interval: Send ping packets to the server at the requested 432 | interval in seconds. Set to ``None`` (the 433 | default) to disable ping/pong logic. Enable to 434 | prevent disconnections when the line is idle for 435 | a certain amount of time, or to detect an 436 | unresponsive server and disconnect. A recommended 437 | interval is 25 seconds. In general it is 438 | preferred to enable ping/pong on the server, and 439 | let the client respond with pong (which it does 440 | regardless of this setting). 441 | :param max_message_size: The maximum size allowed for a message, in 442 | bytes, or ``None`` for no limit. The default 443 | is ``None``. 444 | :param ssl_context: An ``SSLContext`` instance, if a default SSL 445 | context isn't sufficient. 446 | :param thread_class: The ``Thread`` class to use when creating 447 | background threads. The default is the 448 | ``threading.Thread`` class from the Python 449 | standard library. 450 | :param event_class: The ``Event`` class to use when creating event 451 | objects. The default is the `threading.Event`` 452 | class from the Python standard library. 
453 | """ 454 | return cls(url, subprotocols=subprotocols, headers=headers, 455 | receive_bytes=receive_bytes, ping_interval=ping_interval, 456 | max_message_size=max_message_size, ssl_context=ssl_context, 457 | thread_class=thread_class, event_class=event_class) 458 | 459 | def handshake(self): 460 | out_data = self.ws.send(Request(host=self.host, target=self.path, 461 | subprotocols=self.subprotocols, 462 | extra_headers=self.extra_headeers)) 463 | self.sock.send(out_data) 464 | 465 | while True: 466 | in_data = self.sock.recv(self.receive_bytes) 467 | self.ws.receive_data(in_data) 468 | try: 469 | event = next(self.ws.events()) 470 | except StopIteration: # pragma: no cover 471 | pass 472 | else: # pragma: no cover 473 | break 474 | if isinstance(event, RejectConnection): # pragma: no cover 475 | raise ConnectionError(event.status_code) 476 | elif not isinstance(event, AcceptConnection): # pragma: no cover 477 | raise ConnectionError(400) 478 | self.subprotocol = event.subprotocol 479 | self.connected = True 480 | 481 | def close(self, reason=None, message=None): 482 | super().close(reason=reason, message=message) 483 | self.sock.close() 484 | --------------------------------------------------------------------------------