├── .coveragerc ├── .gitignore ├── .isort.cfg ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── async_asgi_testclient ├── __init__.py ├── compatibility.py ├── multipart.py ├── response.py ├── testing.py ├── tests │ └── test_testing.py ├── utils.py └── websocket.py ├── requirements.txt ├── setup.cfg ├── setup.py └── test-requirements.txt /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | async_asgi_testclient/tests/* 4 | async_asgi_testclient/compatibility.py 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | force_single_line=True 3 | sections=THIRDPARTY,FIRSTPARTY,LOCALFOLDER,STDLIB 4 | no_lines_before=LOCALFOLDER,THIRDPARTY,FIRSTPARTY,STDLIB 5 | force_alphabetical_sort=True 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: focal 2 | language: python 3 | python: 4 | - "3.6" 5 | - "3.7" 6 | - "3.8" 7 | - "3.9" 8 | - "3.10.0" 9 | # command to install dependencies 10 | install: 11 | - pip install codecov 12 | - pip install -e . 
13 | - pip install -r test-requirements.txt 14 | # command to run tests 15 | script: 16 | - isort -c -rc async_asgi_testclient 17 | - black --check async_asgi_testclient 18 | - flake8 async_asgi_testclient 19 | - pytest --cov=async_asgi_testclient -v --cov-report term-missing async_asgi_testclient 20 | after_success: 21 | - codecov 22 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 1.4.11 2 | ----- 3 | - Added handling of absolute redirect paths 4 | [bentheiii] 5 | 6 | 1.4.10 7 | ----- 8 | - Bump an upper constraint for the multidict dependency 9 | [kleschenko] 10 | 11 | 1.4.9 12 | ----- 13 | - Fix websocket scope scheme 14 | [yanyongyu] 15 | 16 | 1.4.8 17 | ----- 18 | - fix Cookie header rendering in requests 19 | [druid8] [Paweł Pecio] 20 | 21 | 1.4.7 22 | ----- 23 | - Add support for Python 3.10 24 | 25 | 1.4.6 26 | ----- 27 | - Maintain hard references to tasks to prevent garbage collection 28 | [MatthewScholefield] 29 | 30 | 1.4.5 31 | ----- 32 | - Add support for Python 3.9 33 | [kleschenko] 34 | 35 | 1.4.4 36 | ----- 37 | - Fix WebSocketSession.receive_json() doesn't support bytes 38 | [masipcat] 39 | 40 | 1.4.3 41 | ----- 42 | - Send header Content-Length 43 | [masipcat] 44 | 45 | 1.4.2 46 | ----- 47 | - Fixed mypy annotation 48 | [masipcat] 49 | - Remove default dict for self.cookies attr 50 | [otsuka] 51 | 52 | 1.4.1 53 | ----- 54 | - Don't decode bytes to string to build multipart 55 | [masipcat] 56 | 57 | 1.4.0 58 | ----- 59 | - Added argument 'cookies' to `websocket_connect()` 60 | [masipcat] 61 | - Renamed `ws.send_str()` to `ws.send_text()` 62 | [masipcat] 63 | - Fix return type annotation of the methods invoking open() 64 | [otsuka] 65 | 66 | 1.3.0 67 | ----- 68 | - Add support for multipart/form-data 69 | [masipcat] 70 | 71 | 1.2.2 72 | ----- 73 | - Quote query_string by default 74 | [masipcat] 75 | 76 | 1.2.1 77 | 
----- 78 | - Add client (remote peer) to scope 79 | [aviramha] 80 | 81 | 1.2.0 82 | ----- 83 | - Added support for Python 3.8 84 | [masipcat] 85 | - Updated test dependencies 86 | [masipcat] 87 | 88 | 1.1.3 89 | ----- 90 | - added default client headers 91 | [logileifs] 92 | 93 | 1.1.2 94 | ----- 95 | - Prevent PytestCollectionWarning 96 | [podhmo] 97 | 98 | 1.1.1 99 | ----- 100 | - fast work-around to make websocket query params work 101 | [grubberr] 102 | 103 | 1.1.0 104 | ----- 105 | - Relicensed library to MIT License 106 | [masipcat] 107 | 108 | 1.0.4 109 | ----- 110 | - ws: added safeguards to validate received ASGI message has expected type and fixed query_string default value 111 | [masipcat] 112 | 113 | 1.0.3 114 | ----- 115 | - Fix response with multiple cookies 116 | [masipcat] 117 | 118 | 1.0.2 119 | ----- 120 | - Fix warning on Py37 and added 'timeout' in 'send_lifespan()' 121 | [masipcat] 122 | 123 | 1.0.1 124 | ----- 125 | - Unpinned dependencies 126 | [masipcat] 127 | 128 | 1.0.0 129 | ----- 130 | - Websocket client 131 | [dmanchon] 132 | 133 | 0.2.2 134 | ----- 135 | - Add 'allow_redirects' to TestClient.open().
Defaults to True 136 | [masipcat] 137 | 138 | 0.2.1 139 | ----- 140 | - Support Python 3.6 and small improvements 141 | [masipcat] 142 | 143 | 0.2.0 144 | ----- 145 | - Streams and redirects 146 | [masipcat] 147 | 148 | 0.1.3 149 | ----- 150 | - Improved cookies support 151 | [masipcat] 152 | 153 | 0.1.2 154 | ----- 155 | - flag on the testclient to catch unhandled server exceptions 156 | [jordic] 157 | 158 | 0.1 159 | --- 160 | - Initial version 161 | [masipcat] 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 Jordi Masip 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE.
20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | include test-requirements.txt 5 | 6 | recursive-include tests * 7 | recursive-exclude * __pycache__ 8 | recursive-exclude * *.py[co] 9 | 10 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # async-asgi-testclient 2 | 3 | [![Build Status](https://travis-ci.com/vinissimus/async-asgi-testclient.svg?branch=master)](https://travis-ci.com/vinissimus/async-asgi-testclient) [![PyPI version](https://badge.fury.io/py/async-asgi-testclient.svg)](https://badge.fury.io/py/async-asgi-testclient) ![](https://img.shields.io/pypi/pyversions/async-asgi-testclient.svg) [![Codecov](https://codecov.io/gh/vinissimus/async-asgi-testclient/branch/master/graph/badge.svg)](https://codecov.io/gh/vinissimus/async-asgi-testclient/branch/master) ![](https://img.shields.io/github/license/vinissimus/async-asgi-testclient) 4 | 5 | Async ASGI TestClient is a library for testing web applications that implement the ASGI specification (versions 2 and 3). 6 | 7 | The motivation behind this project is building a common testing library that doesn't depend on the web framework ([Quart](https://gitlab.com/pgjones/quart), [Starlette](https://github.com/encode/starlette), ...). 8 | 9 | It works by calling the ASGI app directly. This avoids the need to run the app with an HTTP server in a different process/thread/asyncio-loop. Since the app and test run in the same asyncio loop, it's easier to write tests and debug code. 10 | 11 | This library is based on the testing module provided in [Quart](https://gitlab.com/pgjones/quart).
12 | 13 | ## Quickstart 14 | 15 | Requirements: Python 3.6+ 16 | 17 | Installation: 18 | 19 | ```bash 20 | pip install async-asgi-testclient 21 | ``` 22 | 23 | ## Usage 24 | 25 | `my_api.py`: 26 | ```python 27 | from quart import Quart, jsonify 28 | 29 | app = Quart(__name__) 30 | 31 | @app.route("/") 32 | async def root(): 33 | return "plain response" 34 | 35 | @app.route("/json") 36 | async def json(): 37 | return jsonify({"hello": "world"}) 38 | 39 | if __name__ == '__main__': 40 | app.run() 41 | ``` 42 | 43 | `test_app.py`: 44 | ```python 45 | from async_asgi_testclient import TestClient 46 | 47 | import pytest 48 | 49 | @pytest.mark.asyncio 50 | async def test_quart_app(): 51 | from .my_api import app 52 | 53 | async with TestClient(app) as client: 54 | resp = await client.get("/") 55 | assert resp.status_code == 200 56 | assert resp.text == "plain response" 57 | 58 | resp = await client.get("/json") 59 | assert resp.status_code == 200 60 | assert resp.json() == {"hello": "world"} 61 | ``` 62 | 63 | ## Supports 64 | 65 | - [X] cookies 66 | - [X] multipart/form-data 67 | - [X] follow redirects 68 | - [X] response streams 69 | - [X] request streams 70 | - [X] websocket support 71 | -------------------------------------------------------------------------------- /async_asgi_testclient/__init__.py: -------------------------------------------------------------------------------- 1 | from async_asgi_testclient.testing import TestClient # noqa 2 | -------------------------------------------------------------------------------- /async_asgi_testclient/compatibility.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Django Software Foundation and individual contributors. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. 
Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of Django nor the names of its contributors may be used 16 | to endorse or promote products derived from this software without 17 | specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 26 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | import asyncio 31 | import inspect 32 | 33 | 34 | def is_double_callable(application): 35 | """ 36 | Tests to see if an application is a legacy-style (double-callable) application. 
37 | """ 38 | # Look for a hint on the object first 39 | if getattr(application, "_asgi_single_callable", False): 40 | return False 41 | if getattr(application, "_asgi_double_callable", False): 42 | return True 43 | # Uninstanted classes are double-callable 44 | if inspect.isclass(application): 45 | return True 46 | # Instanted classes depend on their __call__ 47 | if hasattr(application, "__call__"): 48 | # We only check to see if its __call__ is a coroutine function - 49 | # if it's not, it still might be a coroutine function itself. 50 | if asyncio.iscoroutinefunction(application.__call__): 51 | return False 52 | # Non-classes we just check directly 53 | return not asyncio.iscoroutinefunction(application) 54 | 55 | 56 | def double_to_single_callable(application): 57 | """ 58 | Transforms a double-callable ASGI application into a single-callable one. 59 | """ 60 | 61 | async def new_application(scope, receive, send): 62 | instance = application(scope) 63 | return await instance(receive, send) 64 | 65 | return new_application 66 | 67 | 68 | def guarantee_single_callable(application): 69 | """ 70 | Takes either a single- or double-callable application and always returns it 71 | in single-callable style. Use this to add backwards compatibility for ASGI 72 | 2.0 applications to your server/test harness/etc. 
73 | """ 74 | if is_double_callable(application): 75 | application = double_to_single_callable(application) 76 | return application 77 | -------------------------------------------------------------------------------- /async_asgi_testclient/multipart.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from typing import Tuple 3 | from typing import Union 4 | 5 | import binascii 6 | import os 7 | 8 | 9 | def encode_multipart_formdata( 10 | fields: Dict[str, Union[str, Tuple]] 11 | ) -> Tuple[bytes, str]: 12 | # Based on https://julien.danjou.info/handling-multipart-form-data-python/ 13 | boundary = binascii.hexlify(os.urandom(16)).decode("ascii") 14 | 15 | body = b"".join( 16 | build_part(boundary, field_name, file_tuple) 17 | for field_name, file_tuple in fields.items() 18 | ) + bytes(f"--{boundary}--\r\n", "ascii") 19 | 20 | content_type = f"multipart/form-data; boundary={boundary}" 21 | 22 | return body, content_type 23 | 24 | 25 | def build_part(boundary: str, field_name: str, file_tuple: Union[str, Tuple]) -> bytes: 26 | """ 27 | file_tuple: 28 | - 'string value' 29 | - (fileobj,) 30 | - ('filename', fileobj) 31 | - ('filename', fileobj, 'content_type') 32 | """ 33 | value = b"" 34 | filename = "" 35 | content_type = "" 36 | 37 | if isinstance(file_tuple, str): 38 | value = file_tuple.encode("ascii") 39 | else: 40 | if len(file_tuple) == 1: 41 | file_ = file_tuple[0] 42 | elif len(file_tuple) == 2: 43 | filename, file_ = file_tuple 44 | elif len(file_tuple) == 3: 45 | filename, file_, content_type = file_tuple 46 | value = file_.read() 47 | 48 | if isinstance(value, str): 49 | value = value.encode("ascii") 50 | 51 | part = f'--{boundary}\r\nContent-Disposition: form-data; name="{field_name}"' 52 | if filename: 53 | part += f'; filename="{filename}"' 54 | 55 | if content_type: 56 | part += f"\r\nContent-Type: {content_type}" 57 | 58 | return part.encode("ascii") + b"\r\n\r\n" + value + b"\r\n" 59 | 
-------------------------------------------------------------------------------- /async_asgi_testclient/response.py: -------------------------------------------------------------------------------- 1 | from requests.exceptions import StreamConsumedError 2 | from requests.models import Response as _Response 3 | from requests.utils import iter_slices 4 | from requests.utils import stream_decode_response_unicode 5 | 6 | import io 7 | 8 | 9 | class BytesRW(object): 10 | def __init__(self): 11 | self._stream = io.BytesIO() 12 | self._rpos = 0 13 | self._wpos = 0 14 | 15 | def read(self, size=-1): 16 | if self._stream is None: 17 | raise Exception("Stream is closed") 18 | self._stream.seek(self._rpos) 19 | bytes_ = self._stream.read(size) 20 | self._rpos += len(bytes_) 21 | return bytes_ 22 | 23 | def write(self, b): 24 | if self._stream is None: 25 | raise Exception("Stream is closed") 26 | self._stream.seek(self._wpos) 27 | n = self._stream.write(b) 28 | self._wpos += n 29 | return n 30 | 31 | def close(self): 32 | self._stream = None 33 | 34 | 35 | class Response(_Response): 36 | def __init__(self, stream: bool, receive, send): 37 | super().__init__() 38 | 39 | self.stream = stream 40 | self.receive_or_fail = receive 41 | self.send = send 42 | self._more_body = False 43 | self.raw = BytesRW() 44 | 45 | async def __aiter__(self): 46 | async for c in self.iter_content(128): 47 | yield c 48 | 49 | async def generate(self, n): 50 | while True: 51 | val = self.raw.read(n) 52 | if val == b"": # EOF 53 | break 54 | yield val 55 | 56 | while self._more_body: 57 | message = await self.receive_or_fail() 58 | if not isinstance(message, dict): 59 | raise Exception(f"Unexpected message {message}") 60 | if message["type"] != "http.response.body": 61 | raise Exception( 62 | f"Excpected message type 'http.response.body'. 
" f"Found {message}" 63 | ) 64 | 65 | yield message["body"] 66 | self._more_body = message.get("more_body", False) 67 | 68 | # Send disconnect 69 | self.send({"type": "http.disconnect"}) 70 | message = await self.receive_or_fail() 71 | assert message.event == "exit" 72 | 73 | async def iter_content(self, chunk_size=1, decode_unicode=False): 74 | if self._content_consumed and isinstance(self._content, bool): 75 | raise StreamConsumedError() 76 | elif chunk_size is not None and not isinstance(chunk_size, int): 77 | raise TypeError( 78 | "chunk_size must be an int, it is instead a %s." % type(chunk_size) 79 | ) 80 | 81 | # simulate reading small chunks of the content 82 | reused_chunks = iter_slices(self._content, chunk_size) 83 | stream_chunks = self.generate(chunk_size) 84 | chunks = reused_chunks if self._content_consumed else stream_chunks 85 | 86 | if decode_unicode: 87 | chunks = stream_decode_response_unicode(chunks, self) 88 | 89 | async for c in chunks: 90 | yield c 91 | -------------------------------------------------------------------------------- /async_asgi_testclient/testing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright P G Jones 2017. 3 | 4 | Permission is hereby granted, free of charge, to any person 5 | obtaining a copy of this software and associated documentation 6 | files (the "Software"), to deal in the Software without 7 | restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 10 | Software is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 18 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 20 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 21 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 22 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | OTHER DEALINGS IN THE SOFTWARE. 24 | """ 25 | from async_asgi_testclient.compatibility import guarantee_single_callable 26 | from async_asgi_testclient.multipart import encode_multipart_formdata 27 | from async_asgi_testclient.response import BytesRW 28 | from async_asgi_testclient.response import Response 29 | from async_asgi_testclient.utils import create_monitored_task 30 | from async_asgi_testclient.utils import flatten_headers 31 | from async_asgi_testclient.utils import is_last_one 32 | from async_asgi_testclient.utils import make_test_headers_path_and_query_string 33 | from async_asgi_testclient.utils import Message 34 | from async_asgi_testclient.utils import receive 35 | from async_asgi_testclient.utils import to_relative_path 36 | from async_asgi_testclient.websocket import WebSocketSession 37 | from functools import partial 38 | from http.cookies import SimpleCookie 39 | from json import dumps 40 | from multidict import CIMultiDict 41 | from typing import Any 42 | from typing import Optional 43 | from typing import Union 44 | from urllib.parse import urlencode 45 | 46 | import asyncio 47 | import inspect 48 | import requests 49 | 50 | sentinel = object() 51 | 52 | 53 | class TestClient: 54 | """A Client bound to an app for testing. 55 | 56 | This should be used to make requests and receive responses from 57 | the app for testing purposes. 
58 | """ 59 | 60 | __test__ = False # prevent pytest.PytestCollectionWarning 61 | 62 | def __init__( 63 | self, 64 | application, 65 | use_cookies: bool = True, 66 | timeout: Optional[Union[int, float]] = None, 67 | headers: Optional[Union[dict, CIMultiDict]] = None, 68 | scope: Optional[dict] = None, 69 | ): 70 | self.application = guarantee_single_callable(application) 71 | self.cookie_jar: Optional[SimpleCookie] = ( 72 | SimpleCookie() if use_cookies else None 73 | ) 74 | self.timeout = timeout 75 | self.headers = headers or {} 76 | self._scope = scope or {} 77 | self._lifespan_input_queue: asyncio.Queue[dict] = asyncio.Queue() 78 | self._lifespan_output_queue: asyncio.Queue[dict] = asyncio.Queue() 79 | self._lifespan_task = None # Must keep hard reference to prevent gc 80 | 81 | async def __aenter__(self): 82 | self._lifespan_task = create_monitored_task( 83 | self.application( 84 | {"type": "lifespan", "asgi": {"version": "3.0"}}, 85 | self._lifespan_input_queue.get, 86 | self._lifespan_output_queue.put, 87 | ), 88 | self._lifespan_output_queue.put_nowait, 89 | ) 90 | 91 | await self.send_lifespan("startup") 92 | return self 93 | 94 | async def __aexit__(self, exc_type, exc, tb): 95 | await self.send_lifespan("shutdown") 96 | self._lifespan_task = None 97 | 98 | async def send_lifespan(self, action): 99 | await self._lifespan_input_queue.put({"type": f"lifespan.{action}"}) 100 | message = await receive(self._lifespan_output_queue, timeout=self.timeout) 101 | 102 | if isinstance(message, Message): 103 | raise Exception(f"{message.event} - {message.reason} - {message.task}") 104 | 105 | if message["type"] == f"lifespan.{action}.complete": 106 | pass 107 | elif message["type"] == f"lifespan.{action}.failed": 108 | raise Exception(message) 109 | 110 | def websocket_connect(self, *args: Any, **kwargs: Any) -> WebSocketSession: 111 | return WebSocketSession(self, *args, **kwargs) 112 | 113 | async def open( 114 | self, 115 | path: str, 116 | *, 117 | method: str = 
"GET", 118 | headers: Optional[Union[dict, CIMultiDict]] = None, 119 | data: Any = None, 120 | form: Optional[dict] = None, 121 | files: Optional[dict] = None, 122 | query_string: Optional[dict] = None, 123 | json: Any = sentinel, 124 | scheme: str = "http", 125 | cookies: Optional[dict] = None, 126 | stream: bool = False, 127 | allow_redirects: bool = True, 128 | ): 129 | """Open a request to the app associated with this client. 130 | 131 | Arguments: 132 | path 133 | The path to request. If the query_string argument is not 134 | defined this argument will be partitioned on a '?' with the 135 | following part being considered the query_string. 136 | 137 | method 138 | The method to make the request with, defaults to 'GET'. 139 | 140 | headers 141 | Headers to include in the request. 142 | 143 | data 144 | Raw data to send in the request body or async generator 145 | 146 | form 147 | Data to send form encoded in the request body. 148 | 149 | files 150 | Data to send as multipart in the request body. 151 | 152 | query_string 153 | To send as a dictionary, alternatively the query_string can be 154 | determined from the path. 155 | 156 | json 157 | Data to send json encoded in the request body. 158 | 159 | scheme 160 | The scheme to use in the request, default http. 161 | 162 | cookies 163 | Cookies to send in the request instead of cookies in 164 | TestClient.cookie_jar 165 | 166 | stream 167 | Return the response in streaming instead of buffering 168 | 169 | allow_redirects 170 | If set to True follows redirects 171 | 172 | Returns: 173 | The response from the app handling the request. 
174 | """ 175 | input_queue: asyncio.Queue[dict] = asyncio.Queue() 176 | output_queue: asyncio.Queue[dict] = asyncio.Queue() 177 | 178 | if not headers: 179 | headers = {} 180 | merged_headers = self.headers.copy() 181 | merged_headers.update(headers) 182 | headers, path, query_string_bytes = make_test_headers_path_and_query_string( 183 | self.application, path, merged_headers, query_string 184 | ) 185 | 186 | if [ 187 | json is not sentinel, 188 | form is not None, 189 | data is not None, 190 | files is not None, 191 | ].count(True) > 1: 192 | raise ValueError( 193 | "Test args 'json', 'form', 'files' and 'data' are mutually exclusive" 194 | ) 195 | 196 | request_data = b"" 197 | 198 | if isinstance(data, str): 199 | request_data = data.encode("utf-8") 200 | elif isinstance(data, bytes): 201 | request_data = data 202 | 203 | if json is not sentinel: 204 | request_data = dumps(json).encode("utf-8") 205 | headers["Content-Type"] = "application/json" 206 | 207 | if form is not None: 208 | request_data = urlencode(form).encode("utf-8") 209 | headers["Content-Type"] = "application/x-www-form-urlencoded" 210 | 211 | if files is not None: 212 | request_data, content_type = encode_multipart_formdata(files) 213 | headers["Content-Type"] = content_type 214 | 215 | if request_data and headers.get("Content-Length") is None: 216 | headers["Content-Length"] = str(len(request_data)) 217 | 218 | if cookies is None: # use TestClient.cookie_jar 219 | cookie_jar = self.cookie_jar 220 | else: 221 | cookie_jar = SimpleCookie(cookies) 222 | 223 | if cookie_jar: 224 | cookie_data = [] 225 | for cookie_name, cookie in cookie_jar.items(): 226 | cookie_data.append(f"{cookie_name}={cookie.value}") 227 | if cookie_data: 228 | headers.add("Cookie", "; ".join(cookie_data)) 229 | 230 | scope = { 231 | "type": "http", 232 | "http_version": "1.1", 233 | "asgi": {"version": "3.0"}, 234 | "method": method, 235 | "scheme": scheme, 236 | "path": path, 237 | "query_string": query_string_bytes, 238 | 
"root_path": "", 239 | "headers": flatten_headers(headers), 240 | } 241 | scope.update(self._scope) 242 | 243 | running_task = create_monitored_task( 244 | self.application(scope, input_queue.get, output_queue.put), 245 | output_queue.put_nowait, 246 | ) 247 | 248 | send = input_queue.put_nowait 249 | receive_or_fail = partial(receive, output_queue, timeout=self.timeout) 250 | 251 | # Send request 252 | if inspect.isasyncgen(data): 253 | async for is_last, body in is_last_one(data): 254 | send({"type": "http.request", "body": body, "more_body": not is_last}) 255 | else: 256 | send({"type": "http.request", "body": request_data}) 257 | 258 | response = Response(stream, receive_or_fail, send) 259 | 260 | # Receive response start 261 | message = await self.wait_response(receive_or_fail, "http.response.start") 262 | response.status_code = message["status"] 263 | response.headers = CIMultiDict( 264 | [(k.decode("utf8"), v.decode("utf8")) for k, v in message["headers"]] 265 | ) 266 | 267 | # Receive initial response body 268 | message = await self.wait_response(receive_or_fail, "http.response.body") 269 | response.raw.write(message["body"]) 270 | response._more_body = message.get("more_body", False) 271 | 272 | # Consume the remaining response if not in stream 273 | if not stream: 274 | bytes_io = BytesRW() 275 | bytes_io.write(response.raw.read()) 276 | async for chunk in response: 277 | bytes_io.write(chunk) 278 | response.raw = bytes_io 279 | response._content = bytes_io.read() 280 | response._content_consumed = True 281 | 282 | if cookie_jar is not None: 283 | cookies = SimpleCookie() 284 | for c in response.headers.getall("Set-Cookie", ""): 285 | cookies.load(c) 286 | response.cookies = requests.cookies.RequestsCookieJar() 287 | response.cookies.update(cookies) 288 | cookie_jar.update(cookies) 289 | 290 | # We need to keep a hard reference to running task to prevent gc 291 | assert running_task # Useless assert to prevent unused variable warnings 292 | 293 | if 
allow_redirects and response.is_redirect: 294 | path = to_relative_path(response.headers["location"]) 295 | return await self.get(path) 296 | else: 297 | return response 298 | 299 | async def wait_response(self, receive_or_fail, type_): 300 | message = await receive_or_fail() 301 | if not isinstance(message, dict): 302 | raise Exception(f"Unexpected message {message}") 303 | if message["type"] != type_: 304 | raise Exception(f"Excpected message type '{type_}'. " f"Found {message}") 305 | return message 306 | 307 | async def delete(self, *args: Any, **kwargs: Any) -> Response: 308 | """Make a DELETE request.""" 309 | return await self.open(*args, method="DELETE", **kwargs) 310 | 311 | async def get(self, *args: Any, **kwargs: Any) -> Response: 312 | """Make a GET request.""" 313 | return await self.open(*args, method="GET", **kwargs) 314 | 315 | async def head(self, *args: Any, **kwargs: Any) -> Response: 316 | """Make a HEAD request.""" 317 | return await self.open(*args, method="HEAD", **kwargs) 318 | 319 | async def options(self, *args: Any, **kwargs: Any) -> Response: 320 | """Make a OPTIONS request.""" 321 | return await self.open(*args, method="OPTIONS", **kwargs) 322 | 323 | async def patch(self, *args: Any, **kwargs: Any) -> Response: 324 | """Make a PATCH request.""" 325 | return await self.open(*args, method="PATCH", **kwargs) 326 | 327 | async def post(self, *args: Any, **kwargs: Any) -> Response: 328 | """Make a POST request.""" 329 | return await self.open(*args, method="POST", **kwargs) 330 | 331 | async def put(self, *args: Any, **kwargs: Any) -> Response: 332 | """Make a PUT request.""" 333 | return await self.open(*args, method="PUT", **kwargs) 334 | 335 | async def trace(self, *args: Any, **kwargs: Any) -> Response: 336 | """Make a TRACE request.""" 337 | return await self.open(*args, method="TRACE", **kwargs) 338 | -------------------------------------------------------------------------------- /async_asgi_testclient/tests/test_testing.py: 
--------------------------------------------------------------------------------
"""End-to-end tests for ``TestClient`` against small Quart and Starlette apps.

Quart-based tests are skipped on Python < 3.7 via ``skipif("PY_VER < (3,7)")``.
"""
from async_asgi_testclient import TestClient
from http.cookies import SimpleCookie
from json import dumps
from sys import version_info as PY_VER  # noqa  (only referenced inside skipif strings)

import asyncio
import io
import pytest


@pytest.fixture
def quart_app():
    """Build a Quart app exposing one route per client feature under test."""
    from quart import Quart, jsonify, request, redirect, Response

    app = Quart(__name__)

    # Lets tests check that lifespan startup ran before the first request.
    @app.before_serving
    async def startup():
        app.custom_init_complete = True

    @app.route("/")
    async def root():
        return "full response"

    @app.route("/json")
    async def json():
        return jsonify({"hello": "world"})

    # Empty 204 body plus a custom response header.
    @app.route("/header")
    async def headers():
        return "", 204, {"X-Header": "Value"}

    # Echoes url-encoded form fields back as JSON.
    @app.route("/form", methods=["POST"])
    async def form():
        form = await request.form
        return jsonify(dict(form))

    @app.route("/check_startup_works")
    async def check_startup_works():
        if app.custom_init_complete:
            return "yes"
        return "no"

    @app.route("/set_cookies", methods=["POST"])
    async def set_cookie():
        r = Response("")
        r.set_cookie(key="my-cookie", value="1234")
        r.set_cookie(key="my-cookie-2", value="5678")
        return r

    @app.route("/clear_cookie", methods=["POST"])
    async def clear_cookie():
        r = Response("")
        r.delete_cookie(key="my-cookie")
        r.delete_cookie(key="my-cookie-2")
        return r

    # Cookies as parsed by the framework.
    @app.route("/cookies")
    async def get_cookie():
        cookies = request.cookies
        return jsonify(cookies)

    # The raw Cookie request header, to check the exact wire format.
    @app.route("/cookies-raw")
    async def get_cookie_raw():
        return Response(request.headers["Cookie"])

    # Never answers within the short client timeouts used by the tests.
    @app.route("/stuck")
    async def stuck():
        await asyncio.sleep(60)

    # Redirects to whatever ?path=... names.
    @app.route("/redir")
    async def redir():
        return redirect(request.args["path"])

    # Reflects every request header back as a response header.
    @app.route("/echoheaders")
    async def echoheaders():
        return "", 200, request.headers

    @app.route("/test_query")
    async def test_query():
        return Response(request.query_string)

    yield app


@pytest.fixture
def starlette_app():
    """Build a Starlette app mirroring ``quart_app`` plus WebSocket/multipart routes."""
    from starlette.applications import Starlette
    from starlette.endpoints import WebSocketEndpoint
    from starlette.responses import JSONResponse, Response

    app = Starlette()

    @app.on_event("startup")
    async def startup():
        app.custom_init_complete = True

    # Echo endpoint; the special payloads "cookies" and "url" report on the
    # connection itself instead of echoing.
    @app.websocket_route("/ws")
    class Echo(WebSocketEndpoint):

        encoding = "text"

        async def on_receive(self, websocket, data):
            if data == "cookies":
                await websocket.send_text(dumps(websocket.cookies))
            elif data == "url":
                await websocket.send_text(str(websocket.url))
            else:
                await websocket.send_text(f"Message text was: {data}")

    @app.route("/")
    async def homepage(request):
        return Response("full response")

    @app.route("/json")
    async def json(request):
        return JSONResponse({"hello": "world"})

    # Absolute-URL redirect, used to test redirect following.
    @app.route("/json-redirect")
    async def json_redirect(request):
        return Response(status_code=302, headers={"Location": "http://localhost/json"})

    @app.route("/header")
    async def headers(request):
        return Response(status_code=204, headers={"X-Header": "Value"})

    # NOTE: reads the underscore-prefixed FormData._dict (private by convention).
    @app.route("/form", methods=["POST"])
    async def form(request):
        form = await request.form()
        return JSONResponse(form._dict)

    @app.route("/multipart", methods=["POST"])
    async def multipart(request):
        form = await request.form()
        return JSONResponse(form._dict)

    # Checks the server-side view of binary multipart fields and file parts:
    # a nameless part arrives as a plain field, named ones as uploads.
    @app.route("/multipart_bin", methods=["POST"])
    async def multipart_bin(request):
        form = await request.form()
        assert form["a"] == "\x89\x01\x02\x03\x04"

        file_b = form["b"]
        assert file_b.filename == "b.bin"
        assert await file_b.read() == b"\x89\x01\x02\x03\x04"

        file_c = form["c"]
        assert file_c.filename == "c.txt"
        assert file_c.content_type == "text/plain"
        assert await file_c.read() == b"01234"

        return Response(status_code=200)

    @app.route("/check_startup_works")
    async def check_startup_works(request):
        if app.custom_init_complete:
            return Response("yes")
        return Response("no")

    @app.route("/set_cookies", methods=["POST"])
    async def set_cookie(request):
        r = Response("")
        r.set_cookie("my-cookie", "1234")
        r.set_cookie("my-cookie-2", "5678")
        return r

    @app.route("/clear_cookie", methods=["POST"])
    async def clear_cookie(request):
        r = Response("")
        r.delete_cookie("my-cookie")
        r.delete_cookie("my-cookie-2")
        return r

    @app.route("/cookies")
    async def get_cookie(request):
        cookies = request.cookies
        return JSONResponse(cookies)

    @app.route("/cookies-raw")
    async def get_cookie_raw(request):
        return Response(request.headers["Cookie"])

    @app.route("/stuck")
    async def stuck(request):
        await asyncio.sleep(60)

    @app.route("/echoheaders")
    async def echoheaders(request):
        return Response(headers=request.headers)

    @app.route("/test_query")
    async def test_query(request):
        return Response(str(request.query_params))

    yield app


@pytest.mark.asyncio
@pytest.mark.skipif("PY_VER < (3,7)")
async def test_TestClient_Quart(quart_app):
    """Exercise the HTTP features of TestClient against Quart, in order.

    The steps share one client, so cookie and header state carries over
    from one request to the next; do not reorder them.
    """
    async with TestClient(quart_app) as client:
        resp = await client.get("/")
        assert resp.status_code == 200
        assert resp.text == "full response"

        resp = await client.get("/json")
        assert resp.status_code == 200
        assert resp.json() == {"hello": "world"}

        resp = await client.get("/header")
        assert resp.status_code == 204
        assert resp.headers["X-Header"] == "Value"
        assert resp.text == ""

        # Non-string form values are expected to arrive as strings.
        resp = await client.post("/form", form=[("user", "root"), ("pswd", 1234)])
        assert resp.json() == {"pswd": "1234", "user": "root"}

        resp = await client.get("/check_startup_works")
        assert resp.status_code == 200
        assert resp.text == "yes"

        resp = await client.post("/set_cookies")
        assert resp.status_code == 200
        assert resp.cookies.get_dict() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        # The cookies set above are expected to be sent back automatically.
        resp = await client.get("/cookies")
        assert resp.status_code == 200
        assert resp.json() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        resp = await client.get("/cookies-raw")
        assert resp.status_code == 200
        assert resp.text == "my-cookie=1234; my-cookie-2=5678"

        resp = await client.post("/clear_cookie")
        assert resp.cookies.get_dict() == {"my-cookie": "", "my-cookie-2": ""}
        assert resp.status_code == 200

        # Client-wide headers are merged with per-request headers.
        client.headers = {"Authorization": "mytoken"}
        resp = await client.get("/echoheaders", headers={"this should be": "merged"})
        assert resp.status_code == 200
        assert resp.headers.get("authorization") == "mytoken"
        assert resp.headers.get("this should be") == "merged"
        # Reset client headers for next tests
        client.headers = {}

        resp = await client.get("/echoheaders")
        assert resp.status_code == 200
        assert "Authorization" not in resp.headers

        # A query passed as a dict and one embedded in the path must encode identically.
        resp = await client.get("/test_query", query_string={"a": 1, "b": "ç"})
        assert resp.status_code == 200
        assert resp.text == "a=1&b=%C3%A7"

        resp = await client.get("/test_query?a=1&b=ç")
        assert resp.status_code == 200
        assert resp.text == "a=1&b=%C3%A7"


@pytest.mark.asyncio
async def test_TestClient_Starlette(starlette_app):
    """Starlette counterpart of ``test_TestClient_Quart``, plus multipart uploads.

    Steps are order-dependent, exactly as in the Quart test.
    """
    async with TestClient(starlette_app) as client:
        resp = await client.get("/")
        assert resp.status_code == 200
        assert resp.text == "full response"

        resp = await client.get("/json")
        assert resp.status_code == 200
        assert resp.json() == {"hello": "world"}

        resp = await client.get("/header")
        assert resp.status_code == 204
        assert resp.headers["X-Header"] == "Value"
        assert resp.text == ""

        resp = await client.post("/form", form=[("user", "root"), ("pswd", 1234)])
        assert resp.json() == {"pswd": "1234", "user": "root"}

        # "files" accepts a plain string or a 1-tuple holding a file-like object.
        file_like = io.StringIO("abcd")
        resp = await client.post("/multipart", files={"a": "abcd", "b": (file_like,)})
        assert resp.json() == {"a": "abcd", "b": "abcd"}

        # Tuple forms: (fileobj,), (filename, fileobj), (filename, fileobj, content_type).
        file_like_1 = io.BytesIO(bytes([0x89, 1, 2, 3, 4]))
        file_like_2 = io.BytesIO(bytes([0x89, 1, 2, 3, 4]))
        file_like_3 = io.BytesIO(bytes("01234", "ascii"))
        resp = await client.post(
            "/multipart_bin",
            files={
                "a": (file_like_1,),
                "b": ("b.bin", file_like_2),
                "c": ("c.txt", file_like_3, "text/plain"),
            },
        )
        assert resp.status_code == 200

        resp = await client.get("/check_startup_works")
        assert resp.status_code == 200
        assert resp.text == "yes"

        resp = await client.post("/set_cookies")
        assert resp.status_code == 200
        assert resp.cookies.get_dict() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        resp = await client.get("/cookies")
        assert resp.status_code == 200
        assert resp.json() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        resp = await client.get("/cookies-raw")
        assert resp.status_code == 200
        assert resp.text == "my-cookie=1234; my-cookie-2=5678"

        resp = await client.post("/clear_cookie")
        assert resp.cookies.get_dict() == {"my-cookie": "", "my-cookie-2": ""}
        assert resp.status_code == 200

        client.headers = {"Authorization": "mytoken"}
        resp = await client.get("/echoheaders", headers={"this should be": "merged"})
        assert resp.status_code == 200
        assert resp.headers.get("authorization") == "mytoken"
        assert resp.headers.get("this should be") == "merged"
        # Reset client headers for next tests
        client.headers = {}

        resp = await client.get("/echoheaders")
        assert resp.status_code == 200
        assert "authorization" not in resp.headers

        resp = await client.get("/test_query", query_string={"a": 1, "b": "ç"})
        assert resp.status_code == 200
        assert resp.text == "a=1&b=%C3%A7"

        resp = await client.get("/test_query?a=1&b=ç")
        assert resp.status_code == 200
        assert resp.text == "a=1&b=%C3%A7"


@pytest.mark.asyncio
@pytest.mark.skipif("PY_VER < (3,7)")
async def test_set_cookie_in_request(quart_app):
    """A per-request ``cookies=`` overrides the client jar for that request only."""
    async with TestClient(quart_app) as client:
        resp = await client.post("/set_cookies")
        assert resp.status_code == 200
        assert resp.cookies.get_dict() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        # Uses 'custom_cookie_jar' instead of 'client.cookie_jar'
        custom_cookie_jar = {"my-cookie": "6666"}
        resp = await client.get("/cookies", cookies=custom_cookie_jar)
        assert resp.status_code == 200
        assert resp.json() == custom_cookie_jar

        # Uses 'client.cookie_jar' again
        resp = await client.get("/cookies")
        assert resp.status_code == 200
        assert resp.json() == {"my-cookie": "1234", "my-cookie-2": "5678"}

        resp = await client.get("/cookies-raw")
        assert resp.status_code == 200
        assert resp.text == "my-cookie=1234; my-cookie-2=5678"


@pytest.mark.asyncio
@pytest.mark.skipif("PY_VER < (3,7)")
async def test_disable_cookies_in_client(quart_app):
    """With ``use_cookies=False`` no Set-Cookie header is parsed into ``resp.cookies``."""
    async with TestClient(quart_app, use_cookies=False) as client:
        resp = await client.post(
            "/set_cookies"
        )  # responds with 'set-cookie: my-cookie=1234' but cookies are disabled
        assert resp.status_code == 200
        assert resp.cookies.get_dict() == {}


@pytest.mark.asyncio
async def test_exception_starlette(starlette_app):
    """An exception raised inside a Starlette view propagates out of ``client.get``."""

    async def view_raiser(request):
        assert 1 == 0

    starlette_app.add_route("/raiser", view_raiser)

    async with TestClient(starlette_app) as client:
        with pytest.raises(AssertionError):
            await client.get("/raiser")


@pytest.mark.asyncio
@pytest.mark.skipif("PY_VER < (3,7)")
async def test_exception_quart(quart_app):
    """Quart turns a view exception into a 500 response instead of raising."""

    @quart_app.route("/raiser")
    async def error():
        assert 1 == 0

    async with TestClient(quart_app) as client:
        resp = await client.get("/raiser")
        # Quart suppresses all types of exceptions
        assert resp.status_code == 500


@pytest.mark.asyncio
@pytest.mark.skipif("PY_VER < (3,7)")
async def test_quart_endpoint_not_responding(quart_app):
    """The client ``timeout`` surfaces as ``asyncio.TimeoutError`` for a hung view."""
    async with TestClient(quart_app, timeout=0.1) as client:
        with pytest.raises(asyncio.TimeoutError):
            await client.get("/stuck")


# NOTE(review): "startlette" is a typo for "starlette"; the name is kept so the
# pytest node id does not change.
@pytest.mark.asyncio
async def test_startlette_endpoint_not_responding(starlette_app):
    """Starlette counterpart of ``test_quart_endpoint_not_responding``."""
    async with TestClient(starlette_app, timeout=0.1) as client:
        with pytest.raises(asyncio.TimeoutError):
            await client.get("/stuck")


@pytest.mark.asyncio
async def test_ws_endpoint(starlette_app):
    """A text frame sent over ``websocket_connect`` is echoed back by ``Echo``."""
    async with TestClient(starlette_app, timeout=0.1) as client:
        async with client.websocket_connect("/ws") as ws:
            await ws.send_text("hi!")
            msg = await ws.receive_text()
            assert msg == "Message text was: hi!"
417 | 418 | 419 | @pytest.mark.asyncio 420 | async def test_ws_endpoint_cookies(starlette_app): 421 | async with TestClient(starlette_app, timeout=0.1) as client: 422 | async with client.websocket_connect("/ws", cookies={"session": "abc"}) as ws: 423 | await ws.send_text("cookies") 424 | msg = await ws.receive_json() 425 | assert msg == {"session": "abc"} 426 | 427 | 428 | @pytest.mark.asyncio 429 | async def test_ws_connect_inherits_test_client_cookies(starlette_app): 430 | client = TestClient(starlette_app, use_cookies=True, timeout=0.1) 431 | client.cookie_jar = SimpleCookie({"session": "abc"}) 432 | async with client: 433 | async with client.websocket_connect("/ws") as ws: 434 | await ws.send_text("cookies") 435 | msg = await ws.receive_text() 436 | assert msg == '{"session": "abc"}' 437 | 438 | 439 | @pytest.mark.asyncio 440 | async def test_ws_connect_default_scheme(starlette_app): 441 | async with TestClient(starlette_app, timeout=0.1) as client: 442 | async with client.websocket_connect("/ws") as ws: 443 | await ws.send_text("url") 444 | msg = await ws.receive_text() 445 | assert msg.startswith("ws://") 446 | 447 | 448 | @pytest.mark.asyncio 449 | async def test_ws_connect_custom_scheme(starlette_app): 450 | async with TestClient(starlette_app, timeout=0.1) as client: 451 | async with client.websocket_connect("/ws", scheme="wss") as ws: 452 | await ws.send_text("url") 453 | msg = await ws.receive_text() 454 | assert msg.startswith("wss://") 455 | 456 | 457 | @pytest.mark.asyncio 458 | async def test_request_stream(starlette_app): 459 | from starlette.responses import StreamingResponse 460 | 461 | async def up_stream(request): 462 | async def gen(): 463 | async for chunk in request.stream(): 464 | yield chunk 465 | 466 | return StreamingResponse(gen()) 467 | 468 | starlette_app.add_route("/upload_stream", up_stream, methods=["POST"]) 469 | 470 | async with TestClient(starlette_app) as client: 471 | 472 | async def stream_gen(): 473 | chunk = b"X" * 1024 474 
| for _ in range(3): 475 | yield chunk 476 | 477 | resp = await client.post("/upload_stream", data=stream_gen(), stream=True) 478 | assert resp.status_code == 200 479 | chunks = [c async for c in resp.iter_content(1024)] 480 | assert len(b"".join(chunks)) == 3 * 1024 481 | 482 | 483 | @pytest.mark.asyncio 484 | async def test_upload_stream_from_download_stream(starlette_app): 485 | from starlette.responses import StreamingResponse 486 | 487 | async def down_stream(request): 488 | def gen(): 489 | for _ in range(3): 490 | yield b"X" * 1024 491 | 492 | return StreamingResponse(gen()) 493 | 494 | async def up_stream(request): 495 | async def gen(): 496 | async for chunk in request.stream(): 497 | yield chunk 498 | 499 | return StreamingResponse(gen()) 500 | 501 | starlette_app.add_route("/download_stream", down_stream, methods=["GET"]) 502 | starlette_app.add_route("/upload_stream", up_stream, methods=["POST"]) 503 | 504 | async with TestClient(starlette_app) as client: 505 | resp = await client.get("/download_stream", stream=True) 506 | assert resp.status_code == 200 507 | resp2 = await client.post( 508 | "/upload_stream", data=resp.iter_content(1024), stream=True 509 | ) 510 | chunks = [c async for c in resp2.iter_content(1024)] 511 | assert len(b"".join(chunks)) == 3 * 1024 512 | 513 | 514 | @pytest.mark.asyncio 515 | @pytest.mark.skipif("PY_VER < (3,7)") 516 | async def test_response_stream(quart_app): 517 | @quart_app.route("/download_stream") 518 | async def down_stream(): 519 | async def async_generator(): 520 | chunk = b"X" * 1024 521 | for _ in range(3): 522 | yield chunk 523 | 524 | return async_generator() 525 | 526 | async with TestClient(quart_app) as client: 527 | resp = await client.get("/download_stream", stream=True) 528 | assert resp.status_code == 200 529 | chunks = [c async for c in resp.iter_content(1024)] 530 | assert len(b"".join(chunks)) == 3 * 1024 531 | 532 | 533 | @pytest.mark.asyncio 534 | async def 
test_response_stream_crashes(starlette_app): 535 | from starlette.responses import StreamingResponse 536 | 537 | @starlette_app.route("/download_stream_crashes") 538 | async def stream_crashes(request): 539 | def gen(): 540 | yield b"X" * 1024 541 | yield b"X" * 1024 542 | yield b"X" * 1024 543 | raise Exception("Stream crashed!") 544 | 545 | return StreamingResponse(gen()) 546 | 547 | async with TestClient(starlette_app) as client: 548 | resp = await client.get("/download_stream_crashes", stream=True) 549 | assert resp.status_code == 200 550 | 551 | with pytest.raises(Exception): 552 | async for _ in resp.iter_content(1024): 553 | pass 554 | 555 | 556 | @pytest.mark.asyncio 557 | async def test_absolute_redirect(starlette_app): 558 | async with TestClient(starlette_app) as client: 559 | resp = await client.get("/json-redirect") 560 | assert resp.status_code == 200 561 | assert resp.json() == {"hello": "world"} 562 | 563 | 564 | @pytest.mark.asyncio 565 | @pytest.mark.skipif("PY_VER < (3,7)") 566 | async def test_follow_redirects(quart_app): 567 | async with TestClient(quart_app) as client: 568 | resp = await client.get("/redir?path=/") 569 | assert resp.status_code == 200 570 | assert resp.text == "full response" 571 | 572 | 573 | @pytest.mark.asyncio 574 | @pytest.mark.skipif("PY_VER < (3,7)") 575 | async def test_no_follow_redirects(quart_app): 576 | async with TestClient(quart_app) as client: 577 | resp = await client.get("/redir?path=/", allow_redirects=False) 578 | assert resp.status_code == 302 579 | -------------------------------------------------------------------------------- /async_asgi_testclient/utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from multidict import CIMultiDict 3 | from typing import Any 4 | from typing import Dict 5 | from typing import List 6 | from typing import Optional 7 | from typing import Tuple 8 | from typing import Union 9 | from urllib.parse 
import quote 10 | from urllib.parse import urlencode 11 | 12 | import asyncio 13 | import re 14 | import sys 15 | 16 | 17 | def flatten_headers(headers: Union[Dict, CIMultiDict]) -> List[Tuple]: 18 | return [(bytes(k.lower(), "utf8"), bytes(v, "utf8")) for k, v in headers.items()] 19 | 20 | 21 | def make_test_headers_path_and_query_string( 22 | app: Any, 23 | path: str, 24 | headers: Optional[Union[dict, CIMultiDict]] = None, 25 | query_string: Optional[dict] = None, 26 | ) -> Tuple[CIMultiDict, str, bytes]: 27 | """Make the headers and path with defaults for testing. 28 | 29 | Arguments: 30 | app: The application to test against. 31 | path: The path to request. If the query_string argument is not 32 | defined this argument will be partitioned on a '?' with 33 | the following part being considered the query_string. 34 | headers: Initial headers to send. 35 | query_string: To send as a dictionary, alternatively the 36 | query_string can be determined from the path. 37 | """ 38 | if headers is None: 39 | headers = CIMultiDict() 40 | elif isinstance(headers, CIMultiDict): 41 | headers = headers 42 | elif headers is not None: 43 | headers = CIMultiDict(headers) 44 | headers.setdefault("Remote-Addr", "127.0.0.1") 45 | headers.setdefault("User-Agent", "ASGI-Test-Client") 46 | headers.setdefault("host", "localhost") 47 | 48 | if "?" 
in path and query_string is not None: 49 | raise ValueError("Query string is defined in the path and as an argument") 50 | if query_string is None: 51 | path, _, query_string_raw = path.partition("?") 52 | query_string_raw = quote(query_string_raw, safe="&=") 53 | else: 54 | query_string_raw = urlencode(query_string, doseq=True) 55 | query_string_bytes = query_string_raw.encode("ascii") 56 | return headers, path, query_string_bytes 57 | 58 | 59 | def to_relative_path(path: str): 60 | if path.startswith("/"): 61 | return path 62 | return re.sub(r"^[a-zA-Z]+://[^/]+/", "/", path) 63 | 64 | 65 | async def is_last_one(gen): 66 | prev_el = None 67 | async for el in gen: 68 | prev_el = el 69 | async for el in gen: 70 | yield (False, prev_el) 71 | prev_el = el 72 | yield (True, prev_el) 73 | 74 | 75 | class Message: 76 | def __init__(self, event, reason, task): 77 | self.event: str = event 78 | self.reason: str = reason 79 | self.task: asyncio.Task = task 80 | 81 | 82 | def create_monitored_task(coro, send): 83 | future = asyncio.ensure_future(coro) 84 | future.add_done_callback(partial(_callback, send)) 85 | return future 86 | 87 | 88 | async def receive(ch, timeout=None): 89 | fut = set_timeout(ch, timeout) 90 | msg = await ch.get() 91 | if not fut.cancelled(): 92 | fut.cancel() 93 | if isinstance(msg, Message): 94 | if msg.event == "err": 95 | raise msg.reason 96 | return msg 97 | 98 | 99 | def _callback(send, fut): 100 | try: 101 | fut.result() 102 | except asyncio.CancelledError: 103 | send(Message("exit", "killed", fut)) 104 | raise 105 | except Exception as e: 106 | send(Message("err", e, fut)) 107 | else: 108 | send(Message("exit", "normal", fut)) 109 | 110 | 111 | async def _send_after(timeout, queue, msg): 112 | if timeout is None: 113 | return 114 | await asyncio.sleep(timeout) 115 | await queue.put(msg) 116 | 117 | 118 | def set_timeout(queue, timeout): 119 | msg = Message("err", asyncio.TimeoutError, current_task()) 120 | return 
asyncio.ensure_future(_send_after(timeout, queue, msg)) 121 | 122 | 123 | def current_task(): 124 | PY37 = sys.version_info >= (3, 7) 125 | if PY37: 126 | return asyncio.current_task() 127 | else: 128 | return asyncio.Task.current_task() 129 | -------------------------------------------------------------------------------- /async_asgi_testclient/websocket.py: -------------------------------------------------------------------------------- 1 | from async_asgi_testclient.utils import create_monitored_task 2 | from async_asgi_testclient.utils import flatten_headers 3 | from async_asgi_testclient.utils import make_test_headers_path_and_query_string 4 | from async_asgi_testclient.utils import Message 5 | from async_asgi_testclient.utils import receive 6 | from http.cookies import SimpleCookie 7 | from typing import Dict 8 | from typing import Optional 9 | 10 | import asyncio 11 | import json 12 | 13 | 14 | class WebSocketSession: 15 | def __init__( 16 | self, 17 | testclient, 18 | path, 19 | headers: Optional[Dict] = None, 20 | cookies: Optional[Dict] = None, 21 | scheme: str = "ws", 22 | ): 23 | self.testclient = testclient 24 | self.path = path 25 | self.headers = headers or {} 26 | self.cookies = cookies 27 | self.scheme = scheme 28 | self.input_queue: asyncio.Queue[dict] = asyncio.Queue() 29 | self.output_queue: asyncio.Queue[dict] = asyncio.Queue() 30 | self._app_task = None # Necessary to keep a hard reference to running task 31 | 32 | async def __aenter__(self): 33 | await self.connect() 34 | return self 35 | 36 | async def __aexit__(self, exc_type, exc, tb): 37 | await self.close() 38 | self._app_task = None 39 | 40 | async def close(self, code: int = 1000): 41 | await self._send({"type": "websocket.disconnect", "code": code}) 42 | 43 | async def send_str(self, data: str) -> None: 44 | await self.send_text(data) 45 | 46 | async def send_text(self, data: str) -> None: 47 | await self._send({"type": "websocket.receive", "text": data}) 48 | 49 | async def 
send_bytes(self, data: bytes) -> None: 50 | await self._send({"type": "websocket.receive", "bytes": data}) 51 | 52 | async def send_json(self, data, mode: str = "text") -> None: 53 | assert mode in ["text", "binary"] 54 | text = json.dumps(data) 55 | if mode == "text": 56 | await self._send({"type": "websocket.receive", "text": text}) 57 | else: 58 | await self._send( 59 | {"type": "websocket.receive", "bytes": text.encode("utf-8")} 60 | ) 61 | 62 | async def _send(self, data): 63 | self.input_queue.put_nowait(data) 64 | 65 | async def receive_text(self) -> str: 66 | message = await self._receive() 67 | if message["type"] != "websocket.send": 68 | raise Exception(message) 69 | return message["text"] 70 | 71 | async def receive_bytes(self) -> bytes: 72 | message = await self._receive() 73 | if message["type"] != "websocket.send": 74 | raise Exception(message) 75 | return message["bytes"] 76 | 77 | async def receive_json(self): 78 | message = await self._receive() 79 | if message["type"] != "websocket.send": 80 | raise Exception(message) 81 | if "text" in message: 82 | data = message["text"] 83 | elif "bytes" in message: 84 | data = message["bytes"] 85 | else: 86 | raise Exception(message) 87 | return json.loads(data) 88 | 89 | async def _receive(self): 90 | return await receive(self.output_queue) 91 | 92 | def __aiter__(self): 93 | return self 94 | 95 | async def __anext__(self): 96 | msg = await self._receive() 97 | if isinstance(msg, Message): 98 | if msg.event == "exit": 99 | raise StopAsyncIteration(msg) 100 | return msg 101 | 102 | async def connect(self): 103 | tc = self.testclient 104 | app = tc.application 105 | headers, path, query_string_bytes = make_test_headers_path_and_query_string( 106 | app, self.path, self.headers 107 | ) 108 | 109 | if self.cookies is None: # use TestClient.cookie_jar 110 | cookie_jar = tc.cookie_jar 111 | else: 112 | cookie_jar = SimpleCookie(self.cookies) 113 | 114 | if cookie_jar and cookie_jar.output(header=""): 115 | 
headers.add("Cookie", cookie_jar.output(header="")) 116 | 117 | scope = { 118 | "type": "websocket", 119 | "headers": flatten_headers(headers), 120 | "path": path, 121 | "query_string": query_string_bytes, 122 | "root_path": "", 123 | "scheme": self.scheme, 124 | "subprotocols": [], 125 | } 126 | 127 | self._app_task = create_monitored_task( 128 | app(scope, self.input_queue.get, self.output_queue.put), 129 | self.output_queue.put_nowait, 130 | ) 131 | 132 | await self._send({"type": "websocket.connect"}) 133 | msg = await self._receive() 134 | assert msg["type"] == "websocket.accept" 135 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | multidict>=4.0,<7.0 2 | requests~=2.21 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | E302 4 | W292 5 | W391 6 | W503 7 | E722 8 | E501 9 | max-line-length = 110 10 | exclude = .git,__pycache__,docs/source/conf.py,old,build,dist 11 | 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import find_packages 7 | from setuptools import setup 8 | 9 | with open("README.md", "r") as f: 10 | readme = f.read() 11 | 12 | with open("requirements.txt", "r") as f: 13 | requirements = f.readlines() 14 | 15 | with open("test-requirements.txt", "r") as f: 16 | test_requirements = f.readlines() 17 | 18 | setup( 19 | author="Jordi Masip", 20 | author_email="jordi@masip.cat", 21 | classifiers=[ 22 | "Development Status :: 5 - Production/Stable", 23 | "Intended Audience :: Developers", 24 | "License :: OSI Approved :: MIT 
License", 25 | "Natural Language :: English", 26 | "Programming Language :: Python :: 3.6", 27 | "Programming Language :: Python :: 3.7", 28 | "Programming Language :: Python :: 3.8", 29 | "Programming Language :: Python :: 3.9", 30 | "Programming Language :: Python :: 3.10", 31 | ], 32 | description="Async client for testing ASGI web applications", 33 | install_requires=requirements, 34 | license="MIT license", 35 | long_description=readme, 36 | long_description_content_type="text/markdown", 37 | include_package_data=True, 38 | name="async-asgi-testclient", 39 | keywords="async asgi testclient", 40 | packages=find_packages(include=["async_asgi_testclient"]), 41 | test_suite="tests", 42 | tests_require=test_requirements, 43 | url="https://github.com/vinissimus/async-asgi-testclient", 44 | version="1.4.11", 45 | zip_safe=False, 46 | ) 47 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | quart==0.17.0; python_version >= '3.7' 2 | starlette==0.12.13 3 | python-multipart==0.0.5 4 | pytest==6.2.5 5 | pytest-asyncio==0.15.0 6 | pytest-cov==2.8.1 7 | black==22.3.0 8 | flake8~=3.8.0 9 | mypy==0.761 10 | isort==4.3.21 11 | --------------------------------------------------------------------------------