├── .gitignore ├── LICENSE ├── README.md ├── example ├── README.md ├── __init__.py ├── async_client.py ├── client.py ├── google │ └── protobuf │ │ ├── empty.proto │ │ ├── empty_connecpy.py │ │ ├── empty_pb2.py │ │ └── empty_pb2.pyi ├── grpc_client.py ├── grpc_server.py ├── haberdasher.proto ├── haberdasher_connecpy.py ├── haberdasher_pb2.py ├── haberdasher_pb2.pyi ├── haberdasher_pb2_grpc.py ├── requirements.txt ├── server.py ├── service.py ├── wsgi_server.py └── wsgi_service.py ├── protoc-gen-connecpy ├── README.md ├── generator │ ├── generator.go │ ├── generator_test.go │ ├── template.go │ └── template_test.go ├── go.mod ├── go.sum └── main.go ├── pyproject.toml ├── src └── connecpy │ ├── __init__.py │ ├── asgi.py │ ├── async_client.py │ ├── base.py │ ├── client.py │ ├── compression.py │ ├── context.py │ ├── cors.py │ ├── encoding.py │ ├── errors.py │ ├── exceptions.py │ ├── interceptor.py │ ├── server.py │ ├── shared_client.py │ └── wsgi.py ├── test └── test_cors.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Binaries for programs and plugins 107 | *.exe 108 | *.exe~ 109 | *.dll 110 | *.so 111 | *.dylib 112 | 113 | # Test binary, build with `go test -c` 114 | *.test 115 | 116 | # Output of the go coverage tool, specifically when used with LiteIDE 117 | *.out 118 | 119 | .vscode 120 | 121 | # generated Go binary 122 | protoc-gen-connecpy/protoc-gen-connecpy 123 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 
7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to <https://unlicense.org> 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Connecpy 2 | 3 | Python implementation of [Connect Protocol](https://connectrpc.com/docs/protocol). 4 | 5 | This repo contains a protoc plugin that generates server and client code and a PyPI package with common implementation details. 6 | 7 | ## Installation 8 | 9 | You can install the protoc plugin to generate files by running the command: 10 | 11 | ```sh 12 | go install github.com/i2y/connecpy/protoc-gen-connecpy@latest 13 | ``` 14 | 15 | Additionally, please add the connecpy package to your project using your preferred package manager.
For instance, with [uv](https://docs.astral.sh/uv/), use the command: 16 | 17 | 18 | ```sh 19 | uv add connecpy 20 | ``` 21 | 22 | or 23 | 24 | ```sh 25 | pip install connecpy 26 | ``` 27 | 28 | 29 | To run the server, you'll need one of the following: [Uvicorn](https://www.uvicorn.org/), [Daphne](https://github.com/django/daphne), or [Hypercorn](https://gitlab.com/pgjones/hypercorn). If your goal is to support both HTTP/1.1 and HTTP/2, you should opt for either Daphne or Hypercorn. Additionally, to test the server, you might need a client command, such as [buf](https://buf.build/docs/installation). 30 | 31 | 32 | ## Generate and run 33 | 34 | Use the protoc plugin to generate connecpy server and client code. 35 | 36 | ```sh 37 | protoc --python_out=./ --pyi_out=./ --connecpy_out=./ ./haberdasher.proto 38 | ``` 39 | 40 | ### Server code (ASGI) 41 | 42 | ```python 43 | # service.py 44 | import random 45 | 46 | from connecpy.exceptions import InvalidArgument 47 | from connecpy.context import ServiceContext 48 | 49 | from haberdasher_pb2 import Hat, Size 50 | 51 | 52 | class HaberdasherService(object): 53 | async def MakeHat(self, req: Size, ctx: ServiceContext) -> Hat: 54 | print("remaining_time: ", ctx.time_remaining()) 55 | if req.inches <= 0: 56 | raise InvalidArgument( 57 | argument="inches", error="I can't make a hat that small!" 
58 | ) 59 | response = Hat( 60 | size=req.inches, 61 | color=random.choice(["white", "black", "brown", "red", "blue"]), 62 | ) 63 | if random.random() > 0.5: 64 | response.name = random.choice( 65 | ["bowler", "baseball cap", "top hat", "derby"] 66 | ) 67 | 68 | return response 69 | ``` 70 | 71 | ```python 72 | # server.py 73 | from connecpy import context 74 | from connecpy.asgi import ConnecpyASGIApp 75 | 76 | import haberdasher_connecpy 77 | from service import HaberdasherService 78 | 79 | service = haberdasher_connecpy.HaberdasherServer( 80 | service=HaberdasherService() 81 | ) 82 | app = ConnecpyASGIApp() 83 | app.add_service(service) 84 | ``` 85 | 86 | Run the server with 87 | ```sh 88 | uvicorn --port=3000 server:app 89 | ``` 90 | or 91 | 92 | ```sh 93 | daphne --port=3000 server:app 94 | ``` 95 | 96 | or 97 | 98 | ```sh 99 | hypercorn --bind :3000 server:app 100 | ``` 101 | 102 | ### Client code (Asynchronous) 103 | 104 | ```python 105 | # async_client.py 106 | import asyncio 107 | 108 | import httpx 109 | 110 | from connecpy.context import ClientContext 111 | from connecpy.exceptions import ConnecpyServerException 112 | 113 | import haberdasher_connecpy, haberdasher_pb2 114 | 115 | 116 | server_url = "http://localhost:3000" 117 | timeout_s = 5 118 | 119 | 120 | async def main(): 121 | session = httpx.AsyncClient( 122 | base_url=server_url, 123 | timeout=timeout_s, 124 | ) 125 | client = haberdasher_connecpy.AsyncHaberdasherClient(server_url, session=session) 126 | 127 | try: 128 | response = await client.MakeHat( 129 | ctx=ClientContext(), 130 | request=haberdasher_pb2.Size(inches=12), 131 | # Optionally provide a session per request 132 | # session=session, 133 | ) 134 | if not response.HasField("name"): 135 | print("We didn't get a name!") 136 | print(response) 137 | except ConnecpyServerException as e: 138 | print(e.code, e.message, e.to_dict()) 139 | finally: 140 | # Close the session (could also use a context manager) 141 | await session.aclose() 142 |
143 | 144 | if __name__ == "__main__": 145 | asyncio.run(main()) 146 | ``` 147 | 148 | Example output: 149 | ``` 150 | size: 12 151 | color: "black" 152 | name: "bowler" 153 | ``` 154 | 155 | ### Client code (Synchronous) 156 | 157 | ```python 158 | # client.py 159 | from connecpy.context import ClientContext 160 | from connecpy.exceptions import ConnecpyServerException 161 | 162 | import haberdasher_connecpy, haberdasher_pb2 163 | 164 | 165 | server_url = "http://localhost:3000" 166 | timeout_s = 5 167 | 168 | 169 | def main(): 170 | client = haberdasher_connecpy.HaberdasherClient(server_url, timeout=timeout_s) 171 | 172 | try: 173 | response = client.MakeHat( 174 | ctx=ClientContext(), 175 | request=haberdasher_pb2.Size(inches=12), 176 | ) 177 | if not response.HasField("name"): 178 | print("We didn't get a name!") 179 | print(response) 180 | except ConnecpyServerException as e: 181 | print(e.code, e.message, e.to_dict()) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | ``` 187 | 188 | ### Other clients 189 | 190 | Of course, you can use any HTTP client to make requests to a Connecpy server. For example, commands like `curl` or `buf curl` can be used, as well as HTTP client libraries such as `requests`, `httpx`, `aiohttp`, and others. The examples below use `curl` and `buf curl`.
191 | 192 | Content-Type: application/proto, HTTP/1.1 193 | ```sh 194 | buf curl --data '{"inches": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat --schema ./haberdasher.proto 195 | ``` 196 | 197 | On Windows, Content-Type: application/proto, HTTP/1.1 198 | ```sh 199 | buf curl --data '{\"inches\": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat --schema .\haberdasher.proto 200 | ``` 201 | 202 | Content-Type: application/proto, HTTP/2 203 | ```sh 204 | buf curl --data '{"inches": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat --http2-prior-knowledge --schema ./haberdasher.proto 205 | ``` 206 | 207 | On Windows, Content-Type: application/proto, HTTP/2 208 | ```sh 209 | buf curl --data '{\"inches\": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat --http2-prior-knowledge --schema .\haberdasher.proto 210 | ``` 211 | 212 | 213 | Content-Type: application/json, HTTP/1.1 214 | ```sh 215 | curl -X POST -H "Content-Type: application/json" -d '{"inches": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat 216 | ``` 217 | 218 | On Windows, Content-Type: application/json, HTTP/1.1 219 | ```sh 220 | curl -X POST -H "Content-Type: application/json" -d '{\"inches\": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat 221 | ``` 222 | 223 | Content-Type: application/json, HTTP/2 224 | ```sh 225 | curl --http2-prior-knowledge -X POST -H "Content-Type: application/json" -d '{"inches": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat 226 | ``` 227 | 228 | On Windows, Content-Type: application/json, HTTP/2 229 | ```sh 230 | curl --http2-prior-knowledge -X POST -H "Content-Type: application/json" -d '{\"inches\": 12}' -v http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat 231 | ``` 232 | 233 | ## WSGI Support 234 | 235 | Connecpy now provides WSGI support via the `ConnecpyWSGIApp`. 
This synchronous application adapts our service endpoints to the WSGI specification. It reads requests from the WSGI `environ`, processes POST requests, and returns responses using `start_response`. This enables integration with legacy WSGI servers and middleware. 236 | 237 | Please see the example in the [example directory](example/wsgi_server.py). 238 | 239 | ## Compression Support 240 | 241 | Connecpy supports various compression methods for both GET and POST requests/responses: 242 | 243 | - gzip 244 | - brotli (br) 245 | - zstandard (zstd) 246 | - identity (no compression) 247 | 248 | For GET requests, specify the compression method using the `compression` query parameter: 249 | ```sh 250 | curl "http://localhost:3000/service/method?compression=gzip&message=..." 251 | ``` 252 | 253 | For POST requests, use the `Content-Encoding` header: 254 | ```sh 255 | curl -H "Content-Encoding: br" -d '{"data": "..."}' http://localhost:3000/service/method 256 | ``` 257 | 258 | The compression is handled directly in the request handlers, ensuring consistent behavior across HTTP methods and frameworks (ASGI/WSGI). 259 | 260 | With Connecpy's compression features, you can automatically handle compressed requests and responses. Here are some examples: 261 | 262 | ### Server-side 263 | 264 | The compression handling is built into both ASGI and WSGI applications. You don't need any additional middleware configuration - it works out of the box! 
265 | 266 | ### Client-side 267 | 268 | For synchronous clients: 269 | ```python 270 | from connecpy.context import ClientContext 271 | 272 | client = HaberdasherClient(server_url) 273 | response = client.MakeHat( 274 | ctx=ClientContext( 275 | headers={ 276 | "Content-Encoding": "br", # Use Brotli compression for request 277 | "Accept-Encoding": "gzip", # Accept gzip compressed response 278 | } 279 | ), 280 | request=request_obj, 281 | ) 282 | ``` 283 | 284 | For async clients: 285 | ```python 286 | async with httpx.AsyncClient() as session: 287 | client = AsyncHaberdasherClient(server_url, session=session) 288 | response = await client.MakeHat( 289 | ctx=ClientContext(), 290 | request=request_obj, 291 | headers={ 292 | "Content-Encoding": "zstd", # Use Zstandard compression for request 293 | "Accept-Encoding": "br", # Accept Brotli compressed response 294 | }, 295 | ) 296 | ``` 297 | 298 | Using GET requests with compression: 299 | ```python 300 | response = client.MakeHat( 301 | ctx=ClientContext(), 302 | request=request_obj, 303 | use_get=True, # Enable GET request (for methods marked with no_side_effects) 304 | params={ 305 | "compression": "gzip", # Use gzip compression for the message 306 | } 307 | ) 308 | ``` 309 | 310 | ### CORS Support 311 | 312 | Connecpy provides built-in CORS support via the `CORSMiddleware`. 
By default, it allows all origins and includes necessary Connect Protocol headers: 313 | 314 | ```python 315 | from connecpy.cors import CORSMiddleware 316 | 317 | app = ConnecpyASGIApp() 318 | app.add_service(service) 319 | app = CORSMiddleware(app) # Use default configuration 320 | ``` 321 | 322 | You can customize the CORS behavior using `CORSConfig`: 323 | 324 | ```python 325 | from connecpy.cors import CORSMiddleware, CORSConfig 326 | 327 | config = CORSConfig( 328 | allow_origin="https://your-domain.com", # Restrict allowed origins 329 | allow_methods=("POST", "GET", "OPTIONS"), # Customize allowed methods 330 | allow_headers=( # Customize allowed headers 331 | "Content-Type", 332 | "Connect-Protocol-Version", 333 | "X-Custom-Header", 334 | ), 335 | access_control_max_age=3600, # Set preflight cache duration 336 | ) 337 | 338 | app = CORSMiddleware(app, config=config) 339 | ``` 340 | 341 | The middleware both handles preflight (OPTIONS) requests and adds appropriate CORS headers to responses. 342 | 343 | ## Connect Protocol 344 | 345 | The Connecpy protoc plugin generates code from `.proto` files based on the [Connect Protocol](https://connectrpc.com/docs/protocol). 346 | Currently, Connecpy supports only unary RPCs, sent via POST (or via GET for methods marked with no side effects, as shown above). Connecpy will support other types of RPCs as well in the near future. 347 | 348 | ## Misc 349 | 350 | ### Server Path Prefix 351 | 352 | You can set a server path prefix by passing `server_path_prefix` to the constructor of the generated server class (e.g. `HaberdasherServer`). Clients pass the same `server_path_prefix` on each call. 353 | 354 | This example sets the server path prefix to `/foo/bar`.
355 | ```python 356 | # server.py 357 | service = haberdasher_connecpy.HaberdasherServer( 358 | service=HaberdasherService(), 359 | server_path_prefix="/foo/bar", 360 | ) 361 | ``` 362 | 363 | ```python 364 | # async_client.py 365 | response = await client.MakeHat( 366 | ctx=ClientContext(), 367 | request=haberdasher_pb2.Size(inches=12), 368 | server_path_prefix="/foo/bar", 369 | ) 370 | ``` 371 | 372 | ### Interceptor (Server Side) 373 | 374 | `ConnecpyASGIApp` supports interceptors. You can add them by passing `interceptors` to the `ConnecpyASGIApp` constructor. 375 | Each interceptor subclasses `AsyncConnecpyServerInterceptor` and implements its `intercept` method: 376 | 377 | ```python 378 | # server.py 379 | from typing import Any, Callable 380 | 381 | from connecpy import context 382 | from connecpy.asgi import ConnecpyASGIApp 383 | from connecpy.interceptor import AsyncConnecpyServerInterceptor 384 | 385 | import haberdasher_connecpy 386 | from service import HaberdasherService 387 | 388 | 389 | class MyInterceptor(AsyncConnecpyServerInterceptor): 390 | def __init__(self, msg): 391 | self._msg = msg 392 | 393 | async def intercept( 394 | self, 395 | method: Callable, 396 | request: Any, 397 | ctx: context.ServiceContext, 398 | method_name: str, 399 | ) -> Any: 400 | print("intercepting " + method_name + " with " + self._msg) 401 | return await method(request, ctx) 402 | 403 | 404 | my_interceptor_a = MyInterceptor("A") 405 | my_interceptor_b = MyInterceptor("B") 406 | 407 | service = haberdasher_connecpy.HaberdasherServer(service=HaberdasherService()) 408 | app = ConnecpyASGIApp( 409 | interceptors=(my_interceptor_a, my_interceptor_b), 410 | ) 411 | app.add_service(service) 412 | ``` 413 | 414 | Note that the `intercept` method of `AsyncConnecpyServerInterceptor` has a signature compatible with the `intercept` method of [grpc_interceptor.server.AsyncServerInterceptor](https://grpc-interceptor.readthedocs.io/en/latest/#async-server-interceptors), so you may be able to convert Connecpy interceptors to gRPC interceptors just by changing the import
statement and the parent class. 415 | 416 | 417 | ### gRPC Compatibility 418 | In Connecpy, unlike connect-go, it is not possible to support both gRPC and Connect RPC simultaneously on the same server and port. In fact, Connecpy itself doesn't support gRPC at all. However, you can implement a gRPC server using the same service code as your Connecpy server, as shown below. This works because the type signature of a Connecpy service class is compatible with the type signature the gRPC framework requires. 419 | The example below uses [grpc.aio](https://grpc.github.io/grpc/python/grpc_asyncio.html); the full sources are in the [example directory](example/README.md). 420 | 421 | 422 | ```python 423 | # grpc_server.py 424 | import asyncio 425 | 426 | from grpc.aio import server 427 | 428 | import haberdasher_pb2_grpc 429 | 430 | # same service.py as the one used in previous server.py 431 | from service import HaberdasherService 432 | 433 | host = "localhost:50051" 434 | 435 | 436 | async def main(): 437 | s = server() 438 | haberdasher_pb2_grpc.add_HaberdasherServicer_to_server(HaberdasherService(), s) 439 | bound_port = s.add_insecure_port(host) 440 | print(f"localhost:{bound_port}") 441 | await s.start() 442 | await s.wait_for_termination() 443 | 444 | 445 | if __name__ == "__main__": 446 | asyncio.run(main()) 447 | ``` 448 | 449 | ```python 450 | # grpc_client.py 451 | import asyncio 452 | 453 | from grpc.aio import insecure_channel 454 | 455 | import haberdasher_pb2 456 | import haberdasher_pb2_grpc 457 | 458 | 459 | target = "localhost:50051" 460 | 461 | 462 | async def main(): 463 | channel = insecure_channel(target) 464 | stub = haberdasher_pb2_grpc.HaberdasherStub(channel) 465 | request = haberdasher_pb2.Size(inches=12) 466 | response = await stub.MakeHat(request) 467 | print(response) 468 | 469 | 470 | if __name__ == "__main__": 471 | asyncio.run(main()) 472 | ``` 473 | 474 | ### Message Body Length 475 | 476 | Currently, the message body length limit is set to
100kb. You can override this by passing `max_receive_message_length` to the `ConnecpyASGIApp` constructor. 477 | 478 | ```python 479 | # this sets max message length to be 10 bytes 480 | app = ConnecpyASGIApp(max_receive_message_length=10) 481 | 482 | ``` 483 | 484 | ## Standing on the shoulders of giants 485 | 486 | The initial version (1.0.0) of this software was created by modifying https://github.com/verloop/twirpy on January 4, 2024, to support the Connect Protocol. Therefore, this software is also licensed under the Unlicense, the same as twirpy. 487 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # Connecpy Examples 2 | 3 | ## Requirements 4 | You can install the requirements for running the examples by executing the following command in the root directory of this repository: 5 | ```sh 6 | uv sync 7 | ``` 8 | 9 | ## Running server (ASGI) 10 | You can run the Connecpy ASGI application using servers such as uvicorn. 11 | Example uvicorn command: `uv run uvicorn --port 3000 server:app` 12 | 13 | ## Running server (WSGI) 14 | You can run the Connecpy WSGI application using the following command: 15 | `uv run wsgi_server.py` 16 | 17 | ## Running client 18 | After the server has started, you can make a request using the example client with the following command: 19 | `uv run async_client.py` or `uv run client.py` 20 | 21 | Example output: 22 | ``` 23 | size: 12 24 | color: "black" 25 | name: "bowler" 26 | ``` 27 | 28 | ## Generated code 29 | The example server and client make use of code generated by the `protoc-gen-connecpy` plugin. You can find the code in the `.` folder (this directory). To generate the code yourself for `haberdasher`, use the steps given below: 30 | 1. Install the `protoc-gen-connecpy` plugin using the steps given [here](/README.md) 31 | 2.
Generate code for `haberdasher` using connecpy plugin : 32 | `uv run python -m grpc_tools.protoc --python_out=. --pyi_out=. --connecpy_out=. -I . haberdasher.proto` 33 | or 34 | `protoc --python_out=. --pyi_out=. --connecpy_out=. -I . haberdasher.proto` 35 | - python_out : The directory where generated Protobuf Python code needs to be saved. 36 | - pyi_out : The directory where generated Python stub code for Protobuf Messages needs to be saved. 37 | - connecpy_out : The directory where generated Connecpy Python server and client code needs to be saved. 38 | 39 | ## Making Compressed Requests 40 | 41 | After starting the server, you can test compression in several ways: 42 | 43 | ### Using the Example Client 44 | The example clients demonstrate how to use compression: 45 | 46 | ```sh 47 | # Async client with Brotli compression 48 | uv run async_client.py 49 | 50 | # Sync client with gzip compression 51 | uv run client.py 52 | ``` 53 | 54 | ### Using curl 55 | Test different compression methods: 56 | 57 | ```sh 58 | # POST request with Brotli compression 59 | curl -X POST \ 60 | -H "Content-Type: application/json" \ 61 | -H "Content-Encoding: br" \ 62 | -d '{"inches": 12}' \ 63 | http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat 64 | 65 | # GET request with gzip compression 66 | curl "http://localhost:3000/i2y.connecpy.example.Haberdasher/MakeHat?compression=gzip&message=eyJpbmNoZXMiOjEyfQ==&base64=1" 67 | ``` 68 | 69 | Note: For GET requests, the message should be base64 encoded when using binary formats like protobuf or compressed data. 70 | 71 | # gRPC Examples 72 | Connecpy does not support gRPC. However, you can use the same service code to run a gRPC server and client. 73 | `grpc_server.py` uses the same service code as `server.py`. 
74 | 75 | 76 | ## Running gRPC server 77 | You can run the gRPC server using the following command: 78 | `uv run grpc_server.py` 79 | 80 | 81 | ## Running gRPC client 82 | After the gRPC server has started, you can make a request using the example client with the following command: 83 | `uv run grpc_client.py` 84 | 85 | ## Generated code (gRPC) 86 | `uv run python -m grpc_tools.protoc --python_out=. --pyi_out=. --grpc_python_out=. -I . haberdasher.proto` 87 | - python_out : The directory where generated Protobuf Python code needs to be saved. 88 | - pyi_out : The directory where generated Python stub code for Protobuf Messages needs to be saved. 89 | - grpc_python_out : The directory where generated gRPC Python server and client stub code needs to be saved. 90 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/i2y/connecpy/6941d57ac24be366b19308a7dc223b239a4f06e7/example/__init__.py -------------------------------------------------------------------------------- /example/async_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | 5 | from connecpy.context import ClientContext 6 | from connecpy.exceptions import ConnecpyServerException 7 | 8 | import haberdasher_connecpy 9 | import haberdasher_pb2 10 | 11 | 12 | server_url = "http://localhost:3000" 13 | timeout_s = 5 14 | 15 | 16 | async def main(): 17 | async with httpx.AsyncClient( 18 | base_url=server_url, 19 | timeout=timeout_s, 20 | ) as session: 21 | client = haberdasher_connecpy.AsyncHaberdasherClient( 22 | server_url, session=session 23 | ) 24 | 25 | # Example 1: POST request with Zstandard compression 26 | try: 27 | response = await client.MakeHat( 28 | ctx=ClientContext(), 29 | request=haberdasher_pb2.Size(inches=12), 30 | headers={ 31 | "Content-Encoding": "zstd", #
Request compression 32 | "Accept-Encoding": "br", # Response compression 33 | }, 34 | ) 35 | print("POST with Zstandard compression:", response) 36 | except ConnecpyServerException as e: 37 | print(e.code, e.message, e.to_dict()) 38 | 39 | # Example 2: GET request with Brotli compression 40 | try: 41 | response = await client.MakeHat( 42 | ctx=ClientContext( 43 | headers={ 44 | "Accept-Encoding": "zstd", # Response compression 45 | } 46 | ), 47 | request=haberdasher_pb2.Size(inches=8), 48 | use_get=True, # Enable GET request 49 | ) 50 | print("\nGET with Brotli compression:", response) 51 | except ConnecpyServerException as e: 52 | print(e.code, e.message, e.to_dict()) 53 | 54 | 55 | if __name__ == "__main__": 56 | asyncio.run(main()) 57 | -------------------------------------------------------------------------------- /example/client.py: -------------------------------------------------------------------------------- 1 | from connecpy.context import ClientContext 2 | from connecpy.exceptions import ConnecpyException, ConnecpyServerException 3 | 4 | import haberdasher_connecpy 5 | import haberdasher_pb2 6 | 7 | 8 | server_url = "http://localhost:3000" 9 | timeout_s = 5 10 | 11 | 12 | def create_large_request(): 13 | """Create a request with a large description to test compression.""" 14 | return haberdasher_pb2.Size( 15 | inches=12, 16 | description="A" * 2048, # Add a 2KB string to ensure compression is worthwhile 17 | ) 18 | 19 | 20 | def main(): 21 | client = haberdasher_connecpy.HaberdasherClient(server_url, timeout=timeout_s) 22 | 23 | # Example 1: POST request with gzip compression (large request) 24 | try: 25 | print("\nTesting POST request with gzip compression...") 26 | response = client.MakeHat( 27 | ctx=ClientContext( 28 | headers={ 29 | "Content-Encoding": "gzip", # Request compression 30 | "Accept-Encoding": "gzip", # Response compression 31 | } 32 | ), 33 | request=create_large_request(), 34 | ) 35 | print("POST with gzip compression successful:", 
response) 36 | except (ConnecpyException, ConnecpyServerException) as e: 37 | print("POST with gzip compression failed:", str(e)) 38 | 39 | # Example 2: POST request with brotli compression (large request) 40 | try: 41 | print("\nTesting POST request with brotli compression...") 42 | response = client.MakeHat( 43 | ctx=ClientContext( 44 | headers={ 45 | "Content-Encoding": "br", # Request compression 46 | "Accept-Encoding": "br", # Response compression 47 | } 48 | ), 49 | request=create_large_request(), 50 | ) 51 | print("POST with brotli compression successful:", response) 52 | except (ConnecpyException, ConnecpyServerException) as e: 53 | print("POST with brotli compression failed:", str(e)) 54 | 55 | # Example 3: GET request without compression 56 | try: 57 | print("\nTesting GET request without compression...") 58 | response = client.MakeHat( 59 | ctx=ClientContext(), # No compression headers 60 | request=haberdasher_pb2.Size(inches=8), # Small request 61 | use_get=True, 62 | ) 63 | print("GET without compression successful:", response) 64 | except (ConnecpyException, ConnecpyServerException) as e: 65 | print("GET without compression failed:", str(e)) 66 | 67 | # Example 4: GET request with ztstd compression (large request) 68 | try: 69 | print("\nTesting GET request with gzip compression...") 70 | response = client.MakeHat( 71 | ctx=ClientContext( 72 | headers={ 73 | "Accept-Encoding": "zstd", # Response compression 74 | "Content-Encoding": "zstd", # Request compression 75 | } 76 | ), 77 | request=create_large_request(), 78 | use_get=True, 79 | ) 80 | print("GET with zstd compression successful:", response) 81 | except (ConnecpyException, ConnecpyServerException) as e: 82 | print("GET with zstd compression failed:", str(e)) 83 | 84 | # Example 5: Test multiple accepted encodings 85 | try: 86 | print("\nTesting POST with multiple accepted encodings...") 87 | response = client.MakeHat( 88 | ctx=ClientContext( 89 | headers={ 90 | "Content-Encoding": "br", # 
Request compression 91 | "Accept-Encoding": "gzip, br, zstd", # Response compression (in order of preference) 92 | } 93 | ), 94 | request=create_large_request(), 95 | ) 96 | print("POST with multiple encodings successful:", response) 97 | except (ConnecpyException, ConnecpyServerException) as e: 98 | print("POST with multiple encodings failed:", str(e)) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /example/google/protobuf/empty.proto: -------------------------------------------------------------------------------- 1 | // Protocol Buffers - Google's data interchange format 2 | // Copyright 2008 Google Inc. All rights reserved. 3 | // https://developers.google.com/protocol-buffers/ 4 | // 5 | // Redistribution and use in source and binary forms, with or without 6 | // modification, are permitted provided that the following conditions are 7 | // met: 8 | // 9 | // * Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // * Redistributions in binary form must reproduce the above 12 | // copyright notice, this list of conditions and the following disclaimer 13 | // in the documentation and/or other materials provided with the 14 | // distribution. 15 | // * Neither the name of Google Inc. nor the names of its 16 | // contributors may be used to endorse or promote products derived from 17 | // this software without specific prior written permission. 18 | // 19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | syntax = "proto3"; 32 | 33 | package google.protobuf; 34 | 35 | option go_package = "google.golang.org/protobuf/types/known/emptypb"; 36 | option java_package = "com.google.protobuf"; 37 | option java_outer_classname = "EmptyProto"; 38 | option java_multiple_files = true; 39 | option objc_class_prefix = "GPB"; 40 | option csharp_namespace = "Google.Protobuf.WellKnownTypes"; 41 | option cc_enable_arenas = true; 42 | 43 | // A generic empty message that you can re-use to avoid defining duplicated 44 | // empty messages in your APIs. A typical example is to use it as the request 45 | // or the response type of an API method. For instance: 46 | // 47 | // service Foo { 48 | // rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); 49 | // } 50 | // 51 | message Empty {} 52 | -------------------------------------------------------------------------------- /example/google/protobuf/empty_connecpy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by https://github.com/i2y/connecpy/protoc-gen-connecpy. DO NOT EDIT! 
3 | # source: google/protobuf/empty.proto 4 | 5 | 6 | from google.protobuf import symbol_database 7 | 8 | _sym_db = symbol_database.Default() 9 | -------------------------------------------------------------------------------- /example/google/protobuf/empty_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: google/protobuf/empty.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import descriptor_pool as _descriptor_pool 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf.internal import builder as _builder 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x1bgoogle/protobuf/empty.proto\x12\x0fgoogle.protobuf"\x07\n\x05\x45mptyB}\n\x13\x63om.google.protobufB\nEmptyProtoP\x01Z.google.golang.org/protobuf/types/known/emptypb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3' 18 | ) 19 | 20 | _globals = globals() 21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 22 | _builder.BuildTopDescriptorsAndMessages( 23 | DESCRIPTOR, "google.protobuf.empty_pb2", _globals 24 | ) 25 | if _descriptor._USE_C_DESCRIPTORS == False: 26 | _globals["DESCRIPTOR"]._options = None 27 | _globals[ 28 | "DESCRIPTOR" 29 | ]._serialized_options = b"\n\023com.google.protobufB\nEmptyProtoP\001Z.google.golang.org/protobuf/types/known/emptypb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes" 30 | _globals["_EMPTY"]._serialized_start = 48 31 | _globals["_EMPTY"]._serialized_end = 55 32 | # @@protoc_insertion_point(module_scope) 33 | -------------------------------------------------------------------------------- 
/example/google/protobuf/empty_pb2.pyi: -------------------------------------------------------------------------------- 1 | from google.protobuf import descriptor as _descriptor 2 | from google.protobuf import message as _message 3 | 4 | DESCRIPTOR: _descriptor.FileDescriptor 5 | 6 | class Empty(_message.Message): 7 | __slots__ = () 8 | def __init__(self) -> None: ... 9 | -------------------------------------------------------------------------------- /example/grpc_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from grpc.aio import insecure_channel 4 | 5 | import haberdasher_pb2 6 | import haberdasher_pb2_grpc 7 | 8 | 9 | target = "localhost:50051" 10 | 11 | 12 | async def main(): 13 | channel = insecure_channel(target) 14 | stub = haberdasher_pb2_grpc.HaberdasherStub(channel) 15 | request = haberdasher_pb2.Size(inches=12) 16 | response = await stub.MakeHat(request) 17 | print(response) 18 | 19 | 20 | if __name__ == "__main__": 21 | asyncio.run(main()) 22 | -------------------------------------------------------------------------------- /example/grpc_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from grpc.aio import server 4 | 5 | import haberdasher_pb2_grpc 6 | from service import HaberdasherService 7 | 8 | 9 | host = "localhost:50051" 10 | 11 | 12 | async def main(): 13 | s = server() 14 | haberdasher_pb2_grpc.add_HaberdasherServicer_to_server(HaberdasherService(), s) 15 | bound_port = s.add_insecure_port(host) 16 | print(f"localhost:{bound_port}") 17 | await s.start() 18 | await s.wait_for_termination() 19 | 20 | 21 | if __name__ == "__main__": 22 | asyncio.run(main()) 23 | -------------------------------------------------------------------------------- /example/haberdasher.proto: -------------------------------------------------------------------------------- 1 | // Copied from 
https://github.com/twitchtv/connecpy/blob/v5.10.2/example/service.proto 2 | syntax = "proto3"; 3 | 4 | package i2y.connecpy.example; 5 | option go_package = "example"; 6 | 7 | // A Hat is a piece of headwear made by a Haberdasher. 8 | message Hat { 9 | // The size of a hat should always be in inches. 10 | int32 size = 1; 11 | 12 | // The color of a hat will never be 'invisible', but other than 13 | // that, anything is fair game. 14 | string color = 2; 15 | 16 | // The name of a hat is it's type. Like, 'bowler', or something. 17 | optional string name = 3; 18 | } 19 | 20 | // Size is passed when requesting a new hat to be made. It's always 21 | // measured in inches. 22 | message Size { 23 | int32 inches = 1; 24 | // Additional description or notes about the requested hat 25 | string description = 2; 26 | } 27 | 28 | // A Haberdasher makes hats for clients. 29 | service Haberdasher { 30 | // MakeHat produces a hat of mysterious, randomly-selected color! 31 | rpc MakeHat(Size) returns (Hat) { 32 | option idempotency_level = NO_SIDE_EFFECTS; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /example/haberdasher_connecpy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by https://github.com/i2y/connecpy/protoc-gen-connecpy. DO NOT EDIT! 
3 | # source: haberdasher.proto 4 | 5 | from typing import Protocol, Union 6 | 7 | import httpx 8 | 9 | from connecpy.async_client import AsyncConnecpyClient 10 | from connecpy.base import Endpoint 11 | from connecpy.server import ConnecpyServer 12 | from connecpy.client import ConnecpyClient 13 | from connecpy.context import ClientContext, ServiceContext 14 | 15 | import haberdasher_pb2 as _pb2 16 | 17 | from google.protobuf import symbol_database 18 | 19 | _sym_db = symbol_database.Default() 20 | 21 | 22 | class Haberdasher(Protocol): 23 | async def MakeHat(self, req: _pb2.Size, ctx: ServiceContext) -> _pb2.Hat: ... 24 | 25 | 26 | class HaberdasherServer(ConnecpyServer): 27 | def __init__(self, *, service: Haberdasher, server_path_prefix=""): 28 | super().__init__() 29 | self._prefix = f"{server_path_prefix}/i2y.connecpy.example.Haberdasher" 30 | self._endpoints = { 31 | "MakeHat": Endpoint[_pb2.Size, _pb2.Hat]( 32 | service_name="Haberdasher", 33 | name="MakeHat", 34 | function=getattr(service, "MakeHat"), 35 | input=_pb2.Size, 36 | output=_pb2.Hat, 37 | allowed_methods=("GET", "POST"), 38 | ), 39 | } 40 | 41 | def serviceName(self): 42 | return "i2y.connecpy.example.Haberdasher" 43 | 44 | 45 | class HaberdasherSync(Protocol): 46 | def MakeHat(self, req: _pb2.Size, ctx: ServiceContext) -> _pb2.Hat: ... 
47 | 48 | 49 | class HaberdasherServerSync(ConnecpyServer): 50 | def __init__(self, *, service: HaberdasherSync, server_path_prefix=""): 51 | super().__init__() 52 | self._prefix = f"{server_path_prefix}/i2y.connecpy.example.Haberdasher" 53 | self._endpoints = { 54 | "MakeHat": Endpoint[_pb2.Size, _pb2.Hat]( 55 | service_name="Haberdasher", 56 | name="MakeHat", 57 | function=getattr(service, "MakeHat"), 58 | input=_pb2.Size, 59 | output=_pb2.Hat, 60 | allowed_methods=("GET", "POST"), 61 | ), 62 | } 63 | 64 | def serviceName(self): 65 | return "i2y.connecpy.example.Haberdasher" 66 | 67 | 68 | class HaberdasherClient(ConnecpyClient): 69 | def MakeHat( 70 | self, 71 | *, 72 | request: _pb2.Size, 73 | ctx: ClientContext, 74 | server_path_prefix: str = "", 75 | use_get: bool = False, 76 | **kwargs, 77 | ) -> _pb2.Hat: 78 | method = "GET" if use_get else "POST" 79 | return self._make_request( 80 | url=f"{server_path_prefix}/i2y.connecpy.example.Haberdasher/MakeHat", 81 | ctx=ctx, 82 | request=request, 83 | response_obj=_pb2.Hat, 84 | method=method, 85 | **kwargs, 86 | ) 87 | 88 | 89 | class AsyncHaberdasherClient(AsyncConnecpyClient): 90 | async def MakeHat( 91 | self, 92 | *, 93 | request: _pb2.Size, 94 | ctx: ClientContext, 95 | server_path_prefix: str = "", 96 | session: Union[httpx.AsyncClient, None] = None, 97 | use_get: bool = False, 98 | **kwargs, 99 | ) -> _pb2.Hat: 100 | method = "GET" if use_get else "POST" 101 | return await self._make_request( 102 | url=f"{server_path_prefix}/i2y.connecpy.example.Haberdasher/MakeHat", 103 | ctx=ctx, 104 | request=request, 105 | response_obj=_pb2.Hat, 106 | method=method, 107 | session=session, 108 | **kwargs, 109 | ) 110 | -------------------------------------------------------------------------------- /example/haberdasher_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # NO CHECKED-IN PROTOBUF GENCODE 4 | # source: haberdasher.proto 5 | # Protobuf Python Version: 5.29.1 6 | """Generated protocol buffer code.""" 7 | 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import descriptor_pool as _descriptor_pool 10 | from google.protobuf import runtime_version as _runtime_version 11 | from google.protobuf import symbol_database as _symbol_database 12 | from google.protobuf.internal import builder as _builder 13 | 14 | _runtime_version.ValidateProtobufRuntimeVersion( 15 | _runtime_version.Domain.PUBLIC, 5, 29, 1, "", "haberdasher.proto" 16 | ) 17 | # @@protoc_insertion_point(imports) 18 | 19 | _sym_db = _symbol_database.Default() 20 | 21 | 22 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 23 | b'\n\x11haberdasher.proto\x12\x14i2y.connecpy.example">\n\x03Hat\x12\x0c\n\x04size\x18\x01 \x01(\x05\x12\r\n\x05\x63olor\x18\x02 \x01(\t\x12\x11\n\x04name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x07\n\x05_name"+\n\x04Size\x12\x0e\n\x06inches\x18\x01 \x01(\x05\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t2T\n\x0bHaberdasher\x12\x45\n\x07MakeHat\x12\x1a.i2y.connecpy.example.Size\x1a\x19.i2y.connecpy.example.Hat"\x03\x90\x02\x01\x42\tZ\x07\x65xampleb\x06proto3' 24 | ) 25 | 26 | _globals = globals() 27 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 28 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "haberdasher_pb2", _globals) 29 | if not _descriptor._USE_C_DESCRIPTORS: 30 | _globals["DESCRIPTOR"]._loaded_options = None 31 | _globals["DESCRIPTOR"]._serialized_options = b"Z\007example" 32 | _globals["_HABERDASHER"].methods_by_name["MakeHat"]._loaded_options = None 33 | _globals["_HABERDASHER"].methods_by_name[ 34 | "MakeHat" 35 | ]._serialized_options = b"\220\002\001" 36 | _globals["_HAT"]._serialized_start = 43 37 | _globals["_HAT"]._serialized_end = 105 38 | _globals["_SIZE"]._serialized_start = 107 39 | _globals["_SIZE"]._serialized_end = 150 40 | 
_globals["_HABERDASHER"]._serialized_start = 152 41 | _globals["_HABERDASHER"]._serialized_end = 236 42 | # @@protoc_insertion_point(module_scope) 43 | -------------------------------------------------------------------------------- /example/haberdasher_pb2.pyi: -------------------------------------------------------------------------------- 1 | from google.protobuf import descriptor as _descriptor 2 | from google.protobuf import message as _message 3 | from typing import ClassVar as _ClassVar, Optional as _Optional 4 | 5 | DESCRIPTOR: _descriptor.FileDescriptor 6 | 7 | class Hat(_message.Message): 8 | __slots__ = ("size", "color", "name") 9 | SIZE_FIELD_NUMBER: _ClassVar[int] 10 | COLOR_FIELD_NUMBER: _ClassVar[int] 11 | NAME_FIELD_NUMBER: _ClassVar[int] 12 | size: int 13 | color: str 14 | name: str 15 | def __init__( 16 | self, 17 | size: _Optional[int] = ..., 18 | color: _Optional[str] = ..., 19 | name: _Optional[str] = ..., 20 | ) -> None: ... 21 | 22 | class Size(_message.Message): 23 | __slots__ = ("inches", "description") 24 | INCHES_FIELD_NUMBER: _ClassVar[int] 25 | DESCRIPTION_FIELD_NUMBER: _ClassVar[int] 26 | inches: int 27 | description: str 28 | def __init__( 29 | self, inches: _Optional[int] = ..., description: _Optional[str] = ... 30 | ) -> None: ... 31 | -------------------------------------------------------------------------------- /example/haberdasher_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | 4 | import grpc 5 | 6 | import haberdasher_pb2 as haberdasher__pb2 7 | 8 | 9 | class HaberdasherStub(object): 10 | """A Haberdasher makes hats for clients.""" 11 | 12 | def __init__(self, channel): 13 | """Constructor. 14 | 15 | Args: 16 | channel: A grpc.Channel. 
17 | """ 18 | self.MakeHat = channel.unary_unary( 19 | "/i2y.connecpy.example.Haberdasher/MakeHat", 20 | request_serializer=haberdasher__pb2.Size.SerializeToString, 21 | response_deserializer=haberdasher__pb2.Hat.FromString, 22 | ) 23 | 24 | 25 | class HaberdasherServicer(object): 26 | """A Haberdasher makes hats for clients.""" 27 | 28 | def MakeHat(self, request, context): 29 | """MakeHat produces a hat of mysterious, randomly-selected color!""" 30 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 31 | context.set_details("Method not implemented!") 32 | raise NotImplementedError("Method not implemented!") 33 | 34 | 35 | def add_HaberdasherServicer_to_server(servicer, server): 36 | rpc_method_handlers = { 37 | "MakeHat": grpc.unary_unary_rpc_method_handler( 38 | servicer.MakeHat, 39 | request_deserializer=haberdasher__pb2.Size.FromString, 40 | response_serializer=haberdasher__pb2.Hat.SerializeToString, 41 | ), 42 | } 43 | generic_handler = grpc.method_handlers_generic_handler( 44 | "i2y.connecpy.example.Haberdasher", rpc_method_handlers 45 | ) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | 48 | 49 | # This class is part of an EXPERIMENTAL API. 
50 | class Haberdasher(object): 51 | """A Haberdasher makes hats for clients.""" 52 | 53 | @staticmethod 54 | def MakeHat( 55 | request, 56 | target, 57 | options=(), 58 | channel_credentials=None, 59 | call_credentials=None, 60 | insecure=False, 61 | compression=None, 62 | wait_for_ready=None, 63 | timeout=None, 64 | metadata=None, 65 | ): 66 | return grpc.experimental.unary_unary( 67 | request, 68 | target, 69 | "/i2y.connecpy.example.Haberdasher/MakeHat", 70 | haberdasher__pb2.Size.SerializeToString, 71 | haberdasher__pb2.Hat.FromString, 72 | options, 73 | channel_credentials, 74 | insecure, 75 | call_credentials, 76 | compression, 77 | wait_for_ready, 78 | timeout, 79 | metadata, 80 | ) 81 | -------------------------------------------------------------------------------- /example/requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn==0.23.2 2 | connecpy==1.0.0 3 | grpcio-tools==1.60.0 4 | -------------------------------------------------------------------------------- /example/server.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from connecpy import context 4 | from connecpy.asgi import ConnecpyASGIApp 5 | from connecpy.cors import CORSMiddleware 6 | from connecpy.interceptor import AsyncConnecpyServerInterceptor 7 | 8 | import haberdasher_connecpy 9 | from service import HaberdasherService 10 | 11 | 12 | class MyInterceptor(AsyncConnecpyServerInterceptor): 13 | def __init__(self, msg): 14 | self._msg = msg 15 | 16 | async def intercept( 17 | self, 18 | method: Callable, 19 | request: Any, 20 | ctx: context.ServiceContext, 21 | method_name: str, 22 | ) -> Any: 23 | print("intercepting " + method_name + " with " + self._msg) 24 | return await method(request, ctx) 25 | 26 | 27 | my_interceptor_a = MyInterceptor("A") 28 | my_interceptor_b = MyInterceptor("B") 29 | 30 | service = 
haberdasher_connecpy.HaberdasherServer(service=HaberdasherService()) 31 | app = ConnecpyASGIApp( 32 | interceptors=(my_interceptor_a, my_interceptor_b), 33 | ) 34 | app.add_service(service) 35 | 36 | # Add CORS support with default configuration 37 | app = CORSMiddleware(app) 38 | -------------------------------------------------------------------------------- /example/service.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from connecpy.exceptions import InvalidArgument 4 | from connecpy.context import ServiceContext 5 | 6 | from haberdasher_pb2 import Hat, Size 7 | 8 | 9 | class HaberdasherService(object): 10 | async def MakeHat(self, req: Size, ctx: ServiceContext) -> Hat: 11 | print("remaining_time: ", ctx.time_remaining()) 12 | if req.inches <= 0: 13 | raise InvalidArgument( 14 | argument="inches", error="I can't make a hat that small!" 15 | ) 16 | response = Hat( 17 | size=req.inches, 18 | color=random.choice(["white", "black", "brown", "red", "blue"]), 19 | ) 20 | if random.random() > 0.5: 21 | response.name = random.choice( 22 | ["bowler", "baseball cap", "top hat", "derby"] 23 | ) 24 | 25 | return response 26 | -------------------------------------------------------------------------------- /example/wsgi_server.py: -------------------------------------------------------------------------------- 1 | from wsgiref.simple_server import make_server 2 | from connecpy.wsgi import ConnecpyWSGIApp 3 | from wsgi_service import HaberdasherService 4 | from haberdasher_connecpy import HaberdasherServerSync 5 | 6 | 7 | def main(): 8 | # Create synchronous service instance 9 | service = HaberdasherService() 10 | 11 | # Create server with service implementation 12 | server = HaberdasherServerSync(service=service) 13 | print(f"Created server with prefix: {server._prefix}") 14 | 15 | # Create WSGI application and add service 16 | app = ConnecpyWSGIApp() 17 | app.add_service(server) 18 | 19 | # Start WSGI server 
20 | with make_server("", 3000, app) as httpd: 21 | print("Serving on port 3000...") 22 | try: 23 | httpd.serve_forever() 24 | except KeyboardInterrupt: 25 | print("\nShutting down server...") 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /example/wsgi_service.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from connecpy.exceptions import InvalidArgument 4 | from connecpy.context import ServiceContext 5 | 6 | from haberdasher_pb2 import Hat, Size 7 | 8 | 9 | class HaberdasherService(object): 10 | def MakeHat(self, req: Size, ctx: ServiceContext) -> Hat: 11 | print("remaining_time: ", ctx.time_remaining()) 12 | if req.inches <= 0: 13 | raise InvalidArgument( 14 | argument="inches", error="I can't make a hat that small!" 15 | ) 16 | response = Hat( 17 | size=req.inches, 18 | color=random.choice(["white", "black", "brown", "red", "blue"]), 19 | ) 20 | if random.random() > 0.5: 21 | response.name = random.choice( 22 | ["bowler", "baseball cap", "top hat", "derby"] 23 | ) 24 | 25 | return response 26 | -------------------------------------------------------------------------------- /protoc-gen-connecpy/README.md: -------------------------------------------------------------------------------- 1 | # protoc-gen-connecpy 2 | A protobuf plugin for generating Connecpy server and client code. 3 | 4 | ## Installing and using the plugin 5 | 1. Make sure your [Go](https://golang.org/) environment and [Protoc](https://github.com/protocolbuffers/protobuf/releases/latest) compiler are properly set up. 6 | 2. Install the plugin: `go install` 7 | This builds the plugin and installs it in the `$GOBIN` directory, which is usually `$GOPATH/bin`. 8 | 3. 
Generate code for `haberdasher.proto` using the connecpy plugin: 9 | `protoc --python_out=./ --connecpy_out=./ haberdasher.proto` 10 | - python_out: The directory where the generated Protobuf Python code is written. 11 | - connecpy_out: The directory where the generated Connecpy Python server and client code is written. 12 | 13 | The compiler gives the error below if it cannot find the plugin. 14 | 15 | ``` 16 | --connecpy_out: protoc-gen-connecpy: Plugin failed with status code 1. 17 | ``` 18 | 19 | In such cases, you can give the absolute path to the plugin, e.g.: `--plugin=protoc-gen-connecpy=$GOBIN/protoc-gen-connecpy` 20 | -------------------------------------------------------------------------------- /protoc-gen-connecpy/generator/generator.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "path" 7 | "strings" 8 | 9 | "github.com/golang/protobuf/protoc-gen-go/descriptor" 10 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" 11 | "google.golang.org/protobuf/proto" 12 | ) 13 | 14 | func Generate(r *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse { 15 | resp := &plugin.CodeGeneratorResponse{} 16 | resp.SupportedFeatures = proto.Uint64(uint64(plugin.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)) 17 | 18 | files := r.GetFileToGenerate() 19 | if len(files) == 0 { 20 | resp.Error = proto.String("no files to generate") 21 | return resp 22 | } 23 | 24 | for _, fileName := range files { 25 | fd, err := getFileDescriptor(r.GetProtoFile(), fileName) 26 | if err != nil { 27 | resp.Error = proto.String("File[" + fileName + "][descriptor]: " + err.Error()) 28 | return resp 29 | } 30 | 31 | connecpyFile, err := GenerateConnecpyFile(fd) 32 | if err != nil { 33 | resp.Error = proto.String("File[" + fileName + "][generate]: " + err.Error()) 34 | return resp 35 | } 36 | resp.File = append(resp.File, connecpyFile) 37 | } 38 | return resp 39 | } 40 | 41 | 
func GenerateConnecpyFile(fd *descriptor.FileDescriptorProto) (*plugin.CodeGeneratorResponse_File, error) { 42 | name := fd.GetName() 43 | 44 | fileNameWithoutSuffix := strings.TrimSuffix(name, path.Ext(name)) 45 | moduleName := strings.Join(strings.Split(fileNameWithoutSuffix, "/"), ".") 46 | 47 | vars := ConnecpyTemplateVariables{ 48 | FileName: name, 49 | ModuleName: moduleName, 50 | } 51 | 52 | svcs := fd.GetService() 53 | packageName := fd.GetPackage() 54 | for _, svc := range svcs { 55 | connecpySvc := &ConnecpyService{ 56 | Name: svc.GetName(), 57 | Package: packageName, 58 | } 59 | 60 | for _, method := range svc.GetMethod() { 61 | idempotencyLevel := method.Options.GetIdempotencyLevel() 62 | noSideEffects := idempotencyLevel == descriptor.MethodOptions_NO_SIDE_EFFECTS 63 | connecpyMethod := &ConnecpyMethod{ 64 | Package: packageName, 65 | ServiceName: connecpySvc.Name, 66 | Name: method.GetName(), 67 | InputType: getSymbolName(method.GetInputType(), packageName), 68 | InputTypeForProtocol: getSymbolNameForProtocol(method.GetInputType(), packageName), 69 | OutputType: getSymbolName(method.GetOutputType(), packageName), 70 | OutputTypeForProtocol: getSymbolNameForProtocol(method.GetOutputType(), packageName), 71 | NoSideEffects: noSideEffects, 72 | } 73 | 74 | connecpySvc.Methods = append(connecpySvc.Methods, connecpyMethod) 75 | } 76 | vars.Services = append(vars.Services, connecpySvc) 77 | } 78 | 79 | var buf = &bytes.Buffer{} 80 | err := ConnecpyTemplate.Execute(buf, vars) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | resp := &plugin.CodeGeneratorResponse_File{ 86 | Name: proto.String(strings.TrimSuffix(name, path.Ext(name)) + "_connecpy.py"), 87 | Content: proto.String(buf.String()), 88 | } 89 | 90 | return resp, nil 91 | } 92 | 93 | func getLocalSymbolName(name string) string { 94 | parts := strings.Split(name, ".") 95 | return "_pb2." 
+ parts[len(parts)-1] 96 | } 97 | 98 | func getSymbolName(name, localPackageName string) string { 99 | if strings.HasPrefix(name, "."+localPackageName) { 100 | return getLocalSymbolName(name) 101 | } 102 | 103 | return "_sym_db.GetSymbol(\"" + name[1:] + "\")" 104 | } 105 | 106 | func getSymbolNameForProtocol(name, localPackageName string) string { 107 | if strings.HasPrefix(name, "."+localPackageName) { 108 | return getLocalSymbolName(name) 109 | } 110 | 111 | return "Any" 112 | } 113 | 114 | func getFileDescriptor(files []*descriptor.FileDescriptorProto, name string) (*descriptor.FileDescriptorProto, error) { 115 | //Assumption: Number of files will not be large enough to justify making a map 116 | for _, f := range files { 117 | if f.GetName() == name { 118 | return f, nil 119 | } 120 | } 121 | return nil, errors.New("could not find descriptor") 122 | } 123 | -------------------------------------------------------------------------------- /protoc-gen-connecpy/generator/generator_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/golang/protobuf/protoc-gen-go/descriptor" 8 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" 9 | "google.golang.org/protobuf/proto" 10 | ) 11 | 12 | func TestGenerateConnecpyFile(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | input *descriptor.FileDescriptorProto 16 | wantFile string 17 | wantErr bool 18 | }{ 19 | { 20 | name: "simple service", 21 | input: &descriptor.FileDescriptorProto{ 22 | Name: proto.String("test.proto"), 23 | Package: proto.String("test"), 24 | Service: []*descriptor.ServiceDescriptorProto{ 25 | { 26 | Name: proto.String("TestService"), 27 | Method: []*descriptor.MethodDescriptorProto{ 28 | { 29 | Name: proto.String("TestMethod"), 30 | InputType: proto.String(".test.TestRequest"), 31 | OutputType: proto.String(".test.TestResponse"), 32 | }, 33 | }, 34 | }, 35 | }, 
36 | }, 37 | wantFile: "test_connecpy.py", 38 | wantErr: false, 39 | }, 40 | { 41 | name: "service with multiple methods", 42 | input: &descriptor.FileDescriptorProto{ 43 | Name: proto.String("multi.proto"), 44 | Package: proto.String("test"), 45 | Service: []*descriptor.ServiceDescriptorProto{ 46 | { 47 | Name: proto.String("MultiService"), 48 | Method: []*descriptor.MethodDescriptorProto{ 49 | { 50 | Name: proto.String("Method1"), 51 | InputType: proto.String(".test.Request1"), 52 | OutputType: proto.String(".test.Response1"), 53 | }, 54 | { 55 | Name: proto.String("Method2"), 56 | InputType: proto.String(".test.Request2"), 57 | OutputType: proto.String(".test.Response2"), 58 | }, 59 | }, 60 | }, 61 | }, 62 | }, 63 | wantFile: "multi_connecpy.py", 64 | wantErr: false, 65 | }, 66 | } 67 | 68 | for _, tt := range tests { 69 | t.Run(tt.name, func(t *testing.T) { 70 | got, err := GenerateConnecpyFile(tt.input) 71 | if (err != nil) != tt.wantErr { 72 | t.Errorf("GenerateConnecpyFile() error = %v, wantErr %v", err, tt.wantErr) 73 | return 74 | } 75 | if err == nil { 76 | if got.GetName() != tt.wantFile { 77 | t.Errorf("GenerateConnecpyFile() got filename = %v, want %v", got.GetName(), tt.wantFile) 78 | } 79 | 80 | content := got.GetContent() 81 | if !strings.Contains(content, "from typing import Any, Protocol, Union") { 82 | t.Error("Generated code missing required imports") 83 | } 84 | if !strings.Contains(content, "class "+strings.Split(tt.input.GetService()[0].GetName(), ".")[0]) { 85 | t.Error("Generated code missing service class") 86 | } 87 | } 88 | }) 89 | } 90 | } 91 | 92 | func TestGenerate(t *testing.T) { 93 | tests := []struct { 94 | name string 95 | req *plugin.CodeGeneratorRequest 96 | wantErr bool 97 | }{ 98 | { 99 | name: "empty request", 100 | req: &plugin.CodeGeneratorRequest{ 101 | FileToGenerate: []string{}, 102 | }, 103 | wantErr: true, 104 | }, 105 | { 106 | name: "valid request", 107 | req: &plugin.CodeGeneratorRequest{ 108 | FileToGenerate: 
[]string{"test.proto"}, 109 | ProtoFile: []*descriptor.FileDescriptorProto{ 110 | { 111 | Name: proto.String("test.proto"), 112 | Package: proto.String("test"), 113 | Service: []*descriptor.ServiceDescriptorProto{ 114 | { 115 | Name: proto.String("TestService"), 116 | Method: []*descriptor.MethodDescriptorProto{ 117 | { 118 | Name: proto.String("TestMethod"), 119 | InputType: proto.String(".test.TestRequest"), 120 | OutputType: proto.String(".test.TestResponse"), 121 | }, 122 | }, 123 | }, 124 | }, 125 | }, 126 | }, 127 | }, 128 | wantErr: false, 129 | }, 130 | } 131 | 132 | for _, tt := range tests { 133 | t.Run(tt.name, func(t *testing.T) { 134 | resp := Generate(tt.req) 135 | if tt.wantErr { 136 | if resp.GetError() == "" { 137 | t.Error("Generate() expected error but got none") 138 | } 139 | } else { 140 | if resp.GetError() != "" { 141 | t.Errorf("Generate() unexpected error: %v", resp.GetError()) 142 | } 143 | if len(resp.GetFile()) == 0 { 144 | t.Error("Generate() returned no files") 145 | } 146 | } 147 | }) 148 | } 149 | } 150 | 151 | func TestGetSymbolName(t *testing.T) { 152 | tests := []struct { 153 | name string 154 | symbolName string 155 | localPackage string 156 | want string 157 | }{ 158 | { 159 | name: "local package type", 160 | symbolName: ".test.TestMessage", 161 | localPackage: "test", 162 | want: "_pb2.TestMessage", 163 | }, 164 | { 165 | name: "external package type", 166 | symbolName: ".other.OtherMessage", 167 | localPackage: "test", 168 | want: "_sym_db.GetSymbol(\"other.OtherMessage\")", 169 | }, 170 | } 171 | 172 | for _, tt := range tests { 173 | t.Run(tt.name, func(t *testing.T) { 174 | got := getSymbolName(tt.symbolName, tt.localPackage) 175 | if got != tt.want { 176 | t.Errorf("getSymbolName() = %v, want %v", got, tt.want) 177 | } 178 | }) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /protoc-gen-connecpy/generator/template.go: 
--------------------------------------------------------------------------------
1 | package generator
2 | 
3 | import "text/template"
4 | 
5 | type ConnecpyTemplateVariables struct {
6 | 	FileName   string
7 | 	ModuleName string
8 | 	Services   []*ConnecpyService
9 | }
10 | 
11 | type ConnecpyService struct {
12 | 	Package string
13 | 	Name    string
14 | 	Comment string
15 | 	Methods []*ConnecpyMethod
16 | }
17 | 
18 | type ConnecpyMethod struct {
19 | 	Package               string
20 | 	ServiceName           string
21 | 	Name                  string
22 | 	Comment               string
23 | 	InputType             string
24 | 	InputTypeForProtocol  string
25 | 	OutputType            string
26 | 	OutputTypeForProtocol string
27 | 	NoSideEffects         bool
28 | }
29 | 
30 | // ConnecpyTemplate - Template for connecpy server and client
31 | var ConnecpyTemplate = template.Must(template.New("ConnecpyTemplate").Parse(`# -*- coding: utf-8 -*-
32 | # Generated by https://github.com/i2y/connecpy/protoc-gen-connecpy. DO NOT EDIT!
33 | # source: {{.FileName}}
34 | 
35 | from typing import Any, Protocol, Union
36 | 
37 | import httpx
38 | 
39 | from connecpy.async_client import AsyncConnecpyClient
40 | from connecpy.base import Endpoint
41 | from connecpy.server import ConnecpyServer
42 | from connecpy.client import ConnecpyClient
43 | from connecpy.context import ClientContext, ServiceContext
44 | 
45 | import {{.ModuleName}}_pb2 as _pb2
46 | 
47 | from google.protobuf import symbol_database
48 | 
49 | _sym_db = symbol_database.Default()
50 | 
51 | {{range .Services}}
52 | class {{.Name}}(Protocol):{{- range .Methods }}
53 |     async def {{.Name}}(self, req: {{.InputTypeForProtocol}}, ctx: ServiceContext) -> {{.OutputTypeForProtocol}}: ...
54 | {{- end }}
55 | 
56 | 
57 | class {{.Name}}Server(ConnecpyServer):
58 |     def __init__(self, *, service: {{.Name}}, server_path_prefix=""):
59 |         super().__init__()
60 |         self._prefix = f"{server_path_prefix}/{{.Package}}.{{.Name}}"
61 |         self._endpoints = { {{- range .Methods }}
62 |             "{{.Name}}": Endpoint[{{.InputType}}, {{.OutputType}}](
63 |                 service_name="{{.ServiceName}}",
64 |                 name="{{.Name}}",
65 |                 function=getattr(service, "{{.Name}}"),
66 |                 input={{.InputType}},
67 |                 output={{.OutputType}},
68 |                 allowed_methods={{if .NoSideEffects}}("GET", "POST"){{else}}("POST",){{end}},
69 |             ),{{- end }}
70 |         }
71 | 
72 |     def serviceName(self):
73 |         return "{{.Package}}.{{.Name}}"
74 | {{- end }}
75 | 
76 | {{range .Services}}
77 | class {{.Name}}Sync(Protocol):{{- range .Methods }}
78 |     def {{.Name}}(self, req: {{.InputTypeForProtocol}}, ctx: ServiceContext) -> {{.OutputTypeForProtocol}}: ...
79 | {{- end }}
80 | 
81 | 
82 | class {{.Name}}ServerSync(ConnecpyServer):
83 |     def __init__(self, *, service: {{.Name}}Sync, server_path_prefix=""):
84 |         super().__init__()
85 |         self._prefix = f"{server_path_prefix}/{{.Package}}.{{.Name}}"
86 |         self._endpoints = { {{- range .Methods }}
87 |             "{{.Name}}": Endpoint[{{.InputType}}, {{.OutputType}}](
88 |                 service_name="{{.ServiceName}}",
89 |                 name="{{.Name}}",
90 |                 function=getattr(service, "{{.Name}}"),
91 |                 input={{.InputType}},
92 |                 output={{.OutputType}},
93 |                 allowed_methods={{if .NoSideEffects}}("GET", "POST"){{else}}("POST",){{end}},
94 |             ),{{- end }}
95 |         }
96 | 
97 |     def serviceName(self):
98 |         return "{{.Package}}.{{.Name}}"
99 | 
100 | 
101 | class {{.Name}}Client(ConnecpyClient):{{range .Methods}}
102 |     def {{.Name}}(
103 |         self,
104 |         *,
105 |         request: {{.InputTypeForProtocol}},
106 |         ctx: ClientContext,
107 |         server_path_prefix: str = "",
108 |         {{if .NoSideEffects}}use_get: bool = False,
109 |         **kwargs,
110 |         {{else}}**kwargs,{{end}}
111 |     ) -> {{.OutputTypeForProtocol}}:
112 |         {{if .NoSideEffects}}method = "GET" if use_get else "POST"{{else}}method = "POST"{{end}}
113 |         return self._make_request(
114 |             url=f"{server_path_prefix}/{{.Package}}.{{.ServiceName}}/{{.Name}}",
115 |             ctx=ctx,
116 |             request=request,
117 |             response_obj={{.OutputType}},
118 |             method=method,
119 |             **kwargs,
120 |         )
121 | {{end}}
122 | 
123 | class Async{{.Name}}Client(AsyncConnecpyClient):{{range .Methods}}
124 |     async def {{.Name}}(
125 |         self,
126 |         *,
127 |         request: {{.InputTypeForProtocol}},
128 |         ctx: ClientContext,
129 |         server_path_prefix: str = "",
130 |         session: Union[httpx.AsyncClient, None] = None,
131 |         {{if .NoSideEffects}}use_get: bool = False,
132 |         **kwargs,
133 |         {{else}}**kwargs,{{end}}
134 |     ) -> {{.OutputTypeForProtocol}}:
135 |         {{if .NoSideEffects}}method = "GET" if use_get else "POST"{{else}}method = "POST"{{end}}
136 |         return await self._make_request(
137 |             url=f"{server_path_prefix}/{{.Package}}.{{.ServiceName}}/{{.Name}}",
138 |             ctx=ctx,
139 |             request=request,
140 |             response_obj={{.OutputType}},
141 |             method=method,
142 |             session=session,
143 |             **kwargs,
144 |         )
145 | {{end}}{{end}}`))
146 | 
--------------------------------------------------------------------------------
/protoc-gen-connecpy/generator/template_test.go:
--------------------------------------------------------------------------------
1 | package generator
2 | 
3 | import (
4 | 	"bytes"
5 | 	"strings"
6 | 	"testing"
7 | 
8 | 	"github.com/golang/protobuf/protoc-gen-go/descriptor"
9 | )
10 | 
11 | func TestConnecpyTemplate(t *testing.T) {
12 | 	tests := []struct {
13 | 		name     string
14 | 		vars     ConnecpyTemplateVariables
15 | 		contains []string
16 | 	}{
17 | 		{
18 | 			name: "simple service",
19 | 			vars: ConnecpyTemplateVariables{
20 | 				FileName:   "test.proto",
21 | 				ModuleName: "test",
22 | 				Services: []*ConnecpyService{
23 | 					{
24 | 						Package: "test",
25 | 						Name:    "TestService",
26 | 						Methods: []*ConnecpyMethod{
27 | 							{
28 | 								Package:               "test",
29 | 								ServiceName:           "TestService",
30 | 								Name:                  "TestMethod",
31 | 								InputType:             "_pb2.TestRequest",
32 | 								InputTypeForProtocol:  "_pb2.TestRequest",
33 | 								OutputType:            "_pb2.TestResponse",
34 | 								OutputTypeForProtocol: "_pb2.TestResponse",
35 | 								NoSideEffects:         false,
36 | 							},
37 | 						},
38 | 					},
39 | 				},
40 | 			},
41 | 			contains: []string{
42 | 				"from typing import Any, Protocol, Union",
43 | 				"class TestService(Protocol):",
44 | 				"class TestServiceServer(ConnecpyServer):",
45 | 				"def TestMethod",
46 | 				`allowed_methods=("POST",)`,
47 | 			},
48 | 		},
49 | 		{
50 | 			name: "service with no side effects method",
51 | 			vars: ConnecpyTemplateVariables{
52 | 				FileName:   "test.proto",
53 | 				ModuleName: "test",
54 | 				Services: []*ConnecpyService{
55 | 					{
56 | 						Package: "test",
57 | 						Name:    "TestService",
58 | 						Methods: []*ConnecpyMethod{
59 | 							{
60 | 								Package:               "test",
61 | 								ServiceName:           "TestService",
62 | 								Name:                  "GetData",
63 | 								InputType:             "_pb2.GetRequest",
64 | 								InputTypeForProtocol:  "_pb2.GetRequest",
65 | 								OutputType:            "_pb2.GetResponse",
66 | 								OutputTypeForProtocol: "_pb2.GetResponse",
67 | 								NoSideEffects:         true,
68 | 							},
69 | 						},
70 | 					},
71 | 				},
72 | 			},
73 | 			contains: []string{
74 | 				`allowed_methods=("GET", "POST")`,
75 | 				"use_get: bool = False",
76 | 				`method = "GET" if use_get else "POST"`,
77 | 			},
78 | 		},
79 | 	}
80 | 
81 | 	for _, tt := range tests {
82 | 		t.Run(tt.name, func(t *testing.T) {
83 | 			var buf bytes.Buffer
84 | 			err := ConnecpyTemplate.Execute(&buf, tt.vars)
85 | 			if err != nil {
86 | 				t.Fatalf("Template execution failed: %v", err)
87 | 			}
88 | 
89 | 			result := buf.String()
90 | 			for _, want := range tt.contains {
91 | 				if !strings.Contains(result, want) {
92 | 					t.Errorf("Generated code missing expected content: %q, got: %q", want, result)
93 | 				}
94 | 			}
95 | 		})
96 | 	}
97 | }
98 | 
99 | func TestConnecpyTemplateWithMethodOptions(t *testing.T) {
100 | 	noSideEffects := descriptor.MethodOptions_NO_SIDE_EFFECTS
101 | 
102 | 	tests := []struct {
103 | 		name               string
104 | 		methodOptions      *descriptor.MethodOptions
105 | 		wantAllowedMethods string
106 | 	}{
107 | 		{
108 | 			name:               "regular method",
109 | 			methodOptions:      nil,
110 | 			wantAllowedMethods: `allowed_methods=("POST",)`,
111 | 		},
112 | 		{
113 | 			name: "no side effects method",
114 | 			methodOptions: &descriptor.MethodOptions{
115 | 				IdempotencyLevel: &noSideEffects,
116 | 			},
117 | 			wantAllowedMethods: `allowed_methods=("GET", "POST")`,
118 | 		},
119 | 	}
120 | 
121 | 	for _, tt := range tests {
122 | 		t.Run(tt.name, func(t *testing.T) {
123 | 			vars := ConnecpyTemplateVariables{
124 | 				FileName:   "test.proto",
125 | 				ModuleName: "test",
126 | 				Services: []*ConnecpyService{
127 | 					{
128 | 						Package: "test",
129 | 						Name:    "TestService",
130 | 						Methods: []*ConnecpyMethod{
131 | 							{
132 | 								Package:       "test",
133 | 								ServiceName:   "TestService",
134 | 								Name:          "TestMethod",
135 | 								NoSideEffects: tt.methodOptions != nil && tt.methodOptions.GetIdempotencyLevel() == descriptor.MethodOptions_NO_SIDE_EFFECTS,
136 | 							},
137 | 						},
138 | 					},
139 | 				},
140 | 			}
141 | 
142 | 			var buf bytes.Buffer
143 | 			err := ConnecpyTemplate.Execute(&buf, vars)
144 | 			if err != nil {
145 | 				t.Fatalf("Template execution failed: %v", err)
146 | 			}
147 | 
148 | 			result := buf.String()
149 | 			if !strings.Contains(result, tt.wantAllowedMethods) {
150 | 				t.Errorf("Generated code missing expected allowed_methods: %q", tt.wantAllowedMethods)
151 | 			}
152 | 		})
153 | 	}
154 | }
155 | 
--------------------------------------------------------------------------------
/protoc-gen-connecpy/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/i2y/connecpy/protoc-gen-connecpy
2 | 
3 | go 1.21
4 | 
5 | require (
6 | 	github.com/golang/protobuf v1.5.4
7 | 	google.golang.org/protobuf v1.34.2
8 | )
9 | 
--------------------------------------------------------------------------------
/protoc-gen-connecpy/go.sum:
--------------------------------------------------------------------------------
1 | github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
2 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
3 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
4 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
5 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
6 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
7 | google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
8 | google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
9 | 
--------------------------------------------------------------------------------
/protoc-gen-connecpy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"io"
5 | 	"log"
6 | 	"os"
7 | 
8 | 	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
9 | 	"github.com/i2y/connecpy/protoc-gen-connecpy/generator"
10 | 	"google.golang.org/protobuf/proto"
11 | )
12 | 
13 | func main() {
14 | 	data, err := io.ReadAll(os.Stdin)
15 | 	if err != nil {
16 | 		log.Fatalln("could not read from stdin", err)
17 | 		return
18 | 	}
19 | 	var req = &plugin.CodeGeneratorRequest{}
20 | 	err = proto.Unmarshal(data, req)
21 | 	if err != nil {
22 | 		log.Fatalln("could not unmarshal proto", err)
23 | 		return
24 | 	}
25 | 	if len(req.GetFileToGenerate()) == 0 {
26 | 		log.Fatalln("no files to generate")
27 | 		return
28 | 	}
29 | 	resp := generator.Generate(req)
30 | 
31 | 	if resp == nil {
32 | 		resp = &plugin.CodeGeneratorResponse{}
33 | 	}
34 | 
35 | 	data, err = proto.Marshal(resp)
36 | 	if err != nil {
37 | 		log.Fatalln("could not marshal response proto", err)
38 | 	}
39 | 	_, err = os.Stdout.Write(data)
40 | 	if err != nil {
41 | 		log.Fatalln("could not write response to stdout", err)
42 | 	}
43 | }
44 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "connecpy"
3 | version = "1.4.2"
4 | description = "Code generator, server library and client library for the Connect Protocol"
5 | authors = [
6 |     { name = "Yasushi Itoh" }
7 | ]
8 | dependencies = [
9 |     "httpx",
10 |     "protobuf",
11 |     "starlette",
12 |     "zstd-asgi>=0.2",
13 |     "brotli-asgi>=1.4.0",
14 |     "brotli>=1.1.0",
15 |     "zstandard>=0.22.0",
16 | ]
17 | readme = "README.md"
18 | requires-python = ">= 3.9"
19 | 
20 | [build-system]
21 | requires = ["hatchling"]
22 | build-backend = "hatchling.build"
23 | 
24 | [tool.uv]
25 | managed = true
26 | dev-dependencies = [
27 |     "ruff>=0.9.4",
28 |     "grpcio-tools",
29 |     "uvicorn",
30 |     "grpcio",
31 |     "protobuf",
32 |     "brotli_asgi",
33 |     "hypercorn",
34 |     "daphne",
35 |     "pip>=24.0",
36 |     "pytest>=8.3.4",
37 |     "pytest-asyncio>=0.25.2",
38 |     "zstd-asgi>=0.2",
39 | ]
40 | 
41 | [tool.hatch.metadata]
42 | allow-direct-references = true
43 | 
44 | [tool.hatch.build.targets.wheel]
45 | packages = ["src/connecpy"]
46 | 
--------------------------------------------------------------------------------
/src/connecpy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/i2y/connecpy/6941d57ac24be366b19308a7dc223b239a4f06e7/src/connecpy/__init__.py
--------------------------------------------------------------------------------
/src/connecpy/asgi.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Iterable, List, Mapping, Tuple
3 | from urllib.parse import parse_qs
4 | import base64
5 | from functools import partial
6 | 
7 | from starlette.requests import Request
8 | 
9 | from . import base
10 | from . import context
11 | from . import errors
12 | from . import exceptions
13 | from . import encoding
14 | from . import compression
15 | 
16 | 
17 | class ConnecpyASGIApp(base.ConnecpyBaseApp):
18 |     """ASGI application for Connecpy."""
19 | 
20 |     async def __call__(self, scope, receive, send):
21 |         """
22 |         Handle incoming ASGI requests.
23 | 
24 |         Args:
25 |             scope (dict): The ASGI scope.
26 |             receive (callable): The ASGI receive function.
27 |             send (callable): The ASGI send function.
28 |         """
29 |         assert scope["type"] == "http"
30 |         try:
31 |             http_method = scope["method"]
32 |             endpoint = self._get_endpoint(scope["path"])
33 |             if http_method not in endpoint.allowed_methods:
34 |                 raise exceptions.ConnecpyServerException(
35 |                     code=errors.Errors.BadRoute,
36 |                     message=f"unsupported method {http_method}",
37 |                 )
38 | 
39 |             headers = scope.get("headers", [])
40 |             accept_encoding = compression.extract_header_value(
41 |                 headers, b"accept-encoding"
42 |             )
43 |             selected_encoding = compression.select_encoding(accept_encoding)
44 | 
45 |             ctx = context.ConnecpyServiceContext(
46 |                 scope["client"], convert_to_mapping(headers)
47 |             )
48 | 
49 |             if http_method == "GET":
50 |                 request = await self._handle_get_request(scope, ctx)
51 |             else:
52 |                 request = await self._handle_post_request(scope, receive, ctx)
53 | 
54 |             proc = endpoint.make_async_proc(self._interceptors)
55 |             response_data = await proc(request, ctx)
56 | 
57 |             encoder = encoding.get_encoder(endpoint, ctx.content_type())
58 |             res_bytes, headers = encoder(response_data)
59 | 
60 |             # Compress response if needed
61 |             if selected_encoding != "identity":
62 |                 compressor = compression.get_compressor(selected_encoding)
63 |                 if compressor:
64 |                     res_bytes = compressor(res_bytes)
65 |                     headers["Content-Encoding"] = [selected_encoding]
66 | 
67 |             combined_headers = dict(
68 |                 add_trailer_prefix(ctx.trailing_metadata()), **headers
69 |             )
70 |             final_headers = convert_to_single_string(combined_headers)
71 | 
72 |             await send(
73 |                 {
74 |                     "type": "http.response.start",
75 |                     "status": 200,
76 |                     "headers": [
77 |                         (k.lower().encode(), v.encode())
78 |                         for k, v in final_headers.items()
79 |                     ],
80 |                 }
81 |             )
82 |             await send(
83 |                 {
84 |                     "type": "http.response.body",
85 |                     "body": res_bytes,
86 |                 }
87 |             )
88 | 
89 |         except Exception as e:
90 |             await self.handle_error(e, scope, receive, send)
91 | 
92 |     async def _handle_get_request(self, scope, ctx):
93 |         """Handle GET request with query parameters."""
94 |         query_string = scope.get("query_string", b"").decode("utf-8")
95 |         params = parse_qs(query_string, keep_blank_values=True)  # an empty proto message arrives as "message="
96 | 
97 |         # Validation
98 |         if "message" not in params:
99 |             raise exceptions.ConnecpyServerException(
100 |                 code=errors.Errors.InvalidArgument,
101 |                 message="'message' parameter is required for GET requests",
102 |             )
103 | 
104 |         # Get and decode message
105 |         message = params["message"][0]
106 |         is_base64 = "base64" in params and params["base64"][0] == "1"
107 | 
108 |         if is_base64:
109 |             try:  # Connect clients may omit base64url padding, so restore it before decoding
110 |                 message = base64.urlsafe_b64decode(message + "=" * (-len(message) % 4))
111 |             except Exception:
112 |                 raise exceptions.ConnecpyServerException(
113 |                     code=errors.Errors.InvalidArgument,
114 |                     message="Invalid base64 encoding",
115 |                 )
116 |         else:
117 |             message = message.encode("utf-8")
118 | 
119 |         # Handle encoding (a missing or blank value falls back to JSON)
120 |         encoding_name = params.get("encoding", [""])[0] or "json"
121 |         decoder = encoding.get_decoder_by_name(encoding_name)
122 |         if not decoder:
123 |             raise exceptions.ConnecpyServerException(
124 |                 code=errors.Errors.Unimplemented,
125 |                 message=f"Unsupported encoding: {encoding_name}",
126 |             )
127 | 
128 |         # Handle compression (a missing or blank value falls back to identity)
129 |         compression_name = params.get("compression", [""])[0] or "identity"
130 |         decompressor = compression.get_decompressor(compression_name)
131 |         if not decompressor:
132 |             raise exceptions.ConnecpyServerException(
133 |                 code=errors.Errors.Unimplemented,
134 |                 message=f"Unsupported compression: {compression_name}",
135 |             )
136 | 
137 |         # Decompress and decode message
138 |         if message:  # Don't decompress empty messages
139 |             message = decompressor(message)
140 | 
141 |         # Get the appropriate decoder for the endpoint
142 |         endpoint = self._get_endpoint(scope["path"])
143 |         decoder = partial(decoder, data_obj=endpoint.input)
144 |         return decoder(message)
145 | 
146 |     async def _handle_post_request(self, scope, receive, ctx):
147 |         """Handle POST request with body."""
148 |         # Get request body and endpoint
149 |         request = Request(scope, receive)
150 |         endpoint = self._get_endpoint(scope["path"])
151 |         req_body = await request.body()
152 | 
153 |         if len(req_body) > self._max_receive_message_length:
154 |             raise exceptions.ConnecpyServerException(
155 |                 code=errors.Errors.InvalidArgument,
156 |                 message=f"Request body exceeds maximum size of {self._max_receive_message_length} bytes",
157 |             )
158 | 
159 |         # Handle compression if specified
160 |         compression_header = (
161 |             dict(scope["headers"]).get(b"content-encoding", b"identity").decode("ascii")
162 |         )
163 |         decompressor = compression.get_decompressor(compression_header)
164 |         if not decompressor:
165 |             raise exceptions.ConnecpyServerException(
166 |                 code=errors.Errors.Unimplemented,
167 |                 message=f"Unsupported compression: {compression_header}",
168 |             )
169 | 
170 |         if req_body:  # Don't decompress empty body
171 |             req_body = decompressor(req_body)
172 | 
173 |         # Get the decoder based on content type
174 |         base_decoder = encoding.get_decoder_by_name(
175 |             "proto" if ctx.content_type() == "application/proto" else "json"
176 |         )
177 |         if not base_decoder:
178 |             raise exceptions.ConnecpyServerException(
179 |                 code=errors.Errors.Unimplemented,
180 |                 message=f"Unsupported encoding: {ctx.content_type()}",
181 |             )
182 | 
183 |         decoder = partial(base_decoder, data_obj=endpoint.input)
184 |         return decoder(req_body)
185 | 
186 |     async def handle_error(self, exc, scope, receive, send):
187 |         """Handle errors that occur during request processing."""
188 |         if not isinstance(exc, exceptions.ConnecpyServerException):
189 |             exc = exceptions.ConnecpyServerException(
190 |                 code=errors.Errors.Internal,
191 |                 message=str(exc),
192 |             )
193 | 
194 |         status = errors.Errors.get_status_code(exc.code)
195 |         headers = [
196 |             (b"content-type", b"application/json"),
197 |         ]
198 | 
199 |         await send(
200 |             {
201 |                 "type": "http.response.start",
202 |                 "status": status,
203 |                 "headers": headers,
204 |             }
205 |         )
206 |         await send(
207 |             {
208 |                 "type": "http.response.body",
209 |                 "body": exc.to_json_bytes(),
210 |             }
211 |         )
212 | 
213 | 
214 | def convert_to_mapping(
215 |     iterable: Iterable[Tuple[bytes, bytes]],
216 | ) -> Mapping[str, List[str]]:
217 |     result = defaultdict(list)
218 |     for key, value in iterable:
219 |         result[key.decode("utf-8")].append(value.decode("utf-8"))
220 |     return dict(result)
221 | 
222 | 
223 | def convert_to_single_string(mapping: Mapping[str, List[str]]) -> Mapping[str, str]:
224 |     return {key: ", ".join(values) for key, values in mapping.items()}
225 | 
226 | 
227 | def add_trailer_prefix(trailers: Mapping[str, List[str]]) -> Mapping[str, List[str]]:
228 |     return {f"trailer-{key}": values for key, values in trailers.items()}
229 | 
--------------------------------------------------------------------------------
/src/connecpy/async_client.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import httpx
3 | from . import shared_client
4 | from . import compression
5 | from . import context
6 | from . import exceptions
7 | from . import errors
8 | 
9 | 
10 | class AsyncConnecpyClient:
11 |     """
12 |     Represents an asynchronous client for Connecpy using httpx.
13 | 
14 |     Args:
15 |         address (str): The address of the Connecpy server.
16 |         session (httpx.AsyncClient): The httpx client session to use for making requests.
17 |     """
18 | 
19 |     def __init__(
20 |         self, address: str, timeout=5, session: Union[httpx.AsyncClient, None] = None
21 |     ) -> None:
22 |         self._address = address
23 |         self._timeout = timeout
24 |         self._session = session
25 | 
26 |     async def _make_request(
27 |         self,
28 |         *,
29 |         url: str,
30 |         request,
31 |         ctx: context.ClientContext,
32 |         response_obj,
33 |         method="POST",
34 |         session: Union[httpx.AsyncClient, None] = None,
35 |         **kwargs,
36 |     ):
37 |         """
38 |         Makes a request to the Connecpy server.
39 | 
40 |         Args:
41 |             url (str): The URL to send the request to.
42 |             request: The request object to send.
43 |             ctx (context.ClientContext): The client context.
44 |             response_obj: The response object class to deserialize the response into.
45 |             method (str): The HTTP method to use for the request. Defaults to "POST".
46 |             session (httpx.AsyncClient, optional): The httpx client session to use for the request.
47 |                 If not provided, the constructor's session is used, or else a temporary client that is closed afterwards.
48 |             **kwargs: Additional keyword arguments to pass to the request.
49 | 
50 |         Returns:
51 |             The deserialized response object.
52 | 
53 |         Raises:
54 |             exceptions.ConnecpyServerException: If an error occurs while making the request.
55 |         """
56 |         # Prepare headers and kwargs using shared logic
57 |         headers, kwargs = shared_client.prepare_headers(ctx, kwargs, self._timeout)
58 | 
59 |         try:
60 |             if session or self._session:
61 |                 client = session or self._session
62 |                 close_client = False
63 |             else:
64 |                 client = httpx.AsyncClient()
65 |                 close_client = True
66 | 
67 |             try:
68 |                 if "content-encoding" in headers:
69 |                     request_data, headers = shared_client.compress_request(
70 |                         request, headers, compression
71 |                     )
72 |                 else:
73 |                     request_data = request.SerializeToString()
74 | 
75 |                 if method == "GET":
76 |                     params = shared_client.prepare_get_params(request_data, headers)
77 |                     kwargs["params"] = params
78 |                     kwargs["headers"].pop("content-type", None)
79 |                     resp = await client.get(url=self._address + url, **kwargs)
80 |                 else:
81 |                     resp = await client.post(
82 |                         url=self._address + url, content=request_data, **kwargs
83 |                     )
84 | 
85 |                 resp.raise_for_status()
86 | 
87 |                 if resp.status_code == 200:
88 |                     response = response_obj()
89 |                     response.ParseFromString(resp.content)
90 |                     return response
91 |                 else:
92 |                     raise exceptions.ConnecpyServerException.from_json(
93 |                         resp.json()  # httpx's Response.json() is synchronous; it must not be awaited
94 |                     )
95 |             finally:
96 |                 if close_client:
97 |                     await client.aclose()
98 |         except httpx.TimeoutException as e:
99 |             raise exceptions.ConnecpyServerException(
100 |                 code=errors.Errors.DeadlineExceeded,
101 |                 message=str(e) or "request timeout",
102 |             )
103 |         except httpx.HTTPStatusError as e:
104 |             raise exceptions.ConnecpyServerException(
105 |                 code=errors.Errors.Unavailable,
106 |                 message=str(e),
107 |             )
108 |         except exceptions.ConnecpyServerException:
109 |             # Already a Connect error (e.g. built by from_json above); keep its original code.
110 |             raise
111 |         except Exception as e:
112 |             raise exceptions.ConnecpyServerException(
113 |                 code=errors.Errors.Internal,
114 |                 message=str(e),
115 |             )
116 | 
--------------------------------------------------------------------------------
/src/connecpy/base.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 | from functools import reduce
4 | from typing import Callable, Generic, Tuple, TypeVar, Union
5 | 
6 | from starlette import concurrency
7 | 
8 | from . import context
9 | from . import exceptions
10 | from . import errors
11 | from . import interceptor
12 | from . import server
13 | from . import encoding
14 | 
15 | 
16 | T = TypeVar("T")
17 | U = TypeVar("U")
18 | 
19 | 
20 | @dataclass
21 | class Endpoint(Generic[T, U]):
22 |     """
23 |     Represents an endpoint in a service.
24 | 
25 |     Attributes:
26 |         service_name (str): The name of the service.
27 |         name (str): The name of the endpoint.
28 |         function (Callable[[T, context.ServiceContext], U]): The function that implements the endpoint.
29 |         input (type): The type of the input parameter.
30 |         output (type): The type of the output parameter.
31 |         allowed_methods (Tuple[str, ...]): The allowed HTTP methods for the endpoint.
32 |         _async_proc, _proc (Callable[[T, context.ServiceContext], U] | None): Cached results of make_async_proc and make_proc.
33 |     """
34 | 
35 |     service_name: str
36 |     name: str
37 |     function: Callable[
38 |         [
39 |             T,
40 |             context.ServiceContext,
41 |         ],
42 |         U,
43 |     ]
44 |     input: type
45 |     output: type
46 |     allowed_methods: Tuple[str, ...] = ("POST",)
47 |     _async_proc: Union[
48 |         Callable[
49 |             [
50 |                 T,
51 |                 context.ServiceContext,
52 |             ],
53 |             U,
54 |         ],
55 |         None,
56 |     ] = None
57 |     _proc: Union[
58 |         Callable[
59 |             [
60 |                 T,
61 |                 context.ServiceContext,
62 |             ],
63 |             U,
64 |         ],
65 |         None,
66 |     ] = None
67 | 
68 |     def make_async_proc(
69 |         self,
70 |         interceptors: Tuple[interceptor.AsyncConnecpyServerInterceptor, ...],
71 |     ) -> Callable[[T, context.ServiceContext], U]:
72 |         """
73 |         Creates an asynchronous function that implements the endpoint.
74 | 
75 |         Args:
76 |             interceptors (Tuple[interceptor.AsyncConnecpyServerInterceptor, ...]): The interceptors to apply to the endpoint.
77 | 
78 |         Returns:
79 |             Callable[[T, context.ServiceContext], U]: The asynchronous function that implements the endpoint.
80 |         """
81 |         if self._async_proc is not None:  # built once; interceptors passed on later calls are ignored
82 |             return self._async_proc
83 | 
84 |         method_name = self.name
85 |         reversed_interceptors = reversed(interceptors)
86 |         self._async_proc = reduce(  # type: ignore
87 |             lambda acc, interceptor: interceptor.make_interceptor(acc, method_name),
88 |             reversed_interceptors,
89 |             asynchronize(self.function),
90 |         )  # type: ignore
91 | 
92 |         return self._async_proc  # type: ignore
93 | 
94 |     def make_proc(
95 |         self,
96 |     ) -> Callable[[T, context.ServiceContext], U]:
97 |         """
98 |         Returns the function that implements the endpoint, without async wrapping.
99 | 
100 |         Unlike make_async_proc, no interceptors are applied here: the service
101 |         function is cached on the first call and returned unchanged.
102 | 
103 |         Returns:
104 |             Callable[[T, context.ServiceContext], U]: The endpoint's function as-is.
105 |         """
106 |         if self._proc is not None:
107 |             return self._proc
108 | 
109 |         self._proc = self.function
110 | 
111 |         return self._proc  # type: ignore
112 | 
113 | 
114 | def thread_pool_runner(func):
115 |     async def run(request, ctx: context.ConnecpyServiceContext):
116 |         return await concurrency.run_in_threadpool(func, request, ctx)
117 | 
118 |     return run
119 | 
120 | 
121 | def asynchronize(func) -> Callable:
122 |     """
123 |     Decorator that converts a synchronous function into an asynchronous function.
124 | 
125 |     If the input function is already a coroutine function, it is returned as is.
126 |     Otherwise, it is wrapped in a thread pool runner to execute it asynchronously.
127 | 
128 |     Args:
129 |         func: The synchronous function to be converted.
130 | 
131 |     Returns:
132 |         The converted asynchronous function.
133 | 
134 |     """
135 |     if asyncio.iscoroutinefunction(func):
136 |         return func
137 |     else:
138 |         return thread_pool_runner(func)
139 | 
140 | 
141 | class ConnecpyBaseApp(object):
142 |     """
143 |     Represents the base application class for Connecpy servers.
144 | 
145 |     Args:
146 |         interceptors (Tuple[interceptor.AsyncConnecpyServerInterceptor, ...]): A tuple of interceptors to be applied to the server.
147 |         prefix (str): The prefix to be added to the service endpoints.
148 |         max_receive_message_length (int): The maximum length of the received messages.
149 | 
150 |     Attributes:
151 |         _interceptors (Tuple[interceptor.AsyncConnecpyServerInterceptor, ...]): The interceptors applied to the server.
152 |         _prefix (str): The prefix added to the service endpoints.
153 |         _services (dict): A dictionary of services registered with the server.
154 |         _max_receive_message_length (int): The maximum length of the received messages.
155 | 
156 |     Methods:
157 |         add_service: Adds a service to the server.
158 |         _get_endpoint: Retrieves the endpoint for a given path.
159 |         _get_encoder_decoder: Retrieves the appropriate encoder and decoder for a given endpoint and content type.
160 | 
161 |     Encoding and decoding of message bodies is delegated to the ``encoding``
162 |     module; request dispatch is implemented by subclasses such as the ASGI
163 |     application.
164 |     """
165 | 
166 |     def __init__(
167 |         self,
168 |         interceptors: Tuple[interceptor.AsyncConnecpyServerInterceptor, ...] = (),
169 |         prefix="",
170 |         max_receive_message_length=1024 * 100 * 100,
171 |     ):
172 |         self._interceptors = interceptors
173 |         self._prefix = prefix
174 |         self._services = {}
175 |         self._max_receive_message_length = max_receive_message_length
176 | 
177 |     def add_service(self, svc: server.ConnecpyServer):
178 |         """
179 |         Adds a service to the server.
180 | 
181 |         Args:
182 |             svc (server.ConnecpyServer): The service to be added.
183 |         """
184 |         self._services[self._prefix + svc.prefix] = svc
185 | 
186 |     def _get_endpoint(self, path):
187 |         """
188 |         Retrieves the endpoint for a given path.
189 | 
190 |         Args:
191 |             path (str): The path of the endpoint.
192 | 
193 |         Returns:
194 |             The endpoint for the given path.
195 | 
196 |         Raises:
197 |             exceptions.ConnecpyServerException: If the endpoint is not found.
198 |         """
199 |         svc = self._services.get(path.rsplit("/", 1)[0], None)
200 |         if svc is None:
201 |             raise exceptions.ConnecpyServerException(
202 |                 code=errors.Errors.NotFound, message="not found"
203 |             )
204 | 
205 |         return svc.get_endpoint(path[len(self._prefix) :])
206 | 
207 |     def _get_encoder_decoder(self, endpoint, ctype: str):
208 |         """
209 |         Retrieves the appropriate encoder and decoder for a given endpoint and content type.
210 | 
211 |         Args:
212 |             endpoint: The endpoint to retrieve the encoder and decoder for.
213 |             ctype (str): The content type.
214 | 
215 |         Returns:
216 |             The encoder and decoder functions.
217 |         """
218 |         return encoding.get_encoder_decoder_pair(endpoint, ctype)
219 | 
--------------------------------------------------------------------------------
/src/connecpy/client.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | 
3 | from . import exceptions
4 | from . import errors
5 | from . import compression
6 | from . import shared_client
7 | 
8 | 
9 | class ConnecpyClient:
10 |     def __init__(self, address, timeout=5):
11 |         self._address = address
12 |         self._timeout = timeout
13 | 
14 |     def _make_request(
15 |         self, *, url, request, ctx, response_obj, method="POST", **kwargs
16 |     ):
17 |         """Make an HTTP request to the server."""
18 |         # Prepare headers and kwargs using shared logic
19 |         headers, kwargs = shared_client.prepare_headers(ctx, kwargs, self._timeout)
20 | 
21 |         try:
22 |             with httpx.Client() as client:
23 |                 if "content-encoding" in headers:
24 |                     request_data, headers = shared_client.compress_request(
25 |                         request, headers, compression
26 |                     )
27 |                 else:
28 |                     request_data = request.SerializeToString()
29 | 
30 |                 if method == "GET":
31 |                     params = shared_client.prepare_get_params(request_data, headers)
32 |                     kwargs["params"] = params
33 |                     kwargs["headers"].pop("content-type", None)
34 |                     resp = client.get(url=self._address + url, **kwargs)
35 |                 else:
36 |                     resp = client.post(
37 |                         url=self._address + url, content=request_data, **kwargs
38 |                     )
39 | 
40 |                 resp.raise_for_status()
41 | 
42 |                 if resp.status_code == 200:
43 |                     response = response_obj()
44 |                     try:
45 |                         response.ParseFromString(resp.content)
46 |                         return response
47 |                     except Exception as e:
48 |                         raise exceptions.ConnecpyException(
49 |                             f"Failed to parse response message: {str(e)}"
50 |                         )
51 |                 else:
52 |                     raise exceptions.ConnecpyServerException.from_json(resp.json())
53 | 
54 |         except httpx.TimeoutException as e:
55 |             raise exceptions.ConnecpyServerException(
56 |                 code=errors.Errors.DeadlineExceeded,
57 |                 message=str(e) or "request timeout",
58 |             )
59 |         except httpx.HTTPStatusError as e:
60 |             raise exceptions.ConnecpyServerException(
61 |                 code=errors.Errors.Unavailable,
62 |                 message=str(e),
63 |             )
64 |         except (exceptions.ConnecpyServerException, exceptions.ConnecpyException):
65 |             raise  # already a Connecpy error (e.g. from from_json); propagate unchanged
66 |         except Exception as e:
67 |             raise exceptions.ConnecpyException(str(e))
68 | 
--------------------------------------------------------------------------------
/src/connecpy/compression.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import gzip 3 | import brotli 4 | import zstandard 5 | 6 | 7 | 8 | def gzip_decompress(data: bytes) -> bytes: 9 | """Decompress data using gzip.""" 10 | try: 11 | return gzip.decompress(data) 12 | except gzip.BadGzipFile: 13 | raise 14 | 15 | 16 | def brotli_decompress(data: bytes) -> bytes: 17 | """Decompress data using brotli.""" 18 | try: 19 | return brotli.decompress(data) 20 | except brotli.error: 21 | raise 22 | 23 | 24 | def zstd_decompress(data: bytes) -> bytes: 25 | """Decompress data using zstandard.""" 26 | try: 27 | dctx = zstandard.ZstdDecompressor() 28 | return dctx.decompress(data) 29 | except zstandard.ZstdError: 30 | raise 31 | 32 | 33 | def gzip_compress(data: bytes) -> bytes: 34 | """Compress data using gzip.""" 35 | try: 36 | return gzip.compress(data) 37 | except Exception: 38 | raise 39 | 40 | 41 | def brotli_compress(data: bytes) -> bytes: 42 | """Compress data using brotli.""" 43 | try: 44 | return brotli.compress(data) 45 | except Exception: 46 | raise 47 | 48 | 49 | def zstd_compress(data: bytes) -> bytes: 50 | """Compress data using zstandard.""" 51 | try: 52 | cctx = zstandard.ZstdCompressor() 53 | return cctx.compress(data) 54 | except Exception: 55 | raise 56 | 57 | 58 | def identity(data: bytes) -> bytes: 59 | """Return data as-is without compression.""" 60 | return data 61 | 62 | 63 | _decompressors = { 64 | "identity": identity, 65 | "gzip": gzip_decompress, 66 | "br": brotli_decompress, 67 | "zstd": zstd_decompress, 68 | } 69 | 70 | 71 | def get_decompressor(compression_name: str) -> Callable[[bytes], bytes] | None: 72 | """Get decompressor function by compression name. 73 | 74 | Args: 75 | compression_name (str): The name of the compression. Can be "identity", "gzip", "br", or "zstd". 
76 | 77 | Returns: 78 | Callable[[bytes], bytes]: The decompressor function for the specified compression. 79 | """ 80 | cmp_lower = compression_name.lower() 81 | decompressor = _decompressors.get(cmp_lower) 82 | if decompressor: 83 | return decompressor 84 | 85 | return None 86 | 87 | 88 | def get_compressor(compression_name: str) -> Callable[[bytes], bytes]: 89 | """Get compressor function by compression name. 90 | 91 | Args: 92 | compression_name (str): The name of the compression. Can be "identity", "gzip", "br", or "zstd". 93 | 94 | Returns: 95 | Callable[[bytes], bytes]: The compressor function for the specified compression. 96 | """ 97 | compressors = { 98 | "identity": identity, 99 | "gzip": gzip_compress, 100 | "br": brotli_compress, 101 | "zstd": zstd_compress, 102 | } 103 | return compressors.get(compression_name) 104 | 105 | 106 | def extract_header_value( 107 | headers: list[tuple[bytes, bytes]] | dict[str, str], name: bytes | str 108 | ) -> bytes | str: 109 | """Get a header value from a list of headers or a headers dictionary. 110 | 111 | Args: 112 | headers: Either a list of (name, value) tuples with bytes, or a dictionary with string keys/values 113 | name: Header name to look for (either bytes or str) 114 | 115 | Returns: 116 | The header value if found, empty bytes or string depending on input type 117 | """ 118 | if isinstance(headers, dict): 119 | # Dictionary case - string keys 120 | name = name.decode("ascii") if isinstance(name, bytes) else name 121 | name = name.lower() 122 | return headers.get(name, "") 123 | else: 124 | # List of tuples case - bytes 125 | name = name.encode("ascii") if isinstance(name, str) else name 126 | name_lower = name.lower() 127 | for key, value in headers: 128 | if key.lower() == name_lower: 129 | return value 130 | return b"" 131 | 132 | 133 | def parse_accept_encoding(accept_encoding: str | bytes) -> list[tuple[str, float]]: 134 | """Parse Accept-Encoding header value with quality values. 
135 | 136 | Args: 137 | accept_encoding: The Accept-Encoding header value (str or bytes) 138 | 139 | Returns: 140 | list[tuple[str, float]]: List of (encoding, q-value) pairs, sorted by q-value 141 | """ 142 | if not accept_encoding: 143 | return [("identity", 1.0)] 144 | 145 | # Convert bytes to string if needed 146 | if isinstance(accept_encoding, bytes): 147 | accept_encoding = accept_encoding.decode("ascii") 148 | 149 | encodings = [] 150 | seen = set() # Track seen encodings to avoid duplicates 151 | 152 | # First, handle special case of "identity;q=0,*;q=0" which means "no encoding allowed" 153 | if accept_encoding.replace(" ", "") == "identity;q=0,*;q=0": 154 | return [("identity", 0.0), ("*", 0.0)] 155 | 156 | for part in accept_encoding.split(","): 157 | part = part.strip() 158 | if not part: 159 | continue 160 | 161 | # Split encoding and q-value 162 | if ";" in part: 163 | encoding, q_part = part.split(";", 1) 164 | encoding = encoding.strip().lower() 165 | try: 166 | if not q_part.strip().lower().startswith("q="): 167 | continue 168 | q = float(q_part.strip().lower().replace("q=", "")) 169 | q = max(0.0, min(1.0, q)) # Clamp between 0 and 1 170 | except (ValueError, AttributeError): 171 | q = 1.0 172 | else: 173 | encoding = part.strip().lower() 174 | q = 1.0 175 | 176 | if encoding and encoding not in seen: 177 | seen.add(encoding) 178 | encodings.append((encoding, q)) 179 | 180 | # Sort by q-value in descending order while preserving the original accept-encoding order for equal q-values 181 | result = sorted(encodings, key=lambda x: -x[1]) 182 | return result 183 | 184 | 185 | # TODO: wrong sorting order, use preference order instead of available order 186 | def select_encoding( 187 | accept_encoding: str | bytes, 188 | available_encodings: tuple[str] = ("br", "gzip", "zstd", "identity"), 189 | ) -> str: 190 | """Select the best compression encoding based on Accept-Encoding header. 
191 | 192 | Args: 193 | accept_encoding: The Accept-Encoding header value (str or bytes) 194 | available_encodings: Tuple of available encodings. 195 | Defaults to ("br", "gzip", "zstd", "identity") 196 | 197 | Returns: 198 | str: The selected encoding name 199 | """ 200 | # Parse Accept-Encoding header with q-values (already sorted by q descending) 201 | encodings = parse_accept_encoding(accept_encoding) 202 | 203 | # Check for "no encoding allowed" case 204 | if len(encodings) == 2 and all(q == 0.0 for _, q in encodings): 205 | if {"identity", "*"} == {enc for enc, _ in encodings}: 206 | return "identity" 207 | 208 | # Iterate over client-preferred encodings (sorted by q-value) 209 | for client_encoding, q in encodings: 210 | if q <= 0: 211 | continue 212 | if client_encoding == "*": 213 | # For wildcard, choose any available encoding not explicitly defined by the client. 214 | excluded = {enc for enc, _ in encodings if enc != "*"} 215 | candidates = [enc for enc in available_encodings if enc not in excluded] 216 | if candidates: 217 | return candidates[0] 218 | else: 219 | # If all available encodings were explicitly mentioned, return the first available. 220 | return available_encodings[0] 221 | elif client_encoding in available_encodings: 222 | return client_encoding 223 | 224 | # If no match found, fallback to identity 225 | return "identity" 226 | -------------------------------------------------------------------------------- /src/connecpy/context.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol, Mapping, MutableMapping, List, Union 2 | import time 3 | 4 | from . import errors 5 | from . import exceptions 6 | 7 | 8 | class ClientContext: 9 | """Context object for storing context information of 10 | request currently being processed. 11 | 12 | Attributes: 13 | _values (Dict[str, Any]): Dictionary to store context key-value pairs. 
14 | _headers (MutableMapping[str, str]): Request headers. 15 | _response_headers (MutableMapping[str, str]): Response headers. 16 | """ 17 | 18 | def __init__(self, *, headers: Union[MutableMapping[str, str], None] = None): 19 | """Create a new Context object 20 | 21 | Keyword arguments: 22 | headers (MutableMapping[str, str] | None): Headers for the request. 23 | """ 24 | 25 | self._values = {} 26 | if headers is None: 27 | headers = {} 28 | self._headers = headers 29 | self._response_headers: MutableMapping[str, str] = {} 30 | 31 | def set(self, key: str, value: Any) -> None: 32 | """Set a Context value 33 | 34 | Args: 35 | key (str): Key for the context key-value pair. 36 | value (Any): Value to be stored. 37 | """ 38 | 39 | self._values[key] = value 40 | 41 | def get(self, key: str) -> Any: 42 | """Get a Context value 43 | 44 | Args: 45 | key (str): Key for the context key-value pair. 46 | 47 | Returns: 48 | Any: The value associated with the key. 49 | """ 50 | 51 | return self._values[key] 52 | 53 | def get_headers(self) -> MutableMapping[str, str]: 54 | """Get request headers that are currently stored. 55 | 56 | Returns: 57 | MutableMapping[str, str]: The request headers. 58 | """ 59 | 60 | return self._headers 61 | 62 | def set_header(self, key: str, value: str) -> None: 63 | """Set a request header 64 | 65 | Args: 66 | key (str): Key for the header. 67 | value (str): Value for the header. 68 | """ 69 | 70 | self._headers[key] = value 71 | 72 | def get_response_headers(self) -> MutableMapping[str, str]: 73 | """Get response headers that are currently stored. 74 | 75 | Returns: 76 | MutableMapping[str, str]: The response headers. 77 | """ 78 | 79 | return self._response_headers 80 | 81 | def set_response_header(self, key: str, value: str) -> None: 82 | """Set a response header 83 | 84 | Args: 85 | key (str): Key for the header. 86 | value (str): Value for the header. 
87 | """ 88 | 89 | self._response_headers[key] = value 90 | 91 | 92 | class ServiceContext(Protocol): 93 | """Represents the context of a service.""" 94 | 95 | async def abort(self, code, message): 96 | """Abort the service with the given code and message. 97 | 98 | Args: 99 | code (int): The error code. 100 | message (str): The error message. 101 | """ 102 | ... 103 | 104 | def code(self) -> int: 105 | """Get the error code. 106 | 107 | Returns: 108 | int: The error code. 109 | """ 110 | ... 111 | 112 | def details(self) -> str: 113 | """Get the error details. 114 | 115 | Returns: 116 | str: The error details. 117 | """ 118 | ... 119 | 120 | def invocation_metadata(self) -> Mapping[str, List[str]]: 121 | """Get the invocation metadata. 122 | 123 | Returns: 124 | Mapping[str, List[str]]: The invocation metadata. 125 | """ 126 | ... 127 | 128 | def peer(self) -> str: 129 | """Get the peer information. 130 | 131 | Returns: 132 | str: The peer information. 133 | """ 134 | ... 135 | 136 | def set_code(self, code: int) -> None: 137 | """Set the error code. 138 | 139 | Args: 140 | code (int): The error code. 141 | """ 142 | ... 143 | 144 | def set_details(self, details: str) -> None: 145 | """Set the error details. 146 | 147 | Args: 148 | details (str): The error details. 149 | """ 150 | ... 151 | 152 | def set_trailing_metadata(self, metadata: Mapping[str, List[str]]) -> None: 153 | """Set the trailing metadata. 154 | 155 | Args: 156 | metadata (Mapping[str, List[str]]): The trailing metadata. 157 | """ 158 | ... 159 | 160 | def time_remaining(self) -> Union[float, None]: 161 | """Get the remaining time. 162 | 163 | Returns: 164 | float | None: The remaining time in seconds, or None if not applicable. 165 | """ 166 | ... 167 | 168 | def trailing_metadata(self) -> Mapping[str, List[str]]: 169 | """Get the trailing metadata. 170 | 171 | Returns: 172 | Mapping[str, List[str]]: The trailing metadata. 173 | """ 174 | ... 
175 | 176 | 177 | class ConnecpyServiceContext: 178 | """ 179 | Represents the context of a Connecpy service. 180 | 181 | Attributes: 182 | _peer: The peer information of the service. 183 | _invocation_metadata: The invocation metadata of the service. 184 | _code: The response code of the service. 185 | _details: The response details of the service. 186 | _trailing_metadata: The trailing metadata of the service. 187 | _timeout_sec: The timeout duration in seconds. 188 | _start_time: The start time of the service. 189 | """ 190 | 191 | def __init__(self, peer: str, invocation_metadata: Mapping[str, List[str]]): 192 | """ 193 | Initialize a Context object. 194 | 195 | Args: 196 | peer (str): The peer information. 197 | invocation_metadata (Mapping[str, List[str]]): The invocation metadata. 198 | 199 | Returns: 200 | None 201 | """ 202 | self._peer = peer 203 | self._invocation_metadata = invocation_metadata 204 | self._code = 200 205 | self._details = "" 206 | self._trailing_metadata = {} 207 | 208 | connect_protocol_version = self._invocation_metadata.get( 209 | "connect-protocol-version", ["1"] 210 | )[0] 211 | if connect_protocol_version != "1": 212 | raise exceptions.ConnecpyServerException( 213 | code=errors.Errors.BadRoute, 214 | message="Invalid connect-protocol-version", 215 | ) 216 | self._connect_protocol_version = connect_protocol_version 217 | 218 | ctype = self._invocation_metadata.get("content-type", ["application/proto"])[0] 219 | if ctype not in ("application/proto", "application/json"): 220 | raise exceptions.ConnecpyServerException( 221 | code=errors.Errors.BadRoute, 222 | message=f"Unsupported content type: {ctype}", 223 | ) 224 | self._content_type = ctype 225 | 226 | timeout_ms: Union[str, None] = invocation_metadata.get( 227 | "connect-timeout-ms", [None] 228 | )[0] 229 | if timeout_ms is None: 230 | self._timeout_sec = None 231 | else: 232 | self._timeout_sec = float(timeout_ms) / 1000.0 233 | self._start_time = time.time() 234 | 235 | async 
def abort(self, code, message): 236 | """ 237 | Abort the current request with the given code and message. 238 | 239 | :param code: The HTTP status code to return. 240 | :param message: The error message to include in the response. 241 | :raises: exceptions.ConnecpyServerException 242 | """ 243 | raise exceptions.ConnecpyServerException(code=code, message=message) 244 | 245 | def code(self) -> int: 246 | """ 247 | Get the status code associated with the context. 248 | 249 | Returns: 250 | int: The status code associated with the context. 251 | """ 252 | return self._code 253 | 254 | def details(self) -> str: 255 | """ 256 | Returns the details of the context. 257 | 258 | :return: The details of the context. 259 | :rtype: str 260 | """ 261 | return self._details 262 | 263 | def invocation_metadata(self) -> Mapping[str, List[str]]: 264 | """ 265 | Returns the invocation metadata associated with the context. 266 | 267 | :return: A mapping of metadata keys to lists of metadata values. 268 | """ 269 | return self._invocation_metadata 270 | 271 | def content_type(self) -> str: 272 | """ 273 | Returns the content type associated with the context. 274 | 275 | :return: The content type associated with the context. 276 | :rtype: str 277 | """ 278 | return self._content_type 279 | 280 | def peer(self): 281 | """ 282 | Returns the peer associated with the context. 283 | """ 284 | return self._peer 285 | 286 | def set_code(self, code: int) -> None: 287 | """ 288 | Set the status code for the context. 289 | 290 | Args: 291 | code (int): The code to set. 292 | 293 | Returns: 294 | None 295 | """ 296 | self._code = code 297 | 298 | def set_details(self, details: str) -> None: 299 | """ 300 | Set the details of the context. 301 | 302 | Args: 303 | details (str): The details to be set. 
304 | 305 | Returns: 306 | None 307 | """ 308 | self._details = details 309 | 310 | def set_trailing_metadata(self, metadata: Mapping[str, List[str]]) -> None: 311 | """ 312 | Sets the trailing metadata for the context. 313 | 314 | Args: 315 | metadata (Mapping[str, List[str]]): A mapping of metadata keys to lists of values. 316 | 317 | Returns: 318 | None 319 | """ 320 | self._trailing_metadata = metadata 321 | 322 | def time_remaining(self) -> Union[float, None]: 323 | """ 324 | Calculate the remaining time until the timeout. 325 | 326 | Returns: 327 | float | None: The remaining time in seconds, or None if no timeout is set. 328 | """ 329 | if self._timeout_sec is None: 330 | return None 331 | return self._timeout_sec - (time.time() - self._start_time) 332 | 333 | def trailing_metadata(self) -> Mapping[str, List[str]]: 334 | """ 335 | Returns the trailing metadata associated with the context. 336 | 337 | :return: A mapping of metadata keys to lists of metadata values. 338 | :rtype: Mapping[str, List[str]] 339 | """ 340 | return self._trailing_metadata 341 | -------------------------------------------------------------------------------- /src/connecpy/cors.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | 3 | 4 | @dataclass 5 | class CORSConfig: 6 | """ 7 | Represents the configuration of the CORS policy. 8 | Attributes: 9 | allow_origin (str): The allowed origin. Defaults to "*". 10 | allow_methods (tuple[str, ...]): The allowed HTTP methods. Defaults to ("POST", "GET"). 11 | allow_headers (tuple[str, ...]): The allowed HTTP headers. Defaults to ( 12 | "Content-Type", 13 | "Connect-Protocol-Version", 14 | "Connect-Timeout-Ms", 15 | "X-User-Agent", 16 | ). 17 | access_control_max_age (int): The maximum age of the preflight request. Defaults to 86400. 18 | """ 19 | 20 | allow_origin: str = field(default="*") 21 | allow_methods: tuple[str, ...] 
= field(default=("POST", "GET")) 22 | allow_headers: tuple[str, ...] = field( 23 | default=( 24 | "Content-Type", 25 | "Connect-Protocol-Version", 26 | "Connect-Timeout-Ms", 27 | "X-User-Agent", 28 | ) 29 | ) 30 | access_control_max_age: int = field(default=86400) 31 | 32 | 33 | # ASGI Miidleware for ConnectRPC CORS 34 | def middleware(app, config: CORSConfig): 35 | """ 36 | Middleware for ConnectRPC CORS. 37 | Args: 38 | app (Callable): The ASGI application. 39 | config (CORSConfig): The CORS configuration. 40 | Returns: 41 | Callable: The ASGI application. 42 | """ 43 | 44 | async def cors_middleware(scope, receive, send): 45 | if scope["type"] != "http": 46 | # Pass through for non-HTTP scopes (e.g. websocket) 47 | await app(scope, receive, send) 48 | return 49 | 50 | # Handle preflight requests 51 | if scope["method"] == "OPTIONS": 52 | headers = [ 53 | (b"access-control-allow-origin", config.allow_origin.encode()), 54 | ( 55 | b"access-control-allow-methods", 56 | b", ".join(m.encode() for m in config.allow_methods), 57 | ), 58 | ( 59 | b"access-control-allow-headers", 60 | b", ".join(h.encode() for h in config.allow_headers), 61 | ), 62 | ( 63 | b"access-control-max-age", 64 | str(config.access_control_max_age).encode(), 65 | ), 66 | ] 67 | 68 | await send( 69 | { 70 | "type": "http.response.start", 71 | "status": 204, 72 | "headers": headers, 73 | } 74 | ) 75 | 76 | await send( 77 | { 78 | "type": "http.response.body", 79 | "body": b"", 80 | } 81 | ) 82 | 83 | # Handle normal requests with CORS headers 84 | if scope["method"] != "OPTIONS": 85 | 86 | async def send_wrapper(message): 87 | if message["type"] == "http.response.start": 88 | message["headers"].append( 89 | (b"access-control-allow-origin", config.allow_origin.encode()) 90 | ) 91 | 92 | await send(message) 93 | 94 | await app(scope, receive, send_wrapper) 95 | 96 | return cors_middleware 97 | 98 | 99 | class CORSMiddleware: 100 | """ 101 | Middleware for ConnectRPC CORS. 
102 | Args: 103 | app (Callable): The ASGI application. 104 | config (CORSConfig): The CORS configuration. 105 | """ 106 | 107 | def __init__(self, app, config: CORSConfig = CORSConfig()): 108 | self._app = app 109 | self._config = config 110 | 111 | async def __call__(self, scope, receive, send): 112 | return await middleware(self._app, self._config)(scope, receive, send) 113 | -------------------------------------------------------------------------------- /src/connecpy/encoding.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | from typing import Any, Callable, Dict, List, Tuple, Union 4 | 5 | from google.protobuf import json_format, message 6 | 7 | from . import errors 8 | from . import exceptions 9 | 10 | 11 | def json_decoder(body: bytes, data_obj: Any) -> Any: 12 | """Decode JSON data.""" 13 | try: 14 | data = json.loads(body.decode("utf-8")) 15 | if issubclass(data_obj, message.Message): 16 | return json_format.ParseDict(data, data_obj()) 17 | return data 18 | except Exception as e: 19 | raise exceptions.ConnecpyServerException( 20 | code=errors.Errors.InvalidArgument, 21 | message=f"Failed to decode JSON message: {str(e)}", 22 | ) 23 | 24 | 25 | def proto_decoder(body: bytes, data_obj: Any) -> Any: 26 | """Decode Protocol Buffer data.""" 27 | try: 28 | msg = data_obj() 29 | msg.ParseFromString(body) 30 | return msg 31 | except Exception as e: 32 | raise exceptions.ConnecpyServerException( 33 | code=errors.Errors.InvalidArgument, 34 | message=f"Failed to decode protobuf message: {str(e)}", 35 | ) 36 | 37 | 38 | def json_encoder(value: Any, data_obj: Any) -> Tuple[bytes, Dict[str, List[str]]]: 39 | """Encode data as JSON.""" 40 | try: 41 | if isinstance(value, message.Message): 42 | data = json_format.MessageToDict(value) 43 | else: 44 | data = value 45 | return ( 46 | json.dumps(data).encode("utf-8"), 47 | {"Content-Type": ["application/json"]}, 48 | ) 49 | except 
Exception as e: 50 | raise exceptions.ConnecpyServerException( 51 | code=errors.Errors.Internal, 52 | message=f"Failed to encode JSON message: {str(e)}", 53 | ) 54 | 55 | 56 | def proto_encoder( 57 | value: message.Message, data_obj: Any 58 | ) -> Tuple[bytes, Dict[str, List[str]]]: 59 | """Encode data as Protocol Buffer.""" 60 | try: 61 | return ( 62 | value.SerializeToString(), 63 | {"Content-Type": ["application/proto"]}, 64 | ) 65 | except Exception as e: 66 | raise exceptions.ConnecpyServerException( 67 | code=errors.Errors.Internal, 68 | message=f"Failed to encode protobuf message: {str(e)}", 69 | ) 70 | 71 | 72 | def get_decoder_by_name(encoding_name: str) -> Union[Callable[[bytes, Any], Any], None]: 73 | """Get decoder function by encoding name.""" 74 | decoders = { 75 | "json": json_decoder, 76 | "proto": proto_decoder, 77 | } 78 | return decoders.get(encoding_name) 79 | 80 | 81 | def get_encoder( 82 | endpoint: Any, ctype: str 83 | ) -> Callable[[Any], Tuple[bytes, Dict[str, List[str]]]]: 84 | """Get encoder function by content type.""" 85 | if ctype == "application/json": 86 | return partial(json_encoder, data_obj=endpoint.output) 87 | elif ctype == "application/proto": 88 | return partial(proto_encoder, data_obj=endpoint.output) 89 | else: 90 | raise exceptions.ConnecpyServerException( 91 | code=errors.Errors.BadRoute, 92 | message=f"unexpected Content-Type: {ctype}", 93 | ) 94 | 95 | 96 | def get_encoder_decoder_pair( 97 | endpoint: Any, ctype: str 98 | ) -> Tuple[ 99 | Callable[[Any], Tuple[bytes, Dict[str, List[str]]]], Callable[[bytes, Any], Any] 100 | ]: 101 | """Get encoder and decoder functions for an endpoint and content type.""" 102 | encoder = get_encoder(endpoint, ctype) 103 | decoder = get_decoder_by_name("proto" if ctype == "application/proto" else "json") 104 | if not decoder: 105 | raise exceptions.ConnecpyServerException( 106 | code=errors.Errors.Unimplemented, 107 | message=f"Unsupported encoding: {ctype}", 108 | ) 109 | return encoder, 
partial(decoder, data_obj=endpoint.input) 110 | -------------------------------------------------------------------------------- /src/connecpy/errors.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Errors(Enum): 5 | """ 6 | Enum class representing different error codes and their corresponding status codes. 7 | """ 8 | 9 | Canceled = "canceled" 10 | Unknown = "unknown" 11 | InvalidArgument = "invalid_argument" 12 | DeadlineExceeded = "deadline_exceeded" 13 | NotFound = "not_found" 14 | BadRoute = "bad_route" 15 | AlreadyExists = "already_exists" 16 | PermissionDenied = "permission_denied" 17 | Unauthenticated = "unauthenticated" 18 | ResourceExhausted = "resource_exhausted" 19 | FailedPrecondition = "failed_precondition" 20 | Aborted = "aborted" 21 | OutOfRange = "out_of_range" 22 | Unimplemented = "unimplemented" 23 | Internal = "internal" 24 | Unavailable = "unavailable" 25 | DataLoss = "data_loss" 26 | Malformed = "malformed" 27 | NoError = "" 28 | 29 | @staticmethod 30 | def get_status_code(code: "Errors") -> int: 31 | """ 32 | Returns the corresponding HTTP status code for the given error code. 33 | 34 | Args: 35 | code (Errors): The error code. 36 | 37 | Returns: 38 | int: The corresponding HTTP status code. 
39 | """ 40 | return { 41 | Errors.Canceled: 408, 42 | Errors.Unknown: 500, 43 | Errors.InvalidArgument: 400, 44 | Errors.Malformed: 400, 45 | Errors.DeadlineExceeded: 408, 46 | Errors.NotFound: 404, 47 | Errors.BadRoute: 404, 48 | Errors.AlreadyExists: 409, 49 | Errors.PermissionDenied: 403, 50 | Errors.Unauthenticated: 401, 51 | Errors.ResourceExhausted: 429, 52 | Errors.FailedPrecondition: 412, 53 | Errors.Aborted: 409, 54 | Errors.OutOfRange: 400, 55 | Errors.Unimplemented: 501, 56 | Errors.Internal: 500, 57 | Errors.Unavailable: 503, 58 | Errors.DataLoss: 500, 59 | Errors.NoError: 200, 60 | }.get(code, 500) 61 | -------------------------------------------------------------------------------- /src/connecpy/exceptions.py: -------------------------------------------------------------------------------- 1 | import http.client as httplib 2 | import json 3 | 4 | from . import errors 5 | 6 | 7 | class ConnecpyException(Exception): 8 | """Base exception class for Connecpy.""" 9 | 10 | pass 11 | 12 | 13 | class ConnecpyServerException(httplib.HTTPException): 14 | """ 15 | Exception class for Connecpy server errors. 16 | 17 | Attributes: 18 | code (errors.Errors): The error code associated with the exception. 19 | message (str): The error message associated with the exception. 20 | """ 21 | 22 | def __init__(self, *, code, message): 23 | """ 24 | Initializes a new instance of the ConnecpyServerException class. 25 | 26 | Args: 27 | code (int): The error code. 28 | message (str): The error message. 
29 | """ 30 | try: 31 | self._code = errors.Errors(code) 32 | except ValueError: 33 | self._code = errors.Errors.Unknown 34 | self._message = message 35 | super(ConnecpyServerException, self).__init__(message) 36 | 37 | @property 38 | def code(self): 39 | if isinstance(self._code, errors.Errors): 40 | return self._code 41 | return errors.Errors.Unknown 42 | 43 | @property 44 | def message(self): 45 | return self._message 46 | 47 | def to_dict(self): 48 | return {"code": self._code.value, "msg": self._message} 49 | 50 | def to_json_bytes(self): 51 | return json.dumps(self.to_dict()).encode("utf-8") 52 | 53 | @staticmethod 54 | def from_json(err_dict): 55 | return ConnecpyServerException( 56 | code=err_dict.get("code", errors.Errors.Unknown), 57 | message=err_dict.get("msg", ""), 58 | ) 59 | 60 | 61 | def InvalidArgument(*args, argument, error): 62 | return ConnecpyServerException( 63 | code=errors.Errors.InvalidArgument, 64 | message="{} {}".format(argument, error), 65 | ) 66 | 67 | 68 | def RequiredArgument(*args, argument): 69 | return InvalidArgument(argument=argument, error="is required") 70 | 71 | 72 | def connecpy_error_from_intermediary(status, reason, headers, body): 73 | if 300 <= status < 400: 74 | # connecpy uses POST which should not redirect 75 | code = errors.Errors.Internal 76 | location = headers.get("location") 77 | message = f'unexpected HTTP status code {status} "{reason}" received, Location="{location}"' 78 | 79 | else: 80 | code = { 81 | 400: errors.Errors.Internal, # JSON response should have been returned 82 | 401: errors.Errors.Unauthenticated, 83 | 403: errors.Errors.PermissionDenied, 84 | 404: errors.Errors.BadRoute, 85 | 429: errors.Errors.ResourceExhausted, 86 | 502: errors.Errors.Unavailable, 87 | 503: errors.Errors.Unavailable, 88 | 504: errors.Errors.Unavailable, 89 | }.get(status, errors.Errors.Unknown) 90 | 91 | message = f'Error from intermediary with HTTP status code {status} "{reason}"' 92 | 93 | return 
ConnecpyServerException(code=code, message=message) 94 | -------------------------------------------------------------------------------- /src/connecpy/interceptor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Callable, Protocol 3 | 4 | from . import context 5 | 6 | 7 | class AsyncServerInterceptor(Protocol): 8 | """Interceptor for asynchronous Connecpy server.""" 9 | 10 | async def intercept( 11 | self, 12 | method: Callable, 13 | request: Any, 14 | ctx: context.ServiceContext, 15 | method_name: str, 16 | ) -> Any: ... 17 | 18 | 19 | class AsyncConnecpyServerInterceptor(ABC): 20 | """ 21 | Base class for asynchronous Connecpy server interceptors. 22 | """ 23 | 24 | def make_interceptor(self, method: Callable, method_name: str): 25 | async def interceptor(request: Any, ctx: context.ServiceContext) -> Any: 26 | return await self.intercept(method, request, ctx, method_name) 27 | 28 | return interceptor 29 | 30 | @abstractmethod 31 | async def intercept( 32 | self, 33 | method: Callable, 34 | request: Any, 35 | ctx: context.ServiceContext, 36 | method_name: str, 37 | ) -> Any: 38 | pass 39 | -------------------------------------------------------------------------------- /src/connecpy/server.py: -------------------------------------------------------------------------------- 1 | from . import exceptions 2 | from . import errors 3 | 4 | 5 | class ConnecpyServer: 6 | """ 7 | Represents a Connecpy server that handles incoming requests and dispatches them to the appropriate endpoints. 8 | """ 9 | 10 | def __init__(self): 11 | self._endpoints = {} 12 | self._prefix = "" 13 | 14 | @property 15 | def prefix(self): 16 | """ 17 | Represents the prefix used for routing requests to endpoints. 18 | """ 19 | return self._prefix 20 | 21 | def get_endpoint(self, path): 22 | """ 23 | Get the endpoint associated with the given path. 
24 | 25 | Args: 26 | path (str): The path of the request. 27 | 28 | Returns: 29 | object: The endpoint associated with the given path. 30 | 31 | Raises: 32 | ConnecpyServerException: If no handler is found for the path or if the service has no endpoint for the given method. 33 | """ 34 | (_, url_pre, rpc_method) = path.rpartition(f"{self._prefix}/") 35 | 36 | if not url_pre or not rpc_method: 37 | raise exceptions.ConnecpyServerException( 38 | code=errors.Errors.BadRoute, 39 | message=f"no handler for path {path}", 40 | ) 41 | 42 | endpoint = self._endpoints.get(rpc_method, None) 43 | if not endpoint: 44 | raise exceptions.ConnecpyServerException( 45 | code=errors.Errors.Unimplemented, 46 | message=f"service has no endpoint {rpc_method}", 47 | ) 48 | 49 | return endpoint 50 | -------------------------------------------------------------------------------- /src/connecpy/shared_client.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | 4 | def prepare_headers(ctx, kwargs, timeout): 5 | headers = {k.lower(): v for k, v in ctx.get_headers().items()} 6 | if "headers" in kwargs: 7 | headers.update({k.lower(): v for k, v in kwargs.pop("headers").items()}) 8 | # Ensure consistent header casing 9 | if "content-type" in headers: 10 | headers["content-type"] = headers.pop("content-type") 11 | if "content-encoding" in headers: 12 | headers["content-encoding"] = headers.pop("content-encoding") 13 | if "accept-encoding" in headers: 14 | headers["accept-encoding"] = headers.pop("accept-encoding") 15 | # Set default headers 16 | if "content-type" not in headers: 17 | headers["content-type"] = "application/proto" 18 | if "accept-encoding" not in headers: 19 | headers["accept-encoding"] = "gzip, br, zstd" 20 | if "timeout" not in kwargs: 21 | kwargs["timeout"] = timeout 22 | headers["connect-timeout-ms"] = str(timeout * 1000) 23 | kwargs["headers"] = headers 24 | return headers, kwargs 25 | 26 | 27 | def 
compress_request(request, headers, compression): 28 | request_data = request.SerializeToString() 29 | # If compression is requested 30 | if "content-encoding" in headers: 31 | compression_name = headers["content-encoding"].lower() 32 | compressor = compression.get_compressor(compression_name) 33 | if not compressor: 34 | raise Exception(f"Unsupported compression method: {compression_name}") 35 | try: 36 | compressed = compressor(request_data) 37 | if len(compressed) < len(request_data): 38 | # Optionally, log compression details 39 | request_data = compressed 40 | else: 41 | headers.pop("content-encoding", None) 42 | except Exception as e: 43 | raise Exception( 44 | f"Failed to compress request with {compression_name}: {str(e)}" 45 | ) 46 | return request_data, headers 47 | 48 | 49 | def prepare_get_params(request_data, headers): 50 | params = {} 51 | if request_data: 52 | params["message"] = base64.urlsafe_b64encode(request_data).decode("ascii") 53 | params["base64"] = "1" 54 | params["encoding"] = ( 55 | "proto" if headers.get("content-type") == "application/proto" else "json" 56 | ) 57 | if "content-encoding" in headers: 58 | params["compression"] = headers.pop("content-encoding") 59 | return params 60 | -------------------------------------------------------------------------------- /src/connecpy/wsgi.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Mapping, Union 3 | import base64 4 | from urllib.parse import parse_qs 5 | from functools import partial 6 | 7 | from . import base 8 | from . import context 9 | from . import errors 10 | from . import exceptions 11 | from . import encoding 12 | from . 
import compression 13 | 14 | 15 | def normalize_wsgi_headers(environ) -> dict: 16 | """Extract and normalize HTTP headers from WSGI environment.""" 17 | headers = {} 18 | if "CONTENT_TYPE" in environ: 19 | headers["content-type"] = environ["CONTENT_TYPE"].lower() 20 | if "CONTENT_LENGTH" in environ: 21 | headers["content-length"] = environ["CONTENT_LENGTH"].lower() 22 | 23 | for key, value in environ.items(): 24 | if key.startswith("HTTP_"): 25 | header = key[5:].replace("_", "-").lower() 26 | headers[header] = value 27 | return headers 28 | 29 | 30 | def convert_to_mapping(headers: dict) -> Mapping[str, List[str]]: 31 | """Convert headers dictionary to the expected mapping format.""" 32 | result = defaultdict(list) 33 | for key, value in headers.items(): 34 | key = key.lower() 35 | if isinstance(value, (list, tuple)): 36 | result[key].extend(str(v) for v in value) 37 | else: 38 | result[key] = [str(value)] 39 | return result 40 | 41 | 42 | def extract_metadata_from_query_params(query_string: str) -> dict: 43 | """Extract metadata from query parameters into a dictionary.""" 44 | return parse_qs(query_string) if query_string else {} 45 | 46 | 47 | def reflect_header(key: str, value: Union[str, List[str]], headers: list) -> None: 48 | """Add a header to the WSGI response headers list. 49 | 50 | Args: 51 | key: Header name 52 | value: Header value or list of values 53 | headers: List of header tuples to append to 54 | """ 55 | if isinstance(value, list): 56 | if value: # Only add if there are values 57 | headers.append((key, str(value[0]))) 58 | else: 59 | headers.append((key, str(value))) 60 | 61 | 62 | def format_response_headers( 63 | base_headers: dict, compression_info: dict, trailers: dict 64 | ) -> list[tuple[str, str]]: 65 | """Format response headers into WSGI compatible format. 
66 | 67 | Args: 68 | base_headers (dict): Base headers from encoder 69 | compression_info (dict): Compression related headers 70 | trailers (dict): Trailer headers 71 | 72 | Returns: 73 | list[tuple[str, str]]: List of header tuples in WSGI format 74 | """ 75 | # Combine all headers 76 | headers = {} 77 | 78 | # Start with base headers 79 | for key, value in base_headers.items(): 80 | if isinstance(value, list): 81 | headers[key] = value[0] if value else "" 82 | else: 83 | headers[key] = str(value) 84 | 85 | # Add compression headers 86 | headers.update(compression_info) 87 | 88 | # Add trailers with prefix 89 | for key, value in trailers.items(): 90 | if isinstance(value, list): 91 | headers[f"trailer-{key.lower()}"] = value[0] if value else "" 92 | else: 93 | headers[f"trailer-{key.lower()}"] = str(value) 94 | 95 | # Convert to WSGI format 96 | return [(str(k).lower(), str(v)) for k, v in headers.items()] 97 | 98 | 99 | def validate_request_headers(headers: dict) -> tuple[str, str]: 100 | """Validate and normalize request headers. 
101 | 102 | Args: 103 | headers: Dictionary of request headers 104 | 105 | Returns: 106 | tuple[str, str]: Normalized content type and content encoding 107 | """ 108 | # Get content type 109 | content_type = headers.get("content-type", "application/json").lower() 110 | if content_type not in ["application/json", "application/proto"]: 111 | raise exceptions.ConnecpyServerException( 112 | code=errors.Errors.InvalidArgument, 113 | message=f"Unsupported Content-Type: {content_type}", 114 | ) 115 | 116 | # Get content encoding 117 | content_encoding = headers.get("content-encoding", "identity").lower() 118 | if content_encoding not in ["identity", "gzip", "br", "zstd"]: 119 | raise exceptions.ConnecpyServerException( 120 | code=errors.Errors.Unimplemented, 121 | message=f"Unsupported Content-Encoding: {content_encoding}", 122 | ) 123 | 124 | return content_type, content_encoding 125 | 126 | 127 | def prepare_response_headers( 128 | base_headers: dict, 129 | selected_encoding: str, 130 | compressed_size: int = None, 131 | ) -> tuple[dict, bool]: 132 | """Prepare response headers and determine if compression should be used. 
133 | 134 | Args: 135 | base_headers: Base response headers 136 | selected_encoding: Selected compression encoding 137 | compressed_size: Size of compressed content (if compression was attempted) 138 | 139 | Returns: 140 | tuple[dict, bool]: Final headers and whether to use compression 141 | """ 142 | headers = base_headers.copy() 143 | use_compression = False 144 | 145 | if "content-type" not in headers: 146 | headers["content-type"] = "application/proto" 147 | 148 | if selected_encoding != "identity" and compressed_size is not None: 149 | headers["content-encoding"] = selected_encoding 150 | use_compression = True 151 | 152 | headers["vary"] = "Accept-Encoding" 153 | return headers, use_compression 154 | 155 | 156 | def read_chunked(input_stream): 157 | body = b"" 158 | while True: 159 | line = input_stream.readline() 160 | if not line: 161 | break 162 | 163 | chunk_size = int(line.strip(), 16) 164 | if chunk_size == 0: 165 | # Zero-sized chunk indicates the end 166 | break 167 | 168 | chunk = input_stream.read(chunk_size) 169 | body += chunk 170 | input_stream.read(2) # CRLF 171 | return body 172 | 173 | 174 | class ConnecpyWSGIApp(base.ConnecpyBaseApp): 175 | """WSGI application for Connecpy.""" 176 | 177 | def __init__(self, interceptors=None): 178 | """Initialize the WSGI application.""" 179 | super().__init__(interceptors=interceptors or ()) 180 | 181 | def add_service(self, svc): 182 | """Add a service to the application. 183 | 184 | Args: 185 | svc: Service instance to add 186 | """ 187 | # Store the service with its full path prefix 188 | self._services[svc._prefix] = svc 189 | 190 | def _get_endpoint(self, path_info): 191 | """Find endpoint for given path. 
192 | 193 | Args: 194 | path_info: The request path 195 | 196 | Returns: 197 | Endpoint instance matching the path 198 | 199 | Raises: 200 | ConnecpyServerException: If endpoint not found or path invalid 201 | """ 202 | if not path_info: 203 | raise exceptions.ConnecpyServerException( 204 | code=errors.Errors.BadRoute, 205 | message="Empty path", 206 | ) 207 | 208 | if path_info.startswith("/"): 209 | path_info = path_info[1:] 210 | 211 | # Split path into service path and method 212 | try: 213 | service_path, method_name = path_info.rsplit("/", 1) 214 | except ValueError: 215 | raise exceptions.ConnecpyServerException( 216 | code=errors.Errors.BadRoute, 217 | message=f"Invalid path format: {path_info}", 218 | ) 219 | 220 | # Look for service 221 | service = self._services.get(f"/{service_path}") 222 | if service is None: 223 | raise exceptions.ConnecpyServerException( 224 | code=errors.Errors.BadRoute, 225 | message=f"No service found for path: {service_path}", 226 | ) 227 | 228 | # Get endpoint from service 229 | endpoint = service._endpoints.get(method_name) 230 | if endpoint is None: 231 | raise exceptions.ConnecpyServerException( 232 | code=errors.Errors.BadRoute, 233 | message=f"Method not found: {method_name}", 234 | ) 235 | 236 | return endpoint 237 | 238 | def handle_error(self, exc, environ, start_response): 239 | """Handle and log errors with detailed information.""" 240 | if isinstance(exc, exceptions.ConnecpyServerException): 241 | status = { 242 | errors.Errors.InvalidArgument: "400 Bad Request", 243 | errors.Errors.BadRoute: "404 Not Found", 244 | errors.Errors.Unimplemented: "501 Not Implemented", 245 | }.get(exc.code, "500 Internal Server Error") 246 | 247 | headers = [("Content-Type", "application/json")] 248 | start_response(status, headers) 249 | return [exc.to_json_bytes()] 250 | else: 251 | 252 | headers = [("Content-Type", "application/json")] 253 | error = exceptions.ConnecpyServerException( 254 | code=errors.Errors.Internal, 255 | 
message=str(exc), 256 | ) 257 | start_response("500 Internal Server Error", headers) 258 | return [error.to_json_bytes()] 259 | 260 | def _handle_post_request(self, environ, endpoint, ctx): 261 | """Handle POST request with body.""" 262 | try: 263 | content_length = environ.get("CONTENT_LENGTH") 264 | if not content_length: 265 | content_length = 0 266 | else: 267 | content_length = int(content_length) 268 | if content_length > 0: 269 | req_body = environ["wsgi.input"].read(content_length) 270 | else: 271 | input_stream = environ["wsgi.input"] 272 | req_body = read_chunked(input_stream) 273 | 274 | # Handle compression if specified 275 | content_encoding = environ.get("HTTP_CONTENT_ENCODING", "identity").lower() 276 | if content_encoding != "identity": 277 | decompressor = compression.get_decompressor(content_encoding) 278 | if not decompressor: 279 | raise exceptions.ConnecpyServerException( 280 | code=errors.Errors.Unimplemented, 281 | message=f"Unsupported compression: {content_encoding}", 282 | ) 283 | try: 284 | req_body = decompressor(req_body) 285 | except Exception as e: 286 | raise exceptions.ConnecpyServerException( 287 | code=errors.Errors.InvalidArgument, 288 | message=f"Failed to decompress request body: {str(e)}", 289 | ) 290 | 291 | # Get decoder based on content type 292 | content_type = ctx.content_type() 293 | 294 | # Default to proto if not specified 295 | if content_type not in ["application/json", "application/proto"]: 296 | content_type = "application/proto" 297 | ctx = context.ConnecpyServiceContext( 298 | environ.get("REMOTE_ADDR"), 299 | convert_to_mapping({"content-type": ["application/proto"]}), 300 | ) 301 | 302 | decoder = encoding.get_decoder_by_name( 303 | "proto" if content_type == "application/proto" else "json" 304 | ) 305 | if not decoder: 306 | raise exceptions.ConnecpyServerException( 307 | code=errors.Errors.Unimplemented, 308 | message=f"Unsupported encoding: {content_type}", 309 | ) 310 | 311 | decoder = partial(decoder, 
data_obj=endpoint.input) 312 | try: 313 | request = decoder(req_body) 314 | return request 315 | except Exception as e: 316 | raise exceptions.ConnecpyServerException( 317 | code=errors.Errors.InvalidArgument, 318 | message=f"Failed to decode request body: {str(e)}", 319 | ) 320 | 321 | except Exception as e: 322 | if not isinstance(e, exceptions.ConnecpyServerException): 323 | raise exceptions.ConnecpyServerException( 324 | code=errors.Errors.Internal, 325 | message=str(e), # TODO 326 | ) 327 | raise 328 | 329 | def _handle_get_request(self, environ, endpoint, ctx): 330 | """Handle GET request with query parameters.""" 331 | try: 332 | query_string = environ.get("QUERY_STRING", "") 333 | params = parse_qs(query_string) 334 | 335 | if "message" not in params: 336 | raise exceptions.ConnecpyServerException( 337 | code=errors.Errors.InvalidArgument, 338 | message="'message' parameter is required for GET requests", 339 | ) 340 | 341 | message = params["message"][0] 342 | 343 | if "base64" in params and params["base64"][0] == "1": 344 | try: 345 | message = base64.urlsafe_b64decode(message.encode("ascii")) 346 | except Exception as e: 347 | raise exceptions.ConnecpyServerException( 348 | code=errors.Errors.InvalidArgument, 349 | message=f"Invalid base64 encoding: {str(e)}", 350 | ) 351 | else: 352 | message = message.encode("utf-8") 353 | 354 | # Handle compression if specified 355 | if "compression" in params: 356 | decompressor = compression.get_decompressor(params["compression"][0]) 357 | if decompressor: 358 | try: 359 | message = decompressor(message) 360 | except Exception as e: 361 | raise exceptions.ConnecpyServerException( 362 | code=errors.Errors.InvalidArgument, 363 | message=f"Failed to decompress message: {str(e)}", 364 | ) 365 | 366 | # Handle GET request with proto decoder 367 | try: 368 | # TODO - Use content type from queryparam 369 | request = encoding.get_decoder_by_name("proto")( 370 | message, data_obj=endpoint.input 371 | ) 372 | return request 
373 | except Exception as e: 374 | raise exceptions.ConnecpyServerException( 375 | code=errors.Errors.InvalidArgument, 376 | message=f"Failed to decode proto message: {str(e)}", 377 | ) 378 | 379 | except Exception as e: 380 | if not isinstance(e, exceptions.ConnecpyServerException): 381 | raise exceptions.ConnecpyServerException( 382 | code=errors.Errors.Internal, 383 | message=str(e), 384 | ) 385 | raise 386 | 387 | def __call__(self, environ, start_response): 388 | """Handle incoming WSGI requests.""" 389 | try: 390 | request_headers = normalize_wsgi_headers(environ) 391 | request_method = environ.get("REQUEST_METHOD") 392 | if request_method == "POST": 393 | ctx = context.ConnecpyServiceContext( 394 | environ.get("REMOTE_ADDR"), convert_to_mapping(request_headers) 395 | ) 396 | else: 397 | metadata = {} 398 | metadata.update( 399 | extract_metadata_from_query_params(environ.get("QUERY_STRING")) 400 | ) 401 | ctx = context.ConnecpyServiceContext( 402 | environ.get("REMOTE_ADDR"), convert_to_mapping(metadata) 403 | ) 404 | endpoint = self._get_endpoint(environ.get("PATH_INFO")) 405 | request_method = environ.get("REQUEST_METHOD") 406 | if request_method not in endpoint.allowed_methods: 407 | raise exceptions.ConnecpyServerException( 408 | code=errors.Errors.BadRoute, 409 | message=f"unsupported method {request_method}", 410 | ) 411 | # Handle request based on method 412 | if request_method == "GET": 413 | request = self._handle_get_request(environ, endpoint, ctx) 414 | else: 415 | request = self._handle_post_request(environ, endpoint, ctx) 416 | 417 | # Process request 418 | proc = endpoint.make_proc() 419 | response = proc(request, ctx) 420 | 421 | # Encode response 422 | encoder = encoding.get_encoder(endpoint, ctx.content_type()) 423 | res_bytes, base_headers = encoder(response) 424 | 425 | # Handle compression if accepted 426 | accept_encoding = request_headers.get("accept-encoding", "identity") 427 | selected_encoding = 
compression.select_encoding(accept_encoding) 428 | compressed_bytes = None 429 | compressor = None 430 | if selected_encoding != "identity": 431 | compressor = compression.get_compressor(selected_encoding) 432 | if compressor: 433 | compressed_bytes = compressor(res_bytes) 434 | response_headers, use_compression = prepare_response_headers( 435 | base_headers, 436 | selected_encoding, 437 | len(compressed_bytes) if compressor is not None else None, 438 | ) 439 | 440 | # Convert headers to WSGI format 441 | wsgi_headers = [] 442 | for key, value in response_headers.items(): 443 | if isinstance(value, list): 444 | if value: # Only add if there are values 445 | wsgi_headers.append((key, str(value[0]))) 446 | else: 447 | wsgi_headers.append((key, str(value))) 448 | 449 | start_response("200 OK", wsgi_headers) 450 | final_response = compressed_bytes if use_compression else res_bytes 451 | return [final_response] 452 | except Exception as e: 453 | return self.handle_error(e, environ, start_response) 454 | -------------------------------------------------------------------------------- /test/test_cors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from starlette.responses import Response 3 | from connecpy.cors import middleware, CORSConfig, CORSMiddleware 4 | from starlette.testclient import TestClient 5 | import asyncio 6 | 7 | 8 | @pytest.fixture 9 | def config(): 10 | return CORSConfig() 11 | 12 | 13 | @pytest.fixture 14 | def app(config): 15 | async def app_scope(scope, receive, send): 16 | if scope["type"] != "http": 17 | return 18 | response = Response("OK") 19 | await response(scope, receive, send) 20 | 21 | return middleware(app_scope, config) 22 | 23 | 24 | def test_preflight_request(app): 25 | client = TestClient(app) 26 | response = client.options( 27 | "/", 28 | headers={ 29 | "Origin": "http://example.com", 30 | "Access-Control-Request-Method": "GET", 31 | }, 32 | ) 33 | assert response.status_code == 
204 34 | assert response.headers["access-control-allow-origin"] == "*" 35 | assert "POST, GET" in response.headers["access-control-allow-methods"] 36 | assert "Content-Type" in response.headers["access-control-allow-headers"] 37 | assert ( 38 | "Connect-Protocol-Version" in response.headers["access-control-allow-headers"] 39 | ) 40 | assert "Connect-Timeout-Ms" in response.headers["access-control-allow-headers"] 41 | assert "X-User-Agent" in response.headers["access-control-allow-headers"] 42 | assert response.headers["access-control-max-age"] == "86400" 43 | 44 | 45 | def test_preflight_request_custom_config(): 46 | config = CORSConfig( 47 | allow_origin="http://example.com", 48 | allow_methods=("POST", "GET", "PUT"), 49 | allow_headers=("Content-Type", "X-User-Agent"), 50 | access_control_max_age=3600, 51 | ) 52 | 53 | async def app_scope(scope, receive, send): 54 | if scope["type"] != "http": 55 | return 56 | response = Response("OK") 57 | await response(scope, receive, send) 58 | 59 | app = middleware(app_scope, config) 60 | 61 | client = TestClient(app) 62 | response = client.options( 63 | "/", 64 | headers={ 65 | "Origin": "http://example.com", 66 | "Access-Control-Request-Method": "GET", 67 | }, 68 | ) 69 | 70 | assert response.status_code == 204 71 | assert response.headers["access-control-allow-origin"] == "http://example.com" 72 | assert "POST, GET, PUT" in response.headers["access-control-allow-methods"] 73 | assert "Content-Type" in response.headers["access-control-allow-headers"] 74 | assert "X-User-Agent" in response.headers["access-control-allow-headers"] 75 | assert ( 76 | "Connect-Protocol-Version" 77 | not in response.headers["access-control-allow-headers"] 78 | ) 79 | assert "Connect-Timeout-Ms" not in response.headers["access-control-allow-headers"] 80 | assert response.headers["access-control-max-age"] == "3600" 81 | 82 | 83 | def test_simple_request(app): 84 | client = TestClient(app) 85 | response = client.get("/", headers={"Origin": 
"http://example.com"}) 86 | assert response.status_code == 200 87 | assert response.headers["access-control-allow-origin"] == "*" 88 | assert response.text == "OK" 89 | 90 | 91 | def test_simple_request_custom_config(): 92 | config = CORSConfig(allow_origin="http://example.com") 93 | 94 | async def app_scope(scope, receive, send): 95 | if scope["type"] != "http": 96 | return 97 | response = Response("OK") 98 | await response(scope, receive, send) 99 | 100 | app = middleware(app_scope, config) 101 | 102 | client = TestClient(app) 103 | response = client.get("/", headers={"Origin": "http://example.com"}) 104 | assert response.status_code == 200 105 | assert response.headers["access-control-allow-origin"] == "http://example.com" 106 | assert response.text == "OK" 107 | 108 | 109 | def test_non_http_scope(app): 110 | async def non_http_scope(): 111 | scope = {"type": "websocket"} 112 | receive = lambda: None 113 | send = lambda message: None 114 | await app(scope, receive, send) 115 | 116 | asyncio.run(non_http_scope()) 117 | 118 | 119 | def test_cors_middleware_class(): 120 | async def app_scope(scope, receive, send): 121 | if scope["type"] != "http": 122 | return 123 | response = Response("OK") 124 | await response(scope, receive, send) 125 | 126 | client = TestClient(CORSMiddleware(app_scope)) 127 | response = client.options( 128 | "/", 129 | headers={ 130 | "Origin": "http://example.com", 131 | "Access-Control-Request-Method": "GET", 132 | }, 133 | ) 134 | assert response.status_code == 204 135 | assert response.headers["access-control-allow-origin"] == "*" 136 | assert "POST, GET" in response.headers["access-control-allow-methods"] 137 | assert "Content-Type" in response.headers["access-control-allow-headers"] 138 | assert ( 139 | "Connect-Protocol-Version" in response.headers["access-control-allow-headers"] 140 | ) 141 | assert "Connect-Timeout-Ms" in response.headers["access-control-allow-headers"] 142 | assert "X-User-Agent" in 
response.headers["access-control-allow-headers"] 143 | assert response.headers["access-control-max-age"] == "86400" 144 | 145 | 146 | def test_cors_middleware_class_custom_config(): 147 | config = CORSConfig( 148 | allow_origin="http://example.com", 149 | allow_methods=("POST", "GET", "PUT"), 150 | allow_headers=("Content-Type", "X-User-Agent"), 151 | access_control_max_age=3600, 152 | ) 153 | 154 | async def app_scope(scope, receive, send): 155 | if scope["type"] != "http": 156 | return 157 | response = Response("OK") 158 | await response(scope, receive, send) 159 | 160 | client = TestClient(CORSMiddleware(app_scope, config=config)) 161 | response = client.options( 162 | "/", 163 | headers={ 164 | "Origin": "http://example.com", 165 | "Access-Control-Request-Method": "GET", 166 | }, 167 | ) 168 | 169 | assert response.status_code == 204 170 | assert response.headers["access-control-allow-origin"] == "http://example.com" 171 | assert "POST, GET, PUT" in response.headers["access-control-allow-methods"] 172 | assert "Content-Type" in response.headers["access-control-allow-headers"] 173 | assert "X-User-Agent" in response.headers["access-control-allow-headers"] 174 | assert ( 175 | "Connect-Protocol-Version" 176 | not in response.headers["access-control-allow-headers"] 177 | ) 178 | assert "Connect-Timeout-Ms" not in response.headers["access-control-allow-headers"] 179 | assert response.headers["access-control-max-age"] == "3600" 180 | --------------------------------------------------------------------------------