├── MANIFEST.in
├── _config.yml
├── .gitlab-ci-alba.yml
├── sockio
    ├── common.py
    ├── __init__.py
    ├── sio.py
    ├── py2.py
    └── aio.py
├── examples
    ├── stream
    │   ├── client.py
    │   └── server.py
    └── req-rep
    │   ├── client.py
    │   └── server.py
├── setup.cfg
├── LICENSE
├── conda
    └── pypi
    │   └── meta.yaml
├── setup.py
├── .gitignore
├── tests
    ├── test_url.py
    ├── conftest.py
    ├── test_py2.py
    ├── test_sio.py
    └── test_aio.py
├── README.md
└── demo.svg


/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | 
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/.gitlab-ci-alba.yml:
--------------------------------------------------------------------------------
1 | include:
2 |   - https://git.cells.es/ctpkg/ci/ctpipeline/raw/master/ctjobdefs-ci.yml
3 |   - https://git.cells.es/ctpkg/ci/ctpipeline/raw/master/ctpipeline.yml
4 | 
--------------------------------------------------------------------------------
/sockio/common.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | 
4 | # IP type-of-service (TOS) byte values; they match the IPTOS_* constants of
5 | # netinet/ip.h (RFC 1349).
6 | IPTOS_NORMAL = 0x0
7 | IPTOS_LOWDELAY = 0x10
8 | IPTOS_THROUGHPUT = 0x08
9 | IPTOS_RELIABILITY = 0x04
10 | IPTOS_MINCOST = 0x02
11 | DEFAULT_LIMIT = 2 ** 20  # 1MB
12 | 
13 | 
14 | log = logging.getLogger("sockio")
15 | 
16 | 
17 | class ConnectionEOFError(ConnectionError):
18 |     pass
19 | 
20 | 
21 | class ConnectionTimeoutError(ConnectionError):
22 |     pass
23 | 
--------------------------------------------------------------------------------
/sockio/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.15.0"
2 | 
3 | 
4 | # Accepted spellings of the ``concurrency`` option, mapped to the backend
5 | # they select ("async" -> sockio.aio, "sync" -> sockio.sio).
6 | CONCURRENCY_MAP = {
7 |     "sync": "sync",
8 |     "syncio": "sync",
9 |     "async": "async",
10 |     "asyncio": "async",
11 | }
12 | 
13 | 
14 | def socket_for_url(url, *args, **kwargs):
15 |     """Create a socket for *url* using the requested concurrency backend.
16 | 
17 |     The optional ``concurrency`` keyword (default ``"async"``) is looked up
18 |     in CONCURRENCY_MAP; every other argument is forwarded unchanged to
19 |     ``sockio.aio.socket_for_url`` or ``sockio.sio.socket_for_url``.
20 |     The backend module is imported lazily, only when it is selected.
21 | 
22 |     Raises:
23 |         ValueError: if ``concurrency`` is not a key of CONCURRENCY_MAP
24 |             (the selected backend may also raise ValueError, e.g. for an
25 |             unsupported URL scheme).
26 |     """
27 |     conc = kwargs.pop("concurrency", "async")
28 |     concurrency = CONCURRENCY_MAP.get(conc)
29 |     if concurrency == "async":
30 |         # deferred import: only load the backend that is actually used
31 |         from . import aio
32 | 
33 |         return aio.socket_for_url(url, *args, **kwargs)
34 |     elif concurrency == "sync":
35 |         from . import sio
36 | 
37 |         return sio.socket_for_url(url, *args, **kwargs)
38 |     raise ValueError("unsupported concurrency {!r}".format(conc))
39 | 
--------------------------------------------------------------------------------
/examples/stream/client.py:
--------------------------------------------------------------------------------
1 | """Example: connect to localhost:12345 and print every line the server streams."""
2 | 
3 | import asyncio
4 | import logging
5 | 
6 | import sockio.aio
7 | 
8 | 
9 | async def main():
10 |     event = asyncio.Event()
11 |     # NOTE(review): event.set is passed as on_eof_received but never awaited here
12 |     s = sockio.aio.TCP("localhost", 12345, on_eof_received=event.set)
13 |     async for line in s:
14 |         print(line)
15 |     await s.close()
16 | 
17 | 
18 | fmt = "%(asctime)-15s %(levelname)-5s %(name)s: %(message)s"
19 | logging.basicConfig(format=fmt, level=logging.DEBUG)
20 | 
21 | 
22 | try:
23 |     if hasattr(asyncio, "run"):
24 |         asyncio.run(main())
25 |     else:
26 |         loop = asyncio.get_event_loop()
27 |         loop.run_until_complete(main())
28 | except KeyboardInterrupt:
29 |     print("Ctrl-C pressed. Bailing out!")
30 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.15.0
3 | commit = True
4 | tag = True
5 | 
6 | [bumpversion:file:setup.py]
7 | search = version="{current_version}"
8 | replace = version="{new_version}"
9 | 
10 | [bumpversion:file:sockio/__init__.py]
11 | search = __version__ = "{current_version}"
12 | replace = __version__ = "{new_version}"
13 | 
14 | [bumpversion:file:conda/pypi/meta.yaml]
15 | search = set version = "{current_version}"
16 | replace = set version = "{new_version}"
17 | 
18 | [bdist_wheel]
19 | universal = 1
20 | 
21 | [flake8]
22 | max-line-length = 88
23 | extend-ignore = E203
24 | exclude = docs
25 | 
26 | [aliases]
27 | test = pytest
28 | 
29 | [tool:pytest]
30 | addopts = --cov-config=.coveragerc --cov sockio
31 |     --cov-report html --cov-report term
32 |     --durations=2 --verbose
33 | 
34 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019-2023 Jose Tiago Macara Coutinho
2 | 
3 | This library is free software; you can redistribute it and/or
4 | modify it under the terms of the GNU Lesser General Public
5 | License as published by the Free Software Foundation; either
6 | version 2.1 of the License, or (at your option) any later version.
7 | 
8 | This library is distributed in the hope that it will be useful,
9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 | Lesser General Public License for more details.
12 | 
13 | You should have received a copy of the GNU Lesser General Public
14 | License along with this library; if not, write to the Free Software
15 | Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 | 
--------------------------------------------------------------------------------
/conda/pypi/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set name = "sockio" %}
2 | {% set version = "0.15.0" %}
3 | 
4 | package:
5 |   name: "{{ name|lower }}"
6 |   version: "{{ version }}"
7 | 
8 | source:
9 |   url: "https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz"
10 | 
11 | build:
12 |   number: 0
13 |   noarch: python
14 |   script: "{{ PYTHON }} -m pip install . -vv"
15 | 
16 | requirements:
17 |   host:
18 |     - pip
19 |     - python
20 |   run:
21 |     - python
22 | 
23 | test:
24 |   imports:
25 |     - sockio
26 |   requires:
27 |     - pytest
28 |     - pytest-asyncio
29 |     - pytest-cov
30 | 
31 | about:
32 |   home: "https://tiagocoutinho.github.io/sockio/"
33 |   license: LGPL-2.1-or-later  # must agree with LICENSE and setup.py
34 |   license_family: LGPL
35 |   license_file: LICENSE
36 |   summary: "Concurrency agnostic socket API"
37 |   doc_url: "https://tiagocoutinho.github.io/sockio/"
38 |   dev_url: "https://github.com/tiagocoutinho/sockio/"
39 | 
40 | extra:
41 |   recipe-maintainers:
42 |     - your-github-id-here
43 | 
--------------------------------------------------------------------------------
/examples/req-rep/client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | 
4 | import sockio.aio
5 | 
6 | 
7 | async def main():
8 |     event = asyncio.Event()
9 |     s = sockio.aio.TCP("localhost", 12345, on_eof_received=event.set)
10 |     reply = await s.write_readline(b"*idn?\n")
11 |     print("Server replies with: {!r}".format(reply))
12 |     print("Looks like the server is running. Great!")
13 |     print("Now, please restart the server...")
14 |     await event.wait()
15 |     print("Thanks for turning it off!")
16 |     print("You now have 5s to turn it back on again.")
17 |     await asyncio.sleep(5)
18 |     print("I will now try another request without explicitly reopening the socket")
19 |     reply = await s.write_readline(b"*idn?\n")
20 |     print("It works! Server replies with: {!r}".format(reply))
21 |     await s.close()
22 | 
23 | 
24 | fmt = "%(asctime)-15s %(levelname)-5s %(name)s: %(message)s"
25 | logging.basicConfig(format=fmt)
26 | 
27 | try:
28 |     if hasattr(asyncio, "run"):
29 |         asyncio.run(main())
30 |     else:
31 |         loop = asyncio.get_event_loop()
32 |         loop.run_until_complete(main())
33 | except KeyboardInterrupt:
34 |     print("Ctrl-C pressed. Bailing out!")
35 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """The setup script."""
4 | 
5 | import io
6 | from setuptools import setup, find_packages
7 | 
8 | # io.open takes `encoding` on Python 2 and 3 alike (python_requires >= 2.7)
9 | with io.open("README.md", encoding="utf-8") as f:
10 |     description = f.read()
11 | 
12 | setup(
13 |     name="sockio",
14 |     author="Jose Tiago Macara Coutinho",
15 |     author_email="coutinhotiago@gmail.com",
16 |     classifiers=[
17 |         "Development Status :: 2 - Pre-Alpha",
18 |         "Intended Audience :: Developers",
19 |         "Natural Language :: English",
20 |         "Programming Language :: Python :: 2.7",
21 |         "Programming Language :: Python :: 3",
22 |         "Programming Language :: Python :: 3.5",
23 |         "Programming Language :: Python :: 3.6",
24 |         "Programming Language :: Python :: 3.7",
25 |         "Programming Language :: Python :: 3.8",
26 |         "Programming Language :: Python :: 3.9",
27 |         "License :: OSI Approved :: GNU Lesser General Public License v2 or later (LGPLv2+)",
28 |     ],
29 |     description="Concurrency agnostic socket API",
30 |     # LICENSE grants LGPL version 2.1 "or (at your option) any later version"
31 |     license="LGPL-2.1-or-later",
32 |     long_description=description,
33 |     long_description_content_type="text/markdown",
34 |     keywords="socket, asyncio",
35 |     packages=find_packages(include=["sockio"]),
36 |     url="https://tiagocoutinho.github.io/sockio/",
37 |     project_urls={
38 |         "Documentation": "https://tiagocoutinho.github.io/sockio/",
39 |         "Source": "https://github.com/tiagocoutinho/sockio/",
40 |     },
41 |     version="0.15.0",
42 |     python_requires=">=2.7",
43 |     zip_safe=True,
44 | )
45 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | #  Usually these files are written by a python script from a template
30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | 
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 | 
62 | # Scrapy stuff:
63 | .scrapy
64 | 
65 | # Sphinx documentation
66 | docs/_build/
67 | 
68 | # PyBuilder
69 | target/
70 | 
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 | 
74 | # pyenv
75 | .python-version
76 | 
77 | # celery beat schedule file
78 | celerybeat-schedule
79 | 
80 | # SageMath parsed files
81 | *.sage.py
82 | 
83 | # dotenv
84 | .env
85 | 
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 | 
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 | 
95 | # Rope project settings
96 | .ropeproject
97 | 
98 | # mkdocs documentation
99 | /site
100 | 
101 | # mypy
102 | .mypy_cache/
103 | 
--------------------------------------------------------------------------------
/tests/test_url.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | 
3 | import pytest
4 | 
5 | from sockio import socket_for_url
6 | 
7 | from conftest import IDN_REQ, IDN_REP
8 | 
9 | 
10 | @pytest.mark.asyncio
11 | async def test_root_socket_for_url(aio_server):
12 |     host, port = aio_server.sockets[0].getsockname()
13 | 
14 |     with pytest.raises(ValueError):
15 |         socket_for_url("udp://{}:{}".format(host, port))
16 | 
17 |     aio_tcp = socket_for_url("tcp://{}:{}".format(host, port))
18 | 
19 |     assert not aio_tcp.connected()
20 |     assert aio_tcp.connection_counter == 0
21 | 
22 |     await aio_tcp.open()
23 |     assert aio_tcp.connected()
24 |     assert aio_tcp.connection_counter == 1
25 | 
26 |     coro = aio_tcp.write_readline(IDN_REQ)
27 |     assert asyncio.iscoroutine(coro)
28 |     reply = await coro
29 |     assert aio_tcp.connected()
30 |     assert aio_tcp.connection_counter == 1
31 |     assert reply == IDN_REP
32 | 
33 |     # release the connection instead of leaving it to garbage collection
34 |     await aio_tcp.close()
35 | 
36 | 
37 | def test_root_socket_for_url_sync(sio_server):
38 |     host, port = sio_server.sockets[0].getsockname()
39 | 
40 |     with pytest.raises(ValueError):
41 |         socket_for_url("udp://{}:{}".format(host, port), concurrency="sync")
42 | 
43 |     tcp = socket_for_url("tcp://{}:{}".format(host, port), concurrency="sync")
44 | 
45 |     assert not tcp.connected()
46 |     assert tcp.connection_counter == 0
47 | 
48 |     tcp.open()
49 |     assert tcp.connected()
50 |     assert tcp.connection_counter == 1
51 | 
52 |     reply = tcp.write_readline(IDN_REQ)
53 |     assert tcp.connected()
54 |     assert tcp.connection_counter == 1
55 |     assert reply == IDN_REP
56 | 
57 |     tcp.close()
58 | 
59 | 
60 | def test_root_socket_for_url_error(sio_server):
61 |     host, port = sio_server.sockets[0].getsockname()
62 | 
63 |     with pytest.raises(ValueError):
64 |         socket_for_url("udp://{}:{}".format(host, port))
65 | 
66 |     with pytest.raises(ValueError):
67 |         socket_for_url("udp://{}:{}".format(host, port), concurrency="async")
68 | 
69 |     with pytest.raises(ValueError):
70 |         socket_for_url("tcp://{}:{}".format(host, port), concurrency="parallel")
71 | 
--------------------------------------------------------------------------------
/examples/stream/server.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import asyncio
3 | import logging
4 | 
5 | 
6 | PY_37 = sys.version_info >= (3, 7)
7 | 
8 | 
9 | async def run(options):
10 |     """Serve forever; each client gets ten numbered lines, one per second."""
11 | 
12 |     async def cb(reader, writer):
13 |         addr = writer.transport.get_extra_info("peername")
14 |         logging.info("client connected from %s", addr)
15 |         try:
16 |             for i in range(10):
17 |                 msg = f"message {i}\n"
18 |                 writer.write(msg.encode())
19 |                 await writer.drain()
20 |                 logging.debug("send %r", msg)
21 |                 await asyncio.sleep(1)
22 |         except Exception as error:
23 |             # e.g. the client disconnected mid-stream
24 |             logging.info("streaming to %s stopped: %r", addr, error)
25 |         finally:
26 |             # release the connection on every path, not only after a full run
27 |             writer.close()
28 |             if PY_37:
29 |                 try:
30 |                     await writer.wait_closed()
31 |                 except Exception:
32 |                     pass  # best effort: the peer may already be gone
33 | 
34 |     server = await asyncio.start_server(cb, host=options.host, port=options.port)
35 |     host, port = server.sockets[0].getsockname()
36 |     logging.info("started accepting requests on %s:%d", host, port)
37 |     async with server:
38 |         await server.serve_forever()
39 | 
40 | 
41 | def main(args=None):
42 |     import argparse
43 | 
44 |     parser = argparse.ArgumentParser()
45 |     log_level_choices = ["critical", "error", "warning", "info", "debug"]
46 |     log_level_choices += [i.upper() for i in log_level_choices]
47 |     parser.add_argument("--host", default="0", help="SCPI bind address")
48 |     # default matches the port hard-coded in examples/stream/client.py
49 |     parser.add_argument(
50 |         "-p", "--port", type=int, default=12345, help="SCPI server port"
51 |     )
52 |     parser.add_argument("--log-level", choices=log_level_choices, default="warning")
53 |     parser.add_argument(
54 |         "-d", "--debug", action="store_true", help="same as --log-level=debug"
55 |     )
56 |     options = parser.parse_args(args)
57 |     fmt = "%(asctime)-15s %(levelname)-5s: %(message)s"
58 |     # --debug overrides --log-level
59 |     level = "DEBUG" if options.debug else options.log_level.upper()
60 |     logging.basicConfig(level=level, format=fmt)
61 |     try:
62 |         coro = run(options)
63 |         if hasattr(asyncio, "run"):
64 |             asyncio.run(coro)
65 |         else:
66 |             loop = asyncio.get_event_loop()
67 |             loop.run_until_complete(coro)
68 |     except KeyboardInterrupt:
69 |         logging.info("Ctrl-C pressed. Bailing out!")
70 | 
71 | 
72 | if __name__ == "__main__":
73 |     main()
74 | 
--------------------------------------------------------------------------------
/examples/req-rep/server.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import asyncio
3 | import logging
4 | 
5 | IDN_REQ, IDN_REP = b"*idn?\n", b"ACME, bla ble ble, 1234, 5678\n"
6 | WRONG_REQ, WRONG_REP = b"wrong question\n", b"ERROR: unknown command\n"
7 | 
8 | 
9 | PY_37 = sys.version_info >= (3, 7)
10 | 
11 | 
12 | async def run(options):
13 |     """Serve forever, answering each request line with IDN_REP or WRONG_REP."""
14 | 
15 |     async def cb(reader, writer):
16 |         addr = writer.transport.get_extra_info("peername")
17 |         logging.info("client connected from %s", addr)
18 |         try:
19 |             while True:
20 |                 data = await reader.readline()
21 |                 if not data:
22 |                     logging.info("client %s disconnected", addr)
23 |                     return
24 |                 logging.debug("recv %r", data)
25 |                 msg = IDN_REP if data.lower() == IDN_REQ else WRONG_REP
26 |                 writer.write(msg)
27 |                 await writer.drain()
28 |                 logging.debug("send %r", msg)
29 |         except Exception as error:
30 |             logging.info("connection with %s failed: %r", addr, error)
31 |         finally:
32 |             # close on every path, including a clean disconnect (return above)
33 |             writer.close()
34 |             if PY_37:
35 |                 try:
36 |                     await writer.wait_closed()
37 |                 except Exception:
38 |                     pass  # best effort: the peer may already be gone
39 | 
40 |     server = await asyncio.start_server(cb, host=options.host, port=options.port)
41 |     host, port = server.sockets[0].getsockname()
42 |     logging.info("started accepting requests on %s:%d", host, port)
43 |     async with server:
44 |         await server.serve_forever()
45 | 
46 | 
47 | def main(args=None):
48 |     import argparse
49 | 
50 |     parser = argparse.ArgumentParser()
51 |     log_level_choices = ["critical", "error", "warning", "info", "debug"]
52 |     log_level_choices += [i.upper() for i in log_level_choices]
53 |     parser.add_argument("--host", default="0", help="SCPI bind address")
54 |     # default matches the port hard-coded in examples/req-rep/client.py
55 |     parser.add_argument(
56 |         "-p", "--port", type=int, default=12345, help="SCPI server port"
57 |     )
58 |     parser.add_argument("--log-level", choices=log_level_choices, default="warning")
59 |     parser.add_argument(
60 |         "-d", "--debug", action="store_true", help="same as --log-level=debug"
61 |     )
62 |     options = parser.parse_args(args)
63 |     fmt = "%(asctime)-15s %(levelname)-5s: %(message)s"
64 |     # --debug overrides --log-level
65 |     level = "DEBUG" if options.debug else options.log_level.upper()
66 |     logging.basicConfig(level=level, format=fmt)
67 |     try:
68 |         coro = run(options)
69 |         if hasattr(asyncio, "run"):
70 |             asyncio.run(coro)
71 |         else:
72 |             loop = asyncio.get_event_loop()
73 |             loop.run_until_complete(coro)
74 |     except KeyboardInterrupt:
75 |         logging.info("Ctrl-C pressed. Bailing out!")
76 | 
77 | 
78 | if __name__ == "__main__":
79 |     main()
80 | 
--------------------------------------------------------------------------------
/sockio/sio.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import functools
3 | import threading
4 | import urllib.parse
5 | 
6 | from . import aio
7 | 
8 | 
9 | class BaseProxy:
10 |     """Wrap *ref*, forwarding unknown attribute lookups to it."""
11 | 
12 |     def __init__(self, ref):
13 |         self._ref = ref
14 | 
15 |     def __getattr__(self, name):
16 |         return getattr(self._ref, name)
17 | 
18 | 
19 | def ensure_running(f):
20 |     """Decorate an EventLoop method so a master loop is started on first use."""
21 | 
22 |     @functools.wraps(f)
23 |     def wrapper(self, *args, **kwargs):
24 |         # Test `loop`, not `thread`: start() sets `thread` before the loop
25 |         # exists, so a concurrent caller could otherwise get a None loop.
26 |         if self.master and self.loop is None:
27 |             with self._start_lock:
28 |                 # re-check under the lock: only the first caller starts it
29 |                 if self.loop is None:
30 |                     self.start()
31 |         return f(self, *args, **kwargs)
32 | 
33 |     return wrapper
34 | 
35 | 
36 | class EventLoop:
37 |     """Run an asyncio loop in a daemon thread and expose blocking proxies.
38 | 
39 |     With ``loop=None`` the instance is the "master": it creates and runs
40 |     its own loop in a thread named "AIOTH" the first time it is needed.
41 |     With an explicit *loop*, start() and stop() refuse to manage it.
42 |     """
43 | 
44 |     def __init__(self, loop=None):
45 |         self.master = loop is None
46 |         self.thread = None
47 |         self.loop = loop
48 |         self.proxies = {}
49 |         # serializes the lazy start performed by ensure_running
50 |         self._start_lock = threading.Lock()
51 | 
52 |     def start(self):
53 |         """Start the master loop thread; return once the loop object exists."""
54 |         if self.thread:
55 |             raise RuntimeError("event loop already started")
56 |         if self.loop:
57 |             raise RuntimeError("cannot run non master event loop")
58 | 
59 |         def run():
60 |             self.loop = asyncio.new_event_loop()
61 |             asyncio.set_event_loop(self.loop)
62 |             started.set()
63 |             self.loop.run_forever()
64 | 
65 |         started = threading.Event()
66 |         self.thread = threading.Thread(name="AIOTH", target=run)
67 |         self.thread.daemon = True
68 |         self.thread.start()
69 |         started.wait()
70 | 
71 |     def stop(self):
72 |         """Ask the master loop to stop (thread and loop attributes are kept)."""
73 |         if self.loop is None:
74 |             if self.thread is None:
75 |                 raise RuntimeError("event loop not started")
76 |         else:
77 |             if self.thread is None:
78 |                 raise RuntimeError("cannot stop non master event loop")
79 |         self.loop.call_soon_threadsafe(self.loop.stop)
80 | 
81 |     @ensure_running
82 |     def run_coroutine(self, coro):
83 |         """Schedule *coro* on the loop; return a concurrent.futures.Future."""
84 |         return asyncio.run_coroutine_threadsafe(coro, self.loop)
85 | 
86 |     def _create_coroutine_threadsafe(self, corof, resolve_future):
87 |         # Wrap coroutine function *corof* so that calling it on a proxy runs
88 |         # it on the loop against the wrapped object; block for the result
89 |         # when resolve_future is true, otherwise return the future.
90 |         @functools.wraps(corof)
91 |         def wrapper(obj, *args, **kwargs):
92 |             coro = corof(obj._ref, *args, **kwargs)
93 |             future = self.run_coroutine(coro)
94 |             return future.result() if resolve_future else future
95 | 
96 |         return wrapper
97 | 
98 |     def _create_proxy_for(self, klass, resolve_futures=True):
99 |         # Build a BaseProxy subclass mirroring the public members of *klass*.
100 |         # NOTE(review): non-coroutine members are copied unchanged, so they
101 |         # run in the caller's thread with the proxy as `self`; attribute
102 |         # assignments they make land on the proxy, not on the wrapped object.
103 |         class Proxy(BaseProxy):
104 |             pass
105 | 
106 |         for name in dir(klass):
107 |             if name.startswith("_"):
108 |                 continue
109 |             member = getattr(klass, name)
110 |             if asyncio.iscoroutinefunction(member):
111 |                 member = self._create_coroutine_threadsafe(member, resolve_futures)
112 |             setattr(Proxy, name, member)
113 |         return Proxy
114 | 
115 |     @ensure_running
116 |     def proxy(self, obj, resolve_futures=True):
117 |         """Return a proxy for *obj*; classes are cached per (type, resolve_futures)."""
118 |         klass = type(obj)
119 |         key = klass, resolve_futures
120 |         Proxy = self.proxies.get(key)
121 |         if not Proxy:
122 |             Proxy = self._create_proxy_for(klass, resolve_futures)
123 |             self.proxies[key] = Proxy
124 |         return Proxy(obj)
125 | 
126 |     @ensure_running
127 |     def tcp(self, host, port, resolve_futures=True, **kwargs):
128 |         """Create an aio.TCP inside the loop thread and return a proxy to it."""
129 | 
130 |         async def create():
131 |             return aio.TCP(host, port, **kwargs)
132 | 
133 |         sock = self.run_coroutine(create()).result()
134 |         return self.proxy(sock, resolve_futures)
135 | 
136 | 
137 | DefaultEventLoop = EventLoop()
138 | TCP = DefaultEventLoop.tcp
139 | 
140 | 
141 | def socket_for_url(url, *args, **kwargs):
142 |     """Return a blocking TCP proxy for a ``tcp://host:port`` URL.
143 | 
144 |     Raises:
145 |         ValueError: for any scheme other than ``tcp``.
146 |     """
147 |     addr = urllib.parse.urlparse(url)
148 |     scheme = addr.scheme
149 |     if scheme == "tcp":
150 |         return TCP(addr.hostname, addr.port, *args, **kwargs)
151 |     raise ValueError("unsupported sync scheme {!r} for {}".format(scheme, url))
152 | 
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import queue
2 | import asyncio
3 | 
4 | import pytest
5 | 
6 | import sockio.aio
7 | import sockio.sio
8 | import sockio.py2
9 | 
10 | 
11 | IDN_REQ, IDN_REP = b"*idn?\n", b"ACME, bla ble ble, 1234, 5678\n"
12 | WRONG_REQ, WRONG_REP = b"wrong question\n", b"ERROR: unknown command\n"
13 | 
14 | 
15 | async def server_coro(start_serving=True):
16 |     writers = set()
17 | 
18 |     async def cb(reader, writer):
19 |         writers.add(writer)
20 |         try:
21 |             while True:
22 |                 data = await reader.readline()
23 |                 data_l = data.lower()
24 |                 if data_l == IDN_REQ:
25 |                     msg = IDN_REP
26 |                 elif data_l.startswith(b"sleep"):
27 |                     t = float(data_l.rsplit(b" ", 1)[-1])
28 |                     await asyncio.sleep(t)
29 |                     msg = b"OK\n"
30 |                 elif data_l.startswith(b"data?"):
31 |                     n = int(data.strip().split(b" ", 1)[-1])
32 |                     if n > 0:
33 |                         for i in range(n):
34 |                             await asyncio.sleep(0.05)
35 |                             writer.write(b"1.2345 5.4321 12345.54321\n")
36 |                             await writer.drain()
37 |                     else:
38 |                         for i in range(abs(n)):
39 |                             await asyncio.sleep(0.05)
40 |                             msg = "message {:04d}".format(i).encode()
41 |                             assert len(msg) == 12
42 |                             writer.write(msg)
43 |                             await writer.drain()
44 |                     writer.close()
45 |                     await writer.wait_closed()
46 |                     return
47 |                 elif data_l.startswith(b"kill"):
48 |                     writer.close()
49 |                     await writer.wait_closed()
50 |                     return
51 |                 elif not data:
52 |                     writer.close()
53 |                     await writer.wait_closed()
54 |                     return
55 |                 else:
56 |                     msg = WRONG_REP
57 |                 # add 2ms delay
58 |                 await asyncio.sleep(0.002)
59 |                 writer.write(msg)
60 |                 await writer.drain()
61 |         except ConnectionResetError:
62 |             pass
63 |         finally:
64 |             writers.remove(writer)
65 | 
66 |     async def stop():
67 |         server.close()
68 |         await server.wait_closed()
69 |         assert not server.is_serving()
70 |         for writer in set(writers):
71 |             writer.close()
72 |             await writer.wait_closed()
73 | 
74 |     server = await asyncio.start_server(cb, host="0", start_serving=start_serving)
75 |     server.stop = stop
76 |     return server
77 | 
78 | 
79 | @pytest.fixture
80 | async def aio_server():
81 |     server = await server_coro()
82 |     yield server
83 |     await server.stop()
84 | 
85 | 
86 | @pytest.fixture
87 | async def aio_tcp(aio_server):
88 |     addr = aio_server.sockets[0].getsockname()
89 |     sock = sockio.aio.TCP(*addr)
90 |     yield sock
91 |     await sock.close()
92 | 
93 | 
94 | @pytest.fixture
95 | def sio_server():
96 |     event_loop = sockio.sio.DefaultEventLoop
97 |     channel = queue.Queue()
98 | 
99 |     async def serve_forever():
100 |         server = await server_coro(start_serving=False)
101 |         channel.put(server)
102 |         await server.serve_forever()
103 |         await server.stop()
104 | 
105 |     event_loop.run_coroutine(serve_forever())
106 |     raw_server = channel.get()
107 |     server = event_loop.proxy(raw_server)
108 |     yield server
109 |     # asyncio objects are not thread-safe: close the server on its own loop
110 |     # thread rather than through the proxy from the test thread
111 |     event_loop.loop.call_soon_threadsafe(raw_server.close)
112 | 
113 | 
114 | @pytest.fixture
115 | def sio_tcp(sio_server):
116 |     addr = sio_server.sockets[0].getsockname()
117 |     sock = sockio.sio.TCP(*addr)
118 |     yield sock
119 |     sock.close()
120 | 
121 | 
122 | @pytest.fixture
123 | def py2_tcp(sio_server):
124 |     addr = sio_server.sockets[0].getsockname()
125 |     sock = sockio.py2.TCP(*addr)
126 |     yield sock
127 |     sock.close()
128 | 
--------------------------------------------------------------------------------
/sockio/py2.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import logging
3 | import functools
4 | import threading
5 | 
6 | try:
7 |     ConnectionError
8 | except NameError:
9 |     # Python 2 has no such builtins: define socket.error based stand-ins
10 | 
11 |     class ConnectionError(socket.error):
12 |         pass
13 | 
14 |     class ConnectionResetError(socket.error):
15 |         pass
16 | 
17 | 
18 | log = logging.getLogger("sockio")
19 | 
20 | 
21 | def ensure_closed_on_error(f):
22 |     """Decorate a Connection method: on socket.error, close and re-raise."""
23 | 
24 |     @functools.wraps(f)
25 |     def wrapper(self, *args, **kwargs):
26 |         try:
27 |             return f(self, *args, **kwargs)
28 |         except socket.error:
29 |             self.close()
30 |             raise
31 | 
32 |     return wrapper
33 | 
34 | 
35 | class Connection(object):
36 |     """One connected TCP socket plus an unbuffered file object over it."""
37 | 
38 |     def __init__(self, host, port, timeout=1.0):
39 |         # timeout also bounds connect(): an unreachable host fails after
40 |         # `timeout` seconds instead of waiting for the OS connect timeout
41 |         self.sock = socket.create_connection((host, port), timeout)
42 |         # IPPROTO_TCP is portable; SOL_TCP is not defined on Windows
43 |         self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
44 |         self.sock.settimeout(timeout)
45 |         self.fobj = self.sock.makefile("rwb", 0)
46 | 
47 |     def close(self):
48 |         # Close the file object first: on Python 3 the socket is only really
49 |         # released once no makefile() object references it any more.
50 |         if self.fobj is not None:
51 |             self.fobj.close()
52 |         if self.sock is not None:
53 |             self.sock.close()
54 |         self.sock = None
55 |         self.fobj = None
56 | 
57 |     def connected(self):
58 |         return self.sock is not None
59 | 
60 |     is_open = property(connected)
61 | 
62 |     @ensure_closed_on_error
63 |     def readline(self):
64 |         data = self.fobj.readline()
65 |         if not data:
66 |             raise ConnectionResetError("remote end disconnected")
67 |         return data
68 | 
69 |     @ensure_closed_on_error
70 |     def read(self, n=-1):
71 |         data = self.fobj.read(n)
72 |         if not data:
73 |             raise ConnectionResetError("remote end disconnected")
74 |         return data
75 | 
76 |     @ensure_closed_on_error
77 |     def write(self, data):
78 |         # sendall rather than fobj.write: with buffering=0 Python 3 returns a
79 |         # raw SocketIO whose write() may send only part of *data*
80 |         self.sock.sendall(data)
81 |         return len(data)
82 | 
83 |     @ensure_closed_on_error
84 |     def writelines(self, lines):
85 |         # a single sendall of the joined payload, for the same reason as write
86 |         self.sock.sendall(b"".join(lines))
87 | 
88 | 
89 | def ensure_connected(f):
90 |     """Decorate a TCP method: run it under the lock with an open connection.
91 | 
92 |     When not connected, open first. When connected and the call raises
93 |     socket.error, reopen once and retry the call.
94 |     NOTE(review): socket.timeout is a socket.error subclass, so a timed-out
95 |     request is also re-sent on a fresh connection.
96 |     """
97 | 
98 |     @functools.wraps(f)
99 |     def wrapper(self, *args, **kwargs):
100 |         with self._lock:
101 |             if not self.connected():
102 |                 self._open()
103 |                 return f(self, *args, **kwargs)
104 |             else:
105 |                 try:
106 |                     return f(self, *args, **kwargs)
107 |                 except socket.error:
108 |                     self._open()
109 |                     return f(self, *args, **kwargs)
110 | 
111 |     return wrapper
112 | 
113 | 
114 | class TCP(object):
115 |     """Blocking TCP client that opens (and reopens) its connection on demand.
116 | 
117 |     I/O methods are serialized by an internal lock (see ensure_connected).
118 |     """
119 | 
120 |     def __init__(self, host, port, timeout=1.0):
121 |         self.host = host
122 |         self.port = port
123 |         self.conn = None
124 |         self.timeout = timeout
125 |         self._log = log.getChild("TCP({0}:{1})".format(host, port))
126 |         self._lock = threading.Lock()
127 |         self.connection_counter = 0
128 | 
129 |     def _open(self):
130 |         if self.connected():
131 |             raise ConnectionError("socket already open")
132 |         self._log.debug("openning connection (#%d)...", self.connection_counter + 1)
133 |         self.conn = Connection(self.host, self.port, timeout=self.timeout)
134 |         self.connection_counter += 1
135 | 
136 |     def open(self):
137 |         with self._lock:
138 |             self._open()
139 | 
140 |     def close(self):
141 |         with self._lock:
142 |             if self.conn is not None:
143 |                 self.conn.close()
144 |             self.conn = None
145 | 
146 |     def connected(self):
147 |         return self.conn is not None and self.conn.connected()
148 | 
149 |     is_open = property(connected)
150 | 
151 |     @ensure_connected
152 |     def write(self, data):
153 |         return self.conn.write(data)
154 | 
155 |     @ensure_connected
156 |     def read(self, n=-1):
157 |         return self.conn.read(n)
158 | 
159 |     @ensure_connected
160 |     def readline(self):
161 |         return self.conn.readline()
162 | 
163 |     @ensure_connected
164 |     def readlines(self, n):
165 |         return [self.conn.readline() for _ in range(n)]
166 | 
167 |     @ensure_connected
168 |     def writelines(self, lines):
169 |         return self.conn.writelines(lines)
170 | 
171 |     @ensure_connected
172 |     def write_read(self, data, n=-1):
173 |         self.conn.write(data)
174 |         return self.conn.read(n)
175 | 
176 |     @ensure_connected
177 |     def write_readline(self, data):
178 |         self.conn.write(data)
179 |         return self.conn.readline()
180 | 
181 |     @ensure_connected
182 |     def write_readlines(self, data, n):
183 |         self.conn.write(data)
184 |         return [self.conn.readline() for _ in range(n)]
185 | 
186 |     @ensure_connected
187 |     def writelines_readlines(self, lines, n=None):
188 |         if n is None:
189 |             n = len(lines)
190 |         self.conn.writelines(lines)
191 |         return [self.conn.readline() for _ in range(n)]
192 | 
193 | 
194 | def main(args=None):
195 |     import argparse
196 | 
197 |     parser = argparse.ArgumentParser()
198 |     log_level_choices = ["critical", "error", "warning", "info", "debug"]
199 |     log_level_choices += [i.upper() for i in log_level_choices]
200 |     parser.add_argument("--host", default="0", help="host / IP")
201 |     parser.add_argument("-p", "--port", type=int, help="port")
202 |     parser.add_argument("--log-level", choices=log_level_choices, default="warning")
203 |     options = parser.parse_args(args)
204 |     fmt = "%(asctime)-15s %(levelname)-5s %(threadName)s %(name)s: %(message)s"
205 |     logging.basicConfig(level=options.log_level.upper(), format=fmt)
206 |     return TCP(options.host, options.port)
207 | 
208 | 
209 | if __name__ == "__main__":
210 |     conn = main()
211 | 
--------------------------------------------------------------------------------
/tests/test_py2.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from sockio.py2 import TCP
4 | 
5 | from conftest import IDN_REQ, IDN_REP, WRONG_REQ, WRONG_REP
6 | 
7 | 
8 | def test_socket_creation():
9 |     sock = TCP("example.com", 34567)
10 |     assert sock.host == "example.com"
11 |     assert sock.port == 34567
12 |     assert not sock.connected()
13 |     assert sock.connection_counter == 0
14 | 
15 | 
16 | def test_open_fail(unused_tcp_port):
17 |     sock = TCP("0", unused_tcp_port)
18 |     assert not sock.connected()
19 |     assert sock.connection_counter == 0
20 | 
21 |     with pytest.raises(ConnectionRefusedError):
22 |         sock.open()
23 |     assert not sock.connected()
24 |     assert sock.connection_counter == 0
25 | 
26 | 
27 | def test_write_fail(unused_tcp_port):
28 |     sock = TCP("0", unused_tcp_port)
29 |     assert not sock.connected()
30 |     assert sock.connection_counter == 0
31 | 
32 |     with pytest.raises(ConnectionRefusedError):
33 |         sock.write(IDN_REQ)
34 |     assert not sock.connected()
35 |     assert sock.connection_counter == 0
36 | 
37 | 
38 | def test_write_read_fail(unused_tcp_port):
39 |     sock = TCP("0", unused_tcp_port)
40 |     assert not sock.connected()
41 |     assert sock.connection_counter == 0
42 | 
43 |     with pytest.raises(ConnectionRefusedError):
44 |         sock.write_read(IDN_REQ)
45 |     assert not sock.connected()
46 |     assert sock.connection_counter == 0
47 | 
48 | 
49 | def test_write_readline_fail(unused_tcp_port):
50 |     sock = TCP("0", unused_tcp_port)
51 |     assert not sock.connected()
52 |     assert sock.connection_counter == 0
53 | 
54 |     with pytest.raises(ConnectionRefusedError):
55 |         sock.write_readline(IDN_REQ)
56 |     assert not sock.connected()
57 |     assert sock.connection_counter == 0
58 | 
59 | 
60 | def test_open_close(sio_server, py2_tcp):
61 |     assert not py2_tcp.connected()
62 |     assert py2_tcp.connection_counter == 0
63 |     assert sio_server.sockets[0].getsockname() == (py2_tcp.host, py2_tcp.port)
64 | 
65 |     py2_tcp.open()
66 |     assert py2_tcp.connected()
67 |     assert py2_tcp.connection_counter == 1
68 | 
69 |     with pytest.raises(ConnectionError):
70 |         py2_tcp.open()
71 |     assert py2_tcp.connected()
72 |     assert py2_tcp.connection_counter == 1
73 | 
74 |     py2_tcp.close()
75 |     assert not py2_tcp.connected()
76 |     assert py2_tcp.connection_counter == 1
77 |     py2_tcp.open()
78 |     assert py2_tcp.connected()
79 |     assert py2_tcp.connection_counter == 2
80 |     py2_tcp.close()
81 |     py2_tcp.close()
82 |     assert not py2_tcp.connected()
83 |     assert py2_tcp.connection_counter == 2
84 | 
85 | 
86 | def test_write_read(py2_tcp):
87 |     for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]:
88 |         reply = py2_tcp.write_read(request, 1024)
89 |         assert py2_tcp.connected()
90 |         assert py2_tcp.connection_counter == 1
91 |         assert expected == reply
92 | 
93 | 
94 | def test_write_readline(py2_tcp):
95 |     for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]:
96 |         reply = py2_tcp.write_readline(request)
97 |         assert py2_tcp.connected()
98 |         assert py2_tcp.connection_counter == 1
99 |         assert expected == reply
100 | 
101 | 
102 | def test_write_readlines(py2_tcp):
103 |     for request, expected in [
104 |         (IDN_REQ, [IDN_REP]),
105 |         (2 * IDN_REQ, 2 * [IDN_REP]),
106 |         (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]),
107 |     ]:
108 |         gen = py2_tcp.write_readlines(request, len(expected))
109 |         reply = [line for line in gen]
110 |         assert py2_tcp.connected()
111 |         assert py2_tcp.connection_counter == 1
112 |         assert expected == reply
113 | 
114 | 
115 | def test_writelines_readlines(py2_tcp):
116 |     for request, expected in [
117 |         ([IDN_REQ], [IDN_REP]),
118 |         (2 * [IDN_REQ], 2 * [IDN_REP]),
119 |         ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]),
120 |     ]:
121 |         gen = py2_tcp.writelines_readlines(request)
122 |         reply = [line for line in gen]
123 |         assert py2_tcp.connected()
124 |         assert py2_tcp.connection_counter == 1
125 |         assert expected == reply
126 | 
127 | 
128 | def test_writelines(py2_tcp):
129 |     for request, expected in [
130 |         ([IDN_REQ], [IDN_REP]),
131 |         (2 * [IDN_REQ], 2 * [IDN_REP]),
132 |         ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]),
133 |     ]:
134 |         answer = py2_tcp.writelines(request)
135 |         assert py2_tcp.connected()
136 |         assert py2_tcp.connection_counter == 1
137 |         assert answer is None
138 | 
139 |         gen = py2_tcp.readlines(len(expected))
140 |         reply = [line for line in gen]
141 |         assert py2_tcp.connected()
142 |         assert py2_tcp.connection_counter == 1
143 |         assert expected == reply
144 | 
145 | 
146 | def test_readline(py2_tcp):
147 |     for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]:
148 |         py2_tcp.write(request)
149 |         assert py2_tcp.connected()
150 |         assert py2_tcp.connection_counter == 1
151 |         reply = py2_tcp.readline()
152 |         assert expected == reply
153 | 
154 | 
155 | def test_readlines(py2_tcp):
156 |     for request, expected in [
157 |         (IDN_REQ, [IDN_REP]),
158 |         (2 * IDN_REQ, 2 * [IDN_REP]),
159 |         (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]),
160 |     ]:
161 |         py2_tcp.write(request)
162 |         assert py2_tcp.connected()
163 |         assert py2_tcp.connection_counter == 1
164 |         gen = py2_tcp.readlines(len(expected))
165 |         reply = [line for line in gen]
166 |         assert expected == reply
167 | 
168 | 
169 | def test_read(py2_tcp):
170 |     for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]:
171 |         py2_tcp.write(request)
172 |         assert py2_tcp.connected()
173 |         assert py2_tcp.connection_counter == 1
174 |         reply, n = b"", 0
175 |         while len(reply) < len(expected) and n < 2:
176 |             reply += py2_tcp.read(1024)
177 |             n += 1
178 |         assert expected == reply
179 | 
--------------------------------------------------------------------------------
/README.md:
-------------------------------------------------------------------------------- 1 | # sockio 2 | 3 | ![Pypi version][pypi] 4 | 5 | A python concurrency agnostic socket library. 6 | 7 | ![spec in action](./demo.svg) 8 | 9 | Helpful when dealing with instruments which work over TCP and implement 10 | simple REQ-REP communication protocols (example: 11 | [SCPI](https://en.m.wikipedia.org/wiki/Standard_Commands_for_Programmable_Instruments)). 12 | 13 | So far it implements REQ-REP and streaming semantics with auto-reconnection facilities. 14 | 15 | The base implementation is written in asyncio with support for different concurrency models: 16 | 17 | * asyncio 18 | * classic blocking API 19 | * future based API 20 | * python 2 compatible blocking API (for those poor souls stuck with python 2) 21 | 22 | 23 | ## Installation 24 | 25 | From within your favorite python environment: 26 | 27 | ```console 28 | pip install sockio 29 | ``` 30 | 31 | ## Usage 32 | 33 | *asyncio* 34 | 35 | ```python 36 | import asyncio 37 | from sockio.aio import TCP 38 | 39 | async def main(): 40 | sock = TCP('acme.example.com', 5000) 41 | # Assuming a SCPI compliant device on the other end we can ask for its identification: 42 | reply = await sock.write_readline(b'*IDN?\n') 43 | print(reply) 44 | 45 | asyncio.run(main()) 46 | ``` 47 | 48 | *classic* 49 | 50 | ```python 51 | from sockio.sio import TCP 52 | 53 | sock = TCP('acme.example.com', 5000) 54 | reply = sock.write_readline(b'*IDN?\n') 55 | print(reply) 56 | ``` 57 | 58 | *concurrent.futures* 59 | 60 | ```python 61 | from sockio.sio import TCP 62 | 63 | sock = TCP('acme.example.com', 5000, resolve_futures=False) 64 | reply = sock.write_readline(b'*IDN?\n').result() 65 | print(reply) 66 | ``` 67 | 68 | *python 2 compatibility* 69 | 70 | ```python 71 | from sockio.py2 import TCP 72 | 73 | sock = TCP('acme.example.com', 5000) 74 | reply = sock.write_readline(b'*IDN?\n') 75 | print(reply) 76 | ``` 77 | 78 | ## Features 79 | 80 | The main goal of a sockio TCP object 
is to facilitate communication 81 | with instruments which listen on a TCP socket. 82 | 83 | The most frequent cases include instruments which expect REQ/REP 84 | semantics with ASCII protocols like SCPI. In these cases most commands 85 | translate into small packets being exchanged between the host and the 86 | instrument. Care has been taken in this library to make sure we reduce 87 | latency as much as possible. This translates into the following defaults 88 | when creating a TCP object: 89 | 90 | * TCP no delay is active. Can be disabled with `TCP(..., no_delay=False)`. 91 | This prevents the kernel from applying 92 | [Nagle's algorithm](https://en.wikipedia.org/wiki/Nagle%27s_algorithm) 93 | * TCP ToS is set to LOWDELAY. This effectively prioritizes our packets 94 | over other concurrent communications. Can be disabled with 95 | `TCP(..., tos=IPTOS_NORMAL)` 96 | 97 | ### Price to pay 98 | 99 | Before going into detail about the features, note that this abstraction comes 100 | with a price. Intentionally, when compared with the low level socket API, the 101 | following features are no longer available: 102 | 103 | 1. The capability of controlling the two ends of the socket independently 104 | (ex: close the write end) 105 | 2. While the low level `socket.recv()` returns an empty bytes object when EOF is reached, 106 | the TCP class raises `ConnectionEOFError` instead and closes both ends of 107 | the connection. 108 | 3. Clever low level operations like `os.dup()` or making the socket non-blocking 109 | 110 | ### REQ-REP semantics 111 | 112 | Many instruments out there have a Request-Reply protocol. A sockio TCP 113 | provides the `write_read` family of methods which simplify communication with 114 | these instruments. These methods are atomic which means different tasks or 115 | threads can safely work with the same socket object (although I would 116 | question myself why would I be doing that in my library/application).
117 | 118 | ### Auto-reconnection 119 | 120 | ```python 121 | sock = TCP('acme.example.com', 5000) 122 | reply = await sock.write_readline(b'*IDN?\n') 123 | print(reply) 124 | 125 | # ... kill the server connection somehow and bring it back to life again 126 | 127 | # You can use the same socket object. It will reconnect automatically 128 | # and work "transparently" 129 | reply = await sock.write_readline(b'*IDN?\n') 130 | print(reply) 131 | ``` 132 | 133 | The auto-reconnection facility is especially useful when, for example, you 134 | move equipment from one place to another, or you need to turn off the 135 | equipment during the night (planet Earth thanks you for saving energy!). 136 | 137 | ### Timeout 138 | 139 | The TCP constructor provides a `connection_timeout` parameter that is used when the 140 | connection is opened and a `timeout` parameter that is taken into account 141 | when performing any data I/O operation (read, write, write_readline, 142 | etc). 143 | By default, they are both None, meaning infinite timeout. 144 | 145 | ```python 146 | sock = TCP('acme.example.com', 5000, connection_timeout=0.1, timeout=1) 147 | ``` 148 | 149 | Additionally, you can override the object timeout on each data I/O method 150 | call by providing an alternative timeout parameter: 151 | 152 | ```python 153 | sock = TCP('acme.example.com', 5000, timeout=1) 154 | # the next call will raise ConnectionTimeoutError (a ConnectionError) if it takes more than 0.1s 155 | reply = await sock.write_readline(b'*IDN?\n', timeout=0.1) 156 | print(reply) 157 | ``` 158 | 159 | ### Custom EOL 160 | 161 | In line based protocols, sometimes people decide `\n` is not a good EOL character. 162 | A sockio TCP can be customized with a different EOL character. Example: 163 | 164 | ```python 165 | sock = TCP('acme.example.com', 5000, eol=b'\r') 166 | ``` 167 | 168 | The EOL character can be overridden in any of the `readline` methods.
Example: 169 | ```python 170 | await sock.write_readline(b'*IDN?\n', eol=b'\r') 171 | ``` 172 | 173 | ### Connection event callbacks 174 | 175 | You can be notified on `connection_made`, `connection_lost` and `eof_received` events 176 | by registering callbacks on the sockio TCP constructor 177 | 178 | This is particularly useful if, for example, you want a specific procedure to be 179 | executed every time the socket is reconnected to make sure your configuration is 180 | right. Example: 181 | 182 | ```python 183 | async def connected(): 184 | await sock.write(b'ACQU:TRIGGER HARDWARE\n') 185 | await sock.write(b'DISPLAY OFF\n') 186 | 187 | sock = TCP('acme.example.com', 5000, on_connection_made=connected) 188 | ``` 189 | 190 | (see examples/req-rep/client.py) 191 | 192 | Connection event callbacks are **not** available in *python 2 compatibility module*. 193 | 194 | ### Streams 195 | 196 | sockio TCPs are asynchronous iterable objects. This means that line streaming 197 | is as easy as: 198 | 199 | ```python 200 | sock = TCP('acme.example.com', 5000, eol=b'\r') 201 | 202 | async for line in sock: 203 | print(line) 204 | ``` 205 | 206 | Streams are **not** available in *python 2 compatibility module*. Let me know 207 | if you need them by writing an issue. Also feel free to make a PR! 208 | 209 | ## Missing features 210 | 211 | * Connection retries 212 | * trio event loop 213 | * curio event loop 214 | 215 | Join the party by bringing your own concurrency library with a PR! 216 | 217 | I am looking in particular for implementations over trio and curio. 
218 | 219 | 220 | [pypi]: https://img.shields.io/pypi/pyversions/sockio.svg 221 | -------------------------------------------------------------------------------- /tests/test_sio.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sockio.sio import TCP  # client under test: the classic blocking API 4 | 5 | from conftest import IDN_REQ, IDN_REP, WRONG_REQ, WRONG_REP  # request/expected-reply constants shared with test_py2 and test_aio 6 | 7 | 8 | def test_socket_creation():  # constructing a TCP must not connect; auto_reconnect defaults to True 9 | sock = TCP("example.com", 34567) 10 | assert sock.host == "example.com" 11 | assert sock.port == 34567 12 | assert sock.auto_reconnect is True 13 | assert not sock.connected() 14 | assert sock.connection_counter == 0 15 | 16 | 17 | def test_open_fail(unused_tcp_port):  # presumably nothing listens on this port, so the connection is refused 18 | sock = TCP("0", unused_tcp_port) 19 | assert not sock.connected() 20 | assert sock.connection_counter == 0 21 | 22 | with pytest.raises(ConnectionRefusedError): 23 | sock.open() 24 | assert not sock.connected() 25 | assert sock.connection_counter == 0 26 | 27 | 28 | def test_write_fail(unused_tcp_port):  # write() on a never-opened socket connects first, so it surfaces the refusal 29 | sock = TCP("0", unused_tcp_port) 30 | assert not sock.connected() 31 | assert sock.connection_counter == 0 32 | 33 | with pytest.raises(ConnectionRefusedError): 34 | sock.write(IDN_REQ) 35 | assert not sock.connected() 36 | assert sock.connection_counter == 0 37 | 38 | 39 | def test_write_read_fail(unused_tcp_port): 40 | sock = TCP("0", unused_tcp_port) 41 | assert not sock.connected() 42 | assert sock.connection_counter == 0 43 | 44 | with pytest.raises(ConnectionRefusedError): 45 | sock.write_read(IDN_REQ) 46 | assert not sock.connected() 47 | assert sock.connection_counter == 0 48 | 49 | 50 | def test_write_readline_fail(unused_tcp_port): 51 | sock = TCP("0", unused_tcp_port) 52 | assert not sock.connected() 53 | assert sock.connection_counter == 0 54 | 55 | with pytest.raises(ConnectionRefusedError): 56 | sock.write_readline(IDN_REQ) 57 | assert not sock.connected() 58 | assert sock.connection_counter == 0 59 | 60 | 61 | def test_open_close(sio_server, sio_tcp):  # each successful open() increments connection_counter; close() does not 62 | 
assert not sio_tcp.connected() 63 | assert sio_tcp.connection_counter == 0 64 | assert sio_server.sockets[0].getsockname() == (sio_tcp.host, sio_tcp.port) 65 | 66 | sio_tcp.open() 67 | assert sio_tcp.connected() 68 | assert sio_tcp.connection_counter == 1 69 | 70 | with pytest.raises(ConnectionError):  # opening an already open socket is an error 71 | sio_tcp.open() 72 | assert sio_tcp.connected() 73 | assert sio_tcp.connection_counter == 1 74 | 75 | sio_tcp.close() 76 | assert not sio_tcp.connected() 77 | assert sio_tcp.connection_counter == 1 78 | sio_tcp.open() 79 | assert sio_tcp.connected() 80 | assert sio_tcp.connection_counter == 2 81 | sio_tcp.close() 82 | sio_tcp.close()  # a second close() must be a harmless no-op 83 | assert not sio_tcp.connected() 84 | assert sio_tcp.connection_counter == 2 85 | 86 | 87 | def test_callbacks(sio_server):  # made/lost fire once per real open/close; a redundant close() fires nothing and eof never fires 88 | host, port = sio_server.sockets[0].getsockname() 89 | state = dict(made=0, lost=0, eof=0) 90 | 91 | def made(): 92 | state["made"] += 1 93 | 94 | def lost(exc):  # connection_lost callbacks receive the exception passed to connection_lost 95 | state["lost"] += 1 96 | 97 | def eof(): 98 | state["eof"] += 1 99 | 100 | sio_tcp = TCP( 101 | host, 102 | port, 103 | on_connection_made=made, 104 | on_connection_lost=lost, 105 | on_eof_received=eof, 106 | ) 107 | assert not sio_tcp.connected() 108 | assert sio_tcp.connection_counter == 0 109 | assert state["made"] == 0 110 | assert state["lost"] == 0 111 | assert state["eof"] == 0 112 | 113 | sio_tcp.open() 114 | assert sio_tcp.connected() 115 | assert sio_tcp.connection_counter == 1 116 | assert state["made"] == 1 117 | assert state["lost"] == 0 118 | assert state["eof"] == 0 119 | 120 | with pytest.raises(ConnectionError): 121 | sio_tcp.open() 122 | assert sio_tcp.connected() 123 | assert sio_tcp.connection_counter == 1 124 | assert state["made"] == 1 125 | assert state["lost"] == 0 126 | assert state["eof"] == 0 127 | 128 | sio_tcp.close() 129 | assert not sio_tcp.connected() 130 | assert sio_tcp.connection_counter == 1 131 | assert state["made"] == 1 132 | assert state["lost"] == 1 133 | assert state["eof"] == 0 134 | 135 | 
sio_tcp.open() 136 | assert sio_tcp.connected() 137 | assert sio_tcp.connection_counter == 2 138 | assert state["made"] == 2 139 | assert state["lost"] == 1 140 | assert state["eof"] == 0 141 | 142 | sio_tcp.close() 143 | assert not sio_tcp.connected() 144 | assert sio_tcp.connection_counter == 2 145 | assert state["made"] == 2 146 | assert state["lost"] == 2 147 | assert state["eof"] == 0 148 | 149 | sio_tcp.close() 150 | assert not sio_tcp.connected() 151 | assert sio_tcp.connection_counter == 2 152 | assert state["made"] == 2 153 | assert state["lost"] == 2 154 | assert state["eof"] == 0 155 | 156 | 157 | def test_write_read(sio_tcp): 158 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 159 | reply = sio_tcp.write_read(request, 1024) 160 | assert sio_tcp.connected() 161 | assert sio_tcp.connection_counter == 1 162 | assert expected == reply 163 | 164 | 165 | def test_write_readline(sio_tcp): 166 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 167 | reply = sio_tcp.write_readline(request) 168 | assert sio_tcp.connected() 169 | assert sio_tcp.connection_counter == 1 170 | assert expected == reply 171 | 172 | 173 | def test_write_readlines(sio_tcp): 174 | for request, expected in [ 175 | (IDN_REQ, [IDN_REP]), 176 | (2 * IDN_REQ, 2 * [IDN_REP]), 177 | (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]), 178 | ]: 179 | gen = sio_tcp.write_readlines(request, len(expected)) 180 | reply = [line for line in gen] 181 | assert sio_tcp.connected() 182 | assert sio_tcp.connection_counter == 1 183 | assert expected == reply 184 | 185 | 186 | def test_writelines_readlines(sio_tcp): 187 | for request, expected in [ 188 | ([IDN_REQ], [IDN_REP]), 189 | (2 * [IDN_REQ], 2 * [IDN_REP]), 190 | ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]), 191 | ]: 192 | gen = sio_tcp.writelines_readlines(request) 193 | reply = [line for line in gen] 194 | assert sio_tcp.connected() 195 | assert sio_tcp.connection_counter == 1 196 | assert expected == reply 197
| 198 | 199 | def test_writelines(sio_tcp): 200 | for request, expected in [ 201 | ([IDN_REQ], [IDN_REP]), 202 | (2 * [IDN_REQ], 2 * [IDN_REP]), 203 | ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]), 204 | ]: 205 | answer = sio_tcp.writelines(request) 206 | assert sio_tcp.connected() 207 | assert sio_tcp.connection_counter == 1 208 | assert answer is None 209 | 210 | gen = sio_tcp.readlines(len(expected)) 211 | reply = [line for line in gen] 212 | assert sio_tcp.connected() 213 | assert sio_tcp.connection_counter == 1 214 | assert expected == reply 215 | 216 | 217 | def test_readline(sio_tcp): 218 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 219 | answer = sio_tcp.write(request) 220 | assert sio_tcp.connected() 221 | assert sio_tcp.connection_counter == 1 222 | assert answer is None 223 | reply = sio_tcp.readline() 224 | assert expected == reply 225 | 226 | 227 | def test_readuntil(sio_tcp): 228 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 229 | answer = sio_tcp.write(request) 230 | assert sio_tcp.connected() 231 | assert sio_tcp.connection_counter == 1 232 | assert answer is None 233 | reply = sio_tcp.readuntil(b"\n") 234 | assert expected == reply 235 | 236 | 237 | def test_readexactly(sio_tcp):  # each reply is consumed as two exact-size reads 238 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 239 | answer = sio_tcp.write(request) 240 | assert sio_tcp.connected() 241 | assert sio_tcp.connection_counter == 1 242 | assert answer is None 243 | reply = sio_tcp.readexactly(len(expected) - 5) 244 | assert expected[:-5] == reply 245 | reply = sio_tcp.readexactly(5) 246 | assert expected[-5:] == reply 247 | 248 | 249 | def test_readlines(sio_tcp): 250 | for request, expected in [ 251 | (IDN_REQ, [IDN_REP]), 252 | (2 * IDN_REQ, 2 * [IDN_REP]), 253 | (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]), 254 | ]: 255 | answer = sio_tcp.write(request) 256 | assert sio_tcp.connected() 257 | assert sio_tcp.connection_counter == 1 258 | assert answer is
None 259 | gen = sio_tcp.readlines(len(expected)) 260 | reply = [line for line in gen] 261 | assert expected == reply 262 | 263 | 264 | def test_read(sio_tcp): 265 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 266 | answer = sio_tcp.write(request) 267 | assert sio_tcp.connected() 268 | assert sio_tcp.connection_counter == 1 269 | assert answer is None 270 | reply, n = b"", 0  # read() may return a partial reply, so allow up to two reads 271 | while len(reply) < len(expected) and n < 2: 272 | reply += sio_tcp.read(1024) 273 | n += 1 274 | assert expected == reply 275 | -------------------------------------------------------------------------------- /sockio/aio.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import socket 3 | import asyncio 4 | import functools 5 | import threading 6 | import urllib.parse 7 | 8 | from .common import IPTOS_LOWDELAY, DEFAULT_LIMIT, ConnectionEOFError, ConnectionTimeoutError, log 9 | 10 | 11 | _PY_37 = sys.version_info >= (3, 7)  # gates StreamWriter.wait_closed(), which was added in Python 3.7 12 | 13 | _LOCK = threading.Lock()  # serializes the lazy creation of each TCP's asyncio.Lock 14 | 15 | DFT_KEEP_ALIVE = dict(active=1, idle=60, retry=3, interval=10)  # see configure_socket for the socket option each key maps to 16 | 17 | 18 | def ensure_connection(f):  # decorator: serialize calls per TCP, auto-(re)connect, apply an optional per-call timeout 19 | assert asyncio.iscoroutinefunction(f) 20 | name = f.__name__ 21 | 22 | @functools.wraps(f) 23 | async def wrapper(self, *args, **kwargs): 24 | if self._lock is None:  # lazily create the asyncio.Lock; double-checked under the global _LOCK 25 | with _LOCK: 26 | if self._lock is None: 27 | self._lock = asyncio.Lock() 28 | timeout = kwargs.pop("timeout", self.timeout)  # per-call override; never forwarded to f 29 | async with self._lock: 30 | if self.auto_reconnect and not self.connected():  # NOTE(review): with auto_reconnect=False on a never-opened/closed socket, f runs with reader/writer None and fails with AttributeError, not ConnectionError 31 | await self.open() 32 | coro = f(self, *args, **kwargs)  # only the I/O is bounded by timeout; open() above uses connection_timeout 33 | if timeout is not None: 34 | coro = asyncio.wait_for(coro, timeout) 35 | try: 36 | return await coro 37 | except asyncio.TimeoutError as error:  # surface timeouts as ConnectionTimeoutError (a ConnectionError subclass) 38 | msg = "{} call timeout on '{}:{}'".format(name, self.host, self.port) 39 | raise ConnectionTimeoutError(msg) from error 40 | 41 | return wrapper 42 | 43 | 44 | def raw_handle_read(f):  # decorator: drop the connection on any read error and turn an EOF reply into ConnectionEOFError 45 | assert asyncio.iscoroutinefunction(f) 46 | 47 | @functools.wraps(f) 48 |
async def wrapper(self, *args, **kwargs): 49 | try: 50 | reply = await f(self, *args, **kwargs) 51 | except BaseException: 52 | await self.close() 53 | raise 54 | if not reply: 55 | await self.close() 56 | raise ConnectionEOFError("Connection closed by peer") 57 | return reply 58 | 59 | return wrapper 60 | 61 | 62 | class StreamReaderProtocol(asyncio.StreamReaderProtocol): 63 | 64 | def connection_lost(self, exc): 65 | result = super().connection_lost(exc) 66 | self._exec_callback("connection_lost_cb", exc) 67 | return result 68 | 69 | def eof_received(self): 70 | result = super().eof_received() 71 | self._exec_callback("eof_received_cb") 72 | return result 73 | 74 | def _exec_callback(self, name, *args, **kwargs): 75 | callback = getattr(self, name, None) 76 | if callback is None: 77 | return 78 | try: 79 | res = callback(*args, **kwargs) 80 | if asyncio.iscoroutine(res): 81 | self._loop.create_task(res) 82 | except Exception: 83 | log.exception("Error in %s callback %r", name, callback.__name__) 84 | 85 | 86 | class StreamReader(asyncio.StreamReader): 87 | 88 | async def readline(self, eol=b"\n"): 89 | # This implementation is a copy of the asyncio.StreamReader.readline() 90 | # with the purpose of supporting different EOL characters. 
# we walk on thin ice here: we rely on the internal _buffer and 92 | # _maybe_resume_transport members 93 | try: 94 | line = await self.readuntil(eol) 95 | except asyncio.IncompleteReadError as e:  # EOF before eol: return what arrived (b"" only at a bare EOF) 96 | return e.partial 97 | except asyncio.LimitOverrunError as e: 98 | if self._buffer.startswith(eol, e.consumed): 99 | del self._buffer[: e.consumed + len(eol)] 100 | else: 101 | self._buffer.clear() 102 | self._maybe_resume_transport() 103 | raise ValueError(e.args[0]) 104 | return line 105 | 106 | def __len__(self):  # number of buffered bytes not yet read 107 | return len(self._buffer) 108 | 109 | def reset(self):  # discard all buffered, unread data 110 | self._buffer.clear() 111 | 112 | 113 | def configure_socket(sock, no_delay=True, tos=IPTOS_LOWDELAY, keep_alive=DFT_KEEP_ALIVE):  # options missing on this platform are skipped, never fatal 114 | if hasattr(socket, "TCP_NODELAY") and no_delay: 115 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 116 | if hasattr(socket, "IP_TOS"): 117 | sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, tos) 118 | if keep_alive is not None and hasattr(socket, "SO_KEEPALIVE"): 119 | if isinstance(keep_alive, (int, bool)): 120 | keep_alive = dict(active=int(keep_alive == 1))  # True/1 -> on; False/0/any other int -> off 121 | active = keep_alive.get('active') 122 | idle = keep_alive.get('idle') # aka keepalive_time 123 | interval = keep_alive.get('interval') # aka keepalive_intvl 124 | retry = keep_alive.get('retry') # aka keepalive_probes 125 | if active is not None: 126 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, active) 127 | if idle is not None and hasattr(socket, "TCP_KEEPIDLE"):  # the TCP_KEEP* constants are not available on every platform 128 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle) 129 | if interval is not None and hasattr(socket, "TCP_KEEPINTVL"): 130 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) 131 | if retry is not None and hasattr(socket, "TCP_KEEPCNT"): 132 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, retry) 133 | 134 | 135 | async def open_connection(  # like asyncio.open_connection, but with sockio's reader/protocol subclasses and socket tuning 136 | host=None, 137 | port=None, 138 | loop=None, 139 | limit=DEFAULT_LIMIT, 140 | flags=0, 141 | on_connection_lost=None, 142 | on_eof_received=None, 143 | no_delay=True, 144 | tos=IPTOS_LOWDELAY, 145 |
keep_alive=DFT_KEEP_ALIVE 146 | ): 147 | if loop is None: 148 | loop = asyncio.get_event_loop()  # called from inside a coroutine, so this is the running loop 149 | reader = StreamReader(limit=limit, loop=loop) 150 | protocol = StreamReaderProtocol(reader, loop=loop) 151 | protocol.connection_lost_cb = on_connection_lost 152 | protocol.eof_received_cb = on_eof_received 153 | transport, _ = await loop.create_connection( 154 | lambda: protocol, host, port, flags=flags 155 | ) 156 | writer = asyncio.StreamWriter(transport, protocol, reader, loop) 157 | sock = writer.transport.get_extra_info("socket") 158 | configure_socket(sock, no_delay=no_delay, tos=tos, keep_alive=keep_alive)  # NOTE(review): if this raises, the new transport is never closed 159 | return reader, writer 160 | 161 | 162 | class BaseStream: 163 | """Base asynchronous iterator stream helper for TCP connections""" 164 | 165 | def __init__(self, tcp): 166 | self.tcp = tcp 167 | 168 | async def _read(self): 169 | raise NotImplementedError 170 | 171 | def __aiter__(self): 172 | return self 173 | 174 | async def __anext__(self): 175 | try: 176 | return await self._read() 177 | except ConnectionEOFError:  # peer closed: end the iteration instead of raising 178 | raise StopAsyncIteration 179 | 180 | 181 | class LineStream(BaseStream): 182 | """Line based asynchronous iterator stream helper for TCP connections""" 183 | 184 | def __init__(self, tcp, eol=None):  # eol=None falls back to the TCP object's own eol 185 | super().__init__(tcp) 186 | self.eol = eol 187 | 188 | async def _read(self): 189 | return await self.tcp.readline(eol=self.eol) 190 | 191 | 192 | class BlockStream(BaseStream): 193 | """ 194 | Fixed based asynchronous iterator stream helper for TCP connections.
195 | 196 | - If limit is an int the block is of fixed size 197 | (TCP.readexactly semantics) 198 | - If limit is a string, block ends when the limit string is 199 | found (TCP.readuntil semantics) 200 | """ 201 | 202 | def __init__(self, tcp, limit): 203 | super().__init__(tcp) 204 | self.limit = limit 205 | 206 | async def _read(self): 207 | try: 208 | if isinstance(self.limit, int): 209 | return await self.tcp.readexactly(self.limit) 210 | else: 211 | return await self.tcp.readuntil(self.limit) 212 | except asyncio.IncompleteReadError as error: 213 | if error.partial: 214 | raise 215 | else: 216 | raise ConnectionEOFError() 217 | 218 | 219 | class TCP: 220 | def __init__( 221 | self, 222 | host, 223 | port, 224 | eol=b"\n", 225 | auto_reconnect=True, 226 | on_connection_made=None, 227 | on_connection_lost=None, 228 | on_eof_received=None, 229 | buffer_size=DEFAULT_LIMIT, 230 | no_delay=True, 231 | tos=IPTOS_LOWDELAY, 232 | connection_timeout=None, 233 | timeout=None, 234 | keep_alive=DFT_KEEP_ALIVE, 235 | ): 236 | self.host = host 237 | self.port = port 238 | self.eol = eol 239 | self.buffer_size = buffer_size 240 | self.auto_reconnect = auto_reconnect 241 | self.connection_counter = 0 242 | self.on_connection_made = on_connection_made 243 | self.on_connection_lost = on_connection_lost 244 | self.on_eof_received = on_eof_received 245 | self.no_delay = no_delay 246 | self.tos = tos 247 | self.connection_timeout = connection_timeout 248 | self.timeout = timeout 249 | self.keep_alive = keep_alive 250 | self.reader = None 251 | self.writer = None 252 | self._lock = None 253 | self._log = log.getChild("TCP({}:{})".format(host, port)) 254 | 255 | def __del__(self): 256 | if self.writer is not None: 257 | loop = self.writer._loop # !watch out: access internal stream loop 258 | if loop is not None and not loop.is_closed(): 259 | self.writer.close() 260 | else: 261 | self._log.info("could not close stream: loop closed") 262 | 263 | def __aiter__(self): 264 | return 
LineStream(self) 265 | 266 | async def open(self, **kwargs): 267 | connection_timeout = kwargs.get("timeout", self.connection_timeout) 268 | if self.connected(): 269 | raise ConnectionError("socket already open") 270 | self._log.debug("open connection (#%d)", self.connection_counter + 1) 271 | # make sure everything is clean before creating a new connection 272 | await self.close() 273 | coro = open_connection( 274 | self.host, 275 | self.port, 276 | limit=self.buffer_size, 277 | on_connection_lost=self.on_connection_lost, 278 | on_eof_received=self.on_eof_received, 279 | no_delay=self.no_delay, 280 | tos=self.tos, 281 | keep_alive=self.keep_alive 282 | ) 283 | if connection_timeout is not None: 284 | coro = asyncio.wait_for(coro, connection_timeout) 285 | 286 | try: 287 | self.reader, self.writer = await coro 288 | except asyncio.TimeoutError: 289 | addr = self.host, self.port 290 | raise ConnectionTimeoutError("Connect call timeout on {}".format(addr)) 291 | 292 | if self.on_connection_made is not None: 293 | try: 294 | res = self.on_connection_made() 295 | if asyncio.iscoroutine(res): 296 | await res 297 | except Exception: 298 | log.exception( 299 | "Error in connection_made callback %r", 300 | self.on_connection_made.__name__, 301 | ) 302 | self.connection_counter += 1 303 | 304 | async def close(self): 305 | try: 306 | if self.writer is not None: 307 | self.writer.close() 308 | if _PY_37: 309 | await self.writer.wait_closed() 310 | finally: 311 | self.reader = None 312 | self.writer = None 313 | 314 | def in_waiting(self): 315 | return len(self.reader) if self.connected() else 0 316 | 317 | def connected(self): 318 | return self.reader is not None and not self.at_eof() 319 | 320 | is_open = property(connected) 321 | 322 | def at_eof(self): 323 | return self.reader is not None and self.reader.at_eof() 324 | 325 | @raw_handle_read 326 | async def _read(self, n=-1): 327 | return await self.reader.read(n) 328 | 329 | @raw_handle_read 330 | async def 
_readexactly(self, n): 331 | return await self.reader.readexactly(n) 332 | 333 | @raw_handle_read 334 | async def _readuntil(self, separator=b"\n"): 335 | return await self.reader.readuntil(separator) 336 | 337 | @raw_handle_read 338 | async def _readline(self, eol=None): 339 | if eol is None: 340 | eol = self.eol 341 | return await self.reader.readline(eol=eol) 342 | 343 | @raw_handle_read 344 | async def _readlines(self, n, eol=None): 345 | if eol is None: 346 | eol = self.eol 347 | replies = [] 348 | for i in range(n): 349 | reply = await self.reader.readline(eol=eol) 350 | replies.append(reply) 351 | return replies 352 | 353 | async def _write(self, data): 354 | try: 355 | self.writer.write(data) 356 | await self.writer.drain() 357 | except ConnectionError: 358 | await self.close() 359 | raise 360 | 361 | async def _writelines(self, lines): 362 | try: 363 | self.writer.writelines(lines) 364 | await self.writer.drain() 365 | except ConnectionError: 366 | await self.close() 367 | raise 368 | 369 | @ensure_connection 370 | async def read(self, n=-1): 371 | return await self._read(n) 372 | 373 | @ensure_connection 374 | async def readline(self, eol=None): 375 | return await self._readline(eol=eol) 376 | 377 | @ensure_connection 378 | async def readlines(self, n, eol=None): 379 | return await self._readlines(n, eol=eol) 380 | 381 | @ensure_connection 382 | async def readexactly(self, n): 383 | return await self._readexactly(n) 384 | 385 | @ensure_connection 386 | async def readuntil(self, separator=b"\n"): 387 | return await self._readuntil(separator) 388 | 389 | @ensure_connection 390 | async def readbuffer(self): 391 | """Read all bytes currently available in the underlying buffer""" 392 | size = self.in_waiting() 393 | return (await self._read(size)) if size else b"" 394 | 395 | @ensure_connection 396 | async def write(self, data): 397 | return await self._write(data) 398 | 399 | @ensure_connection 400 | async def writelines(self, lines): 401 | return await 
self._writelines(lines) 402 | 403 | @ensure_connection 404 | async def write_read(self, data, n=-1): 405 | await self._write(data) 406 | return await self._read(n=n) 407 | 408 | @ensure_connection 409 | async def write_readline(self, data, eol=None): 410 | await self._write(data) 411 | return await self._readline(eol=eol) 412 | 413 | @ensure_connection 414 | async def write_readlines(self, data, n, eol=None): 415 | await self._write(data) 416 | return await self._readlines(n, eol=eol) 417 | 418 | @ensure_connection 419 | async def writelines_readlines(self, lines, n=None, eol=None):  # n defaults to one reply line per written line 420 | if n is None: 421 | n = len(lines) 422 | await self._writelines(lines) 423 | return await self._readlines(n, eol=eol) 424 | 425 | def reset_input_buffer(self):  # discard unread buffered input; no-op when not connected 426 | if self.connected(): 427 | self.reader.reset() 428 | 429 | 430 | def socket_for_url(url, *args, **kwargs):  # only the tcp://host:port scheme is supported here 431 | addr = urllib.parse.urlparse(url) 432 | scheme = addr.scheme 433 | if scheme == "tcp": 434 | return TCP(addr.hostname, addr.port, *args, **kwargs) 435 | raise ValueError("unsupported async scheme {!r} for {}".format(scheme, url)) 436 | -------------------------------------------------------------------------------- /tests/test_aio.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio.subprocess 3 | 4 | import pytest 5 | 6 | from sockio.aio import ( 7 | TCP, 8 | ConnectionTimeoutError, 9 | ConnectionEOFError, 10 | LineStream, 11 | BlockStream, 12 | socket_for_url 13 | ) 14 | 15 | from conftest import IDN_REQ, IDN_REP, WRONG_REQ, WRONG_REP  # request/expected-reply constants shared with test_py2 and test_sio 16 | 17 | 18 | def test_socket_creation():  # constructing a TCP must not connect; both timeouts default to None 19 | sock = TCP("example.com", 34567) 20 | assert sock.host == "example.com" 21 | assert sock.port == 34567 22 | assert sock.connection_timeout is None 23 | assert sock.timeout is None 24 | assert sock.auto_reconnect 25 | assert not sock.connected() 26 | assert sock.in_waiting() == 0 27 | assert sock.connection_counter == 0 28 | 29 | 30 | @pytest.mark.asyncio 31 |
async def test_open_fail(unused_tcp_port): 32 | sock = TCP("0", unused_tcp_port) 33 | assert not sock.connected() 34 | assert sock.connection_counter == 0 35 | 36 | with pytest.raises(ConnectionRefusedError): 37 | await sock.open() 38 | assert not sock.connected() 39 | assert sock.connection_counter == 0 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_open_timeout(): 44 | timeout = 0.1 45 | # TODO: Not cool to use an external connection 46 | aio_tcp = TCP("www.google.com", 81, connection_timeout=timeout) 47 | with pytest.raises(ConnectionTimeoutError): 48 | start = time.time() 49 | try: 50 | await aio_tcp.open() 51 | finally: 52 | dt = time.time() - start 53 | assert dt > timeout and dt < (timeout + 0.05) 54 | 55 | # TODO: Not cool to use an external connection 56 | aio_tcp = TCP("www.google.com", 82) 57 | with pytest.raises(ConnectionTimeoutError): 58 | start = time.time() 59 | try: 60 | await aio_tcp.open(timeout=timeout) 61 | finally: 62 | dt = time.time() - start 63 | assert dt > timeout and dt < (timeout + 0.05) 64 | 65 | 66 | @pytest.mark.asyncio 67 | async def test_write_fail(unused_tcp_port): 68 | sock = TCP("0", unused_tcp_port) 69 | assert not sock.connected() 70 | assert sock.connection_counter == 0 71 | 72 | with pytest.raises(ConnectionRefusedError): 73 | await sock.write(IDN_REQ) 74 | assert not sock.connected() 75 | assert sock.in_waiting() == 0 76 | assert sock.connection_counter == 0 77 | 78 | 79 | @pytest.mark.asyncio 80 | async def test_write_read_fail(unused_tcp_port): 81 | sock = TCP("0", unused_tcp_port) 82 | assert not sock.connected() 83 | assert sock.connection_counter == 0 84 | 85 | with pytest.raises(ConnectionRefusedError): 86 | await sock.write_read(IDN_REQ) 87 | assert not sock.connected() 88 | assert sock.in_waiting() == 0 89 | assert sock.connection_counter == 0 90 | 91 | 92 | @pytest.mark.asyncio 93 | async def test_write_readline_fail(unused_tcp_port): 94 | sock = TCP("0", unused_tcp_port) 95 | assert not sock.connected() 
96 | assert sock.connection_counter == 0 97 | 98 | with pytest.raises(ConnectionRefusedError): 99 | await sock.write_readline(IDN_REQ) 100 | assert not sock.connected() 101 | assert sock.in_waiting() == 0 102 | assert sock.connection_counter == 0 103 | 104 | 105 | @pytest.mark.asyncio 106 | async def test_write_readline_error(aio_server, aio_tcp): 107 | with pytest.raises(ConnectionEOFError): 108 | await aio_tcp.write_readline(b"kill\n") 109 | 110 | 111 | @pytest.mark.asyncio 112 | async def test_open_close(aio_server, aio_tcp): 113 | assert not aio_tcp.connected() 114 | assert aio_tcp.connection_counter == 0 115 | assert aio_server.sockets[0].getsockname() == (aio_tcp.host, aio_tcp.port) 116 | 117 | await aio_tcp.open() 118 | assert aio_tcp.connected() 119 | assert aio_tcp.connection_counter == 1 120 | return 121 | with pytest.raises(ConnectionError): 122 | await aio_tcp.open() 123 | assert aio_tcp.connected() 124 | assert aio_tcp.connection_counter == 1 125 | 126 | await aio_tcp.close() 127 | assert not aio_tcp.connected() 128 | assert aio_tcp.connection_counter == 1 129 | await aio_tcp.open() 130 | assert aio_tcp.connected() 131 | assert aio_tcp.connection_counter == 2 132 | await aio_tcp.close() 133 | await aio_tcp.close() 134 | assert not aio_tcp.connected() 135 | assert aio_tcp.connection_counter == 2 136 | 137 | 138 | @pytest.mark.asyncio 139 | async def test_callbacks(aio_server): 140 | host, port = aio_server.sockets[0].getsockname() 141 | state = dict(made=0, lost=0, eof=0) 142 | 143 | def made(): 144 | state["made"] += 1 145 | 146 | def lost(exc): 147 | state["lost"] += 1 148 | 149 | def eof(): 150 | state["eof"] += 1 151 | 152 | aio_tcp = TCP( 153 | host, 154 | port, 155 | on_connection_made=made, 156 | on_connection_lost=lost, 157 | on_eof_received=eof, 158 | ) 159 | assert not aio_tcp.connected() 160 | assert aio_tcp.connection_counter == 0 161 | assert state["made"] == 0 162 | assert state["lost"] == 0 163 | assert state["eof"] == 0 164 | 165 | await 
aio_tcp.open() 166 | assert aio_tcp.connected() 167 | assert aio_tcp.connection_counter == 1 168 | assert state["made"] == 1 169 | assert state["lost"] == 0 170 | assert state["eof"] == 0 171 | 172 | with pytest.raises(ConnectionError): 173 | await aio_tcp.open() 174 | assert aio_tcp.connected() 175 | assert aio_tcp.connection_counter == 1 176 | assert state["made"] == 1 177 | assert state["lost"] == 0 178 | assert state["eof"] == 0 179 | 180 | await aio_tcp.close() 181 | assert not aio_tcp.connected() 182 | assert aio_tcp.connection_counter == 1 183 | assert state["made"] == 1 184 | assert state["lost"] == 1 185 | assert state["eof"] == 0 186 | 187 | await aio_tcp.open() 188 | assert aio_tcp.connected() 189 | assert aio_tcp.connection_counter == 2 190 | assert state["made"] == 2 191 | assert state["lost"] == 1 192 | assert state["eof"] == 0 193 | 194 | await aio_tcp.close() 195 | assert not aio_tcp.connected() 196 | assert aio_tcp.connection_counter == 2 197 | assert state["made"] == 2 198 | assert state["lost"] == 2 199 | assert state["eof"] == 0 200 | 201 | await aio_tcp.close() 202 | assert not aio_tcp.connected() 203 | assert aio_tcp.connection_counter == 2 204 | assert state["made"] == 2 205 | assert state["lost"] == 2 206 | assert state["eof"] == 0 207 | 208 | 209 | @pytest.mark.asyncio 210 | async def test_coroutine_callbacks(aio_server): 211 | host, port = aio_server.sockets[0].getsockname() 212 | RESP_TIME = 0.02 213 | state = dict(made=0, lost=0, eof=0) 214 | 215 | async def made(): 216 | await asyncio.sleep(RESP_TIME) 217 | state["made"] += 1 218 | 219 | async def lost(exc): 220 | await asyncio.sleep(RESP_TIME) 221 | state["lost"] += 1 222 | 223 | async def eof(): 224 | await asyncio.sleep(RESP_TIME) 225 | state["eof"] += 1 226 | 227 | aio_tcp = TCP( 228 | host, 229 | port, 230 | on_connection_made=made, 231 | on_connection_lost=lost, 232 | on_eof_received=eof, 233 | ) 234 | 235 | assert not aio_tcp.connected() 236 | assert aio_tcp.connection_counter == 
0 237 | assert state["made"] == 0 238 | assert state["lost"] == 0 239 | assert state["eof"] == 0 240 | 241 | await aio_tcp.open() 242 | assert aio_tcp.connected() 243 | assert aio_tcp.connection_counter == 1 244 | assert state["made"] == 1 245 | assert state["lost"] == 0 246 | assert state["eof"] == 0 247 | 248 | with pytest.raises(ConnectionError): 249 | await aio_tcp.open() 250 | assert aio_tcp.connected() 251 | assert aio_tcp.connection_counter == 1 252 | assert state["made"] == 1 253 | assert state["lost"] == 0 254 | assert state["eof"] == 0 255 | 256 | await aio_tcp.close() 257 | assert not aio_tcp.connected() 258 | assert aio_tcp.connection_counter == 1 259 | assert state["made"] == 1 260 | assert state["lost"] == 0 261 | assert state["eof"] == 0 262 | await asyncio.sleep(RESP_TIME + 0.01) 263 | assert state["made"] == 1 264 | assert state["lost"] == 1 265 | assert state["eof"] == 0 266 | 267 | await aio_tcp.open() 268 | assert aio_tcp.connected() 269 | assert aio_tcp.connection_counter == 2 270 | assert state["made"] == 2 271 | assert state["lost"] == 1 272 | assert state["eof"] == 0 273 | 274 | await aio_tcp.close() 275 | assert not aio_tcp.connected() 276 | assert aio_tcp.connection_counter == 2 277 | assert state["made"] == 2 278 | assert state["lost"] == 1 279 | assert state["eof"] == 0 280 | await asyncio.sleep(RESP_TIME + 0.01) 281 | assert state["made"] == 2 282 | assert state["lost"] == 2 283 | assert state["eof"] == 0 284 | 285 | await aio_tcp.close() 286 | assert not aio_tcp.connected() 287 | assert aio_tcp.connection_counter == 2 288 | assert state["made"] == 2 289 | assert state["lost"] == 2 290 | assert state["eof"] == 0 291 | await asyncio.sleep(RESP_TIME + 0.01) 292 | assert state["made"] == 2 293 | assert state["lost"] == 2 294 | assert state["eof"] == 0 295 | 296 | 297 | @pytest.mark.asyncio 298 | async def test_error_callback(aio_server): 299 | host, port = aio_server.sockets[0].getsockname() 300 | 301 | state = dict(made=0) 302 | 303 | def 
error_callback(): 304 | state["made"] += 1 305 | raise RuntimeError("cannot handle this") 306 | 307 | aio_tcp = TCP(host, port, on_connection_made=error_callback) 308 | 309 | assert not aio_tcp.connected() 310 | assert aio_tcp.connection_counter == 0 311 | assert state["made"] == 0 312 | 313 | await aio_tcp.open() 314 | assert aio_tcp.connected() 315 | assert aio_tcp.connection_counter == 1 316 | assert state["made"] == 1 317 | 318 | 319 | @pytest.mark.asyncio 320 | async def test_eof_callback(aio_server): 321 | host, port = aio_server.sockets[0].getsockname() 322 | state = dict(made=0, lost=0, eof=0) 323 | 324 | def made(): 325 | state["made"] += 1 326 | 327 | def lost(exc): 328 | state["lost"] += 1 329 | 330 | def eof(): 331 | state["eof"] += 1 332 | 333 | aio_tcp = TCP( 334 | host, 335 | port, 336 | on_connection_made=made, 337 | on_connection_lost=lost, 338 | on_eof_received=eof, 339 | ) 340 | assert not aio_tcp.connected() 341 | assert aio_tcp.connection_counter == 0 342 | assert state["made"] == 0 343 | assert state["lost"] == 0 344 | assert state["eof"] == 0 345 | 346 | await aio_tcp.open() 347 | assert aio_tcp.connected() 348 | assert aio_tcp.connection_counter == 1 349 | assert state["made"] == 1 350 | assert state["lost"] == 0 351 | assert state["eof"] == 0 352 | 353 | await aio_server.stop() 354 | await asyncio.sleep(0.01) # give time for connection to be closed 355 | 356 | assert state["made"] == 1 357 | assert state["lost"] == 0 358 | assert state["eof"] == 1 359 | 360 | 361 | @pytest.mark.asyncio 362 | async def test_write_read(aio_tcp): 363 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 364 | coro = aio_tcp.write_read(request, 1024) 365 | assert asyncio.iscoroutine(coro) 366 | reply = await coro 367 | assert aio_tcp.connected() 368 | assert aio_tcp.connection_counter == 1 369 | assert expected == reply 370 | 371 | 372 | @pytest.mark.asyncio 373 | async def test_write_readline(aio_tcp): 374 | for request, expected in 
[(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 375 | coro = aio_tcp.write_readline(request) 376 | assert asyncio.iscoroutine(coro) 377 | reply = await coro 378 | assert aio_tcp.connected() 379 | assert aio_tcp.connection_counter == 1 380 | assert expected == reply 381 | 382 | 383 | @pytest.mark.asyncio 384 | async def test_write_readlines(aio_tcp): 385 | for request, expected in [ 386 | (IDN_REQ, [IDN_REP]), 387 | (2 * IDN_REQ, 2 * [IDN_REP]), 388 | (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]), 389 | ]: 390 | coro = aio_tcp.write_readlines(request, len(expected)) 391 | assert asyncio.iscoroutine(coro) 392 | reply = await coro 393 | assert aio_tcp.connected() 394 | assert aio_tcp.connection_counter == 1 395 | assert expected == reply 396 | 397 | 398 | @pytest.mark.asyncio 399 | async def test_writelines_readlines(aio_tcp): 400 | for request, expected in [ 401 | ([IDN_REQ], [IDN_REP]), 402 | (2 * [IDN_REQ], 2 * [IDN_REP]), 403 | ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]), 404 | ]: 405 | coro = aio_tcp.writelines_readlines(request) 406 | assert asyncio.iscoroutine(coro) 407 | reply = await coro 408 | assert aio_tcp.connected() 409 | assert aio_tcp.connection_counter == 1 410 | assert expected == reply 411 | 412 | 413 | @pytest.mark.asyncio 414 | async def test_writelines(aio_tcp): 415 | for request, expected in [ 416 | ([IDN_REQ], [IDN_REP]), 417 | (2 * [IDN_REQ], 2 * [IDN_REP]), 418 | ([IDN_REQ, WRONG_REQ], [IDN_REP, WRONG_REP]), 419 | ]: 420 | coro = aio_tcp.writelines(request) 421 | assert asyncio.iscoroutine(coro) 422 | answer = await coro 423 | assert aio_tcp.connected() 424 | assert aio_tcp.connection_counter == 1 425 | assert answer is None 426 | 427 | coro = aio_tcp.readlines(len(expected)) 428 | assert asyncio.iscoroutine(coro) 429 | reply = await coro 430 | assert aio_tcp.connected() 431 | assert aio_tcp.connection_counter == 1 432 | assert expected == reply 433 | 434 | 435 | @pytest.mark.asyncio 436 | async def test_readline(aio_tcp): 437 | for request, 
expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 438 | coro = aio_tcp.write(request) 439 | assert asyncio.iscoroutine(coro) 440 | answer = await coro 441 | assert aio_tcp.connected() 442 | assert aio_tcp.connection_counter == 1 443 | assert answer is None 444 | await asyncio.sleep(0.05) 445 | assert aio_tcp.in_waiting() > 0 446 | coro = aio_tcp.readline() 447 | assert asyncio.iscoroutine(coro) 448 | reply = await coro 449 | assert expected == reply 450 | 451 | 452 | @pytest.mark.asyncio 453 | async def test_readuntil(aio_tcp): 454 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 455 | coro = aio_tcp.write(request) 456 | assert asyncio.iscoroutine(coro) 457 | answer = await coro 458 | assert aio_tcp.connected() 459 | assert aio_tcp.connection_counter == 1 460 | assert answer is None 461 | coro = aio_tcp.readuntil(b"\n") 462 | assert asyncio.iscoroutine(coro) 463 | reply = await coro 464 | assert expected == reply 465 | 466 | 467 | @pytest.mark.asyncio 468 | async def test_readexactly(aio_tcp): 469 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 470 | coro = aio_tcp.write(request) 471 | assert asyncio.iscoroutine(coro) 472 | answer = await coro 473 | assert aio_tcp.connected() 474 | assert aio_tcp.connection_counter == 1 475 | assert answer is None 476 | coro = aio_tcp.readexactly(len(expected) - 5) 477 | assert asyncio.iscoroutine(coro) 478 | reply = await coro 479 | assert expected[:-5] == reply 480 | coro = aio_tcp.readexactly(5) 481 | assert asyncio.iscoroutine(coro) 482 | reply = await coro 483 | assert expected[-5:] == reply 484 | 485 | 486 | @pytest.mark.asyncio 487 | async def test_readlines(aio_tcp): 488 | for request, expected in [ 489 | (IDN_REQ, [IDN_REP]), 490 | (2 * IDN_REQ, 2 * [IDN_REP]), 491 | (IDN_REQ + WRONG_REQ, [IDN_REP, WRONG_REP]), 492 | ]: 493 | coro = aio_tcp.write(request) 494 | assert asyncio.iscoroutine(coro) 495 | answer = await coro 496 | assert aio_tcp.connected() 497 | assert 
aio_tcp.connection_counter == 1 498 | assert answer is None 499 | coro = aio_tcp.readlines(len(expected)) 500 | assert asyncio.iscoroutine(coro) 501 | reply = await coro 502 | assert expected == reply 503 | 504 | 505 | @pytest.mark.asyncio 506 | async def test_read(aio_tcp): 507 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 508 | coro = aio_tcp.write(request) 509 | assert asyncio.iscoroutine(coro) 510 | answer = await coro 511 | assert aio_tcp.connected() 512 | assert aio_tcp.connection_counter == 1 513 | assert answer is None 514 | reply, n = b"", 0 515 | while len(reply) < len(expected) and n < 2: 516 | coro = aio_tcp.read(1024) 517 | assert asyncio.iscoroutine(coro) 518 | reply += await coro 519 | n += 1 520 | assert expected == reply 521 | 522 | 523 | @pytest.mark.asyncio 524 | async def test_readbuffer(aio_tcp): 525 | for request, expected in [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)]: 526 | await aio_tcp.write(request) 527 | assert aio_tcp.connected() 528 | for i in range(10): 529 | if aio_tcp.in_waiting() >= len(expected): 530 | break 531 | await asyncio.sleep(0.001) 532 | reply = await aio_tcp.readbuffer() 533 | assert expected == reply 534 | 535 | 536 | @pytest.mark.asyncio 537 | async def test_parallel_rw(aio_tcp): 538 | async def wr(request, expected_reply): 539 | reply = await aio_tcp.write_readline(request) 540 | return request, reply, expected_reply 541 | 542 | args = 10 * [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)] 543 | coros = [wr(*arg) for arg in args] 544 | result = await asyncio.gather(*coros) 545 | for req, reply, expected in result: 546 | assert reply == expected, "Failed request {}".format(req) 547 | 548 | 549 | @pytest.mark.asyncio 550 | async def test_parallel(aio_tcp): 551 | async def wr(request, expected_reply): 552 | await aio_tcp.write(request) 553 | reply = await aio_tcp.readline() 554 | return request, reply, expected_reply 555 | 556 | args = 10 * [(IDN_REQ, IDN_REP), (WRONG_REQ, WRONG_REP)] 557 | coros = 
[wr(*arg) for arg in args] 558 | result = await asyncio.gather(*coros) 559 | for req, reply, expected in result: 560 | assert reply == expected, "Failed request {}".format(req) 561 | 562 | 563 | @pytest.mark.asyncio 564 | async def test_stream(aio_tcp): 565 | request = b"data? 2\n" 566 | await aio_tcp.write(request) 567 | i = 0 568 | async for line in aio_tcp: 569 | assert line == b"1.2345 5.4321 12345.54321\n" 570 | i += 1 571 | assert i == 2 572 | assert aio_tcp.connection_counter == 1 573 | assert not aio_tcp.connected() 574 | 575 | 576 | @pytest.mark.asyncio 577 | async def test_timeout(aio_tcp): 578 | timeout = 0.1 579 | reply = await aio_tcp.write_readline(IDN_REQ, timeout=timeout) 580 | assert reply == IDN_REP 581 | 582 | start = time.time() 583 | reply = await aio_tcp.write_readline(b"sleep 0.05\n") 584 | dt = time.time() - start 585 | assert dt > 0.05 586 | assert reply == b"OK\n" 587 | 588 | timeout = 0.1 589 | start = time.time() 590 | await aio_tcp.write_readline(b"sleep 0.05\n", timeout=timeout) 591 | dt = time.time() - start 592 | assert dt < timeout 593 | 594 | timeout = 0.09 595 | with pytest.raises(ConnectionTimeoutError): 596 | start = time.time() 597 | try: 598 | await aio_tcp.write_readline(b"sleep 1\n", timeout=timeout) 599 | finally: 600 | dt = time.time() - start 601 | assert dt > timeout and dt < (timeout + 0.05) 602 | 603 | await aio_tcp.close() 604 | 605 | 606 | @pytest.mark.asyncio 607 | async def test_line_stream(aio_tcp): 608 | request = b"data? 2\n" 609 | await aio_tcp.write(request) 610 | i = 0 611 | async for line in LineStream(aio_tcp): 612 | assert line == b"1.2345 5.4321 12345.54321\n" 613 | i += 1 614 | assert i == 2 615 | assert aio_tcp.connection_counter == 1 616 | assert not aio_tcp.connected() 617 | 618 | 619 | @pytest.mark.asyncio 620 | async def test_block_stream(aio_tcp): 621 | request = b"data? 
-5\n" 622 | await aio_tcp.write(request) 623 | i = 0 624 | async for line in BlockStream(aio_tcp, 12): 625 | assert line == "message {:04d}".format(i).encode() 626 | i += 1 627 | assert i == 5 628 | assert aio_tcp.connection_counter == 1 629 | assert not aio_tcp.connected() 630 | 631 | 632 | @pytest.mark.asyncio 633 | async def test_socket_for_url(aio_server): 634 | host, port = aio_server.sockets[0].getsockname() 635 | 636 | with pytest.raises(ValueError): 637 | socket_for_url("udp://{}:{}".format(host, port)) 638 | 639 | aio_tcp = socket_for_url("tcp://{}:{}".format(host, port)) 640 | 641 | assert not aio_tcp.connected() 642 | assert aio_tcp.connection_counter == 0 643 | 644 | await aio_tcp.open() 645 | assert aio_tcp.connected() 646 | assert aio_tcp.connection_counter == 1 647 | 648 | coro = aio_tcp.write_readline(IDN_REQ) 649 | assert asyncio.iscoroutine(coro) 650 | reply = await coro 651 | assert aio_tcp.connected() 652 | assert aio_tcp.connection_counter == 1 653 | assert reply == IDN_REP 654 | -------------------------------------------------------------------------------- /demo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 66 | 87 | 88 | 89 | client $ ──────────────────────────────────────────────────────────────────────────────────────────────────server $ [0] 0:bash* "coutinhocells" 15:13 18-Oct-19client $ ──────────────────────────────────────────────────────────────────────────────────────────────────server $ server $ python examples/req-rep/server.py --log-level=debug --port=12345 server $ python examples/req-rep/server.py --log-level=debug --port=12345 2019-10-18 15:13:43,394 DEBUG: Using selector: EpollSelector 2019-10-18 15:13:43,397 INFO : started accepting requests on 0.0.0.0:12345 [0] 0:python* "coutinhocells" 15:13 18-Oct-19client $ python examples/req-rep/client.py client $ python examples/req-rep/client.py 2019-10-18 15:13:46,433 INFO : client connected from ('127.0.0.1', 
49938) Server replies with: b'ACME, bla ble ble, 1234, 5678\n' Looks like the server is running. Great! Now, please restart the server... 2019-10-18 15:13:46,433 DEBUG: recv b'*idn?\n' 2019-10-18 15:13:46,433 DEBUG: send b'ACME, bla ble ble, 1234, 5678\n' ^C Thanks for turning it off! You now have 5s to turn it back on again. ^C2019-10-18 15:13:51,895 INFO : Ctrl-C pressed. Bailing out! 2019-10-18 15:13:55,436 DEBUG: Using selector: EpollSelector2019-10-18 15:13:55,439 INFO : started accepting requests on 0.0.0.0:12345I will now try another request without explicitly reopening the socket It works! Server replies with: b'ACME, bla ble ble, 1234, 5678\n' 2019-10-18 15:13:56,908 INFO : client connected from ('127.0.0.1', 49948)2019-10-18 15:13:56,908 DEBUG: recv b'*idn?\n'2019-10-18 15:13:56,908 DEBUG: send b'ACME, bla ble ble, 1234, 5678\n'2019-10-18 15:13:56,909 DEBUG: recv b''2019-10-18 15:13:56,909 INFO : client ('127.0.0.1', 49948) disconnected2019-10-18 15:13:56,909 DEBUG: send b'ACME, bla ble ble, 1234, 5678\n'[0] 0:python* "coutinhocells" 15:14 18-Oct-19 [detached (from session 0)] 90 | --------------------------------------------------------------------------------