├── .github └── workflows │ └── ci.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── RELEASE.md ├── commitlint.config.js ├── docs └── immediate_mode.md ├── misc ├── greeter │ ├── README │ ├── baseline_main.py │ ├── client.py │ ├── client_grpcio.py │ ├── failing_client.py │ ├── failing_server.py │ ├── generated │ │ ├── __init__.py │ │ ├── greeter.proto │ │ ├── greeter_grpc.py │ │ ├── greeter_pb2.py │ │ └── greeter_pb2_grpc.py │ ├── main.py │ ├── main_pingpong.py │ ├── main_pingpong_servicer.py │ ├── test_perf.py │ └── test_perf_grpcio.py ├── h2load │ ├── latency_h2load.sh │ ├── old_logs │ │ ├── pypy6.curio.log.txt │ │ ├── pypy6.trio.0.3.log.txt │ │ ├── pypy6.trio.0.4.log.txt │ │ └── python.log.txt │ ├── request.bin │ └── run_h2load.sh └── pypy_tests │ └── bytearray_perf_test.py ├── requirements_test.txt ├── requirements_test_pypy.txt ├── setup.cfg ├── setup.py ├── src └── purerpc │ ├── __init__.py │ ├── _version.py │ ├── client.py │ ├── grpc_proto.py │ ├── grpc_socket.py │ ├── grpclib │ ├── __init__.py │ ├── buffers.py │ ├── config.py │ ├── connection.py │ ├── events.py │ ├── exceptions.py │ ├── headers.py │ └── status.py │ ├── protoc_plugin │ ├── __init__.py │ └── plugin.py │ ├── rpc.py │ ├── server.py │ ├── test_utils.py │ ├── utils.py │ └── wrappers.py └── tests ├── __init__.py ├── conftest.py ├── data ├── echo.proto ├── greeter.proto └── test_package_names │ ├── A.proto │ ├── B.proto │ └── C.proto ├── exceptiongroups.py ├── test_buffers.py ├── test_echo.py ├── test_errors.py ├── test_greeter.py ├── test_metadata.py ├── test_protoc_plugin.py ├── test_server_http2.py ├── test_status_codes.py └── test_test_utils.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches-ignore: 6 | - "dependabot/**" 7 | pull_request: 8 | 9 | jobs: 10 | build_and_test_pinned: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: ['3.8', '3.9', '3.10', '3.11'] # 'pypy-3.7' 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Setup Python 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | cache: 'pip' 24 | cache-dependency-path: 'requirements_test.txt' 25 | - run: pip install . -r requirements_test.txt 26 | - run: pytest 27 | 28 | build_and_test_latest: 29 | runs-on: ${{ matrix.os }} 30 | strategy: 31 | matrix: 32 | # macos-latest disabled due to unexplained timeout 33 | # https://github.com/python-trio/purerpc/issues/39 34 | os: [ubuntu-latest] # TODO: windows-latest 35 | python-version: ['3.12'] 36 | steps: 37 | - uses: actions/checkout@v3 38 | - name: Setup Python 39 | uses: actions/setup-python@v3 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | - run: pip install .[dev] 43 | - run: pytest 44 | 45 | build_and_test_pypy: 46 | if: ${{ false }} # install for pypy is too slow due to grpc packages 47 | runs-on: ${{ matrix.os }} 48 | strategy: 49 | matrix: 50 | os: [ubuntu-latest] 51 | python-version: ['pypy-3.7'] 52 | steps: 53 | - uses: actions/checkout@v3 54 | - name: Setup Python 55 | uses: actions/setup-python@v3 56 | with: 57 | python-version: ${{ matrix.python-version }} 58 | cache: 'pip' 59 | cache-dependency-path: 'requirements_test_pypy.txt' 60 | - run: pip install . 
-r requirements_test_pypy.txt 61 | - run: pytest 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Logs 50 | *.log 51 | 52 | # Sphinx documentation 53 | docs/_build/ 54 | 55 | # PyBuilder 56 | target/ 57 | 58 | # pyenv 59 | .python-version 60 | 61 | # dotenv 62 | .env 63 | 64 | # virtualenv 65 | .venv 66 | venv/ 67 | ENV/ 68 | 69 | # mkdocs documentation 70 | /site 71 | 72 | # mypy 73 | .mypy_cache/ 74 | 75 | # IDEs 76 | .idea 77 | .vscode 78 | 79 | # PyTests 80 | .pytest_cache 81 | 82 | # CLion 83 | /cmake-build-debug/ 84 | /cmake-build-release/ 85 | 86 | .DS_Store 87 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | For the Trio code of conduct, see: 2 | https://trio.readthedocs.io/en/latest/code-of-conduct.html 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # purerpc 2 | 3 | 4 | [![Build Status](https://img.shields.io/github/actions/workflow/status/python-trio/purerpc/ci.yml)](https://github.com/python-trio/purerpc/actions/workflows/ci.yml) 5 | [![PyPI version](https://img.shields.io/pypi/v/purerpc.svg?style=flat)](https://pypi.org/project/purerpc/) 6 | [![Supported Python versions](https://img.shields.io/pypi/pyversions/purerpc.svg)](https://pypi.org/project/purerpc) 7 | 8 | 9 | _purerpc_ is a native, async Python gRPC client and server implementation supporting 10 | [asyncio](https://docs.python.org/3/library/asyncio.html), 11 | [uvloop](https://github.com/MagicStack/uvloop), and 12 | [trio](https://github.com/python-trio/trio) (achieved with [anyio](https://github.com/agronholm/anyio) compatibility layer). 13 | 14 | This project is in maintenance mode. Updates will primarily be limited to fixing 15 | severe bugs, keeping the package usable for actively developed projects, and 16 | easing maintenance. 17 | 18 | For use cases limited to asyncio, consider the Python package published by the 19 | main [grpc](https://github.com/grpc/grpc) project instead. 20 | 21 | ## Requirements 22 | 23 | * CPython >= 3.8 24 | * ? PyPy >= 3.8 25 | 26 | ## Installation 27 | 28 | Latest PyPI version: 29 | 30 | ```bash 31 | pip install purerpc[grpc] 32 | ``` 33 | 34 | NOTE: for PyPy, replace "grpc" with "grpc-pypy". Support is tentative, as 35 | [grpc does not officially support PyPy](https://github.com/grpc/grpc/issues/4221). 36 | 37 | Latest development version: 38 | 39 | ```bash 40 | pip install git+https://github.com/python-trio/purerpc.git[grpc] 41 | ``` 42 | 43 | These invocations will include dependencies for the grpc runtime and 44 | generation of service stubs. 45 | 46 | To install extra dependencies for running tests or examples, using the 47 | `test_utils` module, etc., apply the `[dev]` suffix (e.g. 48 | `pip install purerpc[dev]`). 49 | 50 | ## protoc plugin 51 | 52 | purerpc adds `protoc-gen-purerpc` plugin for `protoc` to your `PATH` environment variable 53 | so you can use it to generate service definition and stubs: 54 | 55 | ```bash 56 | protoc --purerpc_out=. --python_out=. -I. greeter.proto 57 | ``` 58 | 59 | or, if you installed the `grpcio-tools` Python package: 60 | 61 | ```bash 62 | python -m grpc_tools.protoc --purerpc_out=. --python_out=. -I. greeter.proto 63 | ``` 64 | 65 | ## Usage 66 | 67 | NOTE: `greeter_grpc` module is generated by purerpc's `protoc-gen-purerpc` plugin. 
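The examples below assume a proto definition along these lines (an abbreviated form of `misc/greeter/generated/greeter.proto` from this repository, showing only the two methods used below):

```proto
syntax = "proto3";

service Greeter {
    rpc SayHello (HelloRequest) returns (HelloReply) {}
    rpc SayHelloToMany (stream HelloRequest) returns (stream HelloReply) {}
}

message HelloRequest {
    string name = 1;
}

message HelloReply {
    string message = 1;
}
```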
68 | 69 | ### Server 70 | 71 | ```python 72 | from purerpc import Server 73 | from greeter_pb2 import HelloRequest, HelloReply 74 | from greeter_grpc import GreeterServicer 75 | 76 | 77 | class Greeter(GreeterServicer): 78 | async def SayHello(self, message): 79 | return HelloReply(message="Hello, " + message.name) 80 | 81 | async def SayHelloToMany(self, input_messages): 82 | async for message in input_messages: 83 | yield HelloReply(message=f"Hello, {message.name}") 84 | 85 | 86 | if __name__ == '__main__': 87 | server = Server(50055) 88 | server.add_service(Greeter().service) 89 | # NOTE: if you already have an async loop running, use "await server.serve_async()" 90 | import anyio 91 | anyio.run(server.serve_async) # or set explicit backend="asyncio" or "trio" 92 | ``` 93 | 94 | ### Client 95 | 96 | ```python 97 | import purerpc 98 | from greeter_pb2 import HelloRequest, HelloReply 99 | from greeter_grpc import GreeterStub 100 | 101 | 102 | async def gen(): 103 | for i in range(5): 104 | yield HelloRequest(name=str(i)) 105 | 106 | 107 | async def listen(): 108 | async with purerpc.insecure_channel("localhost", 50055) as channel: 109 | stub = GreeterStub(channel) 110 | reply = await stub.SayHello(HelloRequest(name="World")) 111 | print(reply.message) 112 | 113 | async with stub.SayHelloToMany(gen()) as stream: 114 | async for reply in stream: 115 | print(reply.message) 116 | 117 | 118 | if __name__ == '__main__': 119 | # NOTE: if you already have an async loop running, use "await listen()" 120 | import anyio 121 | anyio.run(listen) # or set explicit backend="asyncio" or "trio" 122 | ``` 123 | 124 | You can mix server and client code, for example make a server that requests something using purerpc from another gRPC server, etc. 125 | 126 | More examples in `misc/` folder 127 | 128 | # Project history 129 | 130 | purerpc was originally written by [Andrew Stepanov](https://github.com/standy66) 131 | and used the curio async event loop. Later it 132 | was migrated to the [anyio](https://github.com/agronholm/anyio) API, supporting 133 | asyncio, curio, uvloop, and trio (though curio support has since been dropped 134 | from the API). 135 | 136 | After going a few years unmaintained, the project was adopted by the [python-trio 137 | organization](https://github.com/python-trio) with the intent of ensuring a 138 | continued gRPC solution for Trio users. 139 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release history 2 | 3 | ## Release 0.8.0 (2022-04-20) 4 | 5 | ### Development 6 | 7 | * test runs now cover all backends automatically, so multiple invocations 8 | of pytest are no longer needed 9 | 10 | ### BREAKING CHANGES 11 | 12 | * revise dependency management. The default requirements are strictly for 13 | purerpc's direct dependencies. Use requirements extra "grpc" for grpc 14 | runtime and stub generation. Use "dev" for running tests or examples, 15 | using the test_utils module, etc. 16 | 17 | 18 | ## Release 0.7.1 (2022-04-19) 19 | 20 | ### Bug Fixes 21 | 22 | * fix server end-of-stream handling, where normal client disconnects were 23 | being logged as exceptions 24 | 25 | 26 | ## Release 0.7.0 (2022-04-17) 27 | 28 | ### Features 29 | 30 | * add Server.serve_async(), allowing the grpc server to run concurrently 31 | with other async tasks. (Server.serve() is deprecated.) 
32 | * upgrade anyio dependency, which will resolve conflicts when trying to use 33 | purerpc together with other packages depending on anyio 34 | 35 | ### BREAKING CHANGES 36 | 37 | * drop curio backend support (since anyio has dropped it) 38 | * drop Python 3.5, 3.6 support 39 | 40 | 41 | ## [Release 0.6.1](https://github.com/python-trio/purerpc/compare/v0.6.0...v0.6.1) (2020-04-13) 42 | 43 | ### Bug Fixes 44 | 45 | * build in PyPy 3.6, remove 3.5 builds from CI ([d1bcc9d](https://github.com/python-trio/purerpc/commit/d1bcc9d)) 46 | * remove CPython 3.5 builds ([7488ba8](https://github.com/python-trio/purerpc/commit/7488ba8)) 47 | 48 | 49 | 50 | ## [Release 0.6.0](https://github.com/python-trio/purerpc/compare/v0.5.2...v0.6.0) (2020-04-13) 51 | 52 | ### Features 53 | 54 | * Add TLS Support 55 | 56 | 57 | ## [Release 0.5.2](https://github.com/python-trio/purerpc/compare/v0.5.1...v0.5.2) (2019-07-23) 58 | 59 | 60 | ### Features 61 | 62 | * additional exception shielding for asyncio ([3cbd35c](https://github.com/python-trio/purerpc/commit/3cbd35c)) 63 | 64 | 65 | 66 | ## [Release 0.5.1](https://github.com/python-trio/purerpc/compare/v0.5.0...v0.5.1) (2019-07-23) 67 | 68 | 69 | ### Bug Fixes 70 | 71 | * async generators on python 3.5 ([1c19229](https://github.com/python-trio/purerpc/commit/1c19229)) 72 | 73 | 74 | 75 | ## [Release 0.5.0](https://github.com/python-trio/purerpc/compare/v0.4.1...v0.5.0) (2019-07-23) 76 | 77 | 78 | ### Features 79 | 80 | * can now pass contextmngr or setup_fn/teardown_fn to add_service ([208dd95](https://github.com/python-trio/purerpc/commit/208dd95)) 81 | 82 | 83 | 84 | ## [Release 0.4.1](https://github.com/python-trio/purerpc/compare/v0.4.0...v0.4.1) (2019-07-22) 85 | 86 | 87 | ### Features 88 | 89 | * remove undocumented use of raw_socket in anyio ([6de2c9a](https://github.com/python-trio/purerpc/commit/6de2c9a)) 90 | 91 | 92 | 93 | ## [Release 0.4.0](https://github.com/python-trio/purerpc/compare/v0.3.2...v0.4.0) (2019-07-22) 94 | 95 | 96 | ### Bug Fixes 97 | 98 | * speed improvements ([1cb3d46](https://github.com/python-trio/purerpc/commit/1cb3d46)) 99 | 100 | 101 | ### Features 102 | 103 | * add state property to GRPCStream ([0019d8c](https://github.com/python-trio/purerpc/commit/0019d8c)) 104 | * answer PING frames ([c829901](https://github.com/python-trio/purerpc/commit/c829901)) 105 | * change MAX_CONCURRENT_STREAMS from 1000 to 65536 ([d2d461f](https://github.com/python-trio/purerpc/commit/d2d461f)) 106 | * decouple h2 and grpclib logic ([1f4e6b0](https://github.com/python-trio/purerpc/commit/1f4e6b0)) 107 | * support percent-encoded grpc-message header ([c6636f4](https://github.com/python-trio/purerpc/commit/c6636f4)) 108 | * change default max message length to 32 MB 109 | 110 | 111 | ## [Release 0.3.2](https://github.com/python-trio/purerpc/compare/v0.3.1...v0.3.2) (2019-02-15) 112 | 113 | 114 | ### Bug Fixes 115 | 116 | * fix dependencies, remove some of anyio monkey patches ([ac6c5c2](https://github.com/python-trio/purerpc/commit/ac6c5c2)) 117 | 118 | 119 | 120 | ## [Release 0.3.1](https://github.com/python-trio/purerpc/compare/v0.3.0...v0.3.1) (2019-02-15) 121 | 122 | 123 | ### Bug Fixes 124 | 125 | * fix pickling error in purerpc.test_utils._WrappedResult ([9f0a63d](https://github.com/python-trio/purerpc/commit/9f0a63d)) 126 | 127 | 128 | 129 | ## [Release 0.3.0](https://github.com/python-trio/purerpc/compare/v0.2.1...v0.3.0) (2019-02-14) 130 | 131 | 132 | ### Features 133 | 134 | * expose new functions in purerpc.test_utils 
([07b10e1](https://github.com/python-trio/purerpc/commit/07b10e1)) 135 | * migrate to pytest ([95c0a8b](https://github.com/python-trio/purerpc/commit/95c0a8b)) 136 | 137 | 138 | ### BREAKING CHANGES 139 | 140 | * purerpc.test_utils.PureRPCTestCase is removed 141 | 142 | 143 | 144 | ## [Release 0.2.0](https://github.com/python-trio/purerpc/compare/v0.1.6...v0.2.0) (2019-02-10) 145 | 146 | 147 | ### Features 148 | 149 | * add backend option to Server.serve ([5f47f8e](https://github.com/python-trio/purerpc/commit/5f47f8e)) 150 | * add support for Python 3.5 ([a681192](https://github.com/python-trio/purerpc/commit/a681192)) 151 | * improved exception handling in test utils ([b1df796](https://github.com/python-trio/purerpc/commit/b1df796)) 152 | * migrate to anyio ([746b1c2](https://github.com/python-trio/purerpc/commit/746b1c2)) 153 | 154 | 155 | ### BREAKING CHANGES 156 | 157 | * Server and test now use asyncio event loop by default, 158 | this behaviour can be changed with PURERPC_BACKEND environment variable 159 | * purerpc.Channel is removed, migrate to 160 | purerpc.insecure_channel async context manager (now supports correct 161 | shutdown) 162 | 163 | ## Release 0.1.6 164 | 165 | * Allow passing request headers to method handlers in request argument 166 | * Allow passing custom metadata to method stub calls (in metadata optional keyword argument) 167 | 168 | ## Release 0.1.5 169 | 170 | * Enforce SO_KEEPALIVE with small timeouts 171 | * Expose PureRPCTestCase in purerpc API for unit testing purerpc services 172 | 173 | ## Release 0.1.4 174 | 175 | * Speed up protoc plugin 176 | 177 | ## Release 0.1.3 [PyPI only] 178 | 179 | * Fix long description on PyPI 180 | 181 | ## Release 0.1.2 182 | 183 | * Fix unit tests on Python 3.7 184 | 185 | ## Release 0.1.0 186 | 187 | * Implement immediate mode 188 | 189 | ## Release 0.0.1 190 | 191 | * Initial release 192 | -------------------------------------------------------------------------------- /commitlint.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | rules: { 3 | 'body-leading-blank': [1, 'always'], 4 | 'footer-leading-blank': [1, 'always'], 5 | 'header-max-length': [2, 'always', 72], 6 | 'scope-case': [2, 'always', 'lower-case'], 7 | 'subject-case': [ 8 | 2, 9 | 'never', 10 | ['sentence-case', 'start-case', 'pascal-case', 'upper-case'] 11 | ], 12 | 'subject-empty': [2, 'never'], 13 | 'subject-full-stop': [2, 'never', '.'], 14 | 'type-case': [2, 'always', 'lower-case'], 15 | 'type-empty': [2, 'never'], 16 | 'type-enum': [ 17 | 2, 18 | 'always', 19 | [ 20 | 'build', 21 | 'chore', 22 | 'ci', 23 | 'docs', 24 | 'feat', 25 | 'fix', 26 | 'perf', 27 | 'refactor', 28 | 'revert', 29 | 'style', 30 | 'test' 31 | ] 32 | ] 33 | } 34 | }; 35 | -------------------------------------------------------------------------------- /docs/immediate_mode.md: -------------------------------------------------------------------------------- 1 | # Immediate mode 2 | 3 | The goal: make `send` / `recv` on a Stream actually send or receive a message to / from socket. 4 | 5 | It is a little bit tricky for receive, because all the receives happen in single listening coroutine thread that is bound to the underlying connection: 6 | 7 | ```python 8 | async def listen(): 9 | ... 10 | while True: 11 | data = await sock.recv(BUFFER_SIZE) 12 | if not data: 13 | break 14 | 15 | events = h2_conn.receive_data(data) 16 | for event in events: 17 | # process events 18 | ... 
19 | await sock.sendall(h2_conn.data_to_send()) 20 | ``` 21 | 22 | In contrast, multiple HTTP/2 streams can be opened on the same connection, and these streams are typically processed in separate coroutine threads. That means there is an asymmetry in how `send` and `recv` are processed. 23 | 24 | Because all the receives happen on a separate thread, it seems reasonable to buffer the incoming data; HTTP/2 gives us flow control for that, after all. The buffer should be as simple as possible: ***we cannot block*** in the `listen()` thread on any call other than `sock.recv()` / `sock.sendall()`, because that would increase latency. Backpressure should be applied by means of flow control. The simplest solution is the following: use a `bytes` ring buffer with a maximum size roughly equal to the largest GRPC message that is expected (a maximum size really should be set, because otherwise someone can attack the unlimited buffer). Set the stream flow control window to this size. On receiving a message, just copy the data (or reject the message altogether because of size limits) into the internal buffer. Parse the message from the buffer and acknowledge it via `WINDOW_UPDATE` only in `Stream.recv` (when it is actually needed by the application). With this scheme, trying to send more data than the application consumes violates flow control, and all the buffers are bounded (there is also an overall bound of `num_streams * max_message_size`). 25 | 26 | The sending is implemented as follows: each access to `H2Connection.data_to_send` is protected by a FIFO lock: 27 | 28 | ```python 29 | with global_connection_lock: 30 | data = h2_conn.data_to_send() 31 | await sock.sendall(data) 32 | ``` 33 | 34 | This is a simple primitive that flushes all the data that is ready to be sent to the socket. The flow control on the sending end is also respected: 35 | 36 | ```python 37 | async def send_message(stream_id, msg): 38 | pos = 0 39 | while pos < len(msg): 40 | size = h2_conn.local_flow_control_window(stream_id) 41 | if size == 0: 42 | await flow_control_event[stream_id].wait() 43 | flow_control_event[stream_id].clear() 44 | continue 45 | else: 46 | h2_conn.send_data(stream_id, msg[pos:pos+size]) 47 | pos += size 48 | with global_connection_lock: 49 | data = h2_conn.data_to_send() 50 | await sock.sendall(data) 51 | ``` 52 | 53 | The `flow_control_event` is local to the current stream, and `await flow_control_event[stream_id].set()` is executed each time a `WINDOW_UPDATE` frame for that stream or for the connection itself arrives in the `listen()` thread. 54 | 55 | With this approach, there are seemingly no `Queue`s anywhere and each action does exactly what it says it does, without any buffers. 56 | 57 | For both client and server, there is one listening thread for the connections, and also one sending thread per stream. In the case of the server, the sending threads are spawned by the listening thread itself on the arrival of requests.
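To make the receive side concrete, here is a rough sketch in the same spirit as the snippets above (illustrative pseudocode only, not purerpc's actual implementation; `buffers`, `data_available`, `try_parse_message`, `sock` and `h2_conn` are invented/assumed names): incoming DATA is only copied into a bounded per-stream buffer inside `listen()`, and the flow-control window is re-opened with `acknowledge_received_data` only once `Stream.recv` has actually parsed a message.

```python
# Illustrative sketch only -- buffers, data_available and try_parse_message
# are invented names; locking and error handling are mostly omitted.
import h2.events


class StreamBuffer:
    def __init__(self, stream_id):
        self.stream_id = stream_id
        self.data = bytearray()   # bounded by the stream flow-control window
        self.unacked = 0          # flow-controlled bytes not yet acknowledged


async def listen():
    while True:
        chunk = await sock.recv(BUFFER_SIZE)
        if not chunk:
            break
        for event in h2_conn.receive_data(chunk):
            if isinstance(event, h2.events.DataReceived):
                buf = buffers[event.stream_id]
                # Only copy here; acknowledging now would re-open the window
                # before the application has consumed the data.
                buf.data += event.data
                buf.unacked += event.flow_controlled_length
                await data_available[event.stream_id].set()
        await sock.sendall(h2_conn.data_to_send())


async def stream_recv(buf):
    """What Stream.recv would do in the application coroutine."""
    while True:
        message, consumed = try_parse_message(buf.data)
        if message is not None:
            del buf.data[:consumed]
            # Acknowledge only now, emitting WINDOW_UPDATE so the peer may
            # send the next message.
            h2_conn.acknowledge_received_data(buf.unacked, buf.stream_id)
            buf.unacked = 0
            with global_connection_lock:
                await sock.sendall(h2_conn.data_to_send())
            return message
        await data_available[buf.stream_id].wait()
        data_available[buf.stream_id].clear()
```

Note that, as the cons listed below mention, a real implementation has to track the exact `flow_controlled_length` per message rather than acknowledging everything buffered so far.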
58 | 59 | The benefits also include: 60 | 61 | * no more complicated logic in _writer_thread 62 | * no _writer_thread at all 63 | * no more complicated logic in GRPCConnection.data_to_send regarding flow control and buffering 64 | 65 | The cons: 66 | 67 | * need to rethink "sans IO" GRPCConnection to work with this case 68 | * need to think how to cancel tasks in progress when too large message is received 69 | * need to count `flow_controlled_length`s to acknowledge the correct size (also need to set flow controlled length larger than maximum message length in case of padding) 70 | * need to design more random tests that test identity of the messages before implementing this 71 | * make perf tests 72 | 73 | **UPDATE (2019-02-15)**: Turns out we need `_writer_thread` after all, because we cannot send anything on 74 | `_listener_thread` (or we would block when both ends try to send very large chunks of data). If we just exclude 75 | calls to send from `_listener_thread`, we won't be able to answer PING frames. So instead, we ping `_writer_thread` 76 | so it can do the sending for us. -------------------------------------------------------------------------------- /misc/greeter/README: -------------------------------------------------------------------------------- 1 | Usage notes 2 | 3 | $ python main_pingpong_servicer.py 4 | # (another terminal) 5 | $ python client.py 6 | RPS: 942.7659800962675 7 | RPS: 928.7526104604065 8 | RPS: 856.73468437314 9 | ... 10 | 11 | $ python failing_server.py 12 | # (another terminal) 13 | $ python failing_client.py 14 | Round 0 rps 894.274406271697 avg latency 110.80834655761724 15 | Round 1 rps 821.6505728334758 avg latency 120.84926767349245 16 | Round 2 rps 786.0224822867286 avg latency 126.15007748603823 17 | ... 18 | 19 | $ python baseline_main.py 20 | # (another terminal) 21 | $ python test_perf_grpcio.py --load_type unary 22 | Round 0, RPS: 2653.0537795880555, avg latency: 112.15578104654948 ms, max latency: 113.03731441497803 ms 23 | Round 1, RPS: 2785.6302402624187, avg latency: 107.29748307863872 ms, max latency: 107.88612365722656 ms 24 | Round 2, RPS: 2755.1956914726525, avg latency: 108.47193616231284 ms, max latency: 108.89975547790527 ms 25 | ... 26 | 27 | $ python main.py 28 | # (another terminal) 29 | $ python test_perf.py --load_type unary 30 | Round 0, RPS: 1213.7552524716182, avg latency: 246.56518713633216 ms, max latency: 249.38842296600342 ms 31 | Round 1, RPS: 1325.8140062285606, avg latency: 225.77957820892334 ms, max latency: 231.95827960968018 ms 32 | Round 2, RPS: 1387.5940927539905, avg latency: 215.69242871602378 ms, max latency: 218.73981952667236 ms 33 | ... 
34 | -------------------------------------------------------------------------------- /misc/greeter/baseline_main.py: -------------------------------------------------------------------------------- 1 | """The Python implementation of the GRPC helloworld.Greeter server.""" 2 | 3 | from concurrent import futures 4 | import time 5 | 6 | import grpc 7 | 8 | from generated.greeter_pb2 import HelloReply 9 | from generated.greeter_pb2_grpc import GreeterServicer, add_GreeterServicer_to_server 10 | 11 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24 12 | 13 | 14 | class Greeter(GreeterServicer): 15 | 16 | def SayHelloToMany(self, request_iterator, context): 17 | requests = [] 18 | for request in request_iterator: 19 | requests.append(request) 20 | 21 | name = requests[0].name 22 | for i in range(8): 23 | name += name 24 | yield HelloReply(message='Hello, {}'.format(name)) 25 | 26 | 27 | def serve(): 28 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) 29 | add_GreeterServicer_to_server(Greeter(), server) 30 | server.add_insecure_port('[::]:50055') 31 | server.start() 32 | try: 33 | while True: 34 | time.sleep(_ONE_DAY_IN_SECONDS) 35 | except KeyboardInterrupt: 36 | server.stop(0) 37 | 38 | 39 | if __name__ == '__main__': 40 | serve() 41 | 42 | -------------------------------------------------------------------------------- /misc/greeter/client.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import anyio 4 | 5 | import purerpc 6 | from generated.greeter_pb2 import HelloRequest 7 | from generated.greeter_grpc import GreeterStub 8 | 9 | 10 | async def worker(channel): 11 | stub = GreeterStub(channel) 12 | for i in range(100): 13 | data = "World" * 1 14 | response = await stub.SayHello(HelloRequest(name=data)) 15 | assert(response.message == "Hello, " + data) 16 | 17 | 18 | async def main_coro(): 19 | async with purerpc.insecure_channel("localhost", 50055) as channel: 20 | for _ in range(100): 21 | start = time.time() 22 | async with anyio.create_task_group() as task_group: 23 | for _ in range(100): 24 | task_group.start_soon(worker, channel) 25 | print("RPS: {}".format(10000 / (time.time() - start))) 26 | 27 | 28 | def main(): 29 | purerpc.run(main_coro) 30 | 31 | 32 | if __name__ == "__main__": 33 | try: 34 | main() 35 | except KeyboardInterrupt: 36 | pass 37 | -------------------------------------------------------------------------------- /misc/greeter/client_grpcio.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | from generated.greeter_pb2 import HelloRequest 3 | from generated.greeter_pb2_grpc import GreeterStub 4 | 5 | 6 | def main(): 7 | channel = grpc.insecure_channel("localhost:50055") 8 | stub = GreeterStub(channel) 9 | data = "World" * 20000 10 | response = stub.SayHello(HelloRequest(name=data)) 11 | assert(response.message == "Hello, " + data) 12 | 13 | if __name__ == "__main__": 14 | main() 15 | -------------------------------------------------------------------------------- /misc/greeter/failing_client.py: -------------------------------------------------------------------------------- 1 | import anyio 2 | import sys 3 | import time 4 | from generated import greeter_grpc, greeter_pb2 5 | import purerpc 6 | 7 | 8 | GreeterStub = greeter_grpc.GreeterStub 9 | 10 | 11 | 12 | async def do_load_unary(result_queue, stub, num_requests, message_size): 13 | message = "0" * message_size 14 | start = time.time() 15 | for _ in range(num_requests): 16 | result = (await 
stub.SayHello(greeter_pb2.HelloRequest(name=message))).message 17 | assert (len(result) == message_size) 18 | avg_latency = (time.time() - start) / num_requests 19 | await result_queue.send(avg_latency) 20 | 21 | 22 | async def do_load_stream(result_queue, stub, num_requests, message_size): 23 | message = "0" * message_size 24 | stream = await stub.SayHelloToMany() 25 | start = time.time() 26 | for _ in range(num_requests): 27 | await stream.send_message(greeter_pb2.HelloRequest(name=message)) 28 | result = await stream.receive_message() 29 | assert (len(result.message) == message_size) 30 | avg_latency = (time.time() - start) / num_requests 31 | await stream.close() 32 | await stream.receive_message() 33 | await result_queue.send(avg_latency) 34 | 35 | 36 | async def worker(port, num_concurrent_streams, num_requests_per_stream, 37 | num_rounds, message_size, load_type): 38 | async with purerpc.insecure_channel("localhost", port) as channel: 39 | stub = GreeterStub(channel) 40 | if load_type == "unary": 41 | load_fn = do_load_unary 42 | elif load_type == "stream": 43 | load_fn = do_load_stream 44 | else: 45 | raise ValueError(f"Unknown load type: {load_type}") 46 | for idx in range(num_rounds): 47 | start = time.time() 48 | send_queue, receive_queue = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) 49 | async with anyio.create_task_group() as task_group: 50 | for _ in range(num_concurrent_streams): 51 | task_group.start_soon(load_fn, send_queue, stub, num_requests_per_stream, message_size) 52 | end = time.time() 53 | 54 | rps = num_concurrent_streams * num_requests_per_stream / (end - start) 55 | 56 | latencies = [] 57 | for _ in range(num_concurrent_streams): 58 | latencies.append(await receive_queue.receive()) 59 | 60 | print("Round", idx, "rps", rps, "avg latency", 1000 * sum(latencies) / len(latencies)) 61 | 62 | 63 | if __name__ == "__main__": 64 | try: 65 | purerpc.run(worker, 50055, 100, 50, 10, 1000, "unary") 66 | except KeyboardInterrupt: 67 | pass 68 | -------------------------------------------------------------------------------- /misc/greeter/failing_server.py: -------------------------------------------------------------------------------- 1 | from generated import greeter_grpc, greeter_pb2 2 | 3 | import purerpc 4 | 5 | 6 | GreeterServicer = greeter_grpc.GreeterServicer 7 | class Servicer(GreeterServicer): 8 | async def SayHello(self, message): 9 | return greeter_pb2.HelloReply(message=message.name) 10 | 11 | async def SayHelloGoodbye(self, message): 12 | yield greeter_pb2.HelloReply(message=message.name) 13 | yield greeter_pb2.HelloReply(message=message.name) 14 | 15 | async def SayHelloToManyAtOnce(self, messages): 16 | names = [] 17 | async for message in messages: 18 | names.append(message.name) 19 | return greeter_pb2.HelloReply(message="".join(names)) 20 | 21 | async def SayHelloToMany(self, messages): 22 | async for message in messages: 23 | yield greeter_pb2.HelloReply(message=message.name) 24 | 25 | 26 | def main(): 27 | server = purerpc.Server(50055) 28 | server.add_service(Servicer().service) 29 | server.serve() 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /misc/greeter/generated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/python-trio/purerpc/a3c17dd885d8f36bcf3a78c7506e35f7bc33cccc/misc/greeter/generated/__init__.py 
-------------------------------------------------------------------------------- /misc/greeter/generated/greeter.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | service Greeter { 4 | rpc SayHello (HelloRequest) returns (HelloReply) {} 5 | rpc SayHelloGoodbye (HelloRequest) returns (stream HelloReply) {} 6 | rpc SayHelloToMany (stream HelloRequest) returns (stream HelloReply) {} 7 | rpc SayHelloToManyAtOnce (stream HelloRequest) returns (HelloReply) {} 8 | } 9 | 10 | 11 | message HelloRequest { 12 | string name = 1; 13 | } 14 | 15 | message HelloReply { 16 | string message = 1; 17 | } -------------------------------------------------------------------------------- /misc/greeter/generated/greeter_grpc.py: -------------------------------------------------------------------------------- 1 | import purerpc 2 | import generated.greeter_pb2 3 | 4 | 5 | class GreeterServicer(purerpc.Servicer): 6 | async def SayHello(self, input_message): 7 | raise NotImplementedError() 8 | 9 | async def SayHelloGoodbye(self, input_message): 10 | raise NotImplementedError() 11 | 12 | async def SayHelloToMany(self, input_messages): 13 | raise NotImplementedError() 14 | 15 | async def SayHelloToManyAtOnce(self, input_messages): 16 | raise NotImplementedError() 17 | 18 | @property 19 | def service(self) -> purerpc.Service: 20 | service_obj = purerpc.Service( 21 | "Greeter" 22 | ) 23 | service_obj.add_method( 24 | "SayHello", 25 | self.SayHello, 26 | purerpc.RPCSignature( 27 | purerpc.Cardinality.UNARY_UNARY, 28 | generated.greeter_pb2.HelloRequest, 29 | generated.greeter_pb2.HelloReply, 30 | ) 31 | ) 32 | service_obj.add_method( 33 | "SayHelloGoodbye", 34 | self.SayHelloGoodbye, 35 | purerpc.RPCSignature( 36 | purerpc.Cardinality.UNARY_STREAM, 37 | generated.greeter_pb2.HelloRequest, 38 | generated.greeter_pb2.HelloReply, 39 | ) 40 | ) 41 | service_obj.add_method( 42 | "SayHelloToMany", 43 | self.SayHelloToMany, 44 | purerpc.RPCSignature( 45 | purerpc.Cardinality.STREAM_STREAM, 46 | generated.greeter_pb2.HelloRequest, 47 | generated.greeter_pb2.HelloReply, 48 | ) 49 | ) 50 | service_obj.add_method( 51 | "SayHelloToManyAtOnce", 52 | self.SayHelloToManyAtOnce, 53 | purerpc.RPCSignature( 54 | purerpc.Cardinality.STREAM_UNARY, 55 | generated.greeter_pb2.HelloRequest, 56 | generated.greeter_pb2.HelloReply, 57 | ) 58 | ) 59 | return service_obj 60 | 61 | 62 | class GreeterStub: 63 | def __init__(self, channel): 64 | self._client = purerpc.Client( 65 | "Greeter", 66 | channel 67 | ) 68 | self.SayHello = self._client.get_method_stub( 69 | "SayHello", 70 | purerpc.RPCSignature( 71 | purerpc.Cardinality.UNARY_UNARY, 72 | generated.greeter_pb2.HelloRequest, 73 | generated.greeter_pb2.HelloReply, 74 | ) 75 | ) 76 | self.SayHelloGoodbye = self._client.get_method_stub( 77 | "SayHelloGoodbye", 78 | purerpc.RPCSignature( 79 | purerpc.Cardinality.UNARY_STREAM, 80 | generated.greeter_pb2.HelloRequest, 81 | generated.greeter_pb2.HelloReply, 82 | ) 83 | ) 84 | self.SayHelloToMany = self._client.get_method_stub( 85 | "SayHelloToMany", 86 | purerpc.RPCSignature( 87 | purerpc.Cardinality.STREAM_STREAM, 88 | generated.greeter_pb2.HelloRequest, 89 | generated.greeter_pb2.HelloReply, 90 | ) 91 | ) 92 | self.SayHelloToManyAtOnce = self._client.get_method_stub( 93 | "SayHelloToManyAtOnce", 94 | purerpc.RPCSignature( 95 | purerpc.Cardinality.STREAM_UNARY, 96 | generated.greeter_pb2.HelloRequest, 97 | generated.greeter_pb2.HelloReply, 98 | ) 99 | ) 
-------------------------------------------------------------------------------- /misc/greeter/generated/greeter_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: generated/greeter.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='generated/greeter.proto', 19 | package='', 20 | syntax='proto3', 21 | serialized_options=None, 22 | serialized_pb=_b('\n\x17generated/greeter.proto\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xd2\x01\n\x07Greeter\x12(\n\x08SayHello\x12\r.HelloRequest\x1a\x0b.HelloReply\"\x00\x12\x31\n\x0fSayHelloGoodbye\x12\r.HelloRequest\x1a\x0b.HelloReply\"\x00\x30\x01\x12\x32\n\x0eSayHelloToMany\x12\r.HelloRequest\x1a\x0b.HelloReply\"\x00(\x01\x30\x01\x12\x36\n\x14SayHelloToManyAtOnce\x12\r.HelloRequest\x1a\x0b.HelloReply\"\x00(\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HELLOREQUEST = _descriptor.Descriptor( 29 | name='HelloRequest', 30 | full_name='HelloRequest', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='HelloRequest.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | serialized_options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=27, 55 | serialized_end=55, 56 | ) 57 | 58 | 59 | _HELLOREPLY = _descriptor.Descriptor( 60 | name='HelloReply', 61 | full_name='HelloReply', 62 | filename=None, 63 | file=DESCRIPTOR, 64 | containing_type=None, 65 | fields=[ 66 | _descriptor.FieldDescriptor( 67 | name='message', full_name='HelloReply.message', index=0, 68 | number=1, type=9, cpp_type=9, label=1, 69 | has_default_value=False, default_value=_b("").decode('utf-8'), 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | serialized_options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | serialized_options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=57, 86 | serialized_end=86, 87 | ) 88 | 89 | DESCRIPTOR.message_types_by_name['HelloRequest'] = _HELLOREQUEST 90 | DESCRIPTOR.message_types_by_name['HelloReply'] = _HELLOREPLY 91 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 92 | 93 | HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), dict( 94 | DESCRIPTOR = _HELLOREQUEST, 95 | __module__ = 'generated.greeter_pb2' 96 | # @@protoc_insertion_point(class_scope:HelloRequest) 97 | )) 98 | _sym_db.RegisterMessage(HelloRequest) 
99 | 100 | HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), dict( 101 | DESCRIPTOR = _HELLOREPLY, 102 | __module__ = 'generated.greeter_pb2' 103 | # @@protoc_insertion_point(class_scope:HelloReply) 104 | )) 105 | _sym_db.RegisterMessage(HelloReply) 106 | 107 | 108 | 109 | _GREETER = _descriptor.ServiceDescriptor( 110 | name='Greeter', 111 | full_name='Greeter', 112 | file=DESCRIPTOR, 113 | index=0, 114 | serialized_options=None, 115 | serialized_start=89, 116 | serialized_end=299, 117 | methods=[ 118 | _descriptor.MethodDescriptor( 119 | name='SayHello', 120 | full_name='Greeter.SayHello', 121 | index=0, 122 | containing_service=None, 123 | input_type=_HELLOREQUEST, 124 | output_type=_HELLOREPLY, 125 | serialized_options=None, 126 | ), 127 | _descriptor.MethodDescriptor( 128 | name='SayHelloGoodbye', 129 | full_name='Greeter.SayHelloGoodbye', 130 | index=1, 131 | containing_service=None, 132 | input_type=_HELLOREQUEST, 133 | output_type=_HELLOREPLY, 134 | serialized_options=None, 135 | ), 136 | _descriptor.MethodDescriptor( 137 | name='SayHelloToMany', 138 | full_name='Greeter.SayHelloToMany', 139 | index=2, 140 | containing_service=None, 141 | input_type=_HELLOREQUEST, 142 | output_type=_HELLOREPLY, 143 | serialized_options=None, 144 | ), 145 | _descriptor.MethodDescriptor( 146 | name='SayHelloToManyAtOnce', 147 | full_name='Greeter.SayHelloToManyAtOnce', 148 | index=3, 149 | containing_service=None, 150 | input_type=_HELLOREQUEST, 151 | output_type=_HELLOREPLY, 152 | serialized_options=None, 153 | ), 154 | ]) 155 | _sym_db.RegisterServiceDescriptor(_GREETER) 156 | 157 | DESCRIPTOR.services_by_name['Greeter'] = _GREETER 158 | 159 | # @@protoc_insertion_point(module_scope) 160 | -------------------------------------------------------------------------------- /misc/greeter/generated/greeter_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from generated import greeter_pb2 as generated_dot_greeter__pb2 5 | 6 | 7 | class GreeterStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.SayHello = channel.unary_unary( 18 | '/Greeter/SayHello', 19 | request_serializer=generated_dot_greeter__pb2.HelloRequest.SerializeToString, 20 | response_deserializer=generated_dot_greeter__pb2.HelloReply.FromString, 21 | ) 22 | self.SayHelloGoodbye = channel.unary_stream( 23 | '/Greeter/SayHelloGoodbye', 24 | request_serializer=generated_dot_greeter__pb2.HelloRequest.SerializeToString, 25 | response_deserializer=generated_dot_greeter__pb2.HelloReply.FromString, 26 | ) 27 | self.SayHelloToMany = channel.stream_stream( 28 | '/Greeter/SayHelloToMany', 29 | request_serializer=generated_dot_greeter__pb2.HelloRequest.SerializeToString, 30 | response_deserializer=generated_dot_greeter__pb2.HelloReply.FromString, 31 | ) 32 | self.SayHelloToManyAtOnce = channel.stream_unary( 33 | '/Greeter/SayHelloToManyAtOnce', 34 | request_serializer=generated_dot_greeter__pb2.HelloRequest.SerializeToString, 35 | response_deserializer=generated_dot_greeter__pb2.HelloReply.FromString, 36 | ) 37 | 38 | 39 | class GreeterServicer(object): 40 | # missing associated documentation comment in .proto file 41 | pass 42 | 43 | def SayHello(self, request, context): 44 | # missing associated documentation comment in .proto file 45 | pass 46 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 47 | context.set_details('Method not implemented!') 48 | raise NotImplementedError('Method not implemented!') 49 | 50 | def SayHelloGoodbye(self, request, context): 51 | # missing associated documentation comment in .proto file 52 | pass 53 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 54 | context.set_details('Method not implemented!') 55 | raise NotImplementedError('Method not implemented!') 56 | 57 | def SayHelloToMany(self, request_iterator, context): 58 | # missing associated documentation comment in .proto file 59 | pass 60 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 61 | context.set_details('Method not implemented!') 62 | raise NotImplementedError('Method not implemented!') 63 | 64 | def SayHelloToManyAtOnce(self, request_iterator, context): 65 | # missing associated documentation comment in .proto file 66 | pass 67 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 68 | context.set_details('Method not implemented!') 69 | raise NotImplementedError('Method not implemented!') 70 | 71 | 72 | def add_GreeterServicer_to_server(servicer, server): 73 | rpc_method_handlers = { 74 | 'SayHello': grpc.unary_unary_rpc_method_handler( 75 | servicer.SayHello, 76 | request_deserializer=generated_dot_greeter__pb2.HelloRequest.FromString, 77 | response_serializer=generated_dot_greeter__pb2.HelloReply.SerializeToString, 78 | ), 79 | 'SayHelloGoodbye': grpc.unary_stream_rpc_method_handler( 80 | servicer.SayHelloGoodbye, 81 | request_deserializer=generated_dot_greeter__pb2.HelloRequest.FromString, 82 | response_serializer=generated_dot_greeter__pb2.HelloReply.SerializeToString, 83 | ), 84 | 'SayHelloToMany': grpc.stream_stream_rpc_method_handler( 85 | servicer.SayHelloToMany, 86 | request_deserializer=generated_dot_greeter__pb2.HelloRequest.FromString, 87 | response_serializer=generated_dot_greeter__pb2.HelloReply.SerializeToString, 88 | ), 89 | 'SayHelloToManyAtOnce': grpc.stream_unary_rpc_method_handler( 90 | servicer.SayHelloToManyAtOnce, 91 | request_deserializer=generated_dot_greeter__pb2.HelloRequest.FromString, 92 | response_serializer=generated_dot_greeter__pb2.HelloReply.SerializeToString, 93 | ), 94 | } 95 | generic_handler = grpc.method_handlers_generic_handler( 96 | 'Greeter', rpc_method_handlers) 97 | 
server.add_generic_rpc_handlers((generic_handler,)) 98 | -------------------------------------------------------------------------------- /misc/greeter/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | 4 | from purerpc import Server, Service, Stream 5 | from generated.greeter_pb2 import HelloRequest, HelloReply 6 | 7 | 8 | def configure_logs(log_file=None): 9 | conf = { 10 | "version": 1, 11 | "formatters": { 12 | "simple": { 13 | "format": "[%(asctime)s - %(name)s - %(levelname)s]: %(message)s" 14 | }, 15 | }, 16 | "handlers": { 17 | "console": { 18 | "class": "logging.StreamHandler", 19 | "level": "WARNING", 20 | "formatter": "simple", 21 | "stream": "ext://sys.stdout", 22 | } 23 | }, 24 | "root": { 25 | "level": "WARNING", 26 | "handlers": ["console"], 27 | }, 28 | "disable_existing_loggers": False 29 | } 30 | if log_file is not None: 31 | conf["handlers"]["file"] = { 32 | "class": "logging.FileHandler", 33 | "level": "DEBUG", 34 | "formatter": "simple", 35 | "filename": log_file, 36 | } 37 | conf["root"]["handlers"].append("file") 38 | logging.config.dictConfig(conf) 39 | 40 | 41 | configure_logs() 42 | 43 | 44 | service = Service("Greeter") 45 | 46 | 47 | @service.rpc("SayHelloToMany") 48 | async def say_hello_to_many(message_iterator: Stream[HelloRequest]) -> Stream[HelloReply]: 49 | requests = [] 50 | async for message in message_iterator: 51 | requests.append(message) 52 | 53 | name = requests[0].name 54 | for i in range(8): 55 | name += name 56 | yield HelloReply(message="Hello, {}".format(name)) 57 | 58 | server = Server(port=50055) 59 | server.add_service(service) 60 | 61 | 62 | if __name__ == "__main__": 63 | server.serve() 64 | -------------------------------------------------------------------------------- /misc/greeter/main_pingpong.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | 4 | from purerpc.server import Service, Server 5 | from purerpc.rpc import Stream 6 | from generated.greeter_pb2 import HelloRequest, HelloReply 7 | 8 | 9 | def configure_logs(log_file=None): 10 | conf = { 11 | "version": 1, 12 | "formatters": { 13 | "simple": { 14 | "format": "[%(asctime)s - %(name)s - %(levelname)s]: %(message)s" 15 | }, 16 | }, 17 | "handlers": { 18 | "console": { 19 | "class": "logging.StreamHandler", 20 | "level": "WARNING", 21 | "formatter": "simple", 22 | "stream": "ext://sys.stdout", 23 | } 24 | }, 25 | "root": { 26 | "level": "WARNING", 27 | "handlers": ["console"], 28 | }, 29 | "disable_existing_loggers": False 30 | } 31 | if log_file is not None: 32 | conf["handlers"]["file"] = { 33 | "class": "logging.FileHandler", 34 | "level": "DEBUG", 35 | "formatter": "simple", 36 | "filename": log_file, 37 | } 38 | conf["root"]["handlers"].append("file") 39 | logging.config.dictConfig(conf) 40 | 41 | 42 | configure_logs() 43 | 44 | 45 | service = Service("Greeter") 46 | 47 | 48 | @service.rpc("SayHelloToMany") 49 | async def say_hello_to_many(message: Stream[HelloRequest]) -> Stream[HelloReply]: 50 | yield HelloReply(message="Hello, world!") 51 | 52 | server = Server(50055) 53 | server.add_service(service) 54 | 55 | if __name__ == "__main__": 56 | server.serve() 57 | -------------------------------------------------------------------------------- /misc/greeter/main_pingpong_servicer.py: -------------------------------------------------------------------------------- 1 | from purerpc.server import Server 2 | 
from generated.greeter_pb2 import HelloReply 3 | from generated.greeter_grpc import GreeterServicer 4 | 5 | """ 6 | def configure_logs(log_file=None): 7 | conf = { 8 | "version": 1, 9 | "formatters": { 10 | "simple": { 11 | "format": "[%(asctime)s - %(name)s - %(levelname)s]: %(message)s" 12 | }, 13 | }, 14 | "handlers": { 15 | "console": { 16 | "class": "logging.StreamHandler", 17 | "level": "WARNING", 18 | "formatter": "simple", 19 | "stream": "ext://sys.stdout", 20 | } 21 | }, 22 | "root": { 23 | "level": "WARNING", 24 | "handlers": ["console"], 25 | }, 26 | "disable_existing_loggers": False 27 | } 28 | if log_file is not None: 29 | conf["handlers"]["file"] = { 30 | "class": "logging.FileHandler", 31 | "level": "DEBUG", 32 | "formatter": "simple", 33 | "filename": log_file, 34 | } 35 | conf["root"]["handlers"].append("file") 36 | logging.config.dictConfig(conf) 37 | 38 | 39 | configure_logs() 40 | """ 41 | 42 | 43 | class Greeter(GreeterServicer): 44 | async def SayHello(self, message): 45 | return HelloReply(message="Hello, " + message.name) 46 | 47 | async def SayHelloToMany(self, input_messages): 48 | async for _ in input_messages: 49 | pass 50 | yield HelloReply(message="Hello, world!") 51 | 52 | 53 | if __name__ == "__main__": 54 | server = Server(50055) 55 | server.add_service(Greeter().service) 56 | server.serve() 57 | -------------------------------------------------------------------------------- /misc/greeter/test_perf.py: -------------------------------------------------------------------------------- 1 | import time 2 | import anyio 3 | import sys 4 | import argparse 5 | import multiprocessing 6 | 7 | import purerpc 8 | from generated.greeter_pb2 import HelloRequest, HelloReply 9 | from generated.greeter_grpc import GreeterServicer, GreeterStub 10 | 11 | from purerpc.test_utils import run_purerpc_service_in_process 12 | 13 | 14 | class Greeter(GreeterServicer): 15 | async def SayHello(self, message): 16 | return HelloReply(message=message.name) 17 | 18 | async def SayHelloToMany(self, input_messages): 19 | async for message in input_messages: 20 | yield HelloReply(message=message.name) 21 | 22 | 23 | async def do_load_unary(result_queue, stub, num_requests, message_size): 24 | message = "0" * message_size 25 | start = time.time() 26 | for _ in range(num_requests): 27 | result = (await stub.SayHello(HelloRequest(name=message))).message 28 | assert (len(result) == message_size) 29 | avg_latency = (time.time() - start) / num_requests 30 | await result_queue.send(avg_latency) 31 | 32 | 33 | async def do_load_stream(result_queue, stub, num_requests, message_size): 34 | message = "0" * message_size 35 | stream = await stub.SayHelloToMany() 36 | start = time.time() 37 | for _ in range(num_requests): 38 | await stream.send_message(HelloRequest(name=message)) 39 | result = await stream.receive_message() 40 | assert (len(result.message) == message_size) 41 | avg_latency = (time.time() - start) / num_requests 42 | await stream.close() 43 | await stream.receive_message() 44 | await result_queue.send(avg_latency) 45 | 46 | 47 | async def worker(port, queue, num_concurrent_streams, num_requests_per_stream, 48 | num_rounds, message_size, load_type): 49 | async with purerpc.insecure_channel("localhost", port) as channel: 50 | stub = GreeterStub(channel) 51 | if load_type == "unary": 52 | load_fn = do_load_unary 53 | elif load_type == "stream": 54 | load_fn = do_load_stream 55 | else: 56 | raise ValueError(f"Unknown load type: {load_type}") 57 | for _ in range(num_rounds): 58 | start = 
time.time() 59 | send_queue, receive_queue = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) 60 | async with anyio.create_task_group() as task_group: 61 | for _ in range(num_concurrent_streams): 62 | task_group.start_soon(load_fn, send_queue, stub, num_requests_per_stream, message_size) 63 | end = time.time() 64 | rps = num_concurrent_streams * num_requests_per_stream / (end - start) 65 | queue.put(rps) 66 | results = [] 67 | for _ in range(num_concurrent_streams): 68 | results.append(await receive_queue.receive()) 69 | queue.put(results) 70 | queue.close() 71 | queue.join_thread() 72 | 73 | 74 | def main(): 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--message_size", type=int, default=1000) 77 | parser.add_argument("--num_workers", type=int, default=3) 78 | parser.add_argument("--num_concurrent_streams", type=int, default=100) 79 | parser.add_argument("--num_requests_per_stream", type=int, default=50) 80 | parser.add_argument("--num_rounds", type=int, default=10) 81 | parser.add_argument("--load_type", choices=["unary", "stream"], required=True) 82 | 83 | args = parser.parse_args() 84 | 85 | queues = [multiprocessing.Queue() for _ in range(args.num_workers)] 86 | 87 | with run_purerpc_service_in_process(Greeter().service) as port: 88 | def target_fn(worker_id): 89 | queue = queues[worker_id] 90 | purerpc.run(worker, port, queue, args.num_concurrent_streams, 91 | args.num_requests_per_stream, args.num_rounds, args.message_size, 92 | args.load_type) 93 | 94 | processes = [] 95 | for worker_id in range(args.num_workers): 96 | process = multiprocessing.Process(target=target_fn, args=(worker_id,)) 97 | process.start() 98 | processes.append(process) 99 | 100 | for round_id in range(args.num_rounds): 101 | total_rps = 0 102 | latencies = [] 103 | for queue in queues: 104 | total_rps += queue.get() 105 | latencies.extend(queue.get()) 106 | avg_latency = 1000 * sum(latencies) / len(latencies) 107 | max_latency = 1000 * max(latencies) 108 | print(f"Round {round_id}, RPS: {total_rps}, avg latency: {avg_latency} ms, " 109 | f"max latency: {max_latency} ms") 110 | 111 | for queue in queues: 112 | queue.close() 113 | queue.join_thread() 114 | 115 | for process in processes: 116 | process.join() 117 | 118 | 119 | if __name__ == "__main__": 120 | try: 121 | main() 122 | except KeyboardInterrupt: 123 | pass 124 | -------------------------------------------------------------------------------- /misc/greeter/test_perf_grpcio.py: -------------------------------------------------------------------------------- 1 | # from gevent import monkey 2 | # monkey.patch_all() 3 | # 4 | # import grpc._cython.cygrpc 5 | # grpc._cython.cygrpc.init_grpc_gevent() 6 | 7 | 8 | import time 9 | import argparse 10 | import functools 11 | from queue import Queue 12 | import multiprocessing 13 | 14 | 15 | 16 | 17 | import grpc 18 | from generated.greeter_pb2 import HelloRequest, HelloReply 19 | from generated.greeter_pb2_grpc import GreeterServicer, GreeterStub, add_GreeterServicer_to_server 20 | 21 | from purerpc.test_utils import run_grpc_service_in_process 22 | 23 | 24 | class Greeter(GreeterServicer): 25 | def SayHello(self, message, context): 26 | return HelloReply(message=message.name[::-1]) 27 | 28 | def SayHelloToMany(self, messages, context): 29 | for message in messages: 30 | yield HelloReply(message=message.name[::-1]) 31 | 32 | 33 | def do_load_unary(result_queue, stub, num_requests, message_size): 34 | requests_left = num_requests 35 | avg_latency = 0 36 | message = "0" * 
message_size 37 | start = time.time() 38 | fut = stub.SayHello.future(HelloRequest(name=message)) 39 | 40 | def done_callback(fut): 41 | nonlocal requests_left 42 | nonlocal avg_latency 43 | requests_left -= 1 44 | assert len(fut.result().message) == message_size 45 | if requests_left > 0: 46 | fut = stub.SayHello.future(HelloRequest(name=message)) 47 | fut.add_done_callback(done_callback) 48 | else: 49 | avg_latency = (time.time() - start) / num_requests 50 | result_queue.put(avg_latency) 51 | 52 | fut.add_done_callback(done_callback) 53 | 54 | 55 | # def do_load_stream(result_queue, stub, num_requests, message_size): 56 | # message = "0" * message_size 57 | # stream = await stub.SayHelloToMany() 58 | # start = time.time() 59 | # for _ in range(num_requests): 60 | # await stream.send_message(HelloRequest(name=message)) 61 | # result = await stream.receive_message() 62 | # assert (len(result.message) == message_size) 63 | # avg_latency = (time.time() - start) / num_requests 64 | # await stream.close() 65 | # await stream.receive_message() 66 | # await result_queue.put(avg_latency) 67 | 68 | 69 | def worker(port, queue, num_concurrent_streams, num_requests_per_stream, 70 | num_rounds, message_size, load_type): 71 | with grpc.insecure_channel("localhost:{}".format(port)) as channel: 72 | stub = GreeterStub(channel) 73 | if load_type == "unary": 74 | load_fn = do_load_unary 75 | else: 76 | raise ValueError(f"Unknown load type: {load_type}") 77 | for _ in range(num_rounds): 78 | start = time.time() 79 | task_results = Queue() 80 | for _ in range(num_concurrent_streams): 81 | load_fn(task_results, stub, num_requests_per_stream, message_size) 82 | 83 | results = [] 84 | for _ in range(num_concurrent_streams): 85 | results.append(task_results.get()) 86 | 87 | end = time.time() 88 | rps = num_concurrent_streams * num_requests_per_stream / (end - start) 89 | queue.put(rps) 90 | queue.put(results) 91 | queue.close() 92 | queue.join_thread() 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--message_size", type=int, default=1000) 98 | parser.add_argument("--num_workers", type=int, default=3) 99 | parser.add_argument("--num_concurrent_streams", type=int, default=100) 100 | parser.add_argument("--num_requests_per_stream", type=int, default=50) 101 | parser.add_argument("--num_rounds", type=int, default=10) 102 | parser.add_argument("--load_type", choices=["unary", "stream"], required=True) 103 | 104 | args = parser.parse_args() 105 | 106 | queues = [multiprocessing.Queue() for _ in range(args.num_workers)] 107 | 108 | with run_grpc_service_in_process(functools.partial(add_GreeterServicer_to_server, Greeter())) as port: 109 | def target_fn(worker_id): 110 | queue = queues[worker_id] 111 | worker(port, queue, args.num_concurrent_streams, 112 | args.num_requests_per_stream, args.num_rounds, args.message_size, 113 | args.load_type) 114 | 115 | processes = [] 116 | for worker_id in range(args.num_workers): 117 | process = multiprocessing.Process(target=target_fn, args=(worker_id,)) 118 | process.start() 119 | processes.append(process) 120 | 121 | for round_id in range(args.num_rounds): 122 | total_rps = 0 123 | latencies = [] 124 | for queue in queues: 125 | total_rps += queue.get() 126 | latencies.extend(queue.get()) 127 | avg_latency = 1000 * sum(latencies) / len(latencies) 128 | max_latency = 1000 * max(latencies) 129 | print(f"Round {round_id}, RPS: {total_rps}, avg latency: {avg_latency} ms, " 130 | f"max latency: {max_latency} ms") 131 | 132 | 
for queue in queues: 133 | queue.close() 134 | queue.join_thread() 135 | 136 | for process in processes: 137 | process.join() 138 | -------------------------------------------------------------------------------- /misc/h2load/latency_h2load.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | NUM_REQUESTS=10000 4 | NUM_CLIENTS=1 5 | NUM_STREAMS=1 6 | WINDOW_SIZE_LOG=20 7 | NUM_THREADS=1 8 | AUTHORITY="localhost:50055" 9 | 10 | h2load -n$NUM_REQUESTS -c$NUM_CLIENTS -m$NUM_STREAMS -w$WINDOW_SIZE_LOG -t$NUM_THREADS -d request.bin \ 11 | -H 'te: trailers' -H 'content-type: application/grpc+proto' \ 12 | http://$AUTHORITY/service/SayHelloToMany 13 | -------------------------------------------------------------------------------- /misc/h2load/old_logs/pypy6.curio.log.txt: -------------------------------------------------------------------------------- 1 | starting benchmark... 2 | spawning thread #0: 13 total client(s). 12500 total requests 3 | spawning thread #1: 13 total client(s). 12500 total requests 4 | spawning thread #2: 13 total client(s). 12500 total requests 5 | spawning thread #3: 13 total client(s). 12500 total requests 6 | spawning thread #4: 12 total client(s). 12500 total requests 7 | spawning thread #5: 12 total client(s). 12500 total requests 8 | spawning thread #6: 12 total client(s). 12500 total requests 9 | spawning thread #7: 12 total client(s). 12500 total requests 10 | Application protocol: h2c 11 | progress: 10% done 12 | progress: 20% done 13 | progress: 30% done 14 | progress: 40% done 15 | progress: 50% done 16 | progress: 60% done 17 | progress: 70% done 18 | progress: 80% done 19 | progress: 90% done 20 | progress: 100% done 21 | 22 | finished in 11.65s, 8584.18 req/s, 403.07KB/s 23 | requests: 100000 total, 100000 started, 100000 done, 100000 succeeded, 0 failed, 0 errored, 0 timeout 24 | status codes: 100000 2xx, 0 3xx, 0 4xx, 0 5xx 25 | traffic: 4.59MB (4808200) total, 196.97KB (201700) headers (space savings 95.42%), 1.72MB (1800000) data 26 | min max mean sd +/- sd 27 | time for request: 243.54ms 2.04s 1.10s 398.62ms 76.30% 28 | time for connect: 171us 36.02ms 6.01ms 6.71ms 87.00% 29 | time to 1st byte: 1.91s 2.06s 1.99s 42.59ms 58.00% 30 | req/s : 85.30 90.35 88.18 1.83 52.00% 31 | -------------------------------------------------------------------------------- /misc/h2load/old_logs/pypy6.trio.0.3.log.txt: -------------------------------------------------------------------------------- 1 | starting benchmark... 2 | spawning thread #0: 13 total client(s). 12500 total requests 3 | spawning thread #1: 13 total client(s). 12500 total requests 4 | spawning thread #2: 13 total client(s). 12500 total requests 5 | spawning thread #3: 13 total client(s). 12500 total requests 6 | spawning thread #4: 12 total client(s). 12500 total requests 7 | spawning thread #5: 12 total client(s). 12500 total requests 8 | spawning thread #6: 12 total client(s). 12500 total requests 9 | spawning thread #7: 12 total client(s). 
12500 total requests 10 | Application protocol: h2c 11 | progress: 10% done 12 | progress: 20% done 13 | progress: 30% done 14 | progress: 40% done 15 | progress: 50% done 16 | progress: 60% done 17 | progress: 70% done 18 | progress: 80% done 19 | progress: 90% done 20 | progress: 100% done 21 | 22 | finished in 46.62s, 2145.00 req/s, 100.72KB/s 23 | requests: 100000 total, 100000 started, 100000 done, 100000 succeeded, 0 failed, 0 errored, 0 timeout 24 | status codes: 100000 2xx, 0 3xx, 0 4xx, 0 5xx 25 | traffic: 4.59MB (4808200) total, 196.97KB (201700) headers (space savings 95.42%), 1.72MB (1800000) data 26 | min max mean sd +/- sd 27 | time for request: 35.62ms 6.62s 4.42s 826.95ms 85.26% 28 | time for connect: 177us 30.14ms 5.55ms 7.68ms 91.00% 29 | time to 1st byte: 59.87ms 1.90s 972.50ms 578.10ms 56.00% 30 | req/s : 21.31 22.42 21.87 0.50 42.00% 31 | -------------------------------------------------------------------------------- /misc/h2load/old_logs/pypy6.trio.0.4.log.txt: -------------------------------------------------------------------------------- 1 | starting benchmark... 2 | spawning thread #0: 13 total client(s). 12500 total requests 3 | spawning thread #1: 13 total client(s). 12500 total requests 4 | spawning thread #2: 13 total client(s). 12500 total requests 5 | spawning thread #3: 13 total client(s). 12500 total requests 6 | spawning thread #4: 12 total client(s). 12500 total requests 7 | spawning thread #5: 12 total client(s). 12500 total requests 8 | spawning thread #6: 12 total client(s). 12500 total requests 9 | spawning thread #7: 12 total client(s). 12500 total requests 10 | Application protocol: h2c 11 | progress: 10% done 12 | progress: 20% done 13 | progress: 30% done 14 | progress: 40% done 15 | progress: 50% done 16 | progress: 60% done 17 | progress: 70% done 18 | progress: 80% done 19 | progress: 90% done 20 | progress: 100% done 21 | 22 | finished in 45.13s, 2215.93 req/s, 104.05KB/s 23 | requests: 100000 total, 100000 started, 100000 done, 100000 succeeded, 0 failed, 0 errored, 0 timeout 24 | status codes: 100000 2xx, 0 3xx, 0 4xx, 0 5xx 25 | traffic: 4.59MB (4808200) total, 196.97KB (201700) headers (space savings 95.42%), 1.72MB (1800000) data 26 | min max mean sd +/- sd 27 | time for request: 38.30ms 6.20s 4.28s 846.40ms 87.72% 28 | time for connect: 177us 35.26ms 4.27ms 6.86ms 94.00% 29 | time to 1st byte: 48.94ms 1.34s 588.37ms 385.20ms 57.00% 30 | req/s : 21.92 23.22 22.58 0.54 46.00% 31 | 32 | -------------------------------------------------------------------------------- /misc/h2load/old_logs/python.log.txt: -------------------------------------------------------------------------------- 1 | finished in 178.74s, 559.46 req/s, 26.27KB/s 2 | requests: 100000 total, 100000 started, 100000 done, 100000 succeeded, 0 failed, 0 errored, 0 timeout 3 | status codes: 100000 2xx, 0 3xx, 0 4xx, 0 5xx 4 | traffic: 4.59MB (4808200) total, 196.97KB (201700) headers (space savings 95.42%), 1.72MB (1800000) data 5 | min max mean sd +/- sd 6 | time for request: 332.73ms 23.50s 16.98s 3.50s 87.67% 7 | time for connect: 174us 66.92ms 17.18ms 22.85ms 75.00% 8 | time to 1st byte: 388.39ms 5.57s 2.53s 1.59s 59.00% 9 | req/s : 5.53 5.85 5.69 0.14 39.00% 10 | 11 | -------------------------------------------------------------------------------- /misc/h2load/request.bin: -------------------------------------------------------------------------------- 1 |  2 | World -------------------------------------------------------------------------------- 
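Illustration (not a file in this repository): request.bin is the raw HTTP/2 DATA body that the h2load scripts in this directory post on every request, i.e. a single gRPC length-prefixed message consisting of a 1-byte compressed flag, a 4-byte big-endian payload length, and a serialized greeter HelloRequest (its name field appears to be "World"). A minimal sketch of how such a file could be regenerated, assuming the generated greeter_pb2 module from misc/greeter/generated is importable; the script name is hypothetical:

    # make_request_bin.py (hypothetical helper, not part of the repository)
    import struct

    from generated.greeter_pb2 import HelloRequest

    payload = HelloRequest(name="World").SerializeToString()
    # gRPC length-prefixed framing (same layout MessageWriteBuffer uses):
    # compressed flag (1 byte) + big-endian payload length (4 bytes) + payload
    frame = struct.pack(">?I", False, len(payload)) + payload
    with open("request.bin", "wb") as f:
        f.write(frame)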
/misc/h2load/run_h2load.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | NUM_REQUESTS=1000000 4 | NUM_CLIENTS=100 5 | NUM_STREAMS=100 6 | WINDOW_SIZE_LOG=24 7 | NUM_THREADS=8 8 | AUTHORITY="192.168.1.10:50055" 9 | AUTHORITY="localhost:50055" 10 | 11 | h2load -n$NUM_REQUESTS -c$NUM_CLIENTS -m$NUM_STREAMS -w$WINDOW_SIZE_LOG -t$NUM_THREADS -d request.bin \ 12 | -H 'te: trailers' -H 'content-type: application/grpc+proto' \ 13 | http://$AUTHORITY/Greeter/SayHello 14 | -------------------------------------------------------------------------------- /misc/pypy_tests/bytearray_perf_test.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class ByteBuffer: 5 | def __init__(self, chunk_size=65536): 6 | self._deque = collections.deque([bytearray()]) 7 | self._chunk_size = chunk_size 8 | self._size = 0 9 | 10 | def append(self, data): 11 | pos = 0 12 | while pos < len(data): 13 | data_to_write = min(self._chunk_size - len(self._deque[-1]), len(data) - pos) 14 | self._deque[-1].extend(data[pos:pos + data_to_write]) 15 | if len(self._deque[-1]) == self._chunk_size: 16 | self._deque.append(bytearray()) 17 | pos += data_to_write 18 | self._size += data_to_write 19 | 20 | def popleft(self, amount): 21 | if amount > self._size: 22 | raise ValueError("Trying to extract {} bytes from ByteBuffer of length {}".format( 23 | amount, self._size)) 24 | data = bytearray() 25 | while amount > 0: 26 | data_to_read = min(amount, len(self._deque[0])) 27 | data.extend(self._deque[0][:data_to_read]) 28 | self._deque[0] = self._deque[0][data_to_read:] 29 | if len(self._deque[0]) == 0 and len(self._deque) > 1: 30 | self._deque.popleft() 31 | amount -= data_to_read 32 | self._size -= data_to_read 33 | return bytes(data) 34 | 35 | def __len__(self): 36 | return self._size 37 | 38 | 39 | class ByteBufferV2: 40 | def __init__(self): 41 | self._deque = collections.deque() 42 | self._size = 0 43 | 44 | def append(self, data): 45 | if not isinstance(data, bytes): 46 | raise ValueError("Expected bytes") 47 | if data: 48 | self._deque.append(data) 49 | self._size += len(data) 50 | 51 | def popleft_v2(self, amount): 52 | if amount > self._size: 53 | raise ValueError("Trying to extract {} bytes from ByteBuffer of length {}".format( 54 | amount, self._size)) 55 | data = [] 56 | while amount > 0: 57 | next_element = self._deque[0] 58 | if amount >= len(next_element): 59 | self._size -= len(next_element) 60 | amount -= len(next_element) 61 | data.append(self._deque.popleft()) 62 | else: 63 | data.append(next_element[:amount]) 64 | self._deque[0] = next_element[amount:] 65 | self._size -= amount 66 | amount = 0 67 | return b"".join(data) 68 | 69 | def popleft(self, amount): 70 | if amount > self._size: 71 | raise ValueError("Trying to extract {} bytes from ByteBuffer of length {}".format( 72 | amount, self._size)) 73 | data = [] 74 | while amount > 0: 75 | next_element = self._deque[0] 76 | bytes_to_read = min(amount, len(next_element)) 77 | data.append(next_element[:bytes_to_read]) 78 | self._deque[0] = next_element[bytes_to_read:] 79 | amount -= bytes_to_read 80 | self._size -= bytes_to_read 81 | if len(self._deque[0]) == 0: 82 | self._deque.popleft() 83 | return b"".join(data) 84 | 85 | def __len__(self): 86 | return self._size 87 | 88 | 89 | def main(): 90 | from purerpc.grpclib.buffers import ByteBuffer 91 | b = b"\x00" * 50 92 | x = ByteBuffer() 93 | for i in range(500000): 94 | for j in range(50): 
95 | x.append(b) 96 | x.popleft(2000) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.8 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements_test.txt setup.py 6 | # 7 | anyio==3.5.0 8 | # via purerpc (setup.py) 9 | async-generator==1.10 10 | # via 11 | # purerpc (setup.py) 12 | # trio 13 | attrs==21.4.0 14 | # via 15 | # outcome 16 | # pytest 17 | # trio 18 | build==1.2.1 19 | # via pip-tools 20 | cffi==1.15.0 21 | # via cryptography 22 | click==8.1.2 23 | # via pip-tools 24 | cryptography==36.0.2 25 | # via trustme 26 | exceptiongroup==1.2.0 ; python_version < "3.11" 27 | # via purerpc (setup.py) 28 | grpcio==1.62.1 29 | # via 30 | # grpcio-tools 31 | # purerpc (setup.py) 32 | grpcio-tools==1.62.1 33 | # via purerpc (setup.py) 34 | h2==3.2.0 35 | # via purerpc (setup.py) 36 | hpack==3.0.0 37 | # via h2 38 | hyperframe==5.2.0 39 | # via h2 40 | idna==3.3 41 | # via 42 | # anyio 43 | # trio 44 | # trustme 45 | importlib-metadata==4.11.3 46 | # via build 47 | iniconfig==1.1.1 48 | # via pytest 49 | outcome==1.1.0 50 | # via trio 51 | packaging==21.3 52 | # via 53 | # build 54 | # pytest 55 | pip-tools==7.4.1 56 | # via purerpc (setup.py) 57 | pluggy==1.0.0 58 | # via pytest 59 | protobuf==4.25.3 60 | # via grpcio-tools 61 | py==1.11.0 62 | # via pytest 63 | pycparser==2.21 64 | # via cffi 65 | pyparsing==3.0.8 66 | # via packaging 67 | pyproject-hooks==1.0.0 68 | # via 69 | # build 70 | # pip-tools 71 | pytest==7.1.1 72 | # via purerpc (setup.py) 73 | python-forge==18.6.0 74 | # via purerpc (setup.py) 75 | sniffio==1.2.0 76 | # via 77 | # anyio 78 | # trio 79 | sortedcontainers==2.4.0 80 | # via trio 81 | tblib==1.7.0 82 | # via purerpc (setup.py) 83 | tomli==2.0.1 84 | # via 85 | # build 86 | # pip-tools 87 | # pyproject-hooks 88 | # pytest 89 | trio==0.20.0 90 | # via purerpc (setup.py) 91 | trustme==0.9.0 92 | # via purerpc (setup.py) 93 | uvloop==0.19.0 ; platform_system != "Windows" 94 | # via purerpc (setup.py) 95 | wheel==0.37.1 96 | # via pip-tools 97 | zipp==3.8.0 98 | # via importlib-metadata 99 | 100 | # The following packages are considered to be unsafe in a requirements file: 101 | # pip 102 | # setuptools 103 | -------------------------------------------------------------------------------- /requirements_test_pypy.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.7 3 | # To update, run: 4 | # 5 | # pip-compile --extra=dev,grpc-pypy --output-file=requirements_test_pypy.txt setup.py 6 | # 7 | anyio==3.5.0 8 | # via purerpc (setup.py) 9 | async-generator==1.10 10 | # via 11 | # purerpc (setup.py) 12 | # trio 13 | attrs==21.4.0 14 | # via 15 | # outcome 16 | # pytest 17 | # trio 18 | cffi==1.15.0 19 | # via cryptography 20 | click==8.1.2 21 | # via pip-tools 22 | cryptography==36.0.2 23 | # via trustme 24 | grpcio==1.26.0 25 | # via 26 | # grpcio-tools 27 | # purerpc (setup.py) 28 | grpcio-tools==1.26.0 29 | # via purerpc (setup.py) 30 | h2==3.2.0 31 | # via purerpc (setup.py) 32 | hpack==3.0.0 33 | # via h2 34 | hyperframe==5.2.0 35 | # via h2 36 | idna==3.3 37 | # via 38 | # anyio 39 | # trio 40 | # trustme 41 | importlib-metadata==4.11.3 42 | # via 43 | # click 44 | # 
pep517 45 | # pluggy 46 | # pytest 47 | iniconfig==1.1.1 48 | # via pytest 49 | outcome==1.1.0 50 | # via trio 51 | packaging==21.3 52 | # via pytest 53 | pep517==0.12.0 54 | # via pip-tools 55 | pip-tools==6.6.0 56 | # via purerpc (setup.py) 57 | pluggy==1.0.0 58 | # via pytest 59 | protobuf==3.20.0 60 | # via grpcio-tools 61 | py==1.11.0 62 | # via pytest 63 | pycparser==2.21 64 | # via cffi 65 | pyparsing==3.0.8 66 | # via packaging 67 | pytest==7.1.1 68 | # via purerpc (setup.py) 69 | python-forge==18.6.0 70 | # via purerpc (setup.py) 71 | six==1.16.0 72 | # via grpcio 73 | sniffio==1.2.0 74 | # via 75 | # anyio 76 | # trio 77 | sortedcontainers==2.4.0 78 | # via trio 79 | tblib==1.7.0 80 | # via purerpc (setup.py) 81 | tomli==2.0.1 82 | # via 83 | # pep517 84 | # pytest 85 | trio==0.20.0 86 | # via purerpc (setup.py) 87 | trustme==0.9.0 88 | # via purerpc (setup.py) 89 | typing-extensions==4.1.1 90 | # via 91 | # anyio 92 | # importlib-metadata 93 | uvloop==0.16.0 ; platform_system != "Windows" 94 | # via purerpc (setup.py) 95 | wheel==0.37.1 96 | # via pip-tools 97 | zipp==3.8.0 98 | # via 99 | # importlib-metadata 100 | # pep517 101 | 102 | # The following packages are considered to be unsafe in a requirements file: 103 | # pip 104 | # setuptools 105 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = -s -vvv 6 | testpaths = tests -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from setuptools import setup, find_packages 4 | 5 | exec(open("src/purerpc/_version.py", encoding="utf-8").read()) 6 | 7 | 8 | def read(*names, **kwargs): 9 | with open(os.path.join(os.path.dirname(__file__), *names), "r") as fin: 10 | return fin.read() 11 | 12 | 13 | def main(): 14 | console_scripts = ['protoc-gen-purerpc=purerpc.protoc_plugin.plugin:main'] 15 | setup( 16 | name="purerpc", 17 | version=__version__, 18 | license="Apache License Version 2.0", 19 | description=("Native, async Python gRPC client and server implementation " 20 | "supporting asyncio, uvloop, trio"), 21 | long_description=( 22 | re.compile(r'\bstart-badges\b.*\bend-badges\b', re.M | re.S).sub('', read('README.md')) 23 | ), 24 | long_description_content_type='text/markdown', 25 | author="Andrew Stepanov", 26 | url="https://github.com/python-trio/purerpc", 27 | classifiers=[ 28 | "Development Status :: 4 - Beta", 29 | "Intended Audience :: Developers", 30 | "Intended Audience :: Telecommunications Industry", 31 | "License :: OSI Approved :: Apache Software License", 32 | "Operating System :: MacOS", 33 | "Operating System :: MacOS :: MacOS X", 34 | "Operating System :: POSIX :: BSD", 35 | "Operating System :: POSIX :: Linux", 36 | "Programming Language :: Python", 37 | "Programming Language :: Python :: 3 :: Only", 38 | "Programming Language :: Python :: 3.8", 39 | "Programming Language :: Python :: 3.9", 40 | "Programming Language :: Python :: 3.10", 41 | "Programming Language :: Python :: 3.11", 42 | "Programming Language :: Python :: 3.12", 43 | "Programming Language :: Python :: Implementation :: CPython", 44 | "Programming Language :: Python :: Implementation :: PyPy", 45 | "Framework :: AsyncIO", 46 | "Framework :: Trio", 47 | "Topic :: Internet", 48 | "Topic :: Internet :: WWW/HTTP", 49 
| "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 50 | "Topic :: Software Development :: Libraries", 51 | "Topic :: Software Development :: Code Generators", 52 | "Topic :: Software Development :: Libraries :: Python Modules", 53 | "Topic :: System :: Networking", 54 | ], 55 | keywords=[ 56 | "async", "await", "grpc", "pure-python", "pypy", "network", 57 | "rpc", "http2", 58 | ], 59 | packages=find_packages('src'), 60 | package_dir={'': 'src'}, 61 | test_suite="tests", 62 | python_requires=">=3.7", 63 | install_requires=[ 64 | "h2>=3.1.1,<4", 65 | "anyio>=3.0.0", 66 | "async_generator>=1.10", # for aclosing() only (Python < 3.10) 67 | ], 68 | entry_points={ 69 | "console_scripts": console_scripts, 70 | }, 71 | extras_require={ 72 | 'grpc': [ 73 | "grpcio-tools", 74 | ], 75 | 'grpc-pypy': [ 76 | "grpcio<=1.26", 77 | "grpcio-tools<=1.26", 78 | ], 79 | 'dev': [ 80 | "exceptiongroup; python_version<'3.11'", 81 | "pytest", 82 | "grpcio", 83 | "grpcio-tools", 84 | "uvloop; platform_system!='Windows'", 85 | "tblib>=1.3.2", 86 | "trio>=0.11", 87 | "pip-tools>=6.3.1", 88 | "python-forge>=18.6", 89 | "trustme", 90 | ], 91 | }, 92 | ) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /src/purerpc/__init__.py: -------------------------------------------------------------------------------- 1 | from purerpc.client import insecure_channel, secure_channel, Client 2 | from purerpc.server import Service, Servicer, Server 3 | from purerpc.rpc import Cardinality, RPCSignature, Stream 4 | from purerpc.grpclib.status import Status, StatusCode 5 | from purerpc.grpclib.exceptions import * 6 | from purerpc.utils import run 7 | from purerpc._version import __version__ 8 | -------------------------------------------------------------------------------- /src/purerpc/_version.py: -------------------------------------------------------------------------------- 1 | # This file is imported from __init__.py and exec'd from setup.py 2 | 3 | __version__ = "0.9.0-dev" 4 | -------------------------------------------------------------------------------- /src/purerpc/client.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from contextlib import AsyncExitStack 3 | 4 | import anyio 5 | 6 | from purerpc.grpc_proto import GRPCProtoSocket 7 | from purerpc.grpclib.config import GRPCConfiguration 8 | from purerpc.rpc import RPCSignature, Cardinality 9 | from purerpc.wrappers import ClientStubUnaryUnary, ClientStubStreamStream, ClientStubUnaryStream, \ 10 | ClientStubStreamUnary 11 | 12 | 13 | class _Channel(AsyncExitStack): 14 | def __init__(self, host, port, ssl_context=None): 15 | super().__init__() 16 | self._host = host 17 | self._port = port 18 | self._ssl_context = ssl_context 19 | self._grpc_socket = None 20 | 21 | async def __aenter__(self): 22 | await super().__aenter__() # Does nothing 23 | socket = await anyio.connect_tcp(self._host, self._port, 24 | ssl_context=self._ssl_context, 25 | tls=self._ssl_context is not None, 26 | tls_standard_compatible=False) 27 | config = GRPCConfiguration(client_side=True) 28 | self._grpc_socket = await self.enter_async_context(GRPCProtoSocket(config, socket)) 29 | return self 30 | 31 | 32 | def insecure_channel(host, port): 33 | return _Channel(host, port) 34 | 35 | def secure_channel(host, port, ssl_context): 36 | return _Channel(host, port, ssl_context) 37 | 38 | 39 | class Client: 40 | def __init__(self, service_name: str, channel: 
_Channel): 41 | self.service_name = service_name 42 | self.channel = channel 43 | 44 | async def rpc(self, method_name: str, request_type, response_type, metadata=None): 45 | message_type = request_type.DESCRIPTOR.full_name 46 | if metadata is None: 47 | metadata = () 48 | stream = await self.channel._grpc_socket.start_request("http", self.service_name, 49 | method_name, message_type, 50 | "{}:{}".format(self.channel._host, 51 | self.channel._port), 52 | custom_metadata=metadata) 53 | stream.expect_message_type(response_type) 54 | return stream 55 | 56 | def get_method_stub(self, method_name: str, signature: RPCSignature): 57 | stream_fn = functools.partial(self.rpc, method_name, signature.request_type, 58 | signature.response_type) 59 | if signature.cardinality == Cardinality.STREAM_STREAM: 60 | return ClientStubStreamStream(stream_fn) 61 | elif signature.cardinality == Cardinality.UNARY_STREAM: 62 | return ClientStubUnaryStream(stream_fn) 63 | elif signature.cardinality == Cardinality.STREAM_UNARY: 64 | return ClientStubStreamUnary(stream_fn) 65 | else: 66 | return ClientStubUnaryUnary(stream_fn) 67 | -------------------------------------------------------------------------------- /src/purerpc/grpc_proto.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from purerpc.grpclib.events import MessageReceived, RequestEnded, ResponseEnded 4 | 5 | from .grpc_socket import GRPCStream, GRPCSocket 6 | 7 | 8 | class GRPCProtoStream(GRPCStream): 9 | def expect_message_type(self, message_type): 10 | self._incoming_message_type = message_type 11 | 12 | async def send_message(self, message): 13 | return await super()._send(message.SerializeToString()) 14 | 15 | async def receive_event(self): 16 | event = await super()._receive() 17 | if isinstance(event, MessageReceived) and hasattr(self, "_incoming_message_type"): 18 | binary_data = event.data 19 | event.data = self._incoming_message_type() 20 | event.data.ParseFromString(binary_data) 21 | return event 22 | 23 | async def receive_message(self): 24 | event = await self.receive_event() 25 | if isinstance(event, RequestEnded) or isinstance(event, ResponseEnded): 26 | return None 27 | elif isinstance(event, MessageReceived): 28 | return event.data 29 | else: 30 | return await self.receive_message() 31 | 32 | async def start_response(self, content_type_suffix="", custom_metadata=()): 33 | return await super().start_response( 34 | content_type_suffix if content_type_suffix else "+proto", 35 | custom_metadata) 36 | 37 | 38 | class GRPCProtoSocket(GRPCSocket): 39 | StreamClass = GRPCProtoStream 40 | 41 | async def start_request(self, scheme: str, service_name: str, method_name: str, 42 | message_type=None, authority=None, timeout: datetime.timedelta = None, 43 | content_type_suffix="", custom_metadata=()): 44 | return await super().start_request( 45 | scheme, service_name, method_name, message_type, authority, timeout, 46 | content_type_suffix if content_type_suffix else "+proto", custom_metadata 47 | ) 48 | -------------------------------------------------------------------------------- /src/purerpc/grpc_socket.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import enum 3 | import socket 4 | import datetime 5 | from contextlib import AsyncExitStack 6 | from typing import Dict 7 | 8 | import anyio 9 | import anyio.abc 10 | from purerpc.utils import is_darwin, is_windows 11 | from purerpc.grpclib.exceptions import ProtocolError 12 | 13 | from 
.grpclib.connection import GRPCConfiguration, GRPCConnection 14 | from .grpclib.events import RequestReceived, RequestEnded, ResponseEnded, MessageReceived, WindowUpdated 15 | from .grpclib.buffers import MessageWriteBuffer 16 | from .grpclib.exceptions import StreamClosedError 17 | 18 | 19 | class SocketWrapper(AsyncExitStack): 20 | def __init__(self, grpc_connection: GRPCConnection, stream: anyio.abc.SocketStream): 21 | super().__init__() 22 | self._set_socket_options(stream) 23 | self._stream = stream 24 | self._grpc_connection = grpc_connection 25 | self._flush_event = anyio.Event() 26 | self._running = True 27 | 28 | async def __aenter__(self): 29 | await super().__aenter__() 30 | task_group = await self.enter_async_context(anyio.create_task_group()) 31 | task_group.start_soon(self._writer_thread) 32 | 33 | async def callback(): 34 | self._running = False 35 | self._flush_event.set() 36 | 37 | self.push_async_callback(callback) 38 | return self 39 | 40 | @staticmethod 41 | def _set_socket_options(stream: anyio.abc.SocketStream): 42 | sock = stream.extra(anyio.abc.SocketAttribute.raw_socket) 43 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) 44 | if hasattr(socket, "TCP_KEEPIDLE"): 45 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 300) 46 | elif is_darwin(): 47 | # Darwin specific option 48 | TCP_KEEPALIVE = 16 49 | sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, 300) 50 | if not is_windows(): 51 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30) 52 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) 53 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 54 | 55 | async def _writer_thread(self): 56 | while True: 57 | data = self._grpc_connection.data_to_send() 58 | if data: 59 | await self._stream.send(data) 60 | elif self._running: 61 | await self._flush_event.wait() 62 | self._flush_event = anyio.Event() 63 | else: 64 | return 65 | 66 | async def flush(self): 67 | """This maybe called from different threads.""" 68 | self._flush_event.set() 69 | 70 | async def recv(self, buffer_size: int): 71 | """This may only be called from single thread.""" 72 | return await self._stream.receive(buffer_size) 73 | 74 | 75 | class GRPCStreamState(enum.Enum): 76 | OPEN = 1 77 | HALF_CLOSED_REMOTE = 2 78 | HALF_CLOSED_LOCAL = 3 79 | CLOSED = 4 80 | 81 | 82 | class GRPCStream: 83 | def __init__(self, grpc_connection: GRPCConnection, stream_id: int, socket: SocketWrapper, 84 | grpc_socket: "GRPCSocket"): 85 | self._stream_id = stream_id 86 | self._grpc_connection = grpc_connection 87 | self._grpc_socket = grpc_socket 88 | self._socket = socket 89 | self._flow_control_update_event = anyio.Event() 90 | # TODO: find a reasonable buffer size, or expose it in the API 91 | self._incoming_events = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) # (send, receive) 92 | self._response_started = False 93 | self._state = GRPCStreamState.OPEN 94 | self._start_stream_event = None 95 | self._end_stream_event = None 96 | 97 | @property 98 | def state(self): 99 | return self._state 100 | 101 | @property 102 | def start_stream_event(self): 103 | return self._start_stream_event 104 | 105 | @property 106 | def end_stream_event(self): 107 | return self._end_stream_event 108 | 109 | @property 110 | def stream_id(self): 111 | return self._stream_id 112 | 113 | @property 114 | def client_side(self): 115 | return self._grpc_connection.config.client_side 116 | 117 | @property 118 | def debug_prefix(self): 119 | return "[CLIENT] " if self.client_side else 
"[SERVER] " 120 | 121 | def _close_remote(self): 122 | if self._state == GRPCStreamState.OPEN: 123 | self._state = GRPCStreamState.HALF_CLOSED_REMOTE 124 | elif self._state == GRPCStreamState.HALF_CLOSED_LOCAL: 125 | self._state = GRPCStreamState.CLOSED 126 | del self._grpc_socket._streams[self._stream_id] 127 | 128 | def _close_local(self): 129 | if self._state == GRPCStreamState.OPEN: 130 | self._state = GRPCStreamState.HALF_CLOSED_LOCAL 131 | elif self._state == GRPCStreamState.HALF_CLOSED_REMOTE: 132 | self._state = GRPCStreamState.CLOSED 133 | del self._grpc_socket._streams[self._stream_id] 134 | 135 | async def _set_flow_control_update(self): 136 | self._flow_control_update_event.set() 137 | 138 | async def _wait_flow_control_update(self): 139 | await self._flow_control_update_event.wait() 140 | self._flow_control_update_event = anyio.Event() 141 | 142 | async def _send(self, message: bytes, compress=False): 143 | message_write_buffer = MessageWriteBuffer(self._grpc_connection.config.message_encoding, 144 | self._grpc_connection.config.max_message_length) 145 | message_write_buffer.write_message(message, compress) 146 | while message_write_buffer: 147 | window_size = self._grpc_connection.flow_control_window(self._stream_id) 148 | if window_size <= 0: 149 | await self._wait_flow_control_update() 150 | continue 151 | num_data_to_send = min(window_size, len(message_write_buffer)) 152 | data = message_write_buffer.data_to_send(num_data_to_send) 153 | self._grpc_connection.send_data(self._stream_id, data) 154 | await self._socket.flush() 155 | 156 | async def _receive(self): 157 | event = await self._incoming_events[1].receive() 158 | if isinstance(event, MessageReceived): 159 | self._grpc_connection.acknowledge_received_data(self._stream_id, 160 | event.flow_controlled_length) 161 | await self._socket.flush() 162 | elif isinstance(event, RequestEnded) or isinstance(event, ResponseEnded): 163 | assert self._end_stream_event is None 164 | self._end_stream_event = event 165 | else: 166 | assert self._start_stream_event is None 167 | self._start_stream_event = event 168 | return event 169 | 170 | async def close(self, status=None, content_type_suffix="", custom_metadata=()): 171 | if self.client_side and (status or custom_metadata): 172 | raise ValueError("Client side streams cannot be closed with non-default arguments") 173 | if self._state in (GRPCStreamState.HALF_CLOSED_LOCAL, GRPCStreamState.CLOSED): 174 | raise TypeError("Closing already closed stream") 175 | self._close_local() 176 | if self.client_side: 177 | try: 178 | self._grpc_connection.end_request(self._stream_id) 179 | except StreamClosedError: 180 | # Remote end already closed connection, do nothing here 181 | pass 182 | elif self._response_started: 183 | self._grpc_connection.end_response(self._stream_id, status, custom_metadata) 184 | else: 185 | self._grpc_connection.respond_status(self._stream_id, status, 186 | content_type_suffix, custom_metadata) 187 | await self._socket.flush() 188 | 189 | async def start_response(self, content_type_suffix="", custom_metadata=()): 190 | if self.client_side: 191 | raise ValueError("Cannot start response on client-side socket") 192 | self._grpc_connection.start_response(self._stream_id, content_type_suffix, custom_metadata) 193 | self._response_started = True 194 | await self._socket.flush() 195 | 196 | 197 | # TODO: this name is not correct, should be something like GRPCConnection (but this name is already 198 | # occupied) 199 | class GRPCSocket(AsyncExitStack): 200 | StreamClass = 
GRPCStream 201 | 202 | def __init__(self, config: GRPCConfiguration, sock, 203 | receive_buffer_size=1024*1024): 204 | super().__init__() 205 | self._grpc_connection = GRPCConnection(config=config) 206 | self._socket = SocketWrapper(self._grpc_connection, sock) 207 | self._receive_buffer_size = receive_buffer_size 208 | self._streams = {} # type: Dict[int, GRPCStream] 209 | 210 | async def __aenter__(self): 211 | await super().__aenter__() 212 | self._socket = await self.enter_async_context(self._socket) 213 | self._grpc_connection.initiate_connection() 214 | await self._socket.flush() 215 | if self.client_side: 216 | task_group = await self.enter_async_context(anyio.create_task_group()) 217 | self.callback(task_group.cancel_scope.cancel) 218 | task_group.start_soon(self._reader_thread) 219 | return self 220 | 221 | @property 222 | def client_side(self): 223 | return self._grpc_connection.config.client_side 224 | 225 | def _stream_ctor(self, stream_id): 226 | return self.StreamClass(self._grpc_connection, stream_id, self._socket, self) 227 | 228 | def _allocate_stream(self, stream_id): 229 | self._streams[stream_id] = self._stream_ctor(stream_id) 230 | return self._streams[stream_id] 231 | 232 | async def _listen(self): 233 | while True: 234 | try: 235 | data = await self._socket.recv(self._receive_buffer_size) 236 | # TODO: Not too confident that BrokenResourceError should be treated 237 | # the same as EndOfStream (maybe the handler wants to know?), but it's 238 | # here for parity with anyio 1.x behavior. 239 | except (anyio.EndOfStream, anyio.BrokenResourceError): 240 | return 241 | events = self._grpc_connection.receive_data(data) 242 | await self._socket.flush() 243 | for event in events: 244 | if isinstance(event, WindowUpdated): 245 | if event.stream_id == 0: 246 | for stream in self._streams.values(): 247 | await stream._set_flow_control_update() 248 | elif event.stream_id in self._streams: 249 | await self._streams[event.stream_id]._set_flow_control_update() 250 | continue 251 | elif isinstance(event, RequestReceived): 252 | self._allocate_stream(event.stream_id) 253 | 254 | await self._streams[event.stream_id]._incoming_events[0].send(event) 255 | 256 | if isinstance(event, RequestReceived): 257 | yield self._streams[event.stream_id] 258 | elif isinstance(event, ResponseEnded) or isinstance(event, RequestEnded): 259 | self._streams[event.stream_id]._close_remote() 260 | 261 | async def _reader_thread(self): 262 | async for _ in self._listen(): 263 | raise ProtocolError("Received request on client end") 264 | 265 | async def listen(self): 266 | if self.client_side: 267 | raise ValueError("Cannot listen client-side socket") 268 | 269 | async for value in self._listen(): 270 | yield value 271 | 272 | async def start_request(self, scheme: str, service_name: str, method_name: str, 273 | message_type=None, authority=None, timeout: datetime.timedelta=None, 274 | content_type_suffix="", custom_metadata=()): 275 | if not self.client_side: 276 | raise ValueError("Cannot start request on server-side socket") 277 | stream_id = self._grpc_connection.get_next_available_stream_id() 278 | stream = self._allocate_stream(stream_id) 279 | self._grpc_connection.start_request(stream_id, scheme, service_name, method_name, 280 | message_type, authority, timeout, 281 | content_type_suffix, custom_metadata) 282 | await self._socket.flush() 283 | return stream 284 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/python-trio/purerpc/a3c17dd885d8f36bcf3a78c7506e35f7bc33cccc/src/purerpc/grpclib/__init__.py -------------------------------------------------------------------------------- /src/purerpc/grpclib/buffers.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import collections 3 | from .exceptions import UnsupportedMessageEncodingError, MessageTooLargeError 4 | 5 | 6 | class ByteBuffer: 7 | def __init__(self): 8 | self._deque = collections.deque() 9 | self._size = 0 10 | self._flow_controlled_length = 0 11 | 12 | def append(self, data, flow_controlled_length=None): 13 | if not isinstance(data, bytes) and not isinstance(data, bytearray): 14 | raise ValueError("Expected bytes") 15 | if flow_controlled_length is None: 16 | flow_controlled_length = len(data) 17 | if flow_controlled_length < len(data): 18 | raise ValueError("flow_controlled_length should be >= len(data)") 19 | self._deque.append((data, flow_controlled_length)) 20 | self._size += len(data) 21 | 22 | def popleft_flowcontrol(self, amount): 23 | if amount > self._size: 24 | raise ValueError("Trying to extract {} bytes from ByteBuffer of length {}".format( 25 | amount, self._size)) 26 | data = [] 27 | flow_controlled_length = 0 28 | while amount > 0: 29 | next_element = self._deque[0][0] 30 | next_element_flow_controlled_length = self._deque[0][1] 31 | if amount >= len(next_element): 32 | self._size -= len(next_element) 33 | self._flow_controlled_length -= next_element_flow_controlled_length 34 | amount -= len(next_element) 35 | flow_controlled_length += next_element_flow_controlled_length 36 | data.append(next_element) 37 | self._deque.popleft() 38 | else: 39 | data.append(next_element[:amount]) 40 | self._deque[0] = (next_element[amount:], 41 | next_element_flow_controlled_length - amount) 42 | self._size -= amount 43 | self._flow_controlled_length -= amount 44 | flow_controlled_length += amount 45 | amount = 0 46 | return b"".join(data), flow_controlled_length 47 | 48 | def popleft(self, amount): 49 | return self.popleft_flowcontrol(amount)[0] 50 | 51 | def __len__(self): 52 | return self._size 53 | 54 | @property 55 | def flow_controlled_length(self): 56 | return self._flow_controlled_length 57 | 58 | @property 59 | def length(self): 60 | return len(self) 61 | 62 | 63 | class MessageReadBuffer: 64 | def __init__(self, message_encoding=None, max_message_length=4*1024*1024): 65 | self._message_encoding = message_encoding 66 | self._max_message_length = max_message_length 67 | 68 | self._buffer = ByteBuffer() 69 | self._messages = collections.deque() 70 | 71 | # MessageReadBuffer parser state 72 | self._compressed_flag = None 73 | self._message_length = None 74 | self._flow_controlled_length = None 75 | 76 | def data_received(self, data: bytes, flow_controlled_length=None): 77 | self._buffer.append(data, flow_controlled_length) 78 | self._process_new_messages() 79 | 80 | def decompress(self, data): 81 | if self._message_encoding == "gzip" or self._message_encoding == "deflate": 82 | import zlib 83 | return zlib.decompress(data) 84 | elif self._message_encoding == "snappy": 85 | import snappy 86 | return snappy.decompress(data) 87 | else: 88 | raise UnsupportedMessageEncodingError( 89 | "Unsupported compression: {}".format(self._message_encoding)) 90 | 91 | def _parse_one_message(self): 92 | # either compressed_flag = message_length = flow_controlled_length = None and 
93 | # compressed_flag, message_length are the next elements in self._buffer, or they are all 94 | # not None, and the next element in self._buffer is data 95 | if self._message_length is None: 96 | if len(self._buffer) < 5: 97 | return None, 0 98 | message_header, self._flow_controlled_length = self._buffer.popleft_flowcontrol(5) 99 | self._compressed_flag, self._message_length = struct.unpack('>?I', message_header) 100 | if self._message_length > self._max_message_length: 101 | # Even after the error is raised, the state is not corrupted, and parsing 102 | # can be safely resumed 103 | raise MessageTooLargeError( 104 | "Received message larger than max: {message_length} > {max_message_length}".format( 105 | message_length=self._message_length, 106 | max_message_length=self._max_message_length, 107 | ) 108 | ) 109 | if len(self._buffer) < self._message_length: 110 | return None, 0 111 | data, flow_controlled_length = self._buffer.popleft_flowcontrol(self._message_length) 112 | flow_controlled_length += self._flow_controlled_length 113 | if self._compressed_flag: 114 | data = self.decompress(data) 115 | self._compressed_flag = self._message_length = self._flow_controlled_length = None 116 | return data, flow_controlled_length 117 | 118 | def _process_new_messages(self): 119 | while True: 120 | result, flow_controlled_length = self._parse_one_message() 121 | if result is not None: 122 | self._messages.append((result, flow_controlled_length)) 123 | else: 124 | return 125 | 126 | def __len__(self): 127 | return len(self._messages) 128 | 129 | def read_all_complete_messages(self): 130 | messages = [message for message, _ in self._messages] 131 | self._messages = collections.deque() 132 | return messages 133 | 134 | def read_all_complete_messages_flowcontrol(self): 135 | messages = self._messages 136 | self._messages = collections.deque() 137 | return messages 138 | 139 | def read_message(self): 140 | return self._messages.popleft()[0] 141 | 142 | def read_message_flowcontrol(self): 143 | return self._messages.popleft() 144 | 145 | 146 | class MessageWriteBuffer: 147 | def __init__(self, message_encoding=None, max_message_length=4*1024*1024): 148 | self._buffer = ByteBuffer() 149 | self._message_encoding = message_encoding 150 | self._max_message_length = max_message_length 151 | 152 | def compress(self, data): 153 | if self._message_encoding == "gzip" or self._message_encoding == "deflate": 154 | import zlib 155 | return zlib.compress(data) 156 | elif self._message_encoding == "snappy": 157 | import snappy 158 | return snappy.compress(data) 159 | else: 160 | raise UnsupportedMessageEncodingError( 161 | "Unsupported compression: {}".format(self._message_encoding)) 162 | 163 | def write_message(self, data: bytes, compress=False): 164 | if compress: 165 | data = self.compress(data) 166 | if len(data) > self._max_message_length: 167 | raise MessageTooLargeError( 168 | "Trying to send message larger than max: {message_length} > {max_message_length}".format( 169 | message_length=len(data), 170 | max_message_length=self._max_message_length, 171 | ) 172 | ) 173 | self._buffer.append(struct.pack('>?I', compress, len(data))) 174 | self._buffer.append(data) 175 | 176 | def data_to_send(self, amount): 177 | return self._buffer.popleft(amount) 178 | 179 | def __len__(self): 180 | return len(self._buffer) 181 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/config.py: -------------------------------------------------------------------------------- 1 | 
class GRPCConfiguration: 2 | def __init__(self, client_side: bool, server_string=None, user_agent=None, 3 | message_encoding=None, message_accept_encoding=None, 4 | max_message_length=33554432): 5 | self._client_side = client_side 6 | if client_side and server_string is not None: 7 | raise ValueError("Passed client_side=True and server_string at the same time") 8 | if not client_side and user_agent is not None: 9 | raise ValueError("Passed user_agent but didn't pass client_side=True") 10 | self._server_string = server_string 11 | self._user_agent = user_agent 12 | self._max_message_length = max_message_length 13 | self._message_encoding = message_encoding 14 | 15 | # TODO: this does not need to be passed in config, may be just a single global string 16 | # with all encodings supported by grpclib 17 | if message_accept_encoding is not None: 18 | self._message_accept_encoding = ",".join(message_accept_encoding) 19 | else: 20 | self._message_accept_encoding = None 21 | 22 | @property 23 | def client_side(self): 24 | return self._client_side 25 | 26 | @property 27 | def server_string(self): 28 | return self._server_string 29 | 30 | @property 31 | def user_agent(self): 32 | return self._user_agent 33 | 34 | @property 35 | def message_encoding(self): 36 | return self._message_encoding 37 | 38 | @property 39 | def message_accept_encoding(self): 40 | return self._message_accept_encoding 41 | 42 | @property 43 | def max_message_length(self): 44 | return self._max_message_length 45 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/events.py: -------------------------------------------------------------------------------- 1 | import urllib.parse 2 | import datetime 3 | 4 | from .headers import HeaderDict 5 | from .exceptions import ProtocolError 6 | from .status import Status 7 | 8 | 9 | class Event: 10 | pass 11 | 12 | 13 | class WindowUpdated(Event): 14 | def __init__(self, stream_id, delta): 15 | self.stream_id = stream_id 16 | self.delta = delta 17 | 18 | def __repr__(self): 19 | return "<WindowUpdated stream_id:%s, delta:%s>" % ( 20 | self.stream_id, self.delta 21 | ) 22 | 23 | 24 | 25 | class RequestReceived(Event): 26 | def __init__(self, stream_id: int, scheme: str, service_name: str, method_name: str, 27 | content_type: str): 28 | self.stream_id = stream_id 29 | self.scheme = scheme 30 | self.service_name = service_name 31 | self.method_name = method_name 32 | self.content_type = content_type 33 | self.authority = None 34 | self.timeout = None 35 | self.message_type = None 36 | self.message_encoding = None 37 | self.message_accept_encoding = None 38 | self.user_agent = None 39 | self.custom_metadata = () 40 | 41 | @staticmethod 42 | def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict): 43 | if headers.pop(":method") != "POST": 44 | raise ProtocolError("Unsupported method {}".format(headers[":method"])) 45 | 46 | scheme = headers.pop(":scheme") 47 | if scheme not in ["http", "https"]: 48 | raise ProtocolError("Scheme should be either http or https") 49 | 50 | if headers[":path"].startswith("/"): 51 | service_name, method_name = headers.pop(":path")[1:].split("/") 52 | else: 53 | raise ProtocolError("Path should be /<service_name>/<method_name>") 54 | 55 | if "te" not in headers or headers["te"] != "trailers": 56 | raise ProtocolError("te header not found or not equal to 'trailers', " 57 | "using incompatible proxy?") 58 | else: 59 | headers.pop("te") 60 | 61 | content_type = headers.pop("content-type") 62 | if not content_type.startswith("application/grpc"): 63 | raise
ProtocolError("Content type should start with application/grpc") 64 | 65 | event = RequestReceived(stream_id, scheme, service_name, method_name, content_type) 66 | 67 | if ":authority" in headers: 68 | event.authority = headers.pop(":authority") 69 | 70 | if "grpc-timeout" in headers: 71 | timeout_string = headers.pop("grpc-timeout") 72 | timeout_value, timeout_unit = int(timeout_string[:-1]), timeout_string[-1:] 73 | if timeout_unit == "H": 74 | event.timeout = datetime.timedelta(hours=timeout_value) 75 | elif timeout_unit == "M": 76 | event.timeout = datetime.timedelta(minutes=timeout_value) 77 | elif timeout_unit == "S": 78 | event.timeout = datetime.timedelta(seconds=timeout_value) 79 | elif timeout_unit == "m": 80 | event.timeout = datetime.timedelta(milliseconds=timeout_value) 81 | elif timeout_unit == "u": 82 | event.timeout = datetime.timedelta(microseconds=timeout_value) 83 | elif timeout_unit == "n": 84 | event.timeout = datetime.timedelta(microseconds=timeout_value / 1000) 85 | else: 86 | raise ProtocolError("Unknown timeout unit: {}".format(timeout_unit)) 87 | 88 | if "grpc-encoding" in headers: 89 | event.message_encoding = headers.pop("grpc-encoding") 90 | 91 | if "grpc-accept-encoding" in headers: 92 | event.message_accept_encoding = headers.pop("grpc-accept-encoding").split(",") 93 | 94 | if "user-agent" in headers: 95 | event.user_agent = headers.pop("user-agent") 96 | 97 | if "grpc-message-type" in headers: 98 | event.message_type = headers.pop("grpc-message-type") 99 | 100 | event.custom_metadata = tuple(header for header_name in list(headers.keys()) 101 | for header in headers.extract_headers(header_name)) 102 | return event 103 | 104 | def __repr__(self): 105 | fmt_string = ("") 107 | return fmt_string.format( 108 | stream_id=self.stream_id, 109 | service_name=self.service_name, 110 | method_name=self.method_name, 111 | ) 112 | 113 | 114 | class MessageReceived(Event): 115 | def __init__(self, stream_id: int, data: bytes, flow_controlled_length: int): 116 | self.stream_id = stream_id 117 | self.data = data 118 | self.flow_controlled_length = flow_controlled_length 119 | 120 | def __repr__(self): 121 | fmt_string= ("") 123 | return fmt_string.format( 124 | stream_id=self.stream_id, 125 | flow_controlled_length=self.flow_controlled_length, 126 | ) 127 | 128 | 129 | class RequestEnded(Event): 130 | def __init__(self, stream_id: int): 131 | self.stream_id = stream_id 132 | 133 | def __repr__(self): 134 | fmt_string = "" 135 | return fmt_string.format( 136 | stream_id=self.stream_id, 137 | ) 138 | 139 | 140 | class ResponseReceived(Event): 141 | def __init__(self, stream_id: int, content_type: str): 142 | self.stream_id = stream_id 143 | self.content_type = content_type 144 | self.message_encoding = None 145 | self.message_accept_encoding = None 146 | self.custom_metadata = () 147 | 148 | @staticmethod 149 | def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict): 150 | if int(headers.pop(":status")) != 200: 151 | raise ProtocolError("http status is not 200") 152 | 153 | content_type = headers.pop("content-type") 154 | if not content_type.startswith("application/grpc"): 155 | raise ProtocolError("Content type should start with application/grpc") 156 | 157 | event = ResponseReceived(stream_id, content_type) 158 | 159 | if "grpc-encoding" in headers: 160 | event.message_encoding = headers.pop("grpc-encoding") 161 | 162 | if "grpc-accept-encoding" in headers: 163 | event.message_accept_encoding = headers.pop("grpc-accept-encoding").split(",") 164 
| 165 | event.custom_metadata = tuple(header for header_name in list(headers.keys()) 166 | for header in headers.extract_headers(header_name)) 167 | return event 168 | 169 | def __repr__(self): 170 | fmt_string = "" 171 | return fmt_string.format( 172 | stream_id=self.stream_id, 173 | content_type=self.content_type, 174 | ) 175 | 176 | 177 | class ResponseEnded(Event): 178 | def __init__(self, stream_id: int, status: Status): 179 | self.stream_id = stream_id 180 | self.status = status 181 | self.custom_metadata = () 182 | 183 | @staticmethod 184 | def parse_from_stream_id_and_headers_destructive(stream_id: int, headers: HeaderDict): 185 | if "grpc-status" not in headers: 186 | raise ProtocolError("Expected grpc-status in trailers") 187 | 188 | status_code = int(headers.pop("grpc-status")) 189 | if "grpc-message" in headers: 190 | status_message = urllib.parse.unquote(headers.pop("grpc-message")) 191 | else: 192 | status_message = "" 193 | 194 | event = ResponseEnded(stream_id, Status(status_code, status_message)) 195 | 196 | event.custom_metadata = tuple(header for header_name in list(headers.keys()) 197 | for header in headers.extract_headers(header_name)) 198 | return event 199 | 200 | def __repr__(self): 201 | fmt_string = "" 202 | return fmt_string.format( 203 | stream_id=self.stream_id, 204 | status=self.status, 205 | ) 206 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/exceptions.py: -------------------------------------------------------------------------------- 1 | from .status import Status, StatusCode 2 | 3 | 4 | class GRPCError(Exception): 5 | pass 6 | 7 | 8 | class StreamClosedError(GRPCError): 9 | def __init__(self, stream_id, error_code): 10 | self.stream_id = stream_id 11 | self.error_code = error_code 12 | 13 | 14 | class ProtocolError(GRPCError): 15 | pass 16 | 17 | 18 | class MessageTooLargeError(ProtocolError): 19 | pass 20 | 21 | 22 | class UnsupportedMessageEncodingError(ProtocolError): 23 | pass 24 | 25 | 26 | class RpcFailedError(GRPCError): 27 | def __init__(self, status): 28 | super().__init__("RPC failed with status {status}".format(status=status)) 29 | self._status = status 30 | 31 | @property 32 | def status(self): 33 | return self._status 34 | 35 | 36 | class CancelledError(RpcFailedError): 37 | def __init__(self, message=""): 38 | super().__init__(Status(StatusCode.CANCELLED, message)) 39 | 40 | 41 | class UnknownError(RpcFailedError): 42 | def __init__(self, message=""): 43 | super().__init__(Status(StatusCode.UNKNOWN, message)) 44 | 45 | 46 | class InvalidArgumentError(RpcFailedError): 47 | def __init__(self, message=""): 48 | super().__init__(Status(StatusCode.INVALID_ARGUMENT, message)) 49 | 50 | 51 | class DeadlineExceededError(RpcFailedError): 52 | def __init__(self, message=""): 53 | super().__init__(Status(StatusCode.DEADLINE_EXCEEDED, message)) 54 | 55 | 56 | class NotFoundError(RpcFailedError): 57 | def __init__(self, message=""): 58 | super().__init__(Status(StatusCode.NOT_FOUND, message)) 59 | 60 | 61 | class AlreadyExistsError(RpcFailedError): 62 | def __init__(self, message=""): 63 | super().__init__(Status(StatusCode.ALREADY_EXISTS, message)) 64 | 65 | 66 | class PermissionDeniedError(RpcFailedError): 67 | def __init__(self, message=""): 68 | super().__init__(Status(StatusCode.PERMISSION_DENIED, message)) 69 | 70 | 71 | class ResourceExhaustedError(RpcFailedError): 72 | def __init__(self, message=""): 73 | super().__init__(Status(StatusCode.RESOURCE_EXHAUSTED, message)) 74 | 75 | 76 | 
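# The RpcFailedError subclasses in this module map one-to-one onto the non-OK
# StatusCode values.  A minimal sketch of how they surface in practice (the
# Greeter names below are illustrative, not part of this module): a servicer
# method raises one of these errors, the server closes the stream with the
# corresponding grpc-status trailer (see server.py), and the client re-raises
# the matching subclass via raise_status() at the bottom of this file.
#
#     class Servicer(greeter_grpc.GreeterServicer):
#         async def SayHello(self, message):
#             raise NotFoundError("no user named {}".format(message.name))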
class FailedPreconditionError(RpcFailedError): 77 | def __init__(self, message=""): 78 | super().__init__(Status(StatusCode.FAILED_PRECONDITION, message)) 79 | 80 | 81 | class AbortedError(RpcFailedError): 82 | def __init__(self, message=""): 83 | super().__init__(Status(StatusCode.ABORTED, message)) 84 | 85 | 86 | class OutOfRangeError(RpcFailedError): 87 | def __init__(self, message=""): 88 | super().__init__(Status(StatusCode.OUT_OF_RANGE, message)) 89 | 90 | 91 | class UnimplementedError(RpcFailedError): 92 | def __init__(self, message=""): 93 | super().__init__(Status(StatusCode.UNIMPLEMENTED, message)) 94 | 95 | 96 | class InternalError(RpcFailedError): 97 | def __init__(self, message=""): 98 | super().__init__(Status(StatusCode.INTERNAL, message)) 99 | 100 | 101 | class UnavailableError(RpcFailedError): 102 | def __init__(self, message=""): 103 | super().__init__(Status(StatusCode.UNAVAILABLE, message)) 104 | 105 | 106 | class DataLossError(RpcFailedError): 107 | def __init__(self, message=""): 108 | super().__init__(Status(StatusCode.DATA_LOSS, message)) 109 | 110 | 111 | class UnauthenticatedError(RpcFailedError): 112 | def __init__(self, message=""): 113 | super().__init__(Status(StatusCode.UNAUTHENTICATED, message)) 114 | 115 | 116 | def raise_status(status: Status): 117 | if status.status_code == StatusCode.CANCELLED: 118 | raise CancelledError(status.status_message) 119 | elif status.status_code == StatusCode.UNKNOWN: 120 | raise UnknownError(status.status_message) 121 | elif status.status_code == StatusCode.INVALID_ARGUMENT: 122 | raise InvalidArgumentError(status.status_message) 123 | elif status.status_code == StatusCode.DEADLINE_EXCEEDED: 124 | raise DeadlineExceededError(status.status_message) 125 | elif status.status_code == StatusCode.NOT_FOUND: 126 | raise NotFoundError(status.status_message) 127 | elif status.status_code == StatusCode.ALREADY_EXISTS: 128 | raise AlreadyExistsError(status.status_message) 129 | elif status.status_code == StatusCode.PERMISSION_DENIED: 130 | raise PermissionDeniedError(status.status_message) 131 | elif status.status_code == StatusCode.RESOURCE_EXHAUSTED: 132 | raise ResourceExhaustedError(status.status_message) 133 | elif status.status_code == StatusCode.FAILED_PRECONDITION: 134 | raise FailedPreconditionError(status.status_message) 135 | elif status.status_code == StatusCode.ABORTED: 136 | raise AbortedError(status.status_message) 137 | elif status.status_code == StatusCode.OUT_OF_RANGE: 138 | raise OutOfRangeError(status.status_message) 139 | elif status.status_code == StatusCode.UNIMPLEMENTED: 140 | raise UnimplementedError(status.status_message) 141 | elif status.status_code == StatusCode.INTERNAL: 142 | raise InternalError(status.status_message) 143 | elif status.status_code == StatusCode.UNAVAILABLE: 144 | raise UnavailableError(status.status_message) 145 | elif status.status_code == StatusCode.DATA_LOSS: 146 | raise DataLossError(status.status_message) 147 | elif status.status_code == StatusCode.UNAUTHENTICATED: 148 | raise UnauthenticatedError(status.status_message) 149 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/headers.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import collections 3 | 4 | 5 | class HeaderDict(collections.OrderedDict): 6 | def __init__(self, values): 7 | super().__init__() 8 | for key, value in values: 9 | if key not in self: 10 | self[key] = [value] 11 | else: 12 | self[key].append(value) 13 
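        # The loop below collapses header names that occurred only once back to a
        # plain string value, so HeaderDict([("a", "1"), ("a", "2"), ("b", "3")])
        # ends up behaving like {"a": ["1", "2"], "b": "3"}: only repeated names
        # keep their list form.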
| for key in self: 14 | if len(self[key]) == 1: 15 | self[key] = self[key][0] 16 | 17 | def extract_headers(self, header_name: str): 18 | """Returns all headers with name == header_name as list of tuples (name, value)""" 19 | if header_name.startswith("grpc-"): 20 | return () 21 | else: 22 | value = self.pop(header_name) 23 | is_binary = header_name.endswith("-bin") 24 | 25 | if not isinstance(value, list): 26 | value_list = [value] 27 | else: 28 | value_list = value 29 | 30 | if is_binary: 31 | return ((header_name, b64decode(value)) for value_sublist in value_list for value 32 | in value_sublist.split(",")) 33 | else: 34 | return ((header_name, value) for value in value_list) 35 | 36 | 37 | def sanitize_headers(headers): 38 | for name, value in headers: 39 | if isinstance(value, bytes) and not name.endswith("-bin"): 40 | raise ValueError("Got binary value for header name '{}', but name does not end " 41 | "with '-bin' suffix".format(name)) 42 | if name.startswith("grpc-"): 43 | raise ValueError("Got header with name '{}', but custom metadata headers should " 44 | "not start with 'grpc-' prefix".format(name)) 45 | if name.endswith("-bin"): 46 | yield name, b64encode(value) 47 | else: 48 | yield name, value 49 | 50 | 51 | def b64decode(data: str) -> bytes: 52 | # Apply missing padding 53 | missing_padding = len(data) % 4 54 | if missing_padding: 55 | data += "=" * (4 - missing_padding) 56 | return base64.b64decode(data) 57 | 58 | 59 | def b64encode(data: bytes) -> str: 60 | return base64.b64encode(data).rstrip(b"=").decode("utf-8") 61 | -------------------------------------------------------------------------------- /src/purerpc/grpclib/status.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2014 gRPC authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import enum 17 | 18 | 19 | class StatusCode(enum.Enum): 20 | # OK is returned on success. 21 | OK = 0 22 | 23 | # Canceled indicates the operation was canceled (typically by the caller). 24 | CANCELLED = 1 25 | 26 | # Unknown error. An example of where this error may be returned is 27 | # if a Status value received from another address space belongs to 28 | # an error-space that is not known in this address space. Also 29 | # errors raised by APIs that do not return enough error information 30 | # may be converted to this error. 31 | UNKNOWN = 2 32 | 33 | # InvalidArgument indicates client specified an invalid argument. 34 | # Note that this differs from FailedPrecondition. It indicates arguments 35 | # that are problematic regardless of the state of the system 36 | # (e.g., a malformed file name). 37 | INVALID_ARGUMENT = 3 38 | 39 | # DeadlineExceeded means operation expired before completion. 40 | # For operations that change the state of the system, this error may be 41 | # returned even if the operation has completed successfully. 
For 42 | # example, a successful response from a server could have been delayed 43 | # long enough for the deadline to expire. 44 | DEADLINE_EXCEEDED = 4 45 | 46 | # NotFound means some requested entity (e.g., file or directory) was 47 | # not found. 48 | NOT_FOUND = 5 49 | 50 | # AlreadyExists means an attempt to create an entity failed because one 51 | # already exists. 52 | ALREADY_EXISTS = 6 53 | 54 | # PermissionDenied indicates the caller does not have permission to 55 | # execute the specified operation. It must not be used for rejections 56 | # caused by exhausting some resource (use ResourceExhausted 57 | # instead for those errors). It must not be 58 | # used if the caller cannot be identified (use Unauthenticated 59 | # instead for those errors). 60 | PERMISSION_DENIED = 7 61 | 62 | # ResourceExhausted indicates some resource has been exhausted, perhaps 63 | # a per-user quota, or perhaps the entire file system is out of space. 64 | RESOURCE_EXHAUSTED = 8 65 | 66 | # FailedPrecondition indicates operation was rejected because the 67 | # system is not in a state required for the operation's execution. 68 | # For example, directory to be deleted may be non-empty, an rmdir 69 | # operation is applied to a non-directory, etc. 70 | # 71 | # A litmus test that may help a service implementor in deciding 72 | # between FailedPrecondition, Aborted, and Unavailable: 73 | # (a) Use Unavailable if the client can retry just the failing call. 74 | # (b) Use Aborted if the client should retry at a higher-level 75 | # (e.g., restarting a read-modify-write sequence). 76 | # (c) Use FailedPrecondition if the client should not retry until 77 | # the system state has been explicitly fixed. E.g., if an "rmdir" 78 | # fails because the directory is non-empty, FailedPrecondition 79 | # should be returned since the client should not retry unless 80 | # they have first fixed up the directory by deleting files from it. 81 | # (d) Use FailedPrecondition if the client performs conditional 82 | # REST Get/Update/Delete on a resource and the resource on the 83 | # server does not match the condition. E.g., conflicting 84 | # read-modify-write on the same resource. 85 | FAILED_PRECONDITION = 9 86 | 87 | # Aborted indicates the operation was aborted, typically due to a 88 | # concurrency issue like sequencer check failures, transaction aborts, 89 | # etc. 90 | # 91 | # See litmus test above for deciding between FailedPrecondition, 92 | # Aborted, and Unavailable. 93 | ABORTED = 10 94 | 95 | # OutOfRange means operation was attempted past the valid range. 96 | # E.g., seeking or reading past end of file. 97 | # 98 | # Unlike InvalidArgument, this error indicates a problem that may 99 | # be fixed if the system state changes. For example, a 32-bit file 100 | # system will generate InvalidArgument if asked to read at an 101 | # offset that is not in the range [0,2^32-1], but it will generate 102 | # OutOfRange if asked to read from an offset past the current 103 | # file size. 104 | # 105 | # There is a fair bit of overlap between FailedPrecondition and 106 | # OutOfRange. We recommend using OutOfRange (the more specific 107 | # error) when it applies so that callers who are iterating through 108 | # a space can easily look for an OutOfRange error to detect when 109 | # they are done. 110 | OUT_OF_RANGE = 11 111 | 112 | # Unimplemented indicates operation is not implemented or not 113 | # supported/enabled in this service. 114 | UNIMPLEMENTED = 12 115 | 116 | # Internal errors. 
Means some invariants expected by underlying 117 | # system has been broken. If you see one of these errors, 118 | # something is very broken. 119 | INTERNAL = 13 120 | 121 | # Unavailable indicates the service is currently unavailable. 122 | # This is a most likely a transient condition and may be corrected 123 | # by retrying with a backoff. 124 | # 125 | # See litmus test above for deciding between FailedPrecondition, 126 | # Aborted, and Unavailable. 127 | UNAVAILABLE = 14 128 | 129 | # DataLoss indicates unrecoverable data loss or corruption. 130 | DATA_LOSS = 15 131 | 132 | # Unauthenticated indicates the request does not have valid 133 | # authentication credentials for the operation. 134 | UNAUTHENTICATED = 16 135 | 136 | 137 | class Status: 138 | def __init__(self, status_code_or_int, status_message=""): 139 | if isinstance(status_code_or_int, StatusCode): 140 | self._status_code = status_code_or_int 141 | self._int_value = status_code_or_int.value 142 | else: 143 | try: 144 | self._status_code = StatusCode(status_code_or_int) 145 | except ValueError: 146 | self._status_code = StatusCode.UNKNOWN 147 | self._int_value = status_code_or_int 148 | self._status_message = status_message 149 | 150 | @property 151 | def status_code(self): 152 | return self._status_code 153 | 154 | @property 155 | def int_value(self): 156 | return self._int_value 157 | 158 | @property 159 | def status_message(self): 160 | return self._status_message 161 | 162 | def __str__(self): 163 | suffix = ("" if not self.status_message else ": " + self.status_message) 164 | if self.status_code != StatusCode.UNKNOWN: 165 | prefix = self.status_code.name 166 | else: 167 | prefix = "{name} ({value})".format(name=self.status_code.name, value=self.int_value) 168 | return prefix + suffix 169 | 170 | def __repr__(self): 171 | fmt_string = "Status({name}, code={code}, message={message})" 172 | return fmt_string.format( 173 | name=self.status_code.name, 174 | code=self.int_value, 175 | message=repr(self.status_message) 176 | ) 177 | -------------------------------------------------------------------------------- /src/purerpc/protoc_plugin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/python-trio/purerpc/a3c17dd885d8f36bcf3a78c7506e35f7bc33cccc/src/purerpc/protoc_plugin/__init__.py -------------------------------------------------------------------------------- /src/purerpc/protoc_plugin/plugin.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import itertools 3 | 4 | from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest 5 | from google.protobuf.compiler.plugin_pb2 import CodeGeneratorResponse 6 | from google.protobuf import descriptor_pb2 7 | from purerpc import Cardinality 8 | 9 | 10 | def generate_import_statement(proto_name): 11 | module_path = proto_name[:-len(".proto")].replace("-", "_").replace("/", ".") + "_pb2" 12 | alias = get_python_module_alias(proto_name) 13 | if "." 
in module_path: 14 | # produces import statements in line with grpcio so that tools for 15 | # postprocessing import statements work with purerpc as well 16 | # example: `from foo.bar import zap_pb2 as foo_dot_bar_dot_zap__pb2` 17 | parent_modules, sub_module = module_path.rsplit(".", 1) 18 | return "from " + parent_modules + " import " + sub_module + " as " + alias 19 | else: 20 | return "import " + module_path + " as " + alias 21 | 22 | 23 | def get_python_module_alias(proto_name): 24 | package_name = proto_name[:-len(".proto")] 25 | return package_name.replace("/", "_dot_").replace("-", "_") + "__pb2" 26 | 27 | 28 | def simple_type(type_): 29 | simple_type = type_.split(".")[-1] 30 | return simple_type 31 | 32 | 33 | def get_python_type(proto_name, proto_type): 34 | if proto_type.startswith("."): 35 | return get_python_module_alias(proto_name) + "." + simple_type(proto_type) 36 | else: 37 | return proto_type 38 | 39 | 40 | def generate_single_proto(proto_file: descriptor_pb2.FileDescriptorProto, 41 | proto_for_entity): 42 | lines = ["import purerpc"] 43 | lines.append(generate_import_statement(proto_file.name)) 44 | for dep_module in proto_file.dependency: 45 | lines.append(generate_import_statement(dep_module)) 46 | for service in proto_file.service: 47 | if proto_file.package: 48 | fully_qualified_service_name = proto_file.package + "." + service.name 49 | else: 50 | fully_qualified_service_name = service.name 51 | 52 | lines.append("\n\nclass {service_name}Servicer(purerpc.Servicer):".format(service_name=service.name)) 53 | for method in service.method: 54 | plural_suffix = "s" if method.client_streaming else "" 55 | fmt_string = (" async def {method_name}(self, input_message{plural_suffix}):\n" 56 | " raise NotImplementedError()\n") 57 | lines.append(fmt_string.format( 58 | method_name=method.name, 59 | plural_suffix=plural_suffix, 60 | )) 61 | fmt_string = (" @property\n" 62 | " def service(self) -> purerpc.Service:\n" 63 | " service_obj = purerpc.Service(\n" 64 | " \"{fully_qualified_service_name}\"\n" 65 | " )") 66 | lines.append(fmt_string.format(fully_qualified_service_name=fully_qualified_service_name)) 67 | for method in service.method: 68 | input_proto = proto_for_entity[method.input_type] 69 | output_proto = proto_for_entity[method.output_type] 70 | cardinality = Cardinality.get_cardinality_for(request_stream=method.client_streaming, 71 | response_stream=method.server_streaming) 72 | fmt_string = (" service_obj.add_method(\n" 73 | " \"{method_name}\",\n" 74 | " self.{method_name},\n" 75 | " purerpc.RPCSignature(\n" 76 | " purerpc.{cardinality},\n" 77 | " {input_type},\n" 78 | " {output_type},\n" 79 | " )\n" 80 | " )") 81 | lines.append(fmt_string.format( 82 | method_name=method.name, 83 | cardinality=cardinality, 84 | input_type=get_python_type(input_proto, method.input_type), 85 | output_type=get_python_type(output_proto, method.output_type), 86 | )) 87 | lines.append(" return service_obj\n\n") 88 | 89 | fmt_string = ("class {service_name}Stub:\n" 90 | " def __init__(self, channel):\n" 91 | " self._client = purerpc.Client(\n" 92 | " \"{fully_qualified_service_name}\",\n" 93 | " channel\n" 94 | " )") 95 | lines.append(fmt_string.format( 96 | service_name=service.name, 97 | fully_qualified_service_name=fully_qualified_service_name 98 | )) 99 | for method in service.method: 100 | input_proto = proto_for_entity[method.input_type] 101 | output_proto = proto_for_entity[method.output_type] 102 | cardinality = Cardinality.get_cardinality_for(request_stream=method.client_streaming, 
103 | response_stream=method.server_streaming) 104 | fmt_string = (" self.{method_name} = self._client.get_method_stub(\n" 105 | " \"{method_name}\",\n" 106 | " purerpc.RPCSignature(\n" 107 | " purerpc.{cardinality},\n" 108 | " {input_type},\n" 109 | " {output_type},\n" 110 | " )\n" 111 | " )") 112 | lines.append(fmt_string.format( 113 | method_name=method.name, 114 | cardinality=cardinality, 115 | input_type=get_python_type(input_proto, method.input_type), 116 | output_type=get_python_type(output_proto, method.output_type), 117 | )) 118 | 119 | return "\n".join(lines) 120 | 121 | 122 | def main(): 123 | request = CodeGeneratorRequest.FromString(sys.stdin.buffer.read()) 124 | 125 | files_to_generate = set(request.file_to_generate) 126 | 127 | response = CodeGeneratorResponse() 128 | proto_for_entity = dict() 129 | for proto_file in request.proto_file: 130 | package_name = proto_file.package 131 | for named_entity in itertools.chain(proto_file.message_type, proto_file.enum_type, 132 | proto_file.service, proto_file.extension): 133 | if package_name: 134 | fully_qualified_name = ".".join(["", package_name, named_entity.name]) 135 | else: 136 | fully_qualified_name = "." + named_entity.name 137 | proto_for_entity[fully_qualified_name] = proto_file.name 138 | for proto_file in request.proto_file: 139 | if proto_file.name in files_to_generate: 140 | out = response.file.add() 141 | out.name = proto_file.name.replace('-', "_").replace('.proto', "_grpc.py") 142 | out.content = generate_single_proto(proto_file, proto_for_entity) 143 | sys.stdout.buffer.write(response.SerializeToString()) 144 | -------------------------------------------------------------------------------- /src/purerpc/rpc.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import typing 3 | import collections 4 | import collections.abc 5 | 6 | 7 | Stream = typing.AsyncIterator 8 | 9 | 10 | class Cardinality(enum.Enum): 11 | UNARY_UNARY = 0 12 | UNARY_STREAM = 1 13 | STREAM_UNARY = 2 14 | STREAM_STREAM = 3 15 | 16 | @staticmethod 17 | def get_cardinality_for(*, request_stream, response_stream): 18 | if request_stream and response_stream: 19 | return Cardinality.STREAM_STREAM 20 | elif request_stream and not response_stream: 21 | return Cardinality.STREAM_UNARY 22 | elif not request_stream and response_stream: 23 | return Cardinality.UNARY_STREAM 24 | else: 25 | return Cardinality.UNARY_UNARY 26 | 27 | 28 | class RPCSignature: 29 | def __init__(self, cardinality: Cardinality, request_type, response_type): 30 | self._cardinality = cardinality 31 | self._request_type = request_type 32 | self._response_type = response_type 33 | 34 | @property 35 | def cardinality(self): 36 | return self._cardinality 37 | 38 | @property 39 | def request_type(self): 40 | return self._request_type 41 | 42 | @property 43 | def response_type(self): 44 | return self._response_type 45 | 46 | @staticmethod 47 | def from_annotations(request_annotation, response_annotation): 48 | if (hasattr(request_annotation, "__origin__") and 49 | issubclass(request_annotation.__origin__, collections.abc.AsyncIterator)): 50 | request_type = request_annotation.__args__[0] 51 | request_stream = True 52 | else: 53 | request_type = request_annotation 54 | request_stream = False 55 | 56 | if (hasattr(response_annotation, "__origin__") and 57 | issubclass(response_annotation.__origin__, collections.abc.AsyncIterator)): 58 | response_type = response_annotation.__args__[0] 59 | response_stream = True 60 | else: 61 | response_type = 
response_annotation 62 | response_stream = False 63 | cardinality = Cardinality.get_cardinality_for(request_stream=request_stream, 64 | response_stream=response_stream) 65 | return RPCSignature(cardinality, request_type, response_type) 66 | -------------------------------------------------------------------------------- /src/purerpc/server.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import collections 4 | import functools 5 | 6 | import logging 7 | from contextlib import asynccontextmanager, AsyncExitStack 8 | 9 | import anyio 10 | import anyio.abc 11 | from anyio import TASK_STATUS_IGNORED 12 | from anyio.streams.tls import TLSListener 13 | 14 | from .grpclib.events import RequestReceived 15 | from .grpclib.status import Status, StatusCode 16 | from .grpclib.exceptions import RpcFailedError 17 | from .utils import run as purerpc_run 18 | from purerpc.grpc_proto import GRPCProtoStream, GRPCProtoSocket 19 | from purerpc.rpc import RPCSignature, Cardinality 20 | from purerpc.wrappers import call_server_unary_unary, \ 21 | call_server_unary_stream, call_server_stream_unary, call_server_stream_stream 22 | 23 | from .grpclib.connection import GRPCConfiguration 24 | 25 | log = logging.getLogger(__name__) 26 | 27 | BoundRPCMethod = collections.namedtuple("BoundRPCMethod", ["method_fn", "signature"]) 28 | 29 | 30 | class Service: 31 | def __init__(self, name): 32 | self.name = name 33 | self.methods = {} 34 | 35 | def add_method(self, method_name: str, method_fn, rpc_signature: RPCSignature, 36 | method_signature: inspect.Signature = None): 37 | if method_signature is None: 38 | method_signature = inspect.signature(method_fn) 39 | if len(method_signature.parameters) == 1: 40 | def method_fn_with_headers(arg, request): 41 | return method_fn(arg) 42 | elif len(method_signature.parameters) == 2: 43 | if list(method_signature.parameters.values())[1].name == "request": 44 | method_fn_with_headers = method_fn 45 | else: 46 | raise ValueError("Expected second parameter 'request'") 47 | else: 48 | raise ValueError("Expected method_fn to have exactly one or two parameters") 49 | self.methods[method_name] = BoundRPCMethod(method_fn_with_headers, rpc_signature) 50 | 51 | def rpc(self, method_name): 52 | def decorator(func): 53 | signature = inspect.signature(func) 54 | if signature.return_annotation == signature.empty: 55 | raise ValueError("Only annotated methods can be used with Service.rpc() decorator") 56 | if len(signature.parameters) not in (1, 2): 57 | raise ValueError("Only functions with one or two parameters can be used with " 58 | "Service.rpc() decorator") 59 | parameter = next(iter(signature.parameters.values())) 60 | if parameter.annotation == parameter.empty: 61 | raise ValueError("Only annotated methods can be used with Service.rpc() decorator") 62 | 63 | rpc_signature = RPCSignature.from_annotations(parameter.annotation, 64 | signature.return_annotation) 65 | self.add_method(method_name, func, rpc_signature, method_signature=signature) 66 | return func 67 | 68 | return decorator 69 | 70 | 71 | class Servicer: 72 | @property 73 | def service(self) -> Service: 74 | raise NotImplementedError() 75 | 76 | 77 | @asynccontextmanager 78 | async def _service_wrapper(service=None, setup_fn=None, teardown_fn=None): 79 | if setup_fn is not None: 80 | yield await setup_fn() 81 | else: 82 | yield service 83 | 84 | if teardown_fn is not None: 85 | await teardown_fn() 86 | 87 | 88 | class Server: 89 | def __init__(self, 
port=50055, ssl_context=None): 90 | self.port = port 91 | self._ssl_context = ssl_context 92 | self.services = {} 93 | self._connection_count = 0 94 | self._exception_count = 0 95 | 96 | def add_service(self, service=None, context_manager=None, setup_fn=None, teardown_fn=None, name=None): 97 | if (service is not None) + (context_manager is not None) + (setup_fn is not None) != 1: 98 | raise ValueError("Precisely one of service, context_manager or setup_fn should be set") 99 | if name is None: 100 | if hasattr(service, "name"): 101 | name = service.name 102 | elif hasattr(context_manager, "name"): 103 | name = context_manager.name 104 | elif hasattr(setup_fn, "name"): 105 | name = setup_fn.name 106 | else: 107 | raise ValueError("Could not infer name, please provide 'name' argument to this function call " 108 | "or define 'name' attribute on service, context_manager or setup_fn") 109 | if service is not None: 110 | self.services[name] = _service_wrapper(service=service) 111 | elif context_manager is not None: 112 | self.services[name] = context_manager 113 | elif setup_fn is not None: 114 | self.services[name] = _service_wrapper(setup_fn=setup_fn, teardown_fn=teardown_fn) 115 | else: 116 | raise ValueError("Shouldn't have happened") 117 | 118 | async def serve_async(self, *, task_status=TASK_STATUS_IGNORED): 119 | """Run the grpc server 120 | 121 | The task_status protocol lets the caller know when the server is 122 | listening, and yields the port number (same given to Server constructor). 123 | """ 124 | 125 | # TODO: resource usage warning 126 | async with AsyncExitStack() as stack: 127 | tcp_server = await anyio.create_tcp_listener(local_port=self.port, reuse_port=True) 128 | # read the resulting port, in case it was 0 129 | self.port = tcp_server.extra(anyio.abc.SocketAttribute.local_port) 130 | if self._ssl_context: 131 | tcp_server = TLSListener(tcp_server, self._ssl_context, 132 | standard_compatible=False) 133 | task_status.started(self.port) 134 | 135 | services_dict = {} 136 | for key, value in self.services.items(): 137 | services_dict[key] = await stack.enter_async_context(value) 138 | 139 | await tcp_server.serve(ConnectionHandler(services_dict, self)) 140 | 141 | def serve(self, backend=None): 142 | """ 143 | DEPRECATED - use serve_async() instead 144 | 145 | This function runs an entire async event loop (there can only be one 146 | per thread), and there is no way to know when the server is ready for 147 | connections. 
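        A sketch of the preferred replacement, assuming an already constructed
        servicer (the `servicer` name here is only an illustration):

            import anyio
            import purerpc

            async def main():
                server = purerpc.Server(port=0)   # 0 picks an ephemeral port
                server.add_service(servicer.service)
                async with anyio.create_task_group() as tg:
                    # tg.start() returns once the server is listening and
                    # yields the bound port via the task_status protocol.
                    port = await tg.start(server.serve_async)
                    print("listening on", port)

            anyio.run(main)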
148 | """ 149 | purerpc_run(self.serve_async, backend=backend) 150 | 151 | 152 | class ConnectionHandler: 153 | RECEIVE_BUFFER_SIZE = 65536 154 | 155 | def __init__(self, services: dict, server: Server): 156 | self.config = GRPCConfiguration(client_side=False) 157 | self.services = services 158 | self._server = server 159 | 160 | async def request_received(self, stream: GRPCProtoStream): 161 | try: 162 | await stream.start_response() 163 | event = await stream.receive_event() 164 | 165 | if not isinstance(event, RequestReceived): 166 | await stream.close(Status(StatusCode.INTERNAL, status_message="Expected headers")) 167 | return 168 | 169 | try: 170 | service = self.services[event.service_name] 171 | except KeyError: 172 | await stream.close(Status( 173 | StatusCode.UNIMPLEMENTED, 174 | status_message="Service {service_name} is not implemented".format(service_name=event.service_name) 175 | )) 176 | return 177 | 178 | try: 179 | bound_rpc_method = service.methods[event.method_name] 180 | except KeyError: 181 | await stream.close(Status( 182 | StatusCode.UNIMPLEMENTED, 183 | status_message="Method {method_name} is not implemented in service {service_name}".format( 184 | method_name=event.method_name, 185 | service_name=event.service_name 186 | ) 187 | )) 188 | return 189 | 190 | # TODO: Should at least pass through GeneratorExit 191 | try: 192 | method_fn = functools.partial(bound_rpc_method.method_fn, request=event) 193 | cardinality = bound_rpc_method.signature.cardinality 194 | stream.expect_message_type(bound_rpc_method.signature.request_type) 195 | if cardinality == Cardinality.STREAM_STREAM: 196 | await call_server_stream_stream(method_fn, stream) 197 | elif cardinality == Cardinality.UNARY_STREAM: 198 | await call_server_unary_stream(method_fn, stream) 199 | elif cardinality == Cardinality.STREAM_UNARY: 200 | await call_server_stream_unary(method_fn, stream) 201 | else: 202 | await call_server_unary_unary(method_fn, stream) 203 | except RpcFailedError as error: 204 | await stream.close(error.status) 205 | except: 206 | # TODO: limit catch to Exception, so async cancel can propagate 207 | log.warning("Got exception while writing response stream", 208 | exc_info=log.getEffectiveLevel() == logging.DEBUG) 209 | await stream.close(Status(StatusCode.CANCELLED, status_message=repr(sys.exc_info()))) 210 | except: 211 | # TODO: limit catch to Exception, so async cancel can propagate 212 | log.warning("Got exception in request_received", 213 | exc_info=log.getEffectiveLevel() == logging.DEBUG) 214 | 215 | async def __call__(self, stream_: anyio.abc.SocketStream): 216 | # TODO: Should at least pass through GeneratorExit 217 | self._server._connection_count += 1 218 | try: 219 | async with GRPCProtoSocket(self.config, stream_) as grpc_socket: 220 | # TODO: resource usage warning 221 | # TODO: TaskGroup() uses a lot of memory if the connection is kept for a long time 222 | # TODO: do we really need it here? 223 | async with anyio.create_task_group() as task_group: 224 | async for stream in grpc_socket.listen(): 225 | task_group.start_soon(self.request_received, stream) 226 | except: 227 | # TODO: limit catch to Exception, so async cancel can propagate 228 | # TODO: migrate off this broad catching of exceptions. The library 229 | # user should decide the policy. 
230 | log.warning("Got exception in main dispatch loop", 231 | exc_info=log.getEffectiveLevel() == logging.DEBUG) 232 | self._server._exception_count += 1 233 | -------------------------------------------------------------------------------- /src/purerpc/test_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import collections 3 | import subprocess 4 | import multiprocessing 5 | import tempfile 6 | import shutil 7 | import os 8 | import sys 9 | import inspect 10 | import importlib 11 | import concurrent.futures 12 | import contextlib 13 | import time 14 | import random 15 | import string 16 | from multiprocessing.connection import Connection 17 | 18 | from tblib import pickling_support 19 | pickling_support.install() 20 | 21 | import forge 22 | import anyio 23 | from async_generator import aclosing 24 | 25 | # work around pickle issue on macOS 26 | if sys.platform == 'darwin': 27 | multiprocessing = multiprocessing.get_context('fork') 28 | 29 | 30 | @contextlib.contextmanager 31 | def compile_temp_proto(*relative_proto_paths): 32 | modules = [] 33 | with tempfile.TemporaryDirectory() as temp_dir: 34 | sys.path.insert(0, temp_dir) 35 | try: 36 | for relative_proto_path in relative_proto_paths: 37 | proto_path = os.path.join(os.path.dirname( 38 | inspect.currentframe().f_back.f_back.f_globals['__file__']), 39 | relative_proto_path) 40 | proto_filename = os.path.basename(proto_path) 41 | proto_temp_path = os.path.join(temp_dir, proto_filename) 42 | shutil.copyfile(proto_path, proto_temp_path) 43 | for relative_proto_path in relative_proto_paths: 44 | proto_filename = os.path.basename(relative_proto_path) 45 | proto_temp_path = os.path.join(temp_dir, proto_filename) 46 | cmdline = [sys.executable, '-m', 'grpc_tools.protoc', 47 | '--python_out=.', '--purerpc_out=.', '--grpc_python_out=.', 48 | '-I' + temp_dir, proto_temp_path] 49 | subprocess.check_call(cmdline, cwd=temp_dir) 50 | 51 | pb2_module_name = proto_filename.replace(".proto", "_pb2") 52 | pb2_grpc_module_name = proto_filename.replace(".proto", "_pb2_grpc") 53 | grpc_module_name = proto_filename.replace(".proto", "_grpc") 54 | 55 | pb2_module = importlib.import_module(pb2_module_name) 56 | pb2_grpc_module = importlib.import_module(pb2_grpc_module_name) 57 | grpc_module = importlib.import_module(grpc_module_name) 58 | modules.extend((pb2_module, pb2_grpc_module, grpc_module)) 59 | yield modules 60 | finally: 61 | sys.path.remove(temp_dir) 62 | 63 | 64 | _WrappedResult = collections.namedtuple("_WrappedResult", ("result", "exc_info")) 65 | 66 | 67 | def _wrap_gen_in_process(conn: Connection): 68 | def decorator(gen): 69 | @functools.wraps(gen) 70 | def new_func(*args, **kwargs): 71 | try: 72 | for elem in gen(*args, **kwargs): 73 | conn.send(_WrappedResult(result=elem, exc_info=None)) 74 | except: 75 | conn.send(_WrappedResult(result=None, exc_info=sys.exc_info())) 76 | finally: 77 | conn.close() 78 | return new_func 79 | return decorator 80 | 81 | 82 | async def async_iterable_to_list(async_iterable): 83 | result = [] 84 | async with aclosing(async_iterable) as async_iterable: 85 | async for value in async_iterable: 86 | result.append(value) 87 | return result 88 | 89 | 90 | def random_payload(min_size=1000, max_size=100000): 91 | return "".join( 92 | random.choices(string.ascii_letters, k=random.randint(min_size, max_size)) 93 | ) 94 | 95 | 96 | @contextlib.contextmanager 97 | def _run_context_manager_generator_in_process(cm_gen): 98 | parent_conn, child_conn = 
multiprocessing.Pipe(duplex=False) 99 | target_fn = _wrap_gen_in_process(child_conn)(cm_gen) 100 | 101 | process = multiprocessing.Process(target=target_fn) 102 | process.start() 103 | try: 104 | wrapped_result = parent_conn.recv() 105 | if wrapped_result.exc_info is not None: 106 | raise wrapped_result.exc_info[0].with_traceback(*wrapped_result.exc_info[1:]) 107 | else: 108 | yield wrapped_result.result 109 | finally: 110 | try: 111 | if parent_conn.poll(): 112 | exc_info = parent_conn.recv().exc_info 113 | if exc_info is not None: 114 | raise exc_info[0].with_traceback(*exc_info[1:]) 115 | finally: 116 | process.terminate() 117 | process.join() 118 | parent_conn.close() 119 | 120 | 121 | def _run_purerpc_service_in_process(service, ssl_context=None): 122 | # TODO: there is no reason to run the server as a separate process... 123 | # just use serve_async(). This synchronous cm has timing problems, 124 | # because the server may not be listening before yielding to the body. 125 | 126 | def target_fn(): 127 | import purerpc 128 | import socket 129 | 130 | # Grab an ephemeral port in advance, because we need to yield the port 131 | # before blocking on serve()... 132 | with socket.socket() as sock: 133 | sock.bind(('127.0.0.1', 0)) 134 | port = sock.getsockname()[1] 135 | 136 | server = purerpc.Server(port=port, ssl_context=ssl_context) 137 | server.add_service(service) 138 | yield port 139 | server.serve() 140 | 141 | # async def sleep_10_seconds_then_die(): 142 | # await anyio.sleep(20) 143 | # raise ValueError 144 | # 145 | # async def main(): 146 | # async with anyio.create_task_group() as tg: 147 | # tg.start_soon(server.serve_async) 148 | # tg.start_soon(sleep_10_seconds_then_die) 149 | # 150 | # import cProfile 151 | # cProfile.runctx("purerpc_run(main)", globals(), locals(), sort="tottime") 152 | 153 | return _run_context_manager_generator_in_process(target_fn) 154 | 155 | 156 | @contextlib.contextmanager 157 | def run_purerpc_service_in_process(service, ssl_context=None): 158 | with _run_purerpc_service_in_process(service, ssl_context=ssl_context) as port: 159 | # work around API issue, giving server a chance to listen 160 | time.sleep(.05) 161 | yield port 162 | 163 | 164 | # TODO: remove grpcio dependency from tests. There is no reason to unit test 165 | # grpc project's code, and it's blocking pypy support. 
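# run_grpc_service_in_process() starts a reference grpcio server in a
# subprocess and yields its ephemeral port, mirroring
# run_purerpc_service_in_process() above so the same client tests can be run
# against both implementations.  Typical usage (as in tests/test_echo.py):
#
#     with run_grpc_service_in_process(functools.partial(
#             echo_pb2_grpc.add_EchoServicer_to_server, Servicer())) as port:
#         yield port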
166 | def run_grpc_service_in_process(add_handler_fn): 167 | def target_fn(): 168 | import grpc 169 | server = grpc.server(concurrent.futures.ThreadPoolExecutor(max_workers=1)) 170 | port = server.add_insecure_port('[::]:0') 171 | add_handler_fn(server) 172 | server.start() 173 | yield port 174 | while True: 175 | time.sleep(60) 176 | return _run_context_manager_generator_in_process(target_fn) 177 | 178 | 179 | def run_tests_in_workers(*, target, num_workers): 180 | parent_conn, child_conn = multiprocessing.Pipe(duplex=False) 181 | 182 | @_wrap_gen_in_process(child_conn) 183 | def target_fn(): 184 | target() 185 | yield 186 | 187 | processes = [multiprocessing.Process(target=target_fn) for _ in range(num_workers)] 188 | for process in processes: 189 | process.start() 190 | 191 | try: 192 | for _ in range(num_workers): 193 | wrapped_result = parent_conn.recv() 194 | if wrapped_result.exc_info is not None: 195 | raise wrapped_result.exc_info[0].with_traceback(*wrapped_result.exc_info[1:]) 196 | finally: 197 | parent_conn.close() 198 | for process in processes: 199 | process.join() 200 | 201 | 202 | def grpc_client_parallelize(num_workers): 203 | def decorator(func): 204 | @functools.wraps(func) 205 | def new_func(*args, **kwargs): 206 | def target(): 207 | func(*args, **kwargs) 208 | run_tests_in_workers(target=target, num_workers=num_workers) 209 | 210 | new_func.__parallelized__ = True 211 | return new_func 212 | return decorator 213 | 214 | 215 | def purerpc_client_parallelize(num_tasks): 216 | def decorator(corofunc): 217 | if not inspect.iscoroutinefunction(corofunc): 218 | raise TypeError("Expected coroutine function") 219 | 220 | @functools.wraps(corofunc) 221 | async def new_corofunc(**kwargs): 222 | async with anyio.create_task_group() as tg: 223 | for _ in range(num_tasks): 224 | tg.start_soon(functools.partial(corofunc, **kwargs)) 225 | return new_corofunc 226 | return decorator 227 | 228 | 229 | def grpc_channel(port_fixture_name, channel_arg_name="channel"): 230 | def decorator(func): 231 | if hasattr(func, "__parallelized__") and func.__parallelized__: 232 | raise TypeError("Cannot pass gRPC channel to already parallelized test, grpc_client_parallelize should " 233 | "be the last decorator in chain") 234 | 235 | @forge.compose( 236 | forge.copy(func), 237 | forge.modify(channel_arg_name, name=port_fixture_name, interface_name="port_fixture_value"), 238 | ) 239 | def new_func(*, port_fixture_value, **kwargs): 240 | import grpc 241 | with grpc.insecure_channel('127.0.0.1:{}'.format(port_fixture_value)) as channel: 242 | func(**kwargs, channel=channel) 243 | 244 | return new_func 245 | return decorator 246 | 247 | 248 | def purerpc_channel(port_fixture_name, channel_arg_name="channel"): 249 | def decorator(corofunc): 250 | if not inspect.iscoroutinefunction(corofunc): 251 | raise TypeError("Expected coroutine function") 252 | 253 | @forge.compose( 254 | forge.copy(corofunc), 255 | forge.modify(channel_arg_name, name=port_fixture_name, interface_name="port_fixture_value"), 256 | ) 257 | async def new_corofunc(*, port_fixture_value, **kwargs): 258 | import purerpc 259 | async with purerpc.insecure_channel("127.0.0.1", port_fixture_value) as channel: 260 | await corofunc(**kwargs, channel=channel) 261 | 262 | return new_corofunc 263 | return decorator 264 | -------------------------------------------------------------------------------- /src/purerpc/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import 
platform 4 | 5 | import anyio 6 | 7 | _log = logging.getLogger(__name__) 8 | 9 | 10 | def is_linux(): 11 | return platform.system() == "Linux" 12 | 13 | 14 | def is_darwin(): 15 | return platform.system() == "Darwin" 16 | 17 | 18 | def is_windows(): 19 | return platform.system() == "Windows" 20 | 21 | 22 | def run(func, *args, backend=None, backend_options=None): 23 | """wrapper for anyio.run() with some purerpc-specific conventions 24 | 25 | * if `backend` is None, read it from PURERPC_BACKEND environment variable, 26 | (still defaulting to asyncio) 27 | * allow "uvloop" as a value of `backend` (normally uvloop needs to 28 | be specified via `backend_options` under asyncio) 29 | * if uvloop is selected, raise ModuleNotFoundError if uvloop isn't installed 30 | """ 31 | 32 | if backend is None: 33 | backend = os.getenv("PURERPC_BACKEND", "asyncio") 34 | _log.info("purerpc.run() selected {} backend".format(backend)) 35 | if backend == "uvloop": 36 | backend = "asyncio" 37 | options = dict(use_uvloop=True) 38 | if backend_options is None: 39 | backend_options = options 40 | else: 41 | backend_options.update(options) 42 | if backend == "asyncio" and backend_options and backend_options.get('use_uvloop'): 43 | # Since anyio.run() will silently fall back when uvloop isn't available, 44 | # make the requirement explicit. 45 | import uvloop 46 | return anyio.run(func, *args, backend=backend, backend_options=backend_options) 47 | -------------------------------------------------------------------------------- /src/purerpc/wrappers.py: -------------------------------------------------------------------------------- 1 | import anyio 2 | from contextlib import asynccontextmanager 3 | from async_generator import aclosing 4 | 5 | from .grpclib.exceptions import ProtocolError, raise_status 6 | from .grpclib.status import Status, StatusCode 7 | from purerpc.grpc_proto import GRPCProtoStream 8 | from purerpc.grpclib.events import ResponseEnded 9 | 10 | 11 | async def extract_message_from_singleton_stream(stream): 12 | msg = await stream.receive_message() 13 | if msg is None: 14 | event = stream.end_stream_event 15 | if isinstance(event, ResponseEnded): 16 | raise_status(event.status) 17 | raise ProtocolError("Expected one message, got zero") 18 | if await stream.receive_message() is not None: 19 | raise ProtocolError("Expected one message, got multiple") 20 | return msg 21 | 22 | 23 | async def stream_to_async_iterator(stream: GRPCProtoStream): 24 | while True: 25 | msg = await stream.receive_message() 26 | if msg is None: 27 | event = stream.end_stream_event 28 | if isinstance(event, ResponseEnded): 29 | raise_status(event.status) 30 | return 31 | yield msg 32 | 33 | 34 | async def send_multiple_messages_server(stream, aiter): 35 | async with aclosing(aiter) as aiter: 36 | async for message in aiter: 37 | await stream.send_message(message) 38 | await stream.close(Status(StatusCode.OK)) 39 | 40 | 41 | async def send_single_message_server(stream, message): 42 | await stream.send_message(message) 43 | await stream.close(Status(StatusCode.OK)) 44 | 45 | 46 | async def send_multiple_messages_client(stream, aiter): 47 | try: 48 | async with aclosing(aiter) as aiter: 49 | async for message in aiter: 50 | await stream.send_message(message) 51 | finally: 52 | await stream.close() 53 | 54 | 55 | async def send_single_message_client(stream, message): 56 | try: 57 | await stream.send_message(message) 58 | finally: 59 | await stream.close() 60 | 61 | 62 | async def call_server_unary_unary(func, stream): 63 | msg = 
await extract_message_from_singleton_stream(stream) 64 | await send_single_message_server(stream, await func(msg)) 65 | 66 | 67 | async def call_server_unary_stream(func, stream): 68 | msg = await extract_message_from_singleton_stream(stream) 69 | await send_multiple_messages_server(stream, func(msg)) 70 | 71 | 72 | async def call_server_stream_unary(func, stream): 73 | input_message_stream = stream_to_async_iterator(stream) 74 | await send_single_message_server(stream, await func(input_message_stream)) 75 | 76 | 77 | async def call_server_stream_stream(func, stream): 78 | input_message_stream = stream_to_async_iterator(stream) 79 | await send_multiple_messages_server(stream, func(input_message_stream)) 80 | 81 | 82 | class ClientStub: 83 | def __init__(self, stream_fn): 84 | self._stream_fn = stream_fn 85 | 86 | 87 | class ClientStubUnaryUnary(ClientStub): 88 | async def __call__(self, message, *, metadata=None): 89 | stream = await self._stream_fn(metadata=metadata) 90 | await send_single_message_client(stream, message) 91 | return await extract_message_from_singleton_stream(stream) 92 | 93 | 94 | class ClientStubUnaryStream(ClientStub): 95 | async def __call__(self, message, *, metadata=None): 96 | stream = await self._stream_fn(metadata=metadata) 97 | await send_single_message_client(stream, message) 98 | async for value in stream_to_async_iterator(stream): 99 | yield value 100 | 101 | 102 | class ClientStubStreamUnary(ClientStub): 103 | async def __call__(self, message_aiter, *, metadata=None): 104 | stream = await self._stream_fn(metadata=metadata) 105 | async with anyio.create_task_group() as task_group: 106 | task_group.start_soon(send_multiple_messages_client, stream, message_aiter) 107 | return await extract_message_from_singleton_stream(stream) 108 | 109 | 110 | class ClientStubStreamStream(ClientStub): 111 | @asynccontextmanager 112 | async def call_aiter(self, message_aiter, metadata): 113 | stream = await self._stream_fn(metadata=metadata) 114 | 115 | async with anyio.create_task_group() as task_group: 116 | task_group.start_soon(send_multiple_messages_client, stream, message_aiter) 117 | yield stream_to_async_iterator(stream) 118 | 119 | async def call_stream(self, metadata): 120 | return await self._stream_fn(metadata=metadata) 121 | 122 | def __call__(self, message_aiter=None, *, metadata=None): 123 | if message_aiter is None: 124 | return self.call_stream(metadata) 125 | else: 126 | return self.call_aiter(message_aiter, metadata) 127 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/python-trio/purerpc/a3c17dd885d8f36bcf3a78c7506e35f7bc33cccc/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pytest 4 | 5 | from purerpc.test_utils import compile_temp_proto 6 | 7 | pytestmark = pytest.mark.anyio 8 | 9 | 10 | @pytest.fixture(params=[ 11 | pytest.param(('trio'), id='trio'), 12 | pytest.param(('asyncio'), id='asyncio'), 13 | pytest.param(('asyncio', dict(use_uvloop=True)), id='uvloop', 14 | marks=[pytest.mark.skipif(platform.system() == 'Windows', 15 | reason='uvloop not supported on Windows')]), 16 | ]) 17 | def anyio_backend(request): 18 | return request.param 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def 
greeter_modules(): 23 | with compile_temp_proto("data/greeter.proto") as modules: 24 | yield modules 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def greeter_pb2(greeter_modules): 29 | return greeter_modules[0] 30 | 31 | 32 | @pytest.fixture(scope="session") 33 | def greeter_pb2_grpc(greeter_modules): 34 | return greeter_modules[1] 35 | 36 | 37 | @pytest.fixture(scope="session") 38 | def greeter_grpc(greeter_modules): 39 | return greeter_modules[2] 40 | 41 | 42 | @pytest.fixture(scope="session") 43 | def echo_modules(): 44 | with compile_temp_proto("data/echo.proto") as modules: 45 | yield modules 46 | 47 | 48 | @pytest.fixture(scope="session") 49 | def echo_pb2(echo_modules): 50 | return echo_modules[0] 51 | 52 | 53 | @pytest.fixture(scope="session") 54 | def echo_pb2_grpc(echo_modules): 55 | return echo_modules[1] 56 | 57 | 58 | @pytest.fixture(scope="session") 59 | def echo_grpc(echo_modules): 60 | return echo_modules[2] 61 | -------------------------------------------------------------------------------- /tests/data/echo.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | service Echo { 4 | rpc Echo (EchoRequest) returns (EchoReply) {} 5 | rpc EchoTwoTimes (EchoRequest) returns (stream EchoReply) {} 6 | rpc EchoEachTime (stream EchoRequest) returns (stream EchoReply) {} 7 | rpc EchoLast (stream EchoRequest) returns (EchoReply) {} 8 | rpc EchoLastV2 (stream EchoRequest) returns (stream EchoReply) {} 9 | } 10 | 11 | message EchoRequest { 12 | string data = 1; 13 | } 14 | 15 | message EchoReply { 16 | string data = 1; 17 | } 18 | -------------------------------------------------------------------------------- /tests/data/greeter.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | service Greeter { 4 | rpc SayHello (HelloRequest) returns (HelloReply) {} 5 | rpc SayHelloGoodbye (HelloRequest) returns (stream HelloReply) {} 6 | rpc SayHelloToMany (stream HelloRequest) returns (stream HelloReply) {} 7 | rpc SayHelloToManyAtOnce (stream HelloRequest) returns (HelloReply) {} 8 | } 9 | 10 | message HelloRequest { 11 | string name = 1; 12 | } 13 | 14 | message HelloReply { 15 | string message = 1; 16 | } 17 | -------------------------------------------------------------------------------- /tests/data/test_package_names/A.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package packageA; 3 | 4 | message Message { 5 | string value = 1; 6 | } 7 | -------------------------------------------------------------------------------- /tests/data/test_package_names/B.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package packageB; 3 | 4 | message Message { 5 | string value = 1; 6 | } 7 | -------------------------------------------------------------------------------- /tests/data/test_package_names/C.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "A.proto"; 4 | import "B.proto"; 5 | 6 | service Service { 7 | rpc Method(packageA.Message) returns (packageB.Message); 8 | } 9 | -------------------------------------------------------------------------------- /tests/exceptiongroups.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from contextlib import contextmanager 3 | import sys 4 | 5 | if 
sys.version_info < (3, 11): 6 | from exceptiongroup import ExceptionGroup 7 | 8 | 9 | def _unroll_exceptions( 10 | exceptions: Iterable[Exception] 11 | ) -> Iterable[Exception]: 12 | res: list[Exception] = [] 13 | for exc in exceptions: 14 | if isinstance(exc, ExceptionGroup): 15 | res.extend(_unroll_exceptions(exc.exceptions)) 16 | 17 | else: 18 | res.append(exc) 19 | return res 20 | 21 | 22 | @contextmanager 23 | def unwrap_exceptiongroups_single(): 24 | try: 25 | yield 26 | except ExceptionGroup as e: 27 | exceptions = _unroll_exceptions(e.exceptions) 28 | 29 | assert len(exceptions) == 1, "Exception group contains multiple exceptions" 30 | 31 | raise exceptions[0] 32 | -------------------------------------------------------------------------------- /tests/test_buffers.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | import random 3 | import struct 4 | 5 | import pytest 6 | 7 | from purerpc.grpclib.buffers import ByteBuffer, MessageReadBuffer 8 | 9 | byte_buffer = pytest.fixture(lambda: ByteBuffer()) 10 | byte_array = pytest.fixture(lambda: bytearray()) 11 | 12 | 13 | def test_byte_buffer_random(byte_buffer, byte_array): 14 | for i in range(1000): 15 | data = bytes(random.randint(0, 255) for _ in range(random.randint(0, 100))) 16 | byte_buffer.append(data) 17 | byte_array.extend(data) 18 | assert len(byte_buffer) == len(byte_array) 19 | num_elements = min(random.randint(0, 100), len(byte_buffer)) 20 | assert byte_array[:num_elements] == byte_buffer.popleft(num_elements) 21 | byte_array = byte_array[num_elements:] 22 | 23 | 24 | def test_byte_buffer_large_reads(byte_buffer, byte_array): 25 | for i in range(1000): 26 | for j in range(100): 27 | data = bytes([(i + j) % 256]) 28 | byte_buffer.append(data) 29 | byte_array.extend(data) 30 | assert len(byte_array) == len(byte_buffer) 31 | num_elements = min(random.randint(0, 100), len(byte_buffer)) 32 | assert byte_array[:num_elements] == byte_buffer.popleft(num_elements) 33 | byte_array = byte_array[num_elements:] 34 | 35 | 36 | def test_byte_buffer_large_writes(byte_buffer, byte_array): 37 | data = bytes(range(256)) * 10 38 | for i in range(250): 39 | byte_buffer.append(data) 40 | byte_array.extend(data) 41 | for j in range(10): 42 | assert len(byte_array) == len(byte_buffer) 43 | num_elements = min(random.randint(0, 100), len(byte_buffer)) 44 | assert byte_array[:num_elements] == byte_buffer.popleft(num_elements) 45 | byte_array = byte_array[num_elements:] 46 | 47 | 48 | def test_message_read_buffer(byte_array): 49 | for i in range(100): 50 | data = bytes(range(i)) 51 | compress_flag = False 52 | if i % 2: 53 | data = zlib.compress(data) 54 | compress_flag = True 55 | byte_array.extend(struct.pack('>?I', compress_flag, len(data))) 56 | byte_array.extend(data) 57 | 58 | read_buffer = MessageReadBuffer(message_encoding="gzip") 59 | messages = [] 60 | while byte_array: 61 | if random.choice([True, False]): 62 | num_bytes = random.randint(0, 50) 63 | read_buffer.data_received(bytes(byte_array[:num_bytes])) 64 | byte_array = byte_array[num_bytes:] 65 | else: 66 | messages.extend(read_buffer.read_all_complete_messages()) 67 | messages.extend(read_buffer.read_all_complete_messages()) 68 | 69 | assert len(messages) == 100 70 | for idx, message in enumerate(messages): 71 | assert message == bytes(range(idx)) 72 | -------------------------------------------------------------------------------- /tests/test_echo.py: -------------------------------------------------------------------------------- 1 
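# Echo service integration tests.  The module-scoped fixtures below start a
# purerpc echo server (plus a TLS variant) and a reference grpcio echo server
# in separate processes; the parametrized `echo_port` fixture then runs each
# client test against both the purerpc and grpcio servers.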
| import functools 2 | import ssl 3 | 4 | import anyio 5 | import pytest 6 | import trustme 7 | 8 | import purerpc 9 | from purerpc.test_utils import run_purerpc_service_in_process, run_grpc_service_in_process, \ 10 | async_iterable_to_list, random_payload, grpc_client_parallelize, purerpc_channel, purerpc_client_parallelize, grpc_channel 11 | from .exceptiongroups import unwrap_exceptiongroups_single 12 | 13 | pytestmark = pytest.mark.anyio 14 | 15 | 16 | @pytest.fixture(scope='module') 17 | def ca(): 18 | return trustme.CA() 19 | 20 | 21 | @pytest.fixture(scope='module') 22 | def server_ssl_context(ca): 23 | server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 24 | ca.issue_cert('127.0.0.1').configure_cert(server_context) 25 | return server_context 26 | 27 | 28 | @pytest.fixture(scope='module') 29 | def client_ssl_context(ca): 30 | client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) 31 | ca.configure_trust(client_context) 32 | return client_context 33 | 34 | 35 | def make_servicer(echo_pb2, echo_grpc): 36 | class Servicer(echo_grpc.EchoServicer): 37 | async def Echo(self, message): 38 | return echo_pb2.EchoReply(data=message.data) 39 | 40 | async def EchoTwoTimes(self, message): 41 | yield echo_pb2.EchoReply(data=message.data) 42 | yield echo_pb2.EchoReply(data=message.data) 43 | 44 | async def EchoEachTime(self, messages): 45 | async for message in messages: 46 | yield echo_pb2.EchoReply(data=message.data) 47 | 48 | async def EchoLast(self, messages): 49 | data = [] 50 | async for message in messages: 51 | data.append(message.data) 52 | return echo_pb2.EchoReply(data="".join(data)) 53 | 54 | async def EchoLastV2(self, messages): 55 | data = [] 56 | async for message in messages: 57 | data.append(message.data) 58 | yield echo_pb2.EchoReply(data="".join(data)) 59 | 60 | return Servicer 61 | 62 | 63 | @pytest.fixture(scope="module") 64 | def purerpc_echo_port(echo_pb2, echo_grpc): 65 | Servicer = make_servicer(echo_pb2, echo_grpc) 66 | with run_purerpc_service_in_process(Servicer().service) as port: 67 | yield port 68 | 69 | 70 | @pytest.fixture(scope="module") 71 | def purerpc_echo_port_ssl(echo_pb2, echo_grpc, server_ssl_context): 72 | Servicer = make_servicer(echo_pb2, echo_grpc) 73 | with run_purerpc_service_in_process(Servicer().service, 74 | ssl_context=server_ssl_context) as port: 75 | yield port 76 | 77 | 78 | @pytest.fixture(scope="module") 79 | def grpc_echo_port(echo_pb2, echo_pb2_grpc): 80 | class Servicer(echo_pb2_grpc.EchoServicer): 81 | def Echo(self, message, context): 82 | return echo_pb2.EchoReply(data=message.data) 83 | 84 | def EchoTwoTimes(self, message, context): 85 | yield echo_pb2.EchoReply(data=message.data) 86 | yield echo_pb2.EchoReply(data=message.data) 87 | 88 | def EchoEachTime(self, messages, context): 89 | for message in messages: 90 | yield echo_pb2.EchoReply(data=message.data) 91 | 92 | def EchoLast(self, messages, context): 93 | data = [] 94 | for message in messages: 95 | data.append(message.data) 96 | return echo_pb2.EchoReply(data="".join(data)) 97 | 98 | def EchoLastV2(self, messages, context): 99 | data = [] 100 | for message in messages: 101 | data.append(message.data) 102 | yield echo_pb2.EchoReply(data="".join(data)) 103 | 104 | with run_grpc_service_in_process(functools.partial( 105 | echo_pb2_grpc.add_EchoServicer_to_server, Servicer())) as port: 106 | yield port 107 | 108 | 109 | @pytest.fixture(scope="module", 110 | params=["purerpc_echo_port", "grpc_echo_port"]) 111 | def echo_port(request): 112 | return 
request.getfixturevalue(request.param) 113 | 114 | 115 | @purerpc_channel("echo_port") 116 | @purerpc_client_parallelize(50) 117 | async def test_purerpc_client_large_payload_many_streams(echo_pb2, echo_grpc, channel): 118 | stub = echo_grpc.EchoStub(channel) 119 | data = "World" * 20000 120 | assert (await stub.Echo(echo_pb2.EchoRequest(data=data))).data == data 121 | 122 | 123 | @purerpc_channel("echo_port") 124 | async def test_purerpc_client_large_payload_one_stream(echo_pb2, echo_grpc, channel): 125 | stub = echo_grpc.EchoStub(channel) 126 | data = "World" * 20000 127 | assert (await stub.Echo(echo_pb2.EchoRequest(data=data))).data == data 128 | 129 | 130 | @grpc_client_parallelize(50) 131 | @grpc_channel("echo_port") 132 | def test_grpc_client_large_payload(echo_pb2, echo_pb2_grpc, channel): 133 | stub = echo_pb2_grpc.EchoStub(channel) 134 | data = "World" * 20000 135 | assert stub.Echo(echo_pb2.EchoRequest(data=data)).data == data 136 | 137 | 138 | @purerpc_channel("echo_port") 139 | @purerpc_client_parallelize(20) 140 | async def test_purerpc_client_random_payload(echo_pb2, echo_grpc, channel): 141 | stub = echo_grpc.EchoStub(channel) 142 | data = random_payload() 143 | 144 | async def gen(): 145 | for _ in range(4): 146 | yield echo_pb2.EchoRequest(data=data) 147 | 148 | assert (await stub.Echo(echo_pb2.EchoRequest(data=data))).data == data 149 | assert [response.data for response in await async_iterable_to_list( 150 | stub.EchoTwoTimes(echo_pb2.EchoRequest(data=data)))] == [data] * 2 151 | assert (await stub.EchoLast(gen())).data == data * 4 152 | async with stub.EchoEachTime(gen()) as aiter: 153 | assert [response.data for response in await async_iterable_to_list(aiter)] == [ 154 | data 155 | ] * 4 156 | 157 | 158 | @purerpc_channel("echo_port") 159 | @purerpc_client_parallelize(10) 160 | async def test_purerpc_client_deadlock(echo_pb2, echo_grpc, channel): 161 | stub = echo_grpc.EchoStub(channel) 162 | data = random_payload(min_size=32000, max_size=64000) 163 | 164 | async def gen(): 165 | for _ in range(20): 166 | yield echo_pb2.EchoRequest(data=data) 167 | 168 | async with stub.EchoLastV2(gen()) as aiter: 169 | assert [response.data for response in await async_iterable_to_list(aiter)] == [ 170 | data * 20 171 | ] 172 | 173 | 174 | async def test_purerpc_ssl(echo_pb2, echo_grpc, purerpc_echo_port_ssl, client_ssl_context): 175 | async with purerpc.secure_channel("127.0.0.1", purerpc_echo_port_ssl, 176 | ssl_context=client_ssl_context) as channel: 177 | stub = echo_grpc.EchoStub(channel) 178 | data = random_payload(min_size=32000, max_size=64000) 179 | 180 | async def gen(): 181 | for _ in range(20): 182 | yield echo_pb2.EchoRequest(data=data) 183 | 184 | async with stub.EchoLastV2(gen()) as aiter: 185 | assert [ 186 | response.data for response in await async_iterable_to_list(aiter) 187 | ] == [data * 20] 188 | 189 | 190 | async def test_purerpc_client_disconnect(echo_pb2, echo_grpc): 191 | # when the client disconnects, the server should not log an exception 192 | # 193 | # NOTE: This test demonstrates a client/server test without multiprocessing or 194 | # fixture acrobatics. 
195 | 196 | async with anyio.create_task_group() as tg: 197 | # server 198 | Servicer = make_servicer(echo_pb2, echo_grpc) 199 | server = purerpc.Server(port=0) 200 | server.add_service(Servicer().service) 201 | port = await tg.start(server.serve_async) 202 | 203 | # client 204 | with pytest.raises(anyio.ClosedResourceError), unwrap_exceptiongroups_single(): 205 | async with purerpc.insecure_channel("localhost", port) as channel: 206 | stub = echo_grpc.EchoStub(channel) 207 | 208 | data = 'hello' 209 | assert (await stub.Echo(echo_pb2.EchoRequest(data=data))).data == data 210 | 211 | # close the sending stream, inducing EndOfStream on the server 212 | await channel._grpc_socket._socket._stream.aclose() 213 | await anyio.wait_all_tasks_blocked() 214 | 215 | assert server._connection_count == 1 216 | assert server._exception_count == 0 217 | tg.cancel_scope.cancel() 218 | -------------------------------------------------------------------------------- /tests/test_errors.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pytest 4 | 5 | import purerpc 6 | from purerpc.test_utils import run_purerpc_service_in_process, run_grpc_service_in_process, grpc_channel, \ 7 | grpc_client_parallelize, purerpc_channel 8 | from .exceptiongroups import unwrap_exceptiongroups_single 9 | 10 | pytestmark = pytest.mark.anyio 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def purerpc_port(greeter_pb2, greeter_grpc): 15 | class Servicer(greeter_grpc.GreeterServicer): 16 | async def SayHello(self, message): 17 | raise ValueError("oops my bad") 18 | 19 | async def SayHelloToMany(self, messages): 20 | idx = 1 21 | async for _ in messages: 22 | yield greeter_pb2.HelloReply(message=str(idx)) 23 | if idx == 7: 24 | raise ValueError("Lucky 7") 25 | idx += 1 26 | with run_purerpc_service_in_process(Servicer().service) as port: 27 | yield port 28 | 29 | 30 | @pytest.fixture(scope="module") 31 | def grpc_port(greeter_pb2, greeter_pb2_grpc): 32 | class Servicer(greeter_pb2_grpc.GreeterServicer): 33 | def SayHello(self, message, context): 34 | raise ValueError("oops my bad") 35 | 36 | def SayHelloToMany(self, messages, context): 37 | idx = 1 38 | for _ in messages: 39 | yield greeter_pb2.HelloReply(message=str(idx)) 40 | if idx == 7: 41 | raise ValueError("Lucky 7") 42 | idx += 1 43 | 44 | with run_grpc_service_in_process(functools.partial( 45 | greeter_pb2_grpc.add_GreeterServicer_to_server, Servicer())) as port: 46 | yield port 47 | 48 | 49 | @pytest.fixture(scope="module", 50 | params=["purerpc_port", "grpc_port"]) 51 | def port(request): 52 | return request.getfixturevalue(request.param) 53 | 54 | 55 | @grpc_client_parallelize(1) 56 | @grpc_channel("purerpc_port") 57 | def test_errors_grpc_client(greeter_pb2, greeter_pb2_grpc, channel): 58 | stub = greeter_pb2_grpc.GreeterStub(channel) 59 | with pytest.raises(BaseException, match=r"oops my bad"): 60 | stub.SayHello(greeter_pb2.HelloRequest(name="World")) 61 | 62 | with pytest.raises(BaseException, match=r"Lucky 7"): 63 | for _ in stub.SayHelloToMany(greeter_pb2.HelloRequest() for _ in range(10)): 64 | pass 65 | 66 | 67 | @purerpc_channel("port") 68 | async def test_errors_purerpc_client(greeter_pb2, greeter_grpc, channel): 69 | async def generator(): 70 | for _ in range(7): 71 | yield greeter_pb2.HelloRequest() 72 | 73 | stub = greeter_grpc.GreeterStub(channel) 74 | with pytest.raises(purerpc.RpcFailedError, match=r"oops my bad"): 75 | await stub.SayHello(greeter_pb2.HelloRequest(name="World")) 76 
| 77 | async with stub.SayHelloToMany(generator()) as aiter: 78 | for _ in range(7): 79 | await aiter.__anext__() 80 | with pytest.raises(purerpc.RpcFailedError, match=r"Lucky 7"): 81 | await aiter.__anext__() 82 | 83 | 84 | @purerpc_channel("port") 85 | async def test_errors_purerpc_async_generator(greeter_pb2, greeter_grpc, channel): 86 | async def generator(): 87 | for _ in range(7): 88 | yield greeter_pb2.HelloRequest() 89 | 90 | stub = greeter_grpc.GreeterStub(channel) 91 | 92 | with pytest.raises(ValueError, match="oops"), unwrap_exceptiongroups_single(): 93 | async with stub.SayHelloToMany(generator()) as aiter: 94 | async for resp in aiter: 95 | if resp.message == "2": 96 | raise ValueError("oops") 97 | -------------------------------------------------------------------------------- /tests/test_greeter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pytest 4 | 5 | import purerpc 6 | 7 | from purerpc.test_utils import ( 8 | run_purerpc_service_in_process, run_grpc_service_in_process, async_iterable_to_list, grpc_client_parallelize, 9 | purerpc_channel, purerpc_client_parallelize, grpc_channel 10 | ) 11 | 12 | pytestmark = pytest.mark.anyio 13 | 14 | 15 | def name_generator(greeter_pb2): 16 | names = ('Foo', 'Bar', 'Bat', 'Baz') 17 | for name in names: 18 | yield greeter_pb2.HelloRequest(name=name) 19 | 20 | 21 | async def async_name_generator(greeter_pb2): 22 | for request in name_generator(greeter_pb2): 23 | yield request 24 | 25 | 26 | @pytest.fixture(scope="module") 27 | def purerpc_codegen_greeter_port(greeter_pb2, greeter_grpc): 28 | class Servicer(greeter_grpc.GreeterServicer): 29 | async def SayHello(self, message): 30 | return greeter_pb2.HelloReply(message="Hello, " + message.name) 31 | 32 | async def SayHelloGoodbye(self, message): 33 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 34 | yield greeter_pb2.HelloReply(message="Goodbye, " + message.name) 35 | 36 | async def SayHelloToManyAtOnce(self, messages): 37 | names = [] 38 | async for message in messages: 39 | names.append(message.name) 40 | return greeter_pb2.HelloReply(message="Hello, " + ", ".join(names)) 41 | 42 | async def SayHelloToMany(self, messages): 43 | async for message in messages: 44 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 45 | 46 | with run_purerpc_service_in_process(Servicer().service) as port: 47 | yield port 48 | 49 | 50 | @pytest.fixture(scope="module") 51 | def purerpc_simple_greeter_port(greeter_pb2): 52 | service = purerpc.Service("Greeter") 53 | 54 | @service.rpc("SayHello") 55 | async def say_hello(message: greeter_pb2.HelloRequest) -> greeter_pb2.HelloReply: 56 | return greeter_pb2.HelloReply(message="Hello, " + message.name) 57 | 58 | @service.rpc("SayHelloGoodbye") 59 | async def say_hello_goodbye(message: greeter_pb2.HelloRequest) -> purerpc.Stream[greeter_pb2.HelloReply]: 60 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 61 | yield greeter_pb2.HelloReply(message="Goodbye, " + message.name) 62 | 63 | @service.rpc("SayHelloToManyAtOnce") 64 | async def say_hello_to_many_at_once(messages: purerpc.Stream[greeter_pb2.HelloRequest]) -> greeter_pb2.HelloReply: 65 | names = [] 66 | async for message in messages: 67 | names.append(message.name) 68 | return greeter_pb2.HelloReply(message="Hello, " + ', '.join(names)) 69 | 70 | @service.rpc("SayHelloToMany") 71 | async def say_hello_to_many(messages: purerpc.Stream[greeter_pb2.HelloRequest]) -> 
purerpc.Stream[greeter_pb2.HelloReply]: 72 | async for message in messages: 73 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 74 | 75 | 76 | with run_purerpc_service_in_process(service) as port: 77 | yield port 78 | 79 | 80 | @pytest.fixture(scope="module") 81 | def grpc_greeter_port(greeter_pb2, greeter_pb2_grpc): 82 | class Servicer(greeter_pb2_grpc.GreeterServicer): 83 | def SayHello(self, message, context): 84 | return greeter_pb2.HelloReply(message="Hello, " + message.name) 85 | 86 | def SayHelloGoodbye(self, message, context): 87 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 88 | yield greeter_pb2.HelloReply(message="Goodbye, " + message.name) 89 | 90 | def SayHelloToMany(self, messages, context): 91 | for message in messages: 92 | yield greeter_pb2.HelloReply(message="Hello, " + message.name) 93 | 94 | def SayHelloToManyAtOnce(self, messages, context): 95 | names = [] 96 | for message in messages: 97 | names.append(message.name) 98 | return greeter_pb2.HelloReply(message="Hello, " + ", ".join(names)) 99 | 100 | with run_grpc_service_in_process(functools.partial( 101 | greeter_pb2_grpc.add_GreeterServicer_to_server, Servicer())) as port: 102 | yield port 103 | 104 | 105 | @pytest.fixture(scope="module", 106 | params=["purerpc_codegen_greeter_port", "purerpc_simple_greeter_port"]) 107 | def purerpc_greeter_port(request): 108 | return request.getfixturevalue(request.param) 109 | 110 | 111 | @pytest.fixture(scope="module", 112 | params=["purerpc_codegen_greeter_port", "purerpc_simple_greeter_port", "grpc_greeter_port"]) 113 | def greeter_port(request): 114 | return request.getfixturevalue(request.param) 115 | 116 | 117 | @grpc_client_parallelize(50) 118 | @grpc_channel("purerpc_greeter_port") 119 | def test_grpc_client_parallel(greeter_pb2, greeter_pb2_grpc, channel): 120 | stub = greeter_pb2_grpc.GreeterStub(channel) 121 | assert stub.SayHello(greeter_pb2.HelloRequest(name="World")).message == "Hello, World" 122 | assert [response.message for response in 123 | stub.SayHelloGoodbye(greeter_pb2.HelloRequest(name="World"))] == ["Hello, World", "Goodbye, World"] 124 | assert stub.SayHelloToManyAtOnce(name_generator(greeter_pb2)).message == "Hello, Foo, Bar, Bat, Baz" 125 | assert [response.message for response in stub.SayHelloToMany(name_generator(greeter_pb2))] == \ 126 | ["Hello, Foo", "Hello, Bar", "Hello, Bat", "Hello, Baz"] 127 | 128 | 129 | @purerpc_channel("greeter_port") 130 | @purerpc_client_parallelize(50) 131 | async def test_purerpc_stub_client_parallel(greeter_pb2, greeter_grpc, channel): 132 | stub = greeter_grpc.GreeterStub(channel) 133 | assert (await stub.SayHello(greeter_pb2.HelloRequest(name="World"))).message == "Hello, World" 134 | assert [response.message for response in await async_iterable_to_list( 135 | stub.SayHelloGoodbye(greeter_pb2.HelloRequest(name="World")))] == ["Hello, World", "Goodbye, World"] 136 | assert (await stub.SayHelloToManyAtOnce(async_name_generator(greeter_pb2))).message == "Hello, Foo, Bar, Bat, Baz" 137 | async with stub.SayHelloToMany(async_name_generator(greeter_pb2)) as aiter: 138 | assert [ 139 | response.message for response in await async_iterable_to_list(aiter) 140 | ] == ["Hello, Foo", "Hello, Bar", "Hello, Bat", "Hello, Baz"] 141 | 142 | 143 | @purerpc_channel("greeter_port") 144 | @purerpc_client_parallelize(50) 145 | async def test_purerpc_stream_client_parallel(greeter_pb2, channel): 146 | async def test_say_hello(client): 147 | stream = await client.rpc("SayHello", greeter_pb2.HelloRequest, 
greeter_pb2.HelloReply) 148 | await stream.send_message(greeter_pb2.HelloRequest(name="World")) 149 | await stream.close() 150 | assert (await stream.receive_message()).message == "Hello, World" 151 | assert await stream.receive_message() is None 152 | 153 | async def test_say_hello_goodbye(client): 154 | stream = await client.rpc("SayHelloGoodbye", greeter_pb2.HelloRequest, greeter_pb2.HelloReply) 155 | await stream.send_message(greeter_pb2.HelloRequest(name="World")) 156 | await stream.close() 157 | assert (await stream.receive_message()).message == "Hello, World" 158 | assert (await stream.receive_message()).message == "Goodbye, World" 159 | assert await stream.receive_message() is None 160 | 161 | async def test_say_hello_to_many(client): 162 | stream = await client.rpc("SayHelloToMany", greeter_pb2.HelloRequest, greeter_pb2.HelloReply) 163 | await stream.send_message(greeter_pb2.HelloRequest(name="Foo")) 164 | assert (await stream.receive_message()).message == "Hello, Foo" 165 | await stream.send_message(greeter_pb2.HelloRequest(name="Bar")) 166 | assert (await stream.receive_message()).message == "Hello, Bar" 167 | await stream.send_message(greeter_pb2.HelloRequest(name="Baz")) 168 | await stream.send_message(greeter_pb2.HelloRequest(name="World")) 169 | assert (await stream.receive_message()).message == "Hello, Baz" 170 | assert (await stream.receive_message()).message == "Hello, World" 171 | await stream.close() 172 | assert await stream.receive_message() is None 173 | 174 | async def test_say_hello_to_many_at_once(client): 175 | stream = await client.rpc("SayHelloToManyAtOnce", greeter_pb2.HelloRequest, greeter_pb2.HelloReply) 176 | await stream.send_message(greeter_pb2.HelloRequest(name="Foo")) 177 | await stream.send_message(greeter_pb2.HelloRequest(name="Bar")) 178 | await stream.send_message(greeter_pb2.HelloRequest(name="Baz")) 179 | await stream.send_message(greeter_pb2.HelloRequest(name="World")) 180 | await stream.close() 181 | assert (await stream.receive_message()).message == "Hello, Foo, Bar, Baz, World" 182 | assert await stream.receive_message() is None 183 | 184 | client = purerpc.Client("Greeter", channel) 185 | await test_say_hello(client) 186 | await test_say_hello_goodbye(client) 187 | await test_say_hello_to_many(client) 188 | await test_say_hello_to_many_at_once(client) 189 | -------------------------------------------------------------------------------- /tests/test_metadata.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import pickle 3 | import base64 4 | 5 | import pytest 6 | 7 | from purerpc.test_utils import (run_purerpc_service_in_process, run_grpc_service_in_process, grpc_client_parallelize, 8 | grpc_channel, purerpc_channel) 9 | 10 | pytestmark = pytest.mark.anyio 11 | 12 | 13 | METADATA = ( 14 | ("name", "World"), 15 | ("name", "World2"), 16 | ("name-bin", b"1234"), 17 | ("name-bin", b"123"), 18 | ("true-bin", b"\x00\x00") 19 | ) 20 | 21 | 22 | @pytest.fixture(scope="module") 23 | def purerpc_port(greeter_pb2, greeter_grpc): 24 | class Servicer(greeter_grpc.GreeterServicer): 25 | async def SayHello(self, message, request): 26 | return greeter_pb2.HelloReply(message=base64.b64encode(pickle.dumps( 27 | request.custom_metadata))) 28 | 29 | with run_purerpc_service_in_process(Servicer().service) as port: 30 | yield port 31 | 32 | 33 | @pytest.fixture(scope="module") 34 | def grpc_port(greeter_pb2, greeter_pb2_grpc): 35 | class Servicer(greeter_pb2_grpc.GreeterServicer): 36 | def SayHello(self, 
message, context): 37 | metadata = [] 38 | for key, value in context.invocation_metadata(): 39 | metadata.append((key, value)) 40 | metadata = tuple(metadata) 41 | return greeter_pb2.HelloReply(message=base64.b64encode(pickle.dumps(metadata))) 42 | 43 | with run_grpc_service_in_process(functools.partial( 44 | greeter_pb2_grpc.add_GreeterServicer_to_server, Servicer())) as port: 45 | yield port 46 | 47 | 48 | @grpc_client_parallelize(1) 49 | @grpc_channel("purerpc_port") 50 | def test_metadata_grpc_client(greeter_pb2, greeter_pb2_grpc, channel): 51 | stub = greeter_pb2_grpc.GreeterStub(channel) 52 | response = stub.SayHello(greeter_pb2.HelloRequest(name="World"), metadata=METADATA) 53 | 54 | received_metadata = pickle.loads(base64.b64decode(response.message)) 55 | # remove artifact of older grpcio versions 56 | if received_metadata[-1][0] == "accept-encoding": 57 | received_metadata = received_metadata[:-1] 58 | assert METADATA == received_metadata 59 | 60 | 61 | @purerpc_channel("grpc_port") 62 | async def test_metadata_grpc_server_purerpc_client(greeter_pb2, greeter_grpc, channel): 63 | stub = greeter_grpc.GreeterStub(channel) 64 | response = await stub.SayHello(greeter_pb2.HelloRequest(name="World"), metadata=METADATA) 65 | 66 | received_metadata = pickle.loads(base64.b64decode(response.message)) 67 | assert received_metadata[0][0] == "grpc-message-type" 68 | received_metadata = received_metadata[1:] 69 | assert METADATA == received_metadata 70 | 71 | 72 | @purerpc_channel("purerpc_port") 73 | async def test_metadata_purerpc_server_purerpc_client(greeter_pb2, greeter_grpc, channel): 74 | stub = greeter_grpc.GreeterStub(channel) 75 | response = await stub.SayHello(greeter_pb2.HelloRequest(name="World"), metadata=METADATA) 76 | 77 | received_metadata = pickle.loads(base64.b64decode(response.message)) 78 | assert METADATA == received_metadata 79 | -------------------------------------------------------------------------------- /tests/test_protoc_plugin.py: -------------------------------------------------------------------------------- 1 | import purerpc 2 | import unittest.mock 3 | import purerpc.server 4 | from purerpc.test_utils import compile_temp_proto 5 | 6 | 7 | def test_plugin(greeter_grpc): 8 | assert "GreeterServicer" in dir(greeter_grpc) 9 | assert "GreeterStub" in dir(greeter_grpc) 10 | 11 | GreeterServicer = greeter_grpc.GreeterServicer 12 | assert issubclass(GreeterServicer, purerpc.server.Servicer) 13 | assert "SayHello" in dir(GreeterServicer) 14 | assert callable(GreeterServicer.SayHello) 15 | assert "SayHelloToMany" in dir(GreeterServicer) 16 | 17 | assert callable(GreeterServicer.SayHelloToMany) 18 | assert "SayHelloGoodbye" in dir(GreeterServicer) 19 | assert callable(GreeterServicer.SayHelloGoodbye) 20 | assert "SayHelloToManyAtOnce" in dir(GreeterServicer) 21 | assert callable(GreeterServicer.SayHelloToManyAtOnce) 22 | assert isinstance(GreeterServicer().service, purerpc.Service) 23 | 24 | GreeterStub = greeter_grpc.GreeterStub 25 | channel = unittest.mock.MagicMock() 26 | greeter_stub = GreeterStub(channel) 27 | 28 | assert "SayHello" in dir(greeter_stub) 29 | assert callable(greeter_stub.SayHello) 30 | assert "SayHelloToMany" in dir(greeter_stub) 31 | assert callable(greeter_stub.SayHelloToMany) 32 | assert "SayHelloGoodbye" in dir(greeter_stub) 33 | assert callable(greeter_stub.SayHelloGoodbye) 34 | assert "SayHelloToManyAtOnce" in dir(greeter_stub) 35 | assert callable(greeter_stub.SayHelloToManyAtOnce) 36 | 37 | 38 | def test_package_names_and_imports(): 39 | with 
compile_temp_proto('data/test_package_names/A.proto', 40 | 'data/test_package_names/B.proto', 41 | 'data/test_package_names/C.proto'): 42 | # modules are imported by context manager 43 | # if there is no error then we are good. 44 | pass 45 | -------------------------------------------------------------------------------- /tests/test_server_http2.py: -------------------------------------------------------------------------------- 1 | import time 2 | import socket 3 | import contextlib 4 | 5 | import pytest 6 | import h2.config 7 | import h2.connection 8 | import h2.events 9 | import h2.settings 10 | 11 | import purerpc 12 | import purerpc.grpclib.connection 13 | 14 | from purerpc.test_utils import run_purerpc_service_in_process 15 | 16 | 17 | @pytest.fixture 18 | def dummy_server_port(): 19 | with run_purerpc_service_in_process(purerpc.Service("Greeter")) as port: 20 | # TODO: migrate to serve_async() to avoid timing problems 21 | time.sleep(0.1) 22 | yield port 23 | 24 | 25 | @contextlib.contextmanager 26 | def http2_client_connect(host, port): 27 | sock = socket.socket(socket.AF_INET) 28 | sock.connect((host, port)) 29 | config = h2.config.H2Configuration(client_side=True, header_encoding="utf-8") 30 | conn = h2.connection.H2Connection(config=config) 31 | conn.initiate_connection() 32 | sock.send(conn.data_to_send()) 33 | 34 | try: 35 | yield conn, sock 36 | finally: 37 | sock.close() 38 | 39 | 40 | def http2_receive_events(conn, sock): 41 | try: 42 | sock.settimeout(0.1) 43 | events = [] 44 | while True: 45 | try: 46 | data = sock.recv(4096) 47 | except socket.timeout: 48 | break 49 | if not data: 50 | break 51 | events.extend(conn.receive_data(data)) 52 | finally: 53 | sock.settimeout(None) 54 | return events 55 | 56 | 57 | def test_connection(dummy_server_port): 58 | ping_data = b"DEADBEEF" 59 | 60 | with http2_client_connect("localhost", dummy_server_port) as (conn, sock): 61 | conn.ping(ping_data) 62 | sock.send(conn.data_to_send()) 63 | 64 | events_received = http2_receive_events(conn, sock) 65 | event_types_received = list(map(type, events_received)) 66 | 67 | assert h2.events.RemoteSettingsChanged in event_types_received 68 | assert h2.events.WindowUpdated in event_types_received 69 | assert h2.events.SettingsAcknowledged in event_types_received 70 | assert h2.events.PingAckReceived in event_types_received 71 | 72 | for event in events_received: 73 | if isinstance(event, h2.events.RemoteSettingsChanged): 74 | assert (event.changed_settings[h2.settings.SettingCodes.MAX_FRAME_SIZE].new_value == 75 | purerpc.grpclib.connection.GRPCConnection.MAX_INBOUND_FRAME_SIZE) 76 | 77 | assert (event.changed_settings[h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS].new_value == 78 | purerpc.grpclib.connection.GRPCConnection.MAX_CONCURRENT_STREAMS) 79 | 80 | assert (event.changed_settings[h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE].new_value == 81 | purerpc.grpclib.connection.GRPCConnection.MAX_HEADER_LIST_SIZE) 82 | elif isinstance(event, h2.events.PingAckReceived): 83 | assert event.ping_data == ping_data 84 | 85 | time.sleep(0.2) 86 | 87 | conn.ping(ping_data) 88 | sock.send(conn.data_to_send()) 89 | 90 | events_received = http2_receive_events(conn, sock) 91 | event_types_received = list(map(type, events_received)) 92 | assert h2.events.PingAckReceived in event_types_received 93 | for event in events_received: 94 | if isinstance(event, h2.events.PingAckReceived): 95 | assert event.ping_data == ping_data 96 | -------------------------------------------------------------------------------- 
/tests/test_status_codes.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import pickle 3 | import base64 4 | import string 5 | import re 6 | 7 | import pytest 8 | 9 | import purerpc 10 | from purerpc.test_utils import run_purerpc_service_in_process, run_grpc_service_in_process, grpc_channel, \ 11 | grpc_client_parallelize, purerpc_channel 12 | 13 | pytestmark = pytest.mark.anyio 14 | 15 | 16 | def regex_and(first, second): 17 | return re.compile(r"(?=.*{first})(?=.*{second}).*".format(first=re.escape(first), second=re.escape(second)), 18 | flags=re.DOTALL) 19 | 20 | 21 | STATUS_CODES = [ 22 | (purerpc.CancelledError, "CANCELLED", "percent encoded message: %"), 23 | (purerpc.UnknownError, "UNKNOWN", "привет"), 24 | (purerpc.InvalidArgumentError, "INVALID_ARGUMENT", "\r\n"), 25 | (purerpc.DeadlineExceededError, "DEADLINE_EXCEEDED", string.printable), 26 | (purerpc.NotFoundError, "NOT_FOUND", "message:" + string.whitespace), 27 | (purerpc.AlreadyExistsError, "ALREADY_EXISTS", "detailed message"), 28 | (purerpc.PermissionDeniedError, "PERMISSION_DENIED", "detailed message"), 29 | (purerpc.ResourceExhaustedError, "RESOURCE_EXHAUSTED", "detailed message"), 30 | (purerpc.FailedPreconditionError, "FAILED_PRECONDITION", "detailed message"), 31 | (purerpc.AbortedError, "ABORTED", "detailed message"), 32 | (purerpc.OutOfRangeError, "OUT_OF_RANGE", "detailed message"), 33 | (purerpc.UnimplementedError, "UNIMPLEMENTED", "detailed message"), 34 | (purerpc.InternalError, "INTERNAL", "detailed message"), 35 | (purerpc.UnavailableError, "UNAVAILABLE", "detailed message"), 36 | (purerpc.DataLossError, "DATA_LOSS", "detailed message"), 37 | (purerpc.UnauthenticatedError, "UNAUTHENTICATED", "detailed message"), 38 | ] 39 | 40 | 41 | @pytest.fixture(scope="module") 42 | def purerpc_port(greeter_pb2): 43 | service = purerpc.Service("Greeter") 44 | 45 | @service.rpc("SayHello") 46 | async def say_hello(message: greeter_pb2.HelloRequest) -> greeter_pb2.HelloReply: 47 | status_code_tuple = pickle.loads(base64.b64decode(message.name)) 48 | raise status_code_tuple[0](status_code_tuple[2]) 49 | 50 | with run_purerpc_service_in_process(service) as port: 51 | yield port 52 | 53 | 54 | @pytest.fixture(scope="module") 55 | def grpc_port(greeter_pb2, greeter_pb2_grpc): 56 | class Servicer(greeter_pb2_grpc.GreeterServicer): 57 | def SayHello(self, message, context): 58 | import grpc 59 | status_code_tuple = pickle.loads(base64.b64decode(message.name)) 60 | context.abort(getattr(grpc.StatusCode, status_code_tuple[1]), status_code_tuple[2]) 61 | 62 | with run_grpc_service_in_process(functools.partial( 63 | greeter_pb2_grpc.add_GreeterServicer_to_server, Servicer())) as port: 64 | yield port 65 | 66 | 67 | @pytest.fixture(scope="module", 68 | params=["purerpc_port", "grpc_port"]) 69 | def port(request): 70 | return request.getfixturevalue(request.param) 71 | 72 | 73 | @pytest.fixture 74 | def purerpc_server_wrong_service_name_port(greeter_pb2): 75 | service = purerpc.Service("some_package.SomeWrongServiceName") 76 | 77 | @service.rpc("SayHello") 78 | async def say_hello(message: greeter_pb2.HelloRequest) -> greeter_pb2.HelloReply: 79 | return greeter_pb2.HelloReply(message="Hello, " + message.name) 80 | 81 | with run_purerpc_service_in_process(service) as port: 82 | yield port 83 | 84 | 85 | @pytest.fixture 86 | def purerpc_server_wrong_method_name_port(greeter_pb2): 87 | service = purerpc.Service("Greeter") 88 | 89 | @service.rpc("SomeOtherMethod") 90 | async def 
say_hello(message: greeter_pb2.HelloRequest) -> greeter_pb2.HelloReply: 91 | return greeter_pb2.HelloReply(message="Hello, " + message.name) 92 | 93 | with run_purerpc_service_in_process(service) as port: 94 | yield port 95 | 96 | 97 | @pytest.fixture 98 | def grpc_empty_servicer_port(greeter_pb2_grpc): 99 | class Servicer(greeter_pb2_grpc.GreeterServicer): 100 | pass 101 | 102 | with run_grpc_service_in_process(functools.partial( 103 | greeter_pb2_grpc.add_GreeterServicer_to_server, Servicer())) as port: 104 | yield port 105 | 106 | 107 | @grpc_client_parallelize(1) 108 | @grpc_channel("purerpc_server_wrong_service_name_port") 109 | def test_grpc_client_wrong_service_name(greeter_pb2, greeter_pb2_grpc, channel): 110 | stub = greeter_pb2_grpc.GreeterStub(channel) 111 | with pytest.raises(BaseException, match=r"not implemented"): 112 | stub.SayHello(greeter_pb2.HelloRequest(name="World")) 113 | 114 | 115 | @grpc_client_parallelize(1) 116 | @grpc_channel("purerpc_server_wrong_method_name_port") 117 | def test_grpc_client_wrong_method_name(greeter_pb2, greeter_pb2_grpc, channel): 118 | stub = greeter_pb2_grpc.GreeterStub(channel) 119 | with pytest.raises(BaseException, match=r"not implemented"): 120 | stub.SayHello(greeter_pb2.HelloRequest(name="World")) 121 | 122 | @purerpc_channel("grpc_empty_servicer_port") 123 | async def test_purerpc_client_empty_servicer(greeter_pb2, greeter_grpc, channel): 124 | stub = greeter_grpc.GreeterStub(channel) 125 | with pytest.raises(purerpc.UnimplementedError): 126 | await stub.SayHello(greeter_pb2.HelloRequest(name="World")) 127 | 128 | 129 | @pytest.mark.parametrize("status_code_tuple", STATUS_CODES) 130 | @grpc_client_parallelize(1) 131 | @grpc_channel("purerpc_port") 132 | def test_grpc_client_status_codes(status_code_tuple, greeter_pb2, greeter_pb2_grpc, channel): 133 | stub = greeter_pb2_grpc.GreeterStub(channel) 134 | with pytest.raises(BaseException, match=regex_and(status_code_tuple[1], status_code_tuple[2])): 135 | stub.SayHello(greeter_pb2.HelloRequest(name=base64.b64encode(pickle.dumps(status_code_tuple)))) 136 | 137 | 138 | @pytest.mark.parametrize("status_code_tuple", STATUS_CODES) 139 | @purerpc_channel("port") 140 | async def test_purerpc_client_status_codes(status_code_tuple, greeter_pb2, greeter_grpc, channel): 141 | purerpc_exception = status_code_tuple[0] 142 | stub = greeter_grpc.GreeterStub(channel) 143 | with pytest.raises(purerpc_exception, match=regex_and(status_code_tuple[1], status_code_tuple[2])): 144 | await stub.SayHello(greeter_pb2.HelloRequest(name=base64.b64encode(pickle.dumps(status_code_tuple)))) 145 | -------------------------------------------------------------------------------- /tests/test_test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pytest 5 | import traceback 6 | import multiprocessing 7 | 8 | from purerpc.test_utils import run_tests_in_workers, _run_context_manager_generator_in_process 9 | 10 | 11 | def test_run_tests_in_workers_error(): 12 | def target_fn(): 13 | def inner_2(): 14 | def inner_1(): 15 | raise ValueError("42") 16 | inner_1() 17 | inner_2() 18 | 19 | with pytest.raises(ValueError, match="42"): 20 | run_tests_in_workers(target=target_fn, num_workers=1) 21 | 22 | 23 | def test_run_tests_in_workers_error_traceback(): 24 | def target_fn(): 25 | def inner_2(): 26 | def inner_1(): 27 | raise ValueError("42") 28 | inner_1() 29 | inner_2() 30 | 31 | try: 32 | run_tests_in_workers(target=target_fn, 
num_workers=1) 33 | except ValueError: 34 | exc_type, exc_value, exc_traceback = sys.exc_info() 35 | lines = traceback.format_tb(exc_traceback) 36 | expected_traceback = ("target_fn", "inner_2", "inner_1") 37 | for expected_fname, line in zip(expected_traceback[::-1], lines[::-1]): 38 | assert expected_fname in line 39 | 40 | 41 | def test_run_tests_in_workers(): 42 | num_workers = 10 43 | queue = multiprocessing.Queue() 44 | def target_fn(): 45 | queue.put(os.getpid()) 46 | 47 | 48 | run_tests_in_workers(target=target_fn, num_workers=num_workers) 49 | pids = set() 50 | for _ in range(num_workers): 51 | pid = queue.get_nowait() 52 | pids.add(pid) 53 | assert len(pids) == num_workers 54 | 55 | 56 | def test_run_context_manager_generator_in_process(): 57 | def gen(): 58 | yield 42 59 | 60 | with _run_context_manager_generator_in_process(gen) as result: 61 | assert result == 42 62 | 63 | 64 | def test_run_context_manager_generator_in_process_error_before(): 65 | def gen(): 66 | raise ValueError("42") 67 | 68 | with pytest.raises(ValueError, match="42"): 69 | with _run_context_manager_generator_in_process(gen) as result: 70 | assert result == 42 71 | 72 | 73 | def test_run_context_manager_generator_in_process_error_after(): 74 | def gen(): 75 | yield 42 76 | raise ValueError("42") 77 | 78 | with pytest.raises(ValueError, match="42"): 79 | with _run_context_manager_generator_in_process(gen) as result: 80 | assert result == 42 81 | time.sleep(0.1) 82 | 83 | 84 | def test_run_context_manager_generator_in_process_error_traceback(): 85 | def gen(): 86 | def inner_2(): 87 | def inner_1(): 88 | raise ValueError("42") 89 | inner_1() 90 | inner_2() 91 | 92 | try: 93 | with _run_context_manager_generator_in_process(gen): 94 | pass 95 | except ValueError: 96 | exc_type, exc_value, exc_traceback = sys.exc_info() 97 | lines = traceback.format_tb(exc_traceback) 98 | expected_traceback = ("gen", "inner_2", "inner_1") 99 | for expected_fname, line in zip(expected_traceback[::-1], lines[::-1]): 100 | assert expected_fname in line 101 | --------------------------------------------------------------------------------