├── .github ├── pull_request_template.md └── workflows │ ├── check.yml │ ├── ci.yml │ └── codeql.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── conftest.py ├── pyproject.toml ├── scripts └── benchmark.py ├── src └── synchronicity │ ├── __init__.py │ ├── annotations.py │ ├── async_utils.py │ ├── async_wrap.py │ ├── callback.py │ ├── combined_types.py │ ├── exceptions.py │ ├── interface.py │ ├── overload_tracking.py │ ├── py.typed │ ├── synchronizer.py │ └── type_stubs.py └── test ├── __init__.py ├── async_wrap_test.py ├── asynccontextmanager_test.py ├── callback_test.py ├── conftest.py ├── docstring_test.py ├── exception_test.py ├── fork_test.py ├── generators_test.py ├── getattr_test.py ├── gevent_test.py ├── helper_methods_test.py ├── inspect_test.py ├── nowrap_test.py ├── pickle_test.py ├── shutdown_test.py ├── support ├── _forker.py ├── _gevent.py ├── _shutdown.py ├── _shutdown_async_run.py └── _shutdown_ctx_mgr.py ├── synchronicity_test.py ├── threading_test.py ├── tracebacks_test.py ├── translate_test.py ├── type_stub_e2e_test.py ├── type_stub_helpers ├── .gitignore ├── __init__.py ├── e2e_example_export.py ├── e2e_example_impl.py ├── e2e_example_type_assertions.py └── some_mod.py ├── type_stub_test.py ├── type_stub_translation_test.py └── warnings_test.py /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 4 | 5 | **Issue:** https://github.com/modal-labs/synchronicity/issues/XX or N/A 6 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: Check 2 | on: push 3 | 4 | jobs: 5 | ruff: 6 | name: Ruff 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 11 | - name: Install uv 12 | uses: 
astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # v3 13 | - name: Install Python 14 | uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5 15 | with: 16 | python-version: 3.11 17 | - name: Install dependencies 18 | run: uv sync --only-group=lint 19 | - name: Check lint with Ruff 20 | run: uv run --only-group=lint ruff check . 21 | - name: Check formatting with Ruff 22 | run: uv run --only-group=lint ruff format --diff . 23 | 24 | mypy: 25 | name: MyPy 26 | runs-on: ubuntu-latest 27 | 28 | steps: 29 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 30 | - name: Install uv 31 | uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # v3 32 | - name: Install Python 33 | uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5 34 | with: 35 | python-version: 3.11 36 | - name: Install dependencies 37 | run: uv sync --only-group=lint 38 | - name: Run 39 | run: uv run --only-group=lint mypy src/synchronicity 40 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Run Python Tests 2 | on: push 3 | 4 | jobs: 5 | tests: 6 | strategy: 7 | fail-fast: false # run all variants across python versions/os to completion 8 | matrix: 9 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 10 | os: ["ubuntu-latest"] 11 | include: 12 | - os: "macos-13" # x86-64 13 | python-version: "3.10" 14 | - os: "macos-14" # ARM64 (M1) 15 | python-version: "3.10" 16 | - os: "windows-latest" 17 | python-version: "3.10" 18 | 19 | runs-on: ${{ matrix.os }} 20 | 21 | steps: 22 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 23 | - name: Install uv 24 | uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # v3 25 | - name: Install Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 
# v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install dependencies 30 | run: uv sync --group=dev 31 | - name: Run tests with pytest 32 | run: uv run --group=dev pytest -s 33 | - name: Run README tests 34 | run: uv run --group=dev pytest --markdown-docs README.md 35 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | schedule: 19 | - cron: '19 13 * * 0' 20 | 21 | jobs: 22 | analyze: 23 | name: Analyze 24 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 25 | permissions: 26 | actions: read 27 | contents: read 28 | security-events: write 29 | 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | language: [ 'python' ] 34 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 35 | # Use only 'java' to analyze code written in Java, Kotlin or both 36 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3 42 | 43 | # Initializes the CodeQL tools for scanning. 
44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@b8d3b6e8af63cde30bdc382c0bc28114f4346c88 # v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@b8d3b6e8af63cde30bdc382c0bc28114f4346c88 # v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@b8d3b6e8af63cde30bdc382c0bc28114f4346c88 # v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | venv* 3 | .venv 4 | *.egg-info 5 | build 6 | dist 7 | *.iml 8 | uv.lock 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: "v0.7.1" 4 | hooks: 5 | - id: ruff 6 | # Autofix, and respect `exclude` and `extend-exclude` settings. 7 | args: [--fix, --exit-non-zero-on-fix] 8 | - id: ruff-format 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | clean: 2 | rm -rf build dist 3 | 4 | build: clean 5 | uv build 6 | 7 | publish: build 8 | uv publish 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![CI/CD badge](https://github.com/erikbern/synchronicity/actions/workflows/ci.yml/badge.svg) 2 | [![pypi badge](https://img.shields.io/pypi/v/synchronicity.svg?style=flat)](https://pypi.python.org/pypi/synchronicity) 3 | 4 | Python 3 has some amazing support for async programming but it's arguably made it a bit harder to develop libraries. Are you tired of implementing synchronous _and_ asynchronous methods doing basically the same thing? This might be a simple solution for you. 
5 | 6 | Installing 7 | ========== 8 | 9 | ``` 10 | pip install synchronicity 11 | ``` 12 | 13 | 14 | Background: why is anything like this needed 15 | ============================================ 16 | 17 | Let's say you have an asynchronous function 18 | 19 | ```python fixture:quicksleep 20 | async def f(x): 21 | await asyncio.sleep(1.0) 22 | return x**2 23 | ``` 24 | 25 | And let's say (for whatever reason) you want to offer a synchronous API to users. For instance maybe you want to make it easy to run your code in a basic script, or a user is building something that's mostly CPU-bound, so they don't want to bother with asyncio. 26 | 27 | A "simple" way to create a synchronous equivalent would be to implement a set of synchronous functions where all they do is call [asyncio.run](https://docs.python.org/3/library/asyncio-task.html#asyncio.run) on an asynchronous function. But this isn't a great solution for more complex code: 28 | 29 | * It's kind of tedious grunt work to have to do this for every method/function 30 | * [asyncio.run](https://docs.python.org/3/library/asyncio-task.html#asyncio.run) doesn't work with generators 31 | * In many cases, you need to preserve an event loop running between calls 32 | 33 | The last case is particularly challenging. For instance, let's say you are implementing a client to a database that needs to have a persistent connection, and you want to build it in asyncio: 34 | 35 | ```python 36 | class DBConnection: 37 | def __init__(self, url): 38 | self._url = url 39 | 40 | async def connect(self): 41 | self._connection = await connect_to_database(self._url) 42 | 43 | async def query(self, q): 44 | return await self._connection.run_query(q) 45 | ``` 46 | 47 | How do you expose a synchronous interface to this code? The problem is that wrapping `connect` and `query` in [asyncio.run](https://docs.python.org/3/library/asyncio-task.html#asyncio.run) won't work since you need to _preserve the event loop across calls_. 
It's clear we need something slightly more advanced. 48 | 49 | How to use this library 50 | ======================= 51 | 52 | This library offers a simple `Synchronizer` class that creates an event loop on a separate thread, and wraps functions/generators/classes so that execution happens on that thread. 53 | 54 | Wrapped functions expose two interfaces: 55 | * For synchronous (non-async) use, the wrapper function itself will simply block until the result of the wrapped function is available (note that you can make it return a future as well, [see below](#returning-futures)) 56 | * For async use, you use a special `.aio` member *on the wrapper function itself* which works just like the usual business of calling asynchronous code (`await`, `async for` etc.) - except that async code is executed on the `Synchronizer`'s own event loop ([more on why this matters below](#using-synchronicity-with-other-asynchronous-code)). 57 | 58 | ```python fixture:quicksleep 59 | import asyncio 60 | from synchronicity import Synchronizer 61 | 62 | synchronizer = Synchronizer() 63 | 64 | @synchronizer.wrap 65 | async def f(x): 66 | await asyncio.sleep(1.0) 67 | return x**2 68 | 69 | 70 | # Running f in a synchronous context blocks until the result is available 71 | ret = f(42) # Blocks 72 | assert isinstance(ret, int) 73 | print('f(42) =', ret) 74 | ``` 75 | 76 | Async usage of the `f` wrapper, using the `f.aio` special coroutine function. 
This will execute `f` on `synchronizer`'s event loop - not the main event loop used by `asyncio.run()` here: 77 | ```python continuation fixture:quicksleep 78 | async def g(): 79 | # Running f in an asynchronous context works the normal way 80 | ret = await f.aio(42) # f.aio is roughly equivalent to the original `f` 81 | print('f(42) =', ret) 82 | 83 | asyncio.run(g()) 84 | ``` 85 | 86 | More advanced examples 87 | ====================== 88 | 89 | Generators 90 | ---------- 91 | 92 | The decorator also works on async generators, wrapping them as a regular (non-async) generator: 93 | 94 | ```python fixture:quicksleep 95 | @synchronizer.wrap 96 | async def f(n): 97 | for i in range(n): 98 | await asyncio.sleep(1.0) 99 | yield i 100 | 101 | # Note that the following runs in a synchronous context 102 | # Each number will take 1s to print 103 | for ret in f(3): 104 | print(ret) 105 | ``` 106 | 107 | The wrapped generators can also be called safely in an async context using the `.aio` property: 108 | 109 | ```py continuation fixture:quicksleep 110 | async def async_iteration(): 111 | async for ret in f.aio(3): 112 | pass 113 | 114 | asyncio.run(async_iteration()) 115 | ``` 116 | 117 | Synchronizing whole classes 118 | --------------------------- 119 | 120 | The `Synchronizer` wrapper operates on classes by creating a new class that wraps every method on the class: 121 | 122 | 123 | ```python 124 | @synchronizer.wrap 125 | class DBConnection: 126 | def __init__(self, url): 127 | self._url = url 128 | 129 | async def connect(self): 130 | self._connection = await connect_to_database(self._url) 131 | 132 | async def query(self, q): 133 | return await self._connection.run_query(q) 134 | 135 | 136 | # Now we can call it synchronously, if we want to 137 | db_conn = DBConnection('tcp://localhost:1234') 138 | db_conn.connect() 139 | data = db_conn.query('select * from foo') 140 | ``` 141 | *Or*, we could opt to use the wrapped class in an async context if we want to: 142 | 
```python continuation 143 | async def async_main(): 144 | db_conn = DBConnection('tcp://localhost:1234') 145 | await db_conn.connect.aio() 146 | await db_conn.query.aio('select * from foo') # .aio works on methods too 147 | 148 | asyncio.run(async_main()) 149 | ``` 150 | 151 | Context managers 152 | ---------------- 153 | 154 | You can synchronize context manager classes just like any other class and the special methods will be handled properly. 155 | 156 | ```python fixture:quicksleep 157 | @synchronizer.wrap 158 | class CtxMgr: 159 | def __init__(self, exit_delay: float): 160 | self.exit_delay = exit_delay 161 | 162 | async def __aenter__(self): 163 | pass 164 | 165 | async def __aexit__(self, exc, exc_type, tb): 166 | await asyncio.sleep(self.exit_delay) 167 | 168 | with CtxMgr(exit_delay=1): 169 | print("sleeping 1 second") 170 | print("done") 171 | ``` 172 | 173 | 174 | Returning futures 175 | ----------------- 176 | 177 | You can also make functions return a `concurrent.futures.Future` object by adding `_future=True` to any call. This can be useful if you want to dispatch many calls from a blocking context, but you want to resolve them roughly in parallel: 178 | 179 | ```python fixture:quicksleep 180 | @synchronizer.wrap 181 | async def f(x): 182 | await asyncio.sleep(1.0) 183 | return x**2 184 | 185 | futures = [f(i, _future=True) for i in range(10)] # This returns immediately, but starts running all calls in the background 186 | rets = [fut.result() for fut in futures] # This should take ~1s to run, resolving all futures in parallel 187 | print('first ten squares:', rets) 188 | ``` 189 | 190 | 191 | Using synchronicity with other asynchronous code 192 | ------------------------------------------------ 193 | 194 | Why does synchronicity expose a separate async interface (`.aio`) when you could just use the original unwrapped function that is already async? 
It solves two issues: 195 | * Intercompatibility with the non-async interface - you can pass wrapped class instances to the wrapper and those will be "unwrapped" so that the implementation code only needs to deal with unwrapped objects. 196 | * Separate event loops of the library and the user of the library adds safeguards from event loop blockers for both 197 | 198 | A common pitfall in asynchronous programming is to accidentally lock up an event loop by making non-async long-running calls within the event loop. If your async library shares an event loop with a user's own async code, a synchronous call (typically a bug) in either the library or the user code would prevent the other from running concurrent tasks. Using synchronicity wrappers on your library functions, you avoid this pitfall by isolating the library execution to its own event loop and thread automatically. 199 | 200 | 201 | ```python 202 | import time 203 | 204 | @synchronizer.wrap 205 | async def buggy_library(): 206 | time.sleep(0.1) #non-async sleep, this locks the library's event loop for the duration 207 | 208 | async def async_user_code(): 209 | await buggy_library.aio() # this will not lock the "user's" event loop 210 | ``` 211 | 212 | This library can also be useful in purely asynchronous settings, if you have multiple event loops, if you have some section that is CPU-bound, or some critical code that you want to run on a separate thread for safety. All calls to synchronized functions/generators are thread-safe by design. This makes it a useful alternative to [loop.run_in_executor](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) for simple things. Note however that each synchronizer only runs one thread. 213 | 214 | 215 | Static type support for wrappers 216 | ---------------------------------- 217 | 218 | One issue with the wrapper functions and classes is that they will have different argument and return value types than the wrapped originals (e.g. 
an AsyncGenerator becomes a Generator after being wrapped). This type transformation can't easily be expressed statically in Python's typing system. 219 | 220 | For this reason, synchronicity includes a basic type stub (.pyi) generation tool (`python -m synchronicity.type_stubs`) that takes Python modules names as inputs and emits a `.pyi` file for each module with static types translating any synchronicity-wrapped classes or functions. 221 | 222 | Since the `.pyi` file will sometimes "shadow" the original file, which you still might want to type check for issues in the implementation code, a good practice is to separate wrappers and wrapped implementation code into separate modules and only emit type stubs for the "wrapper modules". 223 | 224 | A recommended structure would be something like this: 225 | 226 | #### _my_library.py (private library implementation) 227 | ```py 228 | import typing 229 | 230 | async def foo() -> typing.AsyncGenerator[int, None]: 231 | yield 1 232 | ``` 233 | 234 | #### my_library.py (public library interface) 235 | ```py notest 236 | import _my_library 237 | from synchronicity import Synchronizer 238 | 239 | synchronizer = Synchronizer() 240 | 241 | foo = synchronizer.wrap(_my_library.foo, name="MyClass", target_module=__name__) 242 | ``` 243 | 244 | You can then emit type stubs for the public module, as part of your build process: 245 | ```shell 246 | python -m synchronicity.type_stubs my_module 247 | ``` 248 | The automatically generated type stub `my_library.pyi` would then look something like: 249 | ```py 250 | import typing 251 | import typing_extensions 252 | 253 | class __foo_spec(typing_extensions.Protocol): 254 | def __call__(self) -> typing.Generator[int, None, None]: 255 | ... 256 | 257 | def aio(self) -> typing.AsyncGenerator[int, None]: 258 | ... 
259 | 260 | foo: __foo_spec 261 | ``` 262 | 263 | The special `*_spec` protocol types here make sure that both calling the wrapped `for x in foo()` method and `async for x in foo.aio()` will be statically valid operations, and their respective return values are typed correctly. 264 | 265 | 266 | Gotchas 267 | ======= 268 | 269 | * If you have a non-async function that *returns* an awaitable or other async entity, but isn't itself defined with the `async` keyword, you have to *type annotate* that function with the correct async return type - otherwise it will not get wrapped correctly by `synchronizer.wrap`: 270 | 271 | ```py 272 | @synchronizer.wrap 273 | def foo() -> typing.AsyncContextManager[str]: 274 | return make_context_manager() 275 | ``` 276 | * If a class is "synchronized", any instance of that class will be a proxy for an instance of the original class. Methods on the class will delegate to methods of the underlying class, but *attributes* of the original class aren't directly reachable and would need getter methods or @properties to be reachable on the wrapper. 277 | * Note that all synchronized code will run on a different thread, and a different event loop, so calling the code might have some minor extra overhead. 278 | * Since all arguments and return values of wrapped functions are recursively run-time inspected to "translate" them, large data structures that are passed in and out can incur extra overhead. This can be disabled using a `@synchronizer.no_io_translation` decorator on the original function. 279 | 280 | 281 | Future ideas 282 | ===== 283 | * Use type annotations instead of runtime type inspection to determine the wrapping logic. This would prevent overly zealous argument/return value inspection when it isn't needed. 284 | * Use (optional?) code generation (using type annotations) instead of runtime wrappers + type stub generation. 
This could make it easier to navigate exception tracebacks, and lead to simpler/better static types for wrappers. 285 | * Support the opposite case, i.e. you have a blocking function/generator/class/object, and you want to call it asynchronously (this is relatively simple to do for plain functions using [loop.run_in_executor](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), but Python has no built-in support for generators, and it would be nice to transform a whole class. 286 | * More/better documentation 287 | * A cleaner way to return futures from sync code (instead of the `_future=True` argument) 288 | 289 | 290 | Release process 291 | =============== 292 | TODO: We should automate this in CI/CD 293 | 294 | * Make a new branch `release-X.Y.Z` from main 295 | * Bump version in pyproject.toml to `X.Y.Z` 296 | * Commit that change and create a PR 297 | * Merge the PR once green 298 | * Checkout main 299 | * `git tag -a vX.Y.Z -m "* release bullets"` 300 | * git push --tags 301 | * `UV_PUBLISH_TOKEN="$PYPI_TOKEN_SYNCHRONICITY" make publish` 302 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pytest 3 | import typing 4 | 5 | from synchronicity import Synchronizer 6 | 7 | 8 | class DummyConnection: 9 | async def run_query(self, query): 10 | pass 11 | 12 | 13 | async def dummy_connect_to_db(url): 14 | return DummyConnection() 15 | 16 | 17 | @pytest.fixture() 18 | def quicksleep(monkeypatch): 19 | from asyncio import sleep as original_sleep 20 | 21 | monkeypatch.setattr("asyncio.sleep", lambda x: original_sleep(x / 1000.0)) 22 | 23 | 24 | def pytest_markdown_docs_globals(): 25 | synchronizer = Synchronizer() 26 | return { 27 | "typing": typing, 28 | "synchronizer": synchronizer, 29 | "asyncio": asyncio, 30 | "connect_to_database": dummy_connect_to_db, 31 | } 32 | 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "synchronicity" 3 | version = "0.10.2" 4 | description = "Export blocking and async library versions from a single async implementation" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Modal Labs" } 8 | ] 9 | requires-python = ">=3.9" 10 | dependencies = [ 11 | "sigtools>=4.0.1", 12 | "typing-extensions>=4.12.2", 13 | ] 14 | classifiers = [ 15 | "Operating System :: OS Independent", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Programming Language :: Python :: 3", 18 | ] 19 | 20 | 21 | [build-system] 22 | requires = ["hatchling"] 23 | build-backend = "hatchling.build" 24 | 25 | 26 | [tool.ruff] 27 | line-length = 120 28 | exclude = ['.venv', '.git', '__pycache__', 'build', 'dist'] 29 | 30 | [tool.ruff.lint] 31 | select = ['E', 'F', 'W', 'I'] 32 | 33 | [tool.ruff.lint.isort] 34 | combine-as-imports = true 35 | known-first-party = ["synchronicity"] 36 | extra-standard-library = ["pytest"] 37 | 38 | [tool.hatch.build.targets.sdist] 39 | exclude = [ 40 | ".*", 41 | ] 42 | 43 | [dependency-groups] 44 | dev = [ 45 | "pre-commit>=3.5.0", 46 | {include-group = "lint"}, 47 | {include-group = "test"} 48 | ] 49 | lint = [ 50 | "mypy-extensions>=1.0.0", 51 | "mypy>=1.13.0", 52 | "ruff>=0.11.13", 53 | ] 54 | test = [ 55 | "console-ctrl>=0.1.0", 56 | "gevent>=24.2.1; python_version < '3.13'", 57 | "pytest>=8.3.3", 58 | "pytest-asyncio>=0.24.0", 59 | "pytest-markdown-docs>=0.7.1", 60 | ] 61 | 62 | 63 | [tool.pytest.ini_options] 64 | filterwarnings = [ 65 | "error", 66 | "ignore::DeprecationWarning", 67 | "ignore::PendingDeprecationWarning", 68 | ] 69 | -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | import asyncio 
2 | import contextlib 3 | import time 4 | 5 | from synchronicity import Synchronizer 6 | 7 | s = Synchronizer() 8 | 9 | 10 | async def _f(): 11 | pass 12 | 13 | 14 | f = s.wrap(_f) 15 | 16 | 17 | @contextlib.contextmanager 18 | def timer(test_str: str): 19 | t0 = time.monotonic() 20 | yield 21 | t1 = time.monotonic() 22 | print(f"Ran {test_str} in {t1 - t0} seconds") 23 | 24 | 25 | n = 10_000 26 | 27 | 28 | async def run_original(): 29 | with timer(f"original * {n}"): 30 | [(await _f()) for i in range(n)] 31 | 32 | 33 | asyncio.run(run_original()) 34 | 35 | with timer(f"sync * {n}"): 36 | [f() for i in range(n)] 37 | 38 | 39 | async def run_some_async(): 40 | with timer(f"async * {n}"): 41 | [(await f.aio()) for i in range(n)] 42 | 43 | 44 | asyncio.run(run_some_async()) 45 | -------------------------------------------------------------------------------- /src/synchronicity/__init__.py: -------------------------------------------------------------------------------- 1 | from .synchronizer import Synchronizer, classproperty 2 | 3 | __all__ = ["Synchronizer", "classproperty"] 4 | -------------------------------------------------------------------------------- /src/synchronicity/annotations.py: -------------------------------------------------------------------------------- 1 | # compatibility utilities/polyfills for supporting older python versions 2 | import importlib 3 | import logging 4 | import sys 5 | import typing 6 | 7 | logger = logging.getLogger("synchronicity") 8 | 9 | # Modules that cannot be evaluated at runtime, e.g., 10 | # only available under the TYPE_CHECKING guard, but can be used freely in stub files 11 | TYPE_CHECKING_OVERRIDES = {"_typeshed"} 12 | 13 | 14 | def evaluated_annotation(annotation, *, globals_=None, declaration_module=None): 15 | # evaluate string annotations... 
16 | imported_declaration_module = None 17 | if globals_ is None and declaration_module is not None: 18 | if declaration_module in sys.modules: 19 | # already loaded module 20 | imported_declaration_module = sys.modules[declaration_module] 21 | else: 22 | imported_declaration_module = importlib.import_module(declaration_module) 23 | globals_ = imported_declaration_module.__dict__ 24 | 25 | try: 26 | return eval(annotation, globals_) 27 | except NameError: 28 | if "." in annotation: 29 | # in case of unimported modules referenced in the annotation itself 30 | # typically happens with TYPE_CHECKING guards etc. 31 | ref_module, _ = annotation.rsplit(".", 1) 32 | # for modules that can't be evaluated at runtime, 33 | # return a ForwardRef with __forward_module__ set 34 | # to the name of the module that we want to import in the stub file 35 | if ref_module in TYPE_CHECKING_OVERRIDES: 36 | ref = typing.ForwardRef(annotation) 37 | ref.__forward_module__ = ref_module 38 | return ref 39 | # hack: import the library *into* the namespace of the supplied globals 40 | exec(f"import {ref_module}", globals_) 41 | return eval(annotation, globals_) 42 | raise 43 | -------------------------------------------------------------------------------- /src/synchronicity/async_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import signal 3 | import sys 4 | import threading 5 | import typing 6 | 7 | from synchronicity.exceptions import NestedEventLoops 8 | 9 | T = typing.TypeVar("T") 10 | 11 | 12 | class Runner: 13 | """Simplified backport of asyncio.Runner from Python 3.11 14 | 15 | Like asyncio.run() but allows multiple calls to the same event loop 16 | before teardown, and is converts sigints into graceful cancellations 17 | similar to asyncio.run on Python 3.11+. 
18 | """ 19 | 20 | def __enter__(self) -> "Runner": 21 | try: 22 | asyncio.get_running_loop() 23 | except RuntimeError: 24 | pass # no event loop - this is what we expect! 25 | else: 26 | raise NestedEventLoops() 27 | 28 | self._loop = asyncio.new_event_loop() 29 | return self 30 | 31 | def __exit__(self, exc_type, exc_value, traceback): 32 | self._loop.run_until_complete(self._loop.shutdown_asyncgens()) 33 | if sys.version_info[:2] >= (3, 9): 34 | # Introduced in Python 3.9 35 | self._loop.run_until_complete(self._loop.shutdown_default_executor()) 36 | 37 | self._loop.close() 38 | return False 39 | 40 | def run(self, coro: typing.Awaitable[T]) -> T: 41 | is_main_thread = threading.current_thread() == threading.main_thread() 42 | self._num_sigints = 0 43 | 44 | coro_task = asyncio.ensure_future(coro, loop=self._loop) 45 | 46 | async def wrapper_coro(): 47 | # this wrapper ensures that we won't reraise KeyboardInterrupt into 48 | # the calling scope until all async finalizers in coro_task have 49 | # finished executing. It even allows the coro to prevent cancellation 50 | # and thereby ignoring the first keyboardinterrupt 51 | return await coro_task 52 | 53 | def _sigint_handler(signum, frame): 54 | # cancel the task in order to have run_until_complete return soon and 55 | # prevent a bunch of unwanted tracebacks when shutting down the 56 | # event loop. 
57 | 58 | # this basically replicates the sigint handler installed by asyncio.run() 59 | self._num_sigints += 1 60 | if self._num_sigints == 1: 61 | # first sigint is graceful 62 | self._loop.call_soon_threadsafe(coro_task.cancel) 63 | return 64 | 65 | # this should normally not happen, but the second sigint would "hard kill" the event loop 66 | # by raising KeyboardInterrupt inside of it 67 | raise KeyboardInterrupt() 68 | 69 | original_sigint_handler = None 70 | try: 71 | # only install signal handler if running from main thread and we haven't disabled sigint 72 | handle_sigint = is_main_thread and signal.getsignal(signal.SIGINT) == signal.default_int_handler 73 | 74 | if handle_sigint: 75 | # intentionally not using _loop.add_signal_handler since it's slow (?) 76 | # and not available on Windows. We just don't want the sigint to 77 | # mess with the event loop anyways 78 | original_sigint_handler = signal.signal(signal.SIGINT, _sigint_handler) 79 | except KeyboardInterrupt: 80 | # this is quite unlikely, but with bad timing we could get interrupted before 81 | # installing the sigint handler and this has happened repeatedly in unit tests 82 | _sigint_handler(signal.SIGINT, None) 83 | 84 | try: 85 | return self._loop.run_until_complete(wrapper_coro()) 86 | except asyncio.CancelledError: 87 | if self._num_sigints > 0: 88 | raise KeyboardInterrupt() # might want to use original_sigint_handler here instead? 
89 | raise # "internal" cancellations, not triggered by KeyboardInterrupt 90 | finally: 91 | if original_sigint_handler: 92 | # reset signal handler 93 | signal.signal(signal.SIGINT, original_sigint_handler) 94 | -------------------------------------------------------------------------------- /src/synchronicity/async_wrap.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | import contextlib 3 | import functools 4 | import inspect 5 | import typing 6 | from contextlib import asynccontextmanager as _asynccontextmanager 7 | 8 | import typing_extensions 9 | 10 | from .exceptions import UserCodeException, suppress_synchronicity_tb_frames 11 | from .interface import Interface 12 | 13 | 14 | def wraps_by_interface(interface: Interface, func): 15 | """Like functools.wraps but maintains `inspect.iscoroutinefunction` and allows custom type annotations overrides 16 | 17 | Use this when the wrapper function is non-async but returns the coroutine resulting 18 | from calling the underlying wrapped `func`. This will make sure that the wrapper 19 | is still an async function in that case, and can be inspected as such. 
20 | 21 | Note: Does not forward async generator information other than explicit annotations 22 | """ 23 | if is_coroutine_function_follow_wrapped(func) and interface == Interface._ASYNC_WITH_BLOCKING_TYPES: 24 | 25 | def asyncfunc_deco(user_wrapper): 26 | @functools.wraps(func) 27 | async def wrapper(*args, **kwargs): 28 | with suppress_synchronicity_tb_frames(): 29 | try: 30 | return await user_wrapper(*args, **kwargs) 31 | except UserCodeException as uc_exc: 32 | uc_exc.exc.__suppress_context__ = True 33 | raise uc_exc.exc 34 | 35 | return wrapper 36 | 37 | return asyncfunc_deco 38 | else: 39 | return functools.wraps(func) 40 | 41 | 42 | def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: 43 | """Determine if func returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" 44 | from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import 45 | 46 | if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: 47 | return is_coroutine_function_follow_wrapped(func.__wrapped__) 48 | return inspect.iscoroutinefunction(func) 49 | 50 | 51 | def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: 52 | """Determine if func returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" 53 | from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import 54 | 55 | if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: 56 | return is_async_gen_function_follow_wrapped(func.__wrapped__) 57 | return inspect.isasyncgenfunction(func) 58 | 59 | 60 | YIELD_TYPE = typing.TypeVar("YIELD_TYPE") 61 | SEND_TYPE = typing.TypeVar("SEND_TYPE") 62 | 63 | 64 | P = typing_extensions.ParamSpec("P") 65 | 66 | 67 | def asynccontextmanager( 68 | f: typing.Callable[P, typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE]], 69 | ) -> typing.Callable[P, typing.AsyncContextManager[YIELD_TYPE]]: 70 | 
"""Wrapper around contextlib.asynccontextmanager that sets correct type annotations 71 | 72 | The standard library one doesn't 73 | """ 74 | acm_factory: typing.Callable[..., typing.AsyncContextManager[YIELD_TYPE]] = _asynccontextmanager(f) 75 | 76 | old_ret = acm_factory.__annotations__.pop("return", None) 77 | if old_ret is not None: 78 | if old_ret.__origin__ in [ 79 | collections.abc.AsyncGenerator, 80 | collections.abc.AsyncIterator, 81 | collections.abc.AsyncIterable, 82 | ]: 83 | acm_factory.__annotations__["return"] = typing.AsyncContextManager[old_ret.__args__[0]] # type: ignore 84 | elif old_ret.__origin__ == contextlib.AbstractAsyncContextManager: 85 | # if the standard lib fixes the annotations in the future, lets not break it... 86 | return acm_factory 87 | else: 88 | raise ValueError( 89 | "To use the fixed @asynccontextmanager, make sure to properly" 90 | " annotate your wrapped function as an AsyncGenerator" 91 | ) 92 | 93 | return acm_factory 94 | -------------------------------------------------------------------------------- /src/synchronicity/callback.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import inspect 3 | 4 | 5 | class Callback: 6 | """A callback is when synchronized call needs to call outside functions passed into it. 
7 | 8 | Currently only supports non-generator functions.""" 9 | 10 | def __init__(self, synchronizer, f): 11 | self._synchronizer = synchronizer 12 | self._f = f 13 | 14 | def _invoke(self, args, kwargs): 15 | # This runs on a separate thread 16 | res = self._f(*args, **kwargs) 17 | if inspect.iscoroutine(res): 18 | try: 19 | loop = asyncio.new_event_loop() 20 | return loop.run_until_complete(res) 21 | finally: 22 | loop.close() 23 | elif inspect.isasyncgen(res): 24 | raise RuntimeError("Async generators are not supported") 25 | elif inspect.isgenerator(res): 26 | raise RuntimeError("Generators are not supported") 27 | else: 28 | return res 29 | 30 | async def __call__(self, *args, **kwargs): 31 | # This translates the opposite way from the code in the synchronizer 32 | args = self._synchronizer._translate_out(args) 33 | kwargs = self._synchronizer._translate_out(kwargs) 34 | 35 | # This function may be blocking, so we need to run it on a thread 36 | loop = asyncio.get_event_loop() 37 | res = await loop.run_in_executor(None, self._invoke, args, kwargs) 38 | 39 | # Now, we need to translate the result _in_ 40 | return self._synchronizer._translate_in(res) 41 | -------------------------------------------------------------------------------- /src/synchronicity/combined_types.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import typing 3 | 4 | import typing_extensions 5 | 6 | from synchronicity.async_wrap import wraps_by_interface 7 | from synchronicity.exceptions import UserCodeException, suppress_synchronicity_tb_frames 8 | from synchronicity.interface import Interface 9 | 10 | if typing.TYPE_CHECKING: 11 | from synchronicity.synchronizer import Synchronizer 12 | 13 | 14 | class FunctionWithAio: 15 | def __init__(self, func, aio_func, synchronizer): 16 | self._func = func 17 | self.aio = self._aio_func = aio_func 18 | self._synchronizer = synchronizer 19 | 20 | def __call__(self, *args, **kwargs): 21 | # 
.__call__ is special - it's being looked up on the class instead of the instance when calling something, 22 | # so setting the magic method from the constructor is not possible 23 | # https://stackoverflow.com/questions/22390532/object-is-not-callable-after-adding-call-method-to-instance 24 | # so we need to use an explicit wrapper function here 25 | with suppress_synchronicity_tb_frames(): 26 | try: 27 | return self._func(*args, **kwargs) 28 | except UserCodeException as uc_exc: 29 | # For Python < 3.11 we use UserCodeException as an exception wrapper 30 | # to remove some internal frames from tracebacks, but it can't remove 31 | # all frames 32 | uc_exc.exc.__suppress_context__ = True 33 | raise uc_exc.exc 34 | 35 | 36 | class MethodWithAio: 37 | """Creates a bound method that can have callable child-properties on the method itself. 38 | 39 | Child-properties are also bound to the parent instance. 40 | """ 41 | 42 | def __init__(self, func, aio_func, synchronizer: "Synchronizer", is_classmethod=False): 43 | self._func = func 44 | self._aio_func = aio_func 45 | self._synchronizer = synchronizer 46 | self._is_classmethod = is_classmethod 47 | 48 | def __get__(self, instance, owner=None): 49 | bind_var = instance if instance is not None and not self._is_classmethod else owner 50 | 51 | bound_func = functools.wraps(self._func)(functools.partial(self._func, bind_var)) # bound blocking function 52 | self._synchronizer._update_wrapper(bound_func, self._func, interface=Interface.BLOCKING) 53 | 54 | bound_aio_func = wraps_by_interface(Interface._ASYNC_WITH_BLOCKING_TYPES, self._aio_func)( 55 | functools.partial(self._aio_func, bind_var) 56 | ) # bound async function 57 | self._synchronizer._update_wrapper(bound_func, self._func, interface=Interface._ASYNC_WITH_BLOCKING_TYPES) 58 | bound_func.aio = bound_aio_func 59 | return bound_func 60 | 61 | 62 | CTX = typing.TypeVar("CTX", covariant=True) 63 | 64 | 65 | class 
AsyncAndBlockingContextManager(typing_extensions.Protocol[CTX]): 66 | def __enter__(self) -> CTX: ... 67 | 68 | async def __aenter__(self) -> CTX: ... 69 | 70 | def __exit__(self, typ, value, tb): ... 71 | 72 | async def __aexit__(self, typ, value, tb): ... 73 | -------------------------------------------------------------------------------- /src/synchronicity/exceptions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import os 4 | import sys 5 | from pathlib import Path 6 | from types import TracebackType 7 | from typing import Literal, Optional 8 | 9 | import synchronicity 10 | 11 | SYNCHRONICITY_TRACEBACK = os.getenv("SYNCHRONICITY_TRACEBACK", "0") == "1" 12 | # note to insert into exception.__notes__ if a traceback frame is hidden 13 | SYNCHRONICITY_TRACEBACK_NOTE = None 14 | 15 | 16 | class UserCodeException(Exception): 17 | """This is used to wrap and unwrap exceptions in "user code". 18 | 19 | This lets us have cleaner tracebacks without all the internal synchronicity stuff.""" 20 | 21 | def __init__(self, exc): 22 | # There's always going to be one place inside synchronicity where we 23 | # catch the exception. We can always safely remove that frame from the 24 | # traceback. 
25 | self.exc = exc 26 | 27 | 28 | def wrap_coro_exception(coro): 29 | async def coro_wrapped(): 30 | try: 31 | return await coro 32 | except StopAsyncIteration: 33 | raise 34 | except asyncio.CancelledError: 35 | # we don't want to wrap these since cancelled Task's are otherwise 36 | # not properly marked as cancelled, and then not treated correctly 37 | # during event loop shutdown (perhaps in other places too) 38 | raise 39 | except UserCodeException: 40 | raise # Pass-through in case it got double-wrapped 41 | except TimeoutError as exc: 42 | # user-raised TimeoutError always needs to be wrapped, or they would interact 43 | # with synchronicity's own timeout handling 44 | # TODO: if we want to get rid of UserCodeException at some point 45 | # we could use a custom version of `asyncio.wait_for` to get around this 46 | raise UserCodeException(exc) 47 | except Exception as exc: 48 | if sys.version_info < (3, 11) and not SYNCHRONICITY_TRACEBACK: 49 | exc.with_traceback(exc.__traceback__.tb_next) # skip the `await coro` frame from above 50 | raise UserCodeException(exc) 51 | raise # raise as is on Python 3.11 - we hide things later 52 | except BaseException as exc: 53 | # special case if a coroutine raises a KeyboardInterrupt or similar 54 | # exception that would otherwise kill the event loop. 
55 | # Not sure if this is wise tbh, but there is a unit test that checks 56 | # for KeyboardInterrupt getting propagated, which would require this 57 | raise UserCodeException(exc) 58 | 59 | return coro_wrapped() 60 | 61 | 62 | async def unwrap_coro_exception(coro): 63 | try: 64 | return await coro 65 | except UserCodeException as uc_exc: 66 | uc_exc.exc.__suppress_context__ = True 67 | raise uc_exc.exc 68 | 69 | 70 | class NestedEventLoops(Exception): 71 | pass 72 | 73 | 74 | _skip_modules = [synchronicity, concurrent.futures, asyncio] 75 | _skip_module_roots = [Path(mod.__file__).parent for mod in _skip_modules if mod.__file__] 76 | 77 | 78 | class suppress_synchronicity_tb_frames: 79 | def __enter__(self): 80 | pass 81 | 82 | def __exit__( 83 | self, exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType] 84 | ) -> Literal[False]: 85 | if tb is None or exc_type is None or exc is None or SYNCHRONICITY_TRACEBACK: 86 | # no exception, or enabled full tracebacks - don't do anything 87 | return False 88 | 89 | def should_hide_file(fn: str): 90 | return any(Path(fn).is_relative_to(modroot) for modroot in _skip_module_roots) 91 | 92 | def get_next_valid(tb: TracebackType) -> Optional[TracebackType]: 93 | next_valid: Optional[TracebackType] = tb 94 | while next_valid is not None and should_hide_file(next_valid.tb_frame.f_code.co_filename or ""): 95 | next_valid = next_valid.tb_next 96 | return next_valid 97 | 98 | cleaned_root = get_next_valid(tb) 99 | if cleaned_root is None: 100 | # no frames outside of skip_modules - return original error 101 | return False 102 | 103 | exc.with_traceback(cleaned_root) # side effect modification of exception object 104 | exc_notes = getattr(exc, "__notes__", []) 105 | if SYNCHRONICITY_TRACEBACK_NOTE is not None and SYNCHRONICITY_TRACEBACK_NOTE not in exc_notes: 106 | exc_notes.append(exc_notes) 107 | 108 | return False 109 | 
-------------------------------------------------------------------------------- /src/synchronicity/interface.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class Interface(enum.Enum): 5 | BLOCKING = enum.auto() 6 | _ASYNC_WITH_BLOCKING_TYPES = enum.auto() # this is *only* used for functions, since all types are blocking 7 | 8 | 9 | # Default names for classes 10 | DEFAULT_CLASS_PREFIX = "Blocking" 11 | 12 | # Default names for functions 13 | DEFAULT_FUNCTION_PREFIXES = { 14 | Interface.BLOCKING: "blocking_", 15 | # this is only used internally - usage will be via `.aio` on the blocking function: 16 | Interface._ASYNC_WITH_BLOCKING_TYPES: "aio_", 17 | } 18 | -------------------------------------------------------------------------------- /src/synchronicity/overload_tracking.py: -------------------------------------------------------------------------------- 1 | """Utility for monkey patching typing.overload to allow run time retrieval overloads 2 | 3 | Requires any @typing.overload to happen within the patched_overload contextmanager, e.g.: 4 | 5 | ```python 6 | with patched_overload(): 7 | # the following could be imported from some other module (as long as it wasn't already loaded), or inlined: 8 | 9 | @typing.overload 10 | def foo(a: int) -> float: 11 | ... 
12 | 13 | def foo(a: typing.Union[bool, int]) -> typing.Union[bool, float]: 14 | if isinstance(a, bool): 15 | return a 16 | return float(a) 17 | 18 | # returns reference to the overloads of foo (the int -> float one in this case) 19 | # in the order they are declared 20 | foo_overloads = get_overloads(foo) 21 | """ 22 | 23 | import contextlib 24 | import typing 25 | from unittest import mock 26 | 27 | overloads: typing.Dict[typing.Tuple[str, str], typing.List] = {} 28 | original_overload = typing.overload 29 | 30 | 31 | class Untrackable(Exception): 32 | pass 33 | 34 | 35 | def _function_locator(f): 36 | if isinstance(f, (staticmethod, classmethod)): 37 | return _function_locator(f.__func__) 38 | 39 | try: 40 | return (f.__module__, f.__qualname__) 41 | except AttributeError: 42 | raise Untrackable() # TODO(elias): handle descriptors like classmethod 43 | 44 | 45 | def _tracking_overload(f): 46 | # hacky thing to track all typing.overload declarations 47 | global overloads, original_overload 48 | try: 49 | locator = _function_locator(f) 50 | overloads.setdefault(locator, []).append(f) 51 | except Untrackable: 52 | print(f"WARNING: can't track overloads for {f}") 53 | 54 | return original_overload(f) 55 | 56 | 57 | @contextlib.contextmanager 58 | def patched_overload(): 59 | with mock.patch("typing.overload", _tracking_overload): 60 | yield 61 | 62 | 63 | def get_overloads(f) -> typing.List: 64 | try: 65 | return overloads.get(_function_locator(f), []) 66 | except Untrackable: 67 | return [] 68 | -------------------------------------------------------------------------------- /src/synchronicity/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modal-labs/synchronicity/731a20bbe9bff1d973eb8bbb87f73f2e9d4bad1a/src/synchronicity/py.typed -------------------------------------------------------------------------------- /src/synchronicity/synchronizer.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | import asyncio.futures 3 | import atexit 4 | import collections.abc 5 | import concurrent 6 | import concurrent.futures 7 | import contextlib 8 | import functools 9 | import inspect 10 | import os 11 | import threading 12 | import types 13 | import typing 14 | import warnings 15 | from typing import ForwardRef, Optional 16 | 17 | import typing_extensions 18 | 19 | from synchronicity.annotations import evaluated_annotation 20 | from synchronicity.combined_types import FunctionWithAio, MethodWithAio 21 | 22 | from .async_wrap import is_async_gen_function_follow_wrapped, is_coroutine_function_follow_wrapped, wraps_by_interface 23 | from .callback import Callback 24 | from .exceptions import UserCodeException, suppress_synchronicity_tb_frames, unwrap_coro_exception, wrap_coro_exception 25 | from .interface import DEFAULT_CLASS_PREFIX, DEFAULT_FUNCTION_PREFIXES, Interface 26 | 27 | _BUILTIN_ASYNC_METHODS = { 28 | "__aiter__": "__iter__", 29 | "__aenter__": "__enter__", 30 | "__aexit__": "__exit__", 31 | "__anext__": "__next__", 32 | "aclose": "close", 33 | } 34 | 35 | IGNORED_ATTRIBUTES = ( 36 | # the "zope" lib monkey patches in some non-introspectable stuff on stdlib abc.ABC. 
37 | # Ignoring __provides__ fixes an incompatibility with `channels[daphne]`, 38 | # where Synchronizer creation fails when wrapping contextlib._AsyncGeneratorContextManager 39 | "__provides__", 40 | ) 41 | 42 | _RETURN_FUTURE_KWARG = "_future" 43 | 44 | TARGET_INTERFACE_ATTR = "_sync_target_interface" 45 | SYNCHRONIZER_ATTR = "_sync_synchronizer" 46 | 47 | 48 | ASYNC_GENERIC_ORIGINS = ( 49 | collections.abc.Awaitable, 50 | collections.abc.Coroutine, 51 | collections.abc.AsyncIterator, 52 | collections.abc.AsyncIterable, 53 | collections.abc.AsyncGenerator, 54 | contextlib.AbstractAsyncContextManager, 55 | ) 56 | 57 | 58 | class classproperty: 59 | """Read-only class property recognized by Synchronizer's wrap method.""" 60 | 61 | def __init__(self, fget): 62 | self.fget = fget 63 | 64 | def __get__(self, obj, owner): 65 | return self.fget(owner) 66 | 67 | 68 | def _type_requires_aio_usage(annotation, declaration_module): 69 | if isinstance(annotation, ForwardRef): 70 | annotation = annotation.__forward_arg__ 71 | if isinstance(annotation, str): 72 | try: 73 | annotation = evaluated_annotation(annotation, declaration_module=declaration_module) 74 | except Exception: 75 | # TODO: this will be incorrect in special case of `arg: "Awaitable[some_forward_ref_type]"`, 76 | # but its a hard problem to solve without passing around globals everywhere 77 | return False 78 | 79 | if hasattr(annotation, "__origin__"): 80 | if annotation.__origin__ in ASYNC_GENERIC_ORIGINS: # type: ignore 81 | return True 82 | # recurse for generic subtypes 83 | for a in getattr(annotation, "__args__", ()): 84 | if _type_requires_aio_usage(a, declaration_module): 85 | return True 86 | return False 87 | 88 | 89 | def should_have_aio_interface(func): 90 | # determines if a blocking function gets an .aio attribute with an async interface to the function or not 91 | if is_coroutine_function_follow_wrapped(func) or is_async_gen_function_follow_wrapped(func): 92 | return True 93 | # check annotations 
if they contain any async entities that would need an event loop to be translated: 94 | # This catches things like vanilla functions returning Coroutines 95 | annos = getattr(func, "__annotations__", {}) 96 | for anno in annos.values(): 97 | if _type_requires_aio_usage(anno, func.__module__): 98 | return True 99 | return False 100 | 101 | 102 | class Synchronizer: 103 | """Helps you offer a blocking (synchronous) interface to asynchronous code.""" 104 | 105 | def __init__( 106 | self, 107 | multiwrap_warning=False, 108 | async_leakage_warning=True, 109 | ): 110 | self._future_poll_interval = 0.1 111 | self._multiwrap_warning = multiwrap_warning 112 | self._async_leakage_warning = async_leakage_warning 113 | self._loop = None 114 | self._loop_creation_lock = threading.Lock() 115 | self._thread = None 116 | self._owner_pid = None 117 | self._stopping = None 118 | 119 | # Special attribute we use to go from wrapped <-> original 120 | self._wrapped_attr = "_sync_wrapped_%d" % id(self) 121 | self._original_attr = "_sync_original_%d" % id(self) 122 | 123 | # Special attribute to mark something as non-wrappable 124 | self._nowrap_attr = "_sync_nonwrap_%d" % id(self) 125 | self._input_translation_attr = "_sync_input_translation_%d" % id(self) 126 | self._output_translation_attr = "_sync_output_translation_%d" % id(self) 127 | 128 | # Prep a synchronized context manager in case one is returned and needs translation 129 | self._ctx_mgr_cls = contextlib._AsyncGeneratorContextManager 130 | self.create_blocking(self._ctx_mgr_cls) 131 | atexit.register(self._close_loop) 132 | 133 | _PICKLE_ATTRS = [ 134 | "_multiwrap_warning", 135 | "_async_leakage_warning", 136 | ] 137 | 138 | def __getstate__(self): 139 | return dict([(attr, getattr(self, attr)) for attr in self._PICKLE_ATTRS]) 140 | 141 | def __setstate__(self, d): 142 | for attr in self._PICKLE_ATTRS: 143 | setattr(self, attr, d[attr]) 144 | 145 | def _start_loop(self): 146 | with self._loop_creation_lock: 147 | if 
self._loop and self._loop.is_running(): 148 | # in case of a race between two _start_loop, the loop might already 149 | # be created here by another thread 150 | return self._loop 151 | 152 | is_ready = threading.Event() 153 | 154 | def thread_inner(): 155 | async def loop_inner(): 156 | self._loop = asyncio.get_running_loop() 157 | self._stopping = asyncio.Event() 158 | is_ready.set() 159 | await self._stopping.wait() # wait until told to stop 160 | 161 | try: 162 | asyncio.run(loop_inner()) 163 | except RuntimeError as exc: 164 | # Python 3.12 raises a RuntimeError when new threads are created at shutdown. 165 | # Swallowing it here is innocuous, but ideally we will revisit this after 166 | # refactoring the shutdown handlers that modal uses to avoid triggering it. 167 | if "can't create new thread at interpreter shutdown" not in str(exc): 168 | raise exc 169 | 170 | self._owner_pid = os.getpid() 171 | thread = threading.Thread(target=thread_inner, daemon=True) 172 | thread.start() 173 | is_ready.wait() # TODO: this might block for a very short time 174 | self._thread = thread 175 | return self._loop 176 | 177 | def _close_loop(self): 178 | # Use getattr to protect against weird gc races when we get here via __del__ 179 | if getattr(self, "_thread", None) is not None: 180 | if not self._loop.is_closed(): 181 | # This also serves the purpose of waking up an idle loop 182 | self._loop.call_soon_threadsafe(self._stopping.set) 183 | self._thread.join() 184 | self._thread = None 185 | self._loop = None 186 | self._owner_pid = None 187 | 188 | def __del__(self): 189 | # TODO: this isn't reliably called, because self.create_blocking(self._ctx_mgr_cls) 190 | # creates a global reference to this Synchronizer which makes it never get gced 191 | self._close_loop() 192 | 193 | def _get_loop(self, start=False) -> asyncio.AbstractEventLoop: 194 | if self._thread and not self._thread.is_alive(): 195 | if self._owner_pid == os.getpid(): 196 | # warn - thread died without us 
forking 197 | raise RuntimeError("Synchronizer thread unexpectedly died") 198 | 199 | self._thread = None 200 | self._loop = None 201 | 202 | if self._loop is None and start: 203 | return self._start_loop() 204 | return self._loop 205 | 206 | def _get_running_loop(self): 207 | # TODO: delete this method 208 | try: 209 | return asyncio.get_running_loop() 210 | except RuntimeError: 211 | return 212 | 213 | def _is_inside_loop(self): 214 | loop = self._get_loop() 215 | if loop is None: 216 | return False 217 | if threading.current_thread() != self._thread: 218 | # gevent does something bad that causes asyncio.get_running_loop() to return self._loop 219 | return False 220 | current_loop = self._get_running_loop() 221 | return loop == current_loop 222 | 223 | def _wrap_check_async_leakage(self, coro): 224 | """Check if a coroutine returns another coroutine (or an async generator) and warn. 225 | 226 | The reason this is important to catch is that otherwise even synchronized code might end up 227 | "leaking" async code into the caller. 
228 | """ 229 | if not self._async_leakage_warning: 230 | return coro 231 | 232 | async def coro_wrapped(): 233 | value = await coro 234 | # TODO: we should include the name of the original function here 235 | if inspect.iscoroutine(value): 236 | warnings.warn(f"Potential async leakage: coroutine returned a coroutine {value}.") 237 | elif inspect.isasyncgen(value): 238 | warnings.warn(f"Potential async leakage: Coroutine returned an async generator {value}.") 239 | return value 240 | 241 | return coro_wrapped() 242 | 243 | def _wrap_instance(self, obj): 244 | # Takes an object and creates a new proxy object for it 245 | cls = obj.__class__ 246 | cls_dct = cls.__dict__ 247 | wrapper_cls = cls_dct[self._wrapped_attr][Interface.BLOCKING] 248 | new_obj = wrapper_cls.__new__(wrapper_cls) 249 | # Store a reference to the original object 250 | new_obj.__dict__[self._original_attr] = obj 251 | new_obj.__dict__[SYNCHRONIZER_ATTR] = self 252 | return new_obj 253 | 254 | def _translate_scalar_in(self, obj): 255 | # If it's an external object, translate it to the internal type 256 | if hasattr(obj, "__dict__"): 257 | if inspect.isclass(obj): # TODO: functions? 258 | return obj.__dict__.get(self._original_attr, obj) 259 | else: 260 | return obj.__dict__.get(self._original_attr, obj) 261 | else: 262 | return obj 263 | 264 | def _translate_scalar_out(self, obj): 265 | # If it's an internal object, translate it to the external interface 266 | if inspect.isclass(obj): # TODO: functions? 
    def _recurse_map(self, mapper, obj):
        # Apply `mapper` to every scalar inside (possibly nested) list/tuple/dict
        # containers, rebuilding the containers. The exact type checks (noqa E721)
        # are deliberate: subclasses of the builtin containers are treated as
        # opaque scalars and passed straight to `mapper`.
        if type(obj) == list:  # noqa: E721
            return list(self._recurse_map(mapper, item) for item in obj)
        elif type(obj) == tuple:  # noqa: E721
            return tuple(self._recurse_map(mapper, item) for item in obj)
        elif type(obj) == dict:  # noqa: E721
            return dict((key, self._recurse_map(mapper, item)) for key, item in obj.items())
        else:
            return mapper(obj)

    def _translate_in(self, obj):
        # Recursively translate wrapper objects in arguments to their
        # implementation counterparts.
        return self._recurse_map(self._translate_scalar_in, obj)

    def _translate_out(self, obj, interface=None):
        # Recursively translate implementation objects in results to their
        # wrapper counterparts.
        # TODO: remove deprecated interface arg - not used but needs deprecation path in case of external usage
        return self._recurse_map(lambda scalar: self._translate_scalar_out(scalar), obj)

    def _translate_coro_out(self, coro, original_func):
        # Wrap `coro` so that its eventual result is translated out, unless
        # output translation was disabled for `original_func`.
        async def unwrap_coro():
            res = await coro
            if getattr(original_func, self._output_translation_attr, True):
                return self._translate_out(res)
            return res

        return unwrap_coro()

    def _run_function_sync(self, coro, original_func):
        """Run `coro` on the synchronizer loop and block until it completes.

        The result future is polled (rather than awaited once) so Windows gets
        a chance to abort on Ctrl-C; on KeyboardInterrupt the *inner* task is
        cancelled so the coroutine's own cancellation logic runs before the
        interrupt is re-raised.
        """
        if self._is_inside_loop():
            raise Exception("Deadlock detected: calling a sync function from the synchronizer loop")

        coro = wrap_coro_exception(coro)
        coro = self._wrap_check_async_leakage(coro)
        loop = self._get_loop(start=True)

        # Filled in (from the loop thread) with the task actually running `coro`.
        inner_task_fut = concurrent.futures.Future()

        async def wrapper_coro():
            # this wrapper is needed since run_coroutine_threadsafe *only* accepts coroutines
            inner_task = loop.create_task(coro)
            inner_task_fut.set_result(inner_task)  # sends the task itself to the origin thread
            return await inner_task

        fut = asyncio.run_coroutine_threadsafe(wrapper_coro(), loop)
        try:
            while 1:
                try:
                    # repeated poll to give Windows a chance to abort on Ctrl-C
                    value = fut.result(timeout=self._future_poll_interval)
                    break
                except concurrent.futures.TimeoutError:
                    pass
        except KeyboardInterrupt as exc:
            # in case there is a keyboard interrupt while we are waiting
            # we cancel the *underlying* coro_task (unlike what fut.cancel() would do)
            # and then wait for the *wrapper* coroutine to get a result back, which
            # happens after the cancellation resolves
            if inner_task_fut.done():
                inner_task: asyncio.Task = inner_task_fut.result()
                loop.call_soon_threadsafe(inner_task.cancel)
                try:
                    value = fut.result()
                except concurrent.futures.CancelledError as expected_cancellation:
                    # we *expect* this cancellation, but defer to the passed coro to potentially
                    # intercept and treat the cancellation some other way
                    expected_cancellation.__suppress_context__ = True
            raise exc  # if cancel - re-raise the original KeyboardInterrupt again

        if getattr(original_func, self._output_translation_attr, True):
            return self._translate_out(value)
        return value

    def _run_function_sync_future(self, coro, original_func):
        # Schedule `coro` on the loop and return a concurrent.futures.Future
        # immediately instead of blocking for the result.
        coro = wrap_coro_exception(coro)
        coro = self._wrap_check_async_leakage(coro)
        loop = self._get_loop(start=True)
        # For futures, we unwrap the result at this point, not in f_wrapped
        coro = unwrap_coro_exception(coro)
        coro = self._translate_coro_out(coro, original_func=original_func)
        return asyncio.run_coroutine_threadsafe(coro, loop)

    async def _run_function_async(self, coro, original_func):
        """Await `coro` on the synchronizer loop from any other running loop.

        Mirrors _run_function_sync's Ctrl-C polling for async callers, and
        routes caller cancellation to the inner task so the wrapped coroutine's
        cleanup runs on the synchronizer loop before the cancellation bubbles up.
        """
        coro = wrap_coro_exception(coro)
        coro = self._wrap_check_async_leakage(coro)
        loop = self._get_loop(start=True)
        if self._is_inside_loop():
            # Already on the synchronizer loop - no cross-thread machinery needed.
            value = await coro
        else:
            inner_task_fut = concurrent.futures.Future()

            async def wrapper_coro():
                inner_task = loop.create_task(coro)
                inner_task_fut.set_result(inner_task)  # sends the task itself to the origin thread
                return await inner_task

            c_fut = asyncio.run_coroutine_threadsafe(wrapper_coro(), loop)
            a_fut = asyncio.wrap_future(c_fut)

            shielded_task = None
            try:
                while 1:
                    # the loop + wait_for timeout is for windows ctrl-C compatibility since
                    # windows doesn't truly interrupt the event loop on sigint
                    try:
                        # We create a task here to prevent an anonymous task inside asyncio.wait_for that could
                        # get an unresolved timeout during cancellation handling below, resulting in a warning
                        # traceback.
                        shielded_task = asyncio.create_task(
                            asyncio.wait_for(
                                # inner shield prevents wait_for from cancelling a_fut on timeout
                                asyncio.shield(a_fut),
                                timeout=self._future_poll_interval,
                            )
                        )
                        # The outer shield prevents a cancelled caller from cancelling a_fut directly
                        # so that we can instead cancel the underlying coro_task and wait for it
                        # to bubble back up as a CancelledError gracefully between threads
                        # in order to run any cancellation logic in the coroutine
                        value = await asyncio.shield(shielded_task)
                        break
                    except asyncio.TimeoutError:
                        continue

            except asyncio.CancelledError:
                try:
                    if a_fut.cancelled():
                        raise  # cancellation came from within c_fut
                    if inner_task_fut.done():
                        inner_task: asyncio.Task = inner_task_fut.result()
                        loop.call_soon_threadsafe(inner_task.cancel)  # cancel task on synchronizer event loop
                    # wait for cancellation logic in the underlying coro to complete
                    # this should typically raise CancelledError, but in case of either:
                    # * cancellation prevention in the coro (catching the CancelledError)
                    # * coro_task resolves before the call_soon_threadsafe above is scheduled
                    # the cancellation in a_fut would be cancelled

                    await a_fut  # wait for cancellation logic to complete - this *normally* raises CancelledError
                    raise  # re-raise the CancelledError regardless - preventing unintended cancellation aborts
                finally:
                    if shielded_task:
                        shielded_task.cancel()  # cancel the shielded task, preventing timeouts

        if getattr(original_func, self._output_translation_attr, True):
            return self._translate_out(value)
        return value

    def _run_generator_sync(self, gen, original_func):
        """Drive async generator `gen` on the loop, re-yielding synchronously.

        Values sent into (and exceptions thrown into) this generator are
        forwarded to the wrapped generator via asend/athrow.
        """
        value, is_exc = None, False
        with suppress_synchronicity_tb_frames():
            while True:
                try:
                    if is_exc:
                        value = self._run_function_sync(gen.athrow(value), original_func)
                    else:
                        value = self._run_function_sync(gen.asend(value), original_func)
                except UserCodeException as uc_exc:
                    uc_exc.exc.__suppress_context__ = True
                    raise uc_exc.exc
                except StopAsyncIteration:
                    break

                try:
                    value = yield value
                    is_exc = False
                except BaseException as exc:
                    # Remember the thrown-in exception so it can be athrow()n
                    # into the wrapped generator on the next iteration.
                    value = exc
                    is_exc = True
self._run_function_sync(gen.athrow(value), original_func) 437 | else: 438 | value = self._run_function_sync(gen.asend(value), original_func) 439 | except UserCodeException as uc_exc: 440 | uc_exc.exc.__suppress_context__ = True 441 | raise uc_exc.exc 442 | except StopAsyncIteration: 443 | break 444 | 445 | try: 446 | value = yield value 447 | is_exc = False 448 | except BaseException as exc: 449 | value = exc 450 | is_exc = True 451 | 452 | async def _run_generator_async(self, gen, original_func): 453 | value, is_exc = None, False 454 | with suppress_synchronicity_tb_frames(): 455 | while True: 456 | try: 457 | if is_exc: 458 | value = await self._run_function_async(gen.athrow(value), original_func) 459 | else: 460 | value = await self._run_function_async(gen.asend(value), original_func) 461 | except UserCodeException as uc_exc: 462 | uc_exc.exc.__suppress_context__ = True 463 | raise uc_exc.exc 464 | except StopAsyncIteration: 465 | break 466 | 467 | try: 468 | value = yield value 469 | is_exc = False 470 | except BaseException as exc: 471 | value = exc 472 | is_exc = True 473 | 474 | def create_callback(self, f): 475 | return Callback(self, f) 476 | 477 | def _update_wrapper(self, f_wrapped, f, name=None, interface=None, target_module=None): 478 | """Very similar to functools.update_wrapper""" 479 | functools.update_wrapper(f_wrapped, f) 480 | if name is not None: 481 | f_wrapped.__name__ = name 482 | f_wrapped.__qualname__ = name 483 | if target_module is not None: 484 | f_wrapped.__module__ = target_module 485 | setattr(f_wrapped, SYNCHRONIZER_ATTR, self) 486 | setattr(f_wrapped, TARGET_INTERFACE_ATTR, interface) 487 | 488 | def _wrap_callable( 489 | self, 490 | f, 491 | interface, 492 | name=None, 493 | allow_futures=True, 494 | unwrap_user_excs=True, 495 | target_module=None, 496 | include_aio_interface=True, 497 | ): 498 | if hasattr(f, self._original_attr): 499 | if self._multiwrap_warning: 500 | warnings.warn(f"Function {f} is already wrapped, but getting 
wrapped again") 501 | return f 502 | 503 | if name is None: 504 | _name = DEFAULT_FUNCTION_PREFIXES[interface] + f.__name__ 505 | else: 506 | _name = name 507 | 508 | @wraps_by_interface(interface, f) 509 | def f_wrapped(*args, **kwargs): 510 | return_future = kwargs.pop(_RETURN_FUTURE_KWARG, False) 511 | 512 | # If this gets called with an argument that represents an external type, 513 | # translate it into an internal type 514 | if getattr(f, self._input_translation_attr, True): 515 | args = self._translate_in(args) 516 | kwargs = self._translate_in(kwargs) 517 | 518 | # Call the function 519 | res = f(*args, **kwargs) 520 | 521 | # Figure out if this is a coroutine or something 522 | is_coroutine = inspect.iscoroutine(res) 523 | is_asyncgen = inspect.isasyncgen(res) 524 | 525 | if return_future: 526 | if not allow_futures: 527 | raise Exception("Can not return future for this function") 528 | elif is_coroutine: 529 | return self._run_function_sync_future(res, f) 530 | elif is_asyncgen: 531 | raise Exception("Can not return futures for generators") 532 | else: 533 | return res 534 | elif is_coroutine: 535 | if interface == Interface._ASYNC_WITH_BLOCKING_TYPES: 536 | coro = self._run_function_async(res, f) 537 | if not is_coroutine_function_follow_wrapped(f): 538 | # If this is a non-async function that returns a coroutine, 539 | # then this is the exit point, and we need to unwrap any 540 | # wrapped exception here. 
Otherwise, the exit point is 541 | # in async_wrap.py 542 | coro = unwrap_coro_exception(coro) 543 | return coro 544 | elif interface == Interface.BLOCKING: 545 | # This is the exit point, so we need to unwrap the exception here 546 | try: 547 | return self._run_function_sync(res, f) 548 | except StopAsyncIteration as exc: 549 | # this is a special case for handling __next__ wrappers around 550 | # __anext__ that raises StopAsyncIteration 551 | raise StopIteration().with_traceback(exc.__traceback__) 552 | except UserCodeException as uc_exc: 553 | # Used to skip a frame when called from `proxy_method`. 554 | if unwrap_user_excs and not (Interface.BLOCKING and include_aio_interface): 555 | uc_exc.exc.__suppress_context__ = True 556 | raise uc_exc.exc 557 | else: 558 | raise uc_exc 559 | elif is_asyncgen: 560 | # Note that the _run_generator_* functions handle their own 561 | # unwrapping of exceptions (this happens during yielding) 562 | if interface == Interface._ASYNC_WITH_BLOCKING_TYPES: 563 | return self._run_generator_async(res, f) 564 | elif interface == Interface.BLOCKING: 565 | return self._run_generator_sync(res, f) 566 | else: 567 | if inspect.isfunction(res) or isinstance(res, functools.partial): # TODO: HACKY HACK 568 | # TODO: this is needed for decorator wrappers that returns functions 569 | # Maybe a bit of a hacky special case that deserves its own decorator 570 | @wraps_by_interface(interface, res) 571 | def f_wrapped(*args, **kwargs): 572 | args = self._translate_in(args) 573 | kwargs = self._translate_in(kwargs) 574 | f_res = res(*args, **kwargs) 575 | if getattr(f, self._output_translation_attr, True): 576 | return self._translate_out(f_res) 577 | else: 578 | return f_res 579 | 580 | return f_wrapped 581 | 582 | if getattr(f, self._output_translation_attr, True): 583 | return self._translate_out(res, interface) 584 | else: 585 | return res 586 | 587 | self._update_wrapper(f_wrapped, f, _name, interface, target_module=target_module) 588 | 
setattr(f_wrapped, self._original_attr, f) 589 | 590 | if interface == Interface.BLOCKING and include_aio_interface and should_have_aio_interface(f): 591 | # special async interface 592 | # this async interface returns *blocking* instances of wrapped objects, not async ones: 593 | async_interface = self._wrap_callable( 594 | f, 595 | interface=Interface._ASYNC_WITH_BLOCKING_TYPES, 596 | name=name, 597 | allow_futures=allow_futures, 598 | unwrap_user_excs=unwrap_user_excs, 599 | target_module=target_module, 600 | ) 601 | f_wrapped = FunctionWithAio(f_wrapped, async_interface, self) 602 | self._update_wrapper(f_wrapped, f, _name, interface, target_module=target_module) 603 | setattr(f_wrapped, self._original_attr, f) 604 | 605 | return f_wrapped 606 | 607 | def _wrap_proxy_method( 608 | synchronizer_self, 609 | method, 610 | interface, 611 | allow_futures=True, 612 | include_aio_interface=True, 613 | ): 614 | if getattr(method, synchronizer_self._nowrap_attr, None): 615 | # This method is marked as non-wrappable 616 | return method 617 | 618 | wrapped_method = synchronizer_self._wrap_callable( 619 | method, 620 | interface, 621 | allow_futures=allow_futures, 622 | unwrap_user_excs=False, 623 | ) 624 | 625 | @wraps_by_interface(interface, wrapped_method) 626 | def proxy_method(self, *args, **kwargs): 627 | instance = self.__dict__[synchronizer_self._original_attr] 628 | with suppress_synchronicity_tb_frames(): 629 | try: 630 | return wrapped_method(instance, *args, **kwargs) 631 | except UserCodeException as uc_exc: 632 | uc_exc.exc.__suppress_context__ = True 633 | raise uc_exc.exc 634 | 635 | if interface == Interface.BLOCKING and include_aio_interface and should_have_aio_interface(method): 636 | async_proxy_method = synchronizer_self._wrap_proxy_method( 637 | method, Interface._ASYNC_WITH_BLOCKING_TYPES, allow_futures 638 | ) 639 | return MethodWithAio(proxy_method, async_proxy_method, synchronizer_self) 640 | 641 | return proxy_method 642 | 643 | def 
_wrap_proxy_staticmethod(self, method, interface): 644 | orig_function = method.__func__ 645 | method = self._wrap_callable(orig_function, interface) 646 | if isinstance(method, FunctionWithAio): 647 | return method # no need to wrap a FunctionWithAio in a staticmethod, as it won't get bound anyways 648 | return staticmethod(method) 649 | 650 | def _wrap_proxy_classmethod(self, orig_classmethod, interface): 651 | orig_func = orig_classmethod.__func__ 652 | method = self._wrap_callable(orig_func, interface, include_aio_interface=False) 653 | 654 | if interface == Interface.BLOCKING and should_have_aio_interface(orig_func): 655 | async_method = self._wrap_callable(orig_func, Interface._ASYNC_WITH_BLOCKING_TYPES) 656 | return MethodWithAio(method, async_method, self, is_classmethod=True) 657 | 658 | return classmethod(method) 659 | 660 | def _wrap_proxy_property(self, prop, interface): 661 | kwargs = {} 662 | for attr in ["fget", "fset", "fdel"]: 663 | if getattr(prop, attr): 664 | func = getattr(prop, attr) 665 | kwargs[attr] = self._wrap_proxy_method( 666 | func, interface, allow_futures=False, include_aio_interface=False 667 | ) 668 | return property(**kwargs) 669 | 670 | def _wrap_proxy_classproperty(self, prop, interface): 671 | wrapped_func = self._wrap_proxy_method(prop.fget, interface, allow_futures=False, include_aio_interface=False) 672 | return classproperty(fget=wrapped_func) 673 | 674 | def _wrap_proxy_constructor(synchronizer_self, cls, interface): 675 | """Returns a custom __init__ for the subclass.""" 676 | 677 | def my_init(self, *args, **kwargs): 678 | # Create base instance 679 | args = synchronizer_self._translate_in(args) 680 | kwargs = synchronizer_self._translate_in(kwargs) 681 | instance = cls(*args, **kwargs) 682 | 683 | # Register self as the wrapped one 684 | interface_instances = {interface: self} 685 | instance.__dict__[synchronizer_self._wrapped_attr] = interface_instances 686 | 687 | # Store a reference to the original object 688 | 
self.__dict__[synchronizer_self._original_attr] = instance 689 | 690 | synchronizer_self._update_wrapper(my_init, cls.__init__, interface=interface) 691 | setattr(my_init, synchronizer_self._original_attr, cls.__init__) 692 | return my_init 693 | 694 | def _wrap_class(self, cls, interface, name, target_module=None): 695 | new_bases = [] 696 | for base in cls.__dict__.get("__orig_bases__", cls.__bases__): 697 | base_is_generic = hasattr(base, "__origin__") 698 | if base is object or (base_is_generic and base.__origin__ == typing.Generic): 699 | new_bases.append(base) # no need to wrap these, just add them as base classes 700 | else: 701 | if base_is_generic: 702 | wrapped_generic = self._wrap(base.__origin__, interface, require_already_wrapped=(name is not None)) 703 | new_bases.append(wrapped_generic.__class_getitem__(base.__args__)) 704 | else: 705 | new_bases.append(self._wrap(base, interface, require_already_wrapped=(name is not None))) 706 | 707 | bases = tuple(new_bases) 708 | new_dict = {self._original_attr: cls} 709 | if cls is not None: 710 | new_dict["__init__"] = self._wrap_proxy_constructor(cls, interface) 711 | 712 | for k, v in cls.__dict__.items(): 713 | if k in _BUILTIN_ASYNC_METHODS: 714 | k_sync = _BUILTIN_ASYNC_METHODS[k] 715 | new_dict[k_sync] = self._wrap_proxy_method( 716 | v, 717 | Interface.BLOCKING, 718 | allow_futures=False, 719 | include_aio_interface=False, 720 | ) 721 | new_dict[k] = self._wrap_proxy_method( 722 | v, 723 | Interface._ASYNC_WITH_BLOCKING_TYPES, 724 | allow_futures=False, 725 | ) 726 | elif k in ("__new__", "__init__"): 727 | # Skip custom constructor in the wrapped class 728 | # Instead, delegate to the base class constructor and wrap it 729 | pass 730 | elif k in IGNORED_ATTRIBUTES: 731 | pass 732 | elif isinstance(v, staticmethod): 733 | # TODO(erikbern): this feels pretty hacky 734 | new_dict[k] = self._wrap_proxy_staticmethod(v, Interface.BLOCKING) 735 | elif isinstance(v, classmethod): 736 | new_dict[k] = 
self._wrap_proxy_classmethod(v, Interface.BLOCKING) 737 | elif isinstance(v, property): 738 | new_dict[k] = self._wrap_proxy_property(v, Interface.BLOCKING) 739 | elif isinstance(v, classproperty): 740 | new_dict[k] = self._wrap_proxy_classproperty(v, Interface.BLOCKING) 741 | elif isinstance(v, MethodWithAio): 742 | # if library defines its own MethodWithAio descriptor we transfer it "as is" to the wrapper 743 | # without wrapping it again 744 | new_dict[k] = v 745 | elif callable(v): 746 | new_dict[k] = self._wrap_proxy_method(v, Interface.BLOCKING) 747 | 748 | if name is None: 749 | name = DEFAULT_CLASS_PREFIX + cls.__name__ 750 | 751 | new_cls = types.new_class(name, bases, exec_body=lambda ns: ns.update(new_dict)) 752 | new_cls.__module__ = cls.__module__ if target_module is None else target_module 753 | new_cls.__doc__ = cls.__doc__ 754 | if "__annotations__" in cls.__dict__: 755 | new_cls.__annotations__ = cls.__annotations__ # transfer annotations 756 | 757 | setattr(new_cls, SYNCHRONIZER_ATTR, self) 758 | return new_cls 759 | 760 | def _wrap( 761 | self, 762 | obj, 763 | interface, 764 | name=None, 765 | require_already_wrapped=False, 766 | target_module=None, 767 | ): 768 | # This method works for classes, functions, and instances 769 | # It wraps the object, and caches the wrapped object 770 | 771 | # Get the list of existing interfaces 772 | if hasattr(obj, "__dict__"): 773 | if self._wrapped_attr not in obj.__dict__: 774 | if isinstance(obj.__dict__, dict): 775 | # This works for instances 776 | obj.__dict__.setdefault(self._wrapped_attr, {}) 777 | else: 778 | # This works for classes & functions 779 | setattr(obj, self._wrapped_attr, {}) 780 | interfaces = obj.__dict__[self._wrapped_attr] 781 | else: 782 | # e.g., TypeVar in Python>=3.12 783 | if not hasattr(obj, self._wrapped_attr): 784 | setattr(obj, self._wrapped_attr, {}) 785 | interfaces = getattr(obj, self._wrapped_attr) 786 | 787 | # If this is already wrapped, return the existing interface 788 
| if interface in interfaces: 789 | if self._multiwrap_warning: 790 | warnings.warn(f"Object {obj} is already wrapped, but getting wrapped again") 791 | return interfaces[interface] 792 | 793 | if require_already_wrapped: 794 | # This happens if a class has a custom name but its base class doesn't 795 | raise RuntimeError(f"{obj} needs to be serialized explicitly with a custom name") 796 | 797 | # Wrap object (different cases based on the type) 798 | if inspect.isclass(obj): 799 | new_obj = self._wrap_class( 800 | obj, 801 | interface, 802 | name, 803 | target_module=target_module, 804 | ) 805 | elif inspect.isfunction(obj): 806 | new_obj = self._wrap_callable(obj, interface, name, target_module=target_module) 807 | elif isinstance(obj, typing_extensions.ParamSpec): 808 | new_obj = self._wrap_param_spec(obj, interface, name, target_module) 809 | elif isinstance(obj, typing.TypeVar): 810 | new_obj = self._wrap_type_var(obj, interface, name, target_module) 811 | elif self._wrapped_attr in obj.__class__.__dict__: 812 | new_obj = self._wrap_instance(obj) 813 | else: 814 | raise Exception("Argument %s is not a class or a callable" % obj) 815 | 816 | # Store the interface on the obj and return 817 | interfaces[interface] = new_obj 818 | return new_obj 819 | 820 | def _wrap_type_var(self, obj, interface, name, target_module): 821 | # TypeVar translation is needed only for type stub generation, in case the 822 | # "bound" attribute refers to a translatable type. 
823 | 824 | # Creates a new identical TypeVar, marked with synchronicity's special attributes 825 | # This lets type stubs "translate" the `bounds` attribute on emitted type vars 826 | # if picked up from module scope and in generics using the base implementation type 827 | 828 | # TODO(elias): Refactor - since this isn't used for live apps, move type stub generation into genstub 829 | new_obj = typing.TypeVar(name, bound=obj.__bound__) # noqa 830 | setattr(new_obj, self._original_attr, obj) 831 | setattr(new_obj, SYNCHRONIZER_ATTR, self) 832 | setattr(new_obj, TARGET_INTERFACE_ATTR, interface) 833 | new_obj.__module__ = target_module 834 | if not hasattr(obj, self._wrapped_attr): 835 | setattr(obj, self._wrapped_attr, {}) 836 | getattr(obj, self._wrapped_attr)[interface] = new_obj 837 | return new_obj 838 | 839 | def _wrap_param_spec(self, obj, interface, name, target_module): 840 | # TODO(elias): Refactor - since this isn't used for live apps, move type stub generation into genstub 841 | new_obj = typing_extensions.ParamSpec(name) # noqa 842 | setattr(new_obj, self._original_attr, obj) 843 | setattr(new_obj, SYNCHRONIZER_ATTR, self) 844 | setattr(new_obj, TARGET_INTERFACE_ATTR, interface) 845 | new_obj.__module__ = target_module 846 | if not hasattr(obj, self._wrapped_attr): 847 | setattr(obj, self._wrapped_attr, {}) 848 | getattr(obj, self._wrapped_attr)[interface] = new_obj 849 | return new_obj 850 | 851 | def nowrap(self, obj): 852 | setattr(obj, self._nowrap_attr, True) 853 | return obj 854 | 855 | def no_input_translation(self, obj): 856 | setattr(obj, self._input_translation_attr, False) 857 | return obj 858 | 859 | def no_output_translation(self, obj): 860 | setattr(obj, self._output_translation_attr, False) 861 | return obj 862 | 863 | def no_io_translation(self, obj): 864 | return self.no_input_translation(self.no_output_translation(obj)) 865 | 866 | # New interface that (almost) doesn't mutate objects 867 | def create_blocking(self, obj, name: 
Optional[str] = None, target_module: Optional[str] = None): 868 | # TODO: deprecate this alias method 869 | return self.wrap(obj, name, target_module) 870 | 871 | def wrap(self, obj, name: Optional[str] = None, target_module: Optional[str] = None): 872 | wrapped = self._wrap(obj, Interface.BLOCKING, name, target_module=target_module) 873 | return wrapped 874 | 875 | def is_synchronized(self, obj): 876 | if inspect.isclass(obj) or inspect.isfunction(obj): 877 | return hasattr(obj, self._original_attr) 878 | else: 879 | return hasattr(obj.__class__, self._original_attr) 880 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modal-labs/synchronicity/731a20bbe9bff1d973eb8bbb87f73f2e9d4bad1a/test/__init__.py -------------------------------------------------------------------------------- /test/async_wrap_test.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | 4 | from synchronicity import async_wrap 5 | from synchronicity.async_wrap import wraps_by_interface 6 | from synchronicity.interface import Interface 7 | from synchronicity.synchronizer import FunctionWithAio 8 | 9 | 10 | def test_wrap_corofunc_using_async(): 11 | async def foo(): 12 | pass 13 | 14 | @wraps_by_interface(Interface._ASYNC_WITH_BLOCKING_TYPES, foo) 15 | async def bar(): 16 | pass 17 | 18 | assert inspect.iscoroutinefunction(bar) 19 | 20 | 21 | def test_wrap_corofunc_using_non_async(): 22 | async def foo(): 23 | pass 24 | 25 | @wraps_by_interface(Interface._ASYNC_WITH_BLOCKING_TYPES, foo) 26 | def bar(): 27 | pass 28 | 29 | assert inspect.iscoroutinefunction(bar) 30 | 31 | 32 | def test_wrap_asynccontextmanager_annotations(): 33 | @async_wrap.asynccontextmanager # this would not work with contextlib.asynccontextmanager 34 | async def foo() -> 
typing.AsyncGenerator[int, None]: ... 35 | 36 | assert foo.__annotations__["return"] == typing.AsyncContextManager[int] 37 | 38 | 39 | def test_wrap_staticmethod(synchronizer): 40 | class Foo: 41 | @staticmethod 42 | async def a_static_method() -> typing.Awaitable[str]: 43 | async def wrapped(): 44 | return "hello" 45 | 46 | return wrapped() 47 | 48 | BlockingFoo = synchronizer.create_blocking(Foo) 49 | 50 | assert isinstance(BlockingFoo.__dict__["a_static_method"], FunctionWithAio) 51 | assert not inspect.iscoroutinefunction(BlockingFoo.__dict__["a_static_method"]._func) 52 | assert inspect.iscoroutinefunction(BlockingFoo.__dict__["a_static_method"].aio) 53 | -------------------------------------------------------------------------------- /test/asynccontextmanager_test.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import pytest 3 | import typing 4 | 5 | from synchronicity.async_wrap import asynccontextmanager 6 | 7 | 8 | async def noop(): 9 | pass 10 | 11 | 12 | async def error(): 13 | raise Exception("problem") 14 | 15 | 16 | class Resource: 17 | def __init__(self): 18 | self.state = "none" 19 | 20 | def get_state(self): 21 | return self.state 22 | 23 | @asynccontextmanager 24 | async def wrap(self) -> typing.AsyncGenerator[None, None]: 25 | self.state = "entered" 26 | try: 27 | yield 28 | finally: 29 | self.state = "exited" 30 | 31 | @asynccontextmanager 32 | async def wrap_yield_twice(self) -> typing.AsyncGenerator[None, None]: 33 | yield 34 | yield 35 | 36 | @asynccontextmanager 37 | async def wrap_never_yield(self) -> typing.AsyncGenerator[None, None]: 38 | if False: 39 | yield 40 | 41 | 42 | def test_asynccontextmanager_sync(synchronizer): 43 | r = synchronizer.create_blocking(Resource)() 44 | assert r.get_state() == "none" 45 | with r.wrap(): 46 | assert r.get_state() == "entered" 47 | assert r.get_state() == "exited" 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def 
@pytest.mark.asyncio
async def test_asynccontextmanager_nested(synchronizer):
    """Nested wrapped context managers must unwind inner-first on a BaseException."""
    teardown_order = []

    @synchronizer.create_blocking
    @asynccontextmanager
    async def a() -> typing.AsyncGenerator[str, None]:
        try:
            yield "foo"
        finally:
            teardown_order.append("A")

    @synchronizer.create_blocking
    @asynccontextmanager
    async def b() -> typing.AsyncGenerator[str, None]:
        # b() wraps a(), so its finally must run before a()'s
        async with a() as inner_value:
            try:
                yield inner_value
            finally:
                teardown_order.append("B")

    with pytest.raises(BaseException):
        async with b():
            raise BaseException("boom!")

    assert teardown_order == ["B", "A"]
def sleep(ms):
    """Block the calling thread for *ms* milliseconds, then return *ms*."""
    seconds = ms / 1000
    time.sleep(seconds)
    return ms
def test_docs(synchronizer):
    """Docstrings must survive wrapping, on both the blocking wrapper and its .aio."""

    class Foo:
        def __init__(self):
            """init docs"""
            self._attrs = {}

        async def bar(self):
            """bar docs"""

    plain = Foo()
    assert plain.__init__.__doc__ == "init docs"
    assert plain.bar.__doc__ == "bar docs"

    BlockingFoo = synchronizer.create_blocking(Foo)
    wrapped = BlockingFoo()
    assert wrapped.__init__.__doc__ == "init docs"
    assert wrapped.bar.__doc__ == "bar docs"

    assert wrapped.bar.aio.__doc__ == "bar docs"
# Short enough to keep tests fast, long enough to measure that work happened.
SLEEP_DELAY = 0.1
# Windows clocks have coarser resolution; timing assertions get this slack.
WINDOWS_TIME_RESOLUTION_FIX = 0.01 if sys.platform == "win32" else 0.0


class CustomExceptionCause(Exception):
    """Marker used as an explicit `raise ... from` cause in tests."""

    pass


class CustomException(Exception):
    """Marker raised by the fixtures below."""

    pass


async def f_raises(exc):
    """Sleep briefly, then raise the provided exception instance."""
    await asyncio.sleep(SLEEP_DELAY)
    raise exc


async def f_raises_with_cause():
    """Sleep briefly, then raise with an explicit __cause__ chained on."""
    await asyncio.sleep(SLEEP_DELAY)
    raise CustomException("something failed") from CustomExceptionCause("exception cause")
def test_function_raises_with_cause_sync(synchronizer):
    """A blocking wrapper must preserve __cause__ on a raised exception."""
    start = time.monotonic()
    with pytest.raises(CustomException) as exc:
        wrapped = synchronizer.create_blocking(f_raises_with_cause)
        wrapped()
    elapsed = time.monotonic() - start
    # the call must actually have awaited the fixture's sleep, not short-circuited
    assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= elapsed < 2 * SLEEP_DELAY
    assert isinstance(exc.value.__cause__, CustomExceptionCause)
async def f_raises_baseexc():
    """Sleep briefly, then raise KeyboardInterrupt (a BaseException, not Exception)."""
    await asyncio.sleep(SLEEP_DELAY)
    # deliberately a BaseException subclass to test non-Exception propagation
    raise KeyboardInterrupt
def decorator(f):
    """Trivial pass-through decorator; tests that wrapped callables still unwrap."""

    @functools.wraps(f)
    def inner(*args, **kwargs):
        # delegate verbatim; functools.wraps copies __name__/__wrapped__ etc.
        return f(*args, **kwargs)

    return inner
# Shared event log: tests clear() it, then assert on producer/consumer interleaving.
events = []


async def async_producer():
    """Yield 0..9, logging a "producer" event just before each value is handed over."""
    for value in range(10):
        events.append("producer")
        yield value
async def athrow_example_gen():
    """Generator fixture distinguishing how throw()/athrow() is handled.

    Yields "hello" first; responds "world" to an injected ZeroDivisionError and
    "foobar" to any other injected BaseException (e.g. KeyboardInterrupt).
    """
    try:
        await asyncio.sleep(0.1)
        yield "hello"
    except ZeroDivisionError:
        # the narrow handler: a plain Exception thrown into the generator
        await asyncio.sleep(0.2)
        yield "world"
    except BaseException:
        # catch-all for BaseException subclasses that aren't Exceptions
        yield "foobar"
def test_custom_generator(synchronizer):
    """A wrapped class whose __aiter__ returns an async generator must iterate synchronously."""
    events.clear()
    BlockingMyGenerator = synchronizer.create_blocking(MyGenerator)
    instance = BlockingMyGenerator()
    for _ in instance:
        events.append("consumer")
    # strict alternation proves values are consumed one at a time, not buffered
    assert events == ["producer", "consumer"] * 10
__setattr__(self, k, v): 21 | if k in self.__annotations__: 22 | # Only needed because the constructor sets _attrs 23 | self.__dict__[k] = v 24 | else: 25 | self._attrs[k] = v 26 | 27 | @property 28 | def z(self): 29 | return self._attrs["x"] 30 | 31 | @staticmethod 32 | def make_foo(): 33 | return Foo() 34 | 35 | @classproperty 36 | def my_cls_prop(cls): 37 | return "abc" 38 | 39 | @classproperty 40 | async def another_cls_prop(cls): 41 | await asyncio.sleep(0.01) 42 | return "another-cls-prop" 43 | 44 | foo = Foo() 45 | foo.x = 42 46 | assert await foo.x == 42 47 | with pytest.raises(KeyError): 48 | await foo.y 49 | assert foo.z == 42 50 | assert Foo.my_cls_prop == "abc" 51 | assert await Foo.another_cls_prop == "another-cls-prop" 52 | 53 | BlockingFoo = synchronizer.create_blocking(Foo) 54 | 55 | blocking_foo = BlockingFoo() 56 | blocking_foo.x = 43 57 | assert blocking_foo.x == 43 58 | with pytest.raises(KeyError): 59 | blocking_foo.y 60 | assert blocking_foo.z == 43 61 | assert BlockingFoo.my_cls_prop == "abc" 62 | assert BlockingFoo.another_cls_prop == "another-cls-prop" 63 | 64 | blocking_foo = BlockingFoo.make_foo() 65 | blocking_foo.x = 44 66 | assert isinstance(blocking_foo, BlockingFoo) 67 | 68 | # TODO: there is no longer a way to make async properties, but there is this w/ async __getattr__: 69 | assert await blocking_foo.__getattr__.aio("x") == 44 70 | -------------------------------------------------------------------------------- /test/gevent_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import subprocess 3 | import sys 4 | from pathlib import Path 5 | 6 | 7 | @pytest.mark.skipif(sys.version_info >= (3, 13), reason="gevent seems broken on Python 3.13") 8 | @pytest.mark.skipif( 9 | sys.platform == "win32", reason="gevent support broken on Windows, probably due to event loop patching" 10 | ) 11 | def test_gevent(): 12 | # Run it in a separate process because gevent modifies a lot of 
modules 13 | fn = Path(__file__).parent / "support" / "_gevent.py" 14 | ret = subprocess.run([sys.executable, fn], stdout=sys.stdout, stderr=sys.stderr, timeout=5) 15 | assert ret.returncode == 0 16 | -------------------------------------------------------------------------------- /test/helper_methods_test.py: -------------------------------------------------------------------------------- 1 | def test_is_synchronized(synchronizer): 2 | class Foo: 3 | pass 4 | 5 | BlockingFoo = synchronizer.create_blocking(Foo) 6 | assert synchronizer.is_synchronized(Foo) is False 7 | assert synchronizer.is_synchronized(BlockingFoo) is True 8 | -------------------------------------------------------------------------------- /test/inspect_test.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | 4 | class _Api: 5 | def blocking_func(self): 6 | pass 7 | 8 | async def async_func(self): 9 | pass 10 | 11 | 12 | def test_inspect_coroutinefunction(synchronizer): 13 | BlockingApi = synchronizer.create_blocking(_Api) 14 | 15 | assert inspect.iscoroutinefunction(BlockingApi.blocking_func) is False 16 | assert inspect.iscoroutinefunction(BlockingApi.async_func) is False 17 | assert hasattr(BlockingApi.blocking_func, "aio") is False 18 | assert inspect.iscoroutinefunction(BlockingApi.async_func.aio) is True 19 | -------------------------------------------------------------------------------- /test/nowrap_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | 4 | 5 | def test_nowrap(synchronizer): 6 | @synchronizer.create_blocking 7 | class MyClass: 8 | async def f(self, x): 9 | await asyncio.sleep(0.2) 10 | return x**2 11 | 12 | @synchronizer.nowrap 13 | def g(self, x): 14 | # This runs on the wrapped class 15 | return self.f(x) * x # calls the blocking function 16 | 17 | my_obj = MyClass() 18 | 19 | t0 = time.time() 20 | assert my_obj.f(111) == 12321 21 | assert 
class PicklableClass:
    """Module-level class (hence picklable) with a single async method."""

    async def f(self, x):
        """Return the square of *x*."""
        squared = x**2
        return squared
def test_shutdown():
    """Ctrl-C must cancel the running coroutine *before* KeyboardInterrupt surfaces."""
    # Run in a subprocess so we can deliver a realistic (simulated) Ctrl-C.
    script = Path(__file__).parent / "support" / "_shutdown.py"
    with PopenWithCtrlC(
        [sys.executable, "-u", script], stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
    ) as proc:
        for _ in range(2):  # arbitrary count: the child loops until interrupted
            assert proc.stdout.readline() == "running\n"
        proc.send_ctrl_c()
        # in rare races one more "running" line can slip out before cancellation
        for _ in range(2):
            if proc.stdout.readline() == "cancelled\n":
                break
        else:
            assert False

        assert proc.stdout.readline() == "handled cancellation\n"
        assert proc.stdout.readline() == "exit async\n"
        # the keyboard interrupt must come *after* the function finished cancelling
        assert proc.stdout.readline() == "keyboard interrupt\n"

        assert proc.stderr.read().strip() == ""
async def f(x):
    """Square *x* after a short async pause (driven through the synchronizer under gevent)."""
    await asyncio.sleep(0.1)
    result = x**2
    return result
async def run():
    """Loop forever printing progress; report each stage of cancellation when interrupted.

    The parent test asserts on the exact stdout lines, so the print strings and
    their ordering are part of this fixture's contract.
    """
    try:
        while True:
            print("running")
            await asyncio.sleep(0.3)
    except asyncio.CancelledError:
        # prove that async work is still possible while handling cancellation
        print("cancelled")
        await asyncio.sleep(0.1)
        print("handled cancellation")
        raise
    finally:
        # runs on both the cancellation and (unreachable) normal-exit paths
        await asyncio.sleep(0.1)
        print("exit async")
ctx_mgr(): 11 | try: 12 | if sys.argv[1] == "enter": 13 | while True: 14 | print("enter") 15 | await asyncio.sleep(0.3) 16 | elif sys.argv[1] == "yield": 17 | yield 18 | else: 19 | print("this should not happen") 20 | finally: 21 | print("exit") 22 | 23 | 24 | s = Synchronizer() 25 | blocking_ctx_mgr = s.create_blocking(ctx_mgr) 26 | try: 27 | with blocking_ctx_mgr(): 28 | while True: 29 | print("in ctx") 30 | time.sleep(0.3) 31 | except KeyboardInterrupt: 32 | print("keyboard interrupt") 33 | -------------------------------------------------------------------------------- /test/synchronicity_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import inspect 4 | import pytest 5 | import sys 6 | import threading 7 | import time 8 | import typing 9 | from typing import Coroutine 10 | from unittest.mock import MagicMock 11 | 12 | from synchronicity import Synchronizer 13 | 14 | SLEEP_DELAY = 0.5 15 | WINDOWS_TIME_RESOLUTION_FIX = 0.01 if sys.platform == "win32" else 0.0 16 | 17 | 18 | async def f(x): 19 | await asyncio.sleep(SLEEP_DELAY) 20 | return x**2 21 | 22 | 23 | async def f2(fn, x): 24 | return await fn(x) 25 | 26 | 27 | def test_function_sync(synchronizer): 28 | s = synchronizer 29 | t0 = time.monotonic() 30 | f_s = s.create_blocking(f) 31 | assert f_s.__name__ == "blocking_f" 32 | ret = f_s(42) 33 | assert ret == 1764 34 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 35 | 36 | 37 | @pytest.mark.asyncio 38 | async def test_function_async(synchronizer): 39 | s = synchronizer 40 | f_s = s.wrap(f) 41 | t0 = time.monotonic() 42 | ret = await f_s.aio(42) 43 | assert ret == 1764 44 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 45 | 46 | 47 | def test_function_sync_future(synchronizer): 48 | t0 = time.monotonic() 49 | f_s = synchronizer.create_blocking(f) 50 | assert f_s.__name__ == 
"blocking_f" 51 | fut = f_s(42, _future=True) 52 | assert isinstance(fut, concurrent.futures.Future) 53 | assert time.monotonic() - t0 < SLEEP_DELAY 54 | assert fut.result() == 1764 55 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 56 | 57 | 58 | @pytest.mark.asyncio 59 | async def test_function_async_as_function_attribute(synchronizer): 60 | s = synchronizer 61 | t0 = time.monotonic() 62 | f_s = s.create_blocking(f).aio 63 | assert f_s.__name__ == "aio_f" 64 | coro = f_s(42) 65 | assert inspect.iscoroutine(coro) 66 | assert time.monotonic() - t0 < SLEEP_DELAY 67 | assert await coro == 1764 68 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 69 | 70 | # Make sure the same-loop calls work 71 | f2_s = s.create_blocking(f2).aio 72 | assert f2_s.__name__ == "aio_f2" 73 | coro = f2_s(f_s, 42) 74 | assert await coro == 1764 75 | 76 | # Make sure cross-loop calls work 77 | s2 = Synchronizer() 78 | f2_s2 = s2.create_blocking(f2).aio 79 | assert f2_s2.__name__ == "aio_f2" 80 | coro = f2_s2(f_s, 42) 81 | assert await coro == 1764 82 | s2._close_loop() 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_function_async_block_event_loop(synchronizer): 87 | async def spinlock(): 88 | # This blocks the event loop, but not the main event loop 89 | time.sleep(SLEEP_DELAY) 90 | 91 | spinlock_s = synchronizer.create_blocking(spinlock) 92 | spinlock_coro = spinlock_s.aio() 93 | sleep_coro = asyncio.sleep(SLEEP_DELAY) 94 | 95 | t0 = time.monotonic() 96 | await asyncio.gather(spinlock_coro, sleep_coro) 97 | duration = time.monotonic() - t0 98 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= duration < 1.5 * SLEEP_DELAY 99 | 100 | 101 | def test_function_many_parallel_sync(synchronizer): 102 | g = synchronizer.create_blocking(f) 103 | t0 = time.monotonic() 104 | rets = [g(i) for i in range(10)] # Will resolve serially 105 | assert ( 106 | len(rets) * SLEEP_DELAY - 
WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < (len(rets) + 1) * SLEEP_DELAY 107 | ) 108 | 109 | 110 | def test_function_many_parallel_sync_futures(synchronizer): 111 | g = synchronizer.create_blocking(f) 112 | t0 = time.monotonic() 113 | futs = [g(i, _future=True) for i in range(100)] 114 | assert isinstance(futs[0], concurrent.futures.Future) 115 | assert time.monotonic() - t0 < SLEEP_DELAY 116 | assert [fut.result() for fut in futs] == [z**2 for z in range(100)] 117 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 118 | 119 | 120 | @pytest.mark.asyncio 121 | async def test_function_many_parallel_async(synchronizer): 122 | g = synchronizer.create_blocking(f) 123 | t0 = time.monotonic() 124 | coros = [g.aio(i) for i in range(20)] 125 | assert inspect.iscoroutine(coros[0]) 126 | assert time.monotonic() - t0 < 0.01 # invoking coroutine functions should be cheap 127 | assert await asyncio.gather(*coros) == [z**2 for z in range(20)] 128 | dur = time.monotonic() - t0 129 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= dur < 2 * SLEEP_DELAY 130 | 131 | 132 | async def gen(n): 133 | for i in range(n): 134 | await asyncio.sleep(SLEEP_DELAY) 135 | yield i 136 | 137 | 138 | async def gen2(generator, n): 139 | async for ret in generator(n): 140 | yield ret 141 | 142 | 143 | def test_generator_sync(synchronizer): 144 | synchronizer = synchronizer 145 | t0 = time.monotonic() 146 | gen_s = synchronizer.create_blocking(gen) 147 | it = gen_s(3) 148 | assert inspect.isgenerator(it) 149 | assert time.monotonic() - t0 < SLEEP_DELAY 150 | lst = list(it) 151 | assert lst == [0, 1, 2] 152 | assert time.monotonic() - t0 + WINDOWS_TIME_RESOLUTION_FIX >= len(lst) * SLEEP_DELAY 153 | 154 | 155 | @pytest.mark.asyncio 156 | async def test_generator_async(synchronizer): 157 | t0 = time.monotonic() 158 | gen_s = synchronizer.create_blocking(gen).aio 159 | 160 | asyncgen = gen_s(3) 161 | assert inspect.isasyncgen(asyncgen) 162 | assert 
time.monotonic() - t0 < SLEEP_DELAY 163 | lst = [z async for z in asyncgen] 164 | assert lst == [0, 1, 2] 165 | assert time.monotonic() - t0 + WINDOWS_TIME_RESOLUTION_FIX >= len(lst) * SLEEP_DELAY 166 | 167 | # Make sure same-loop calls work 168 | gen2_s = synchronizer.create_blocking(gen2).aio 169 | asyncgen = gen2_s(gen_s, 3) 170 | lst = [z async for z in asyncgen] 171 | assert lst == [0, 1, 2] 172 | 173 | # Make sure cross-loop calls work 174 | s2 = synchronizer 175 | gen2_s2 = s2.create_blocking(gen2).aio 176 | asyncgen = gen2_s2(gen_s, 3) 177 | lst = [z async for z in asyncgen] 178 | assert lst == [0, 1, 2] 179 | 180 | 181 | @pytest.mark.asyncio 182 | async def test_function_returning_coroutine(synchronizer): 183 | def func() -> Coroutine: 184 | async def inner(): 185 | return 10 186 | 187 | return inner() 188 | 189 | blocking_func = synchronizer.create_blocking(func) 190 | assert blocking_func() == 10 191 | coro = blocking_func.aio() 192 | assert inspect.iscoroutine(coro) 193 | assert await coro == 10 194 | 195 | 196 | def test_sync_lambda_returning_coroutine_sync(synchronizer): 197 | t0 = time.monotonic() 198 | g = synchronizer.create_blocking(lambda z: f(z + 1)) 199 | ret = g(42) 200 | assert ret == 1849 201 | assert time.monotonic() - t0 >= SLEEP_DELAY 202 | 203 | 204 | def test_sync_lambda_returning_coroutine_sync_futures(synchronizer): 205 | t0 = time.monotonic() 206 | g = synchronizer.create_blocking(lambda z: f(z + 1)) 207 | fut = g(42, _future=True) 208 | assert isinstance(fut, concurrent.futures.Future) 209 | assert time.monotonic() - t0 < SLEEP_DELAY 210 | assert fut.result() == 1849 211 | assert time.monotonic() - t0 >= SLEEP_DELAY 212 | 213 | 214 | @pytest.mark.asyncio 215 | async def test_sync_inline_func_returning_coroutine_async(synchronizer): 216 | t0 = time.monotonic() 217 | 218 | # NOTE: we don't create the async variant unless we know the function returns a coroutine 219 | def func(z) -> Coroutine: 220 | return f(z + 1) 221 | 222 | g = 
synchronizer.create_blocking(func) 223 | coro = g.aio(42) 224 | assert inspect.iscoroutine(coro) 225 | assert time.monotonic() - t0 < SLEEP_DELAY 226 | assert await coro == 1849 227 | assert time.monotonic() - t0 >= SLEEP_DELAY 228 | 229 | 230 | class Base: 231 | def __init__(self, x): 232 | self._x = x 233 | 234 | 235 | class MyClass(Base): 236 | def __init__(self, x): 237 | super().__init__(x) 238 | 239 | async def start(self): 240 | async def task(): 241 | await asyncio.sleep(SLEEP_DELAY) 242 | return self._x 243 | 244 | loop = asyncio.get_event_loop() 245 | self._task = loop.create_task(task()) 246 | 247 | async def get_result(self): 248 | ret = await self._task 249 | return ret**2 250 | 251 | async def __aenter__(self): 252 | await asyncio.sleep(SLEEP_DELAY) 253 | return 42 254 | 255 | async def __aexit__(self, exc_type, exc, tb): 256 | await asyncio.sleep(SLEEP_DELAY) 257 | 258 | @staticmethod 259 | async def my_static_method(): 260 | await asyncio.sleep(SLEEP_DELAY) 261 | return 43 262 | 263 | @classmethod 264 | async def my_class_method(cls): 265 | await asyncio.sleep(SLEEP_DELAY) 266 | return 44 267 | 268 | async def __aiter__(self): 269 | for i in range(self._x): 270 | yield i 271 | 272 | 273 | def test_class_sync(synchronizer): 274 | BlockingBase = synchronizer.create_blocking(Base, name="BlockingBase") 275 | BlockingMyClass = synchronizer.create_blocking(MyClass, name="BlockingMyClass") 276 | 277 | assert BlockingMyClass.__name__ == "BlockingMyClass" 278 | obj = BlockingMyClass(x=42) 279 | assert isinstance(obj, BlockingMyClass) 280 | assert isinstance(obj, BlockingBase) 281 | obj.start() 282 | ret = obj.get_result() 283 | assert ret == 1764 284 | 285 | t0 = time.monotonic() 286 | with obj as z: 287 | assert z == 42 288 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 289 | assert time.monotonic() - t0 + WINDOWS_TIME_RESOLUTION_FIX >= 2 * SLEEP_DELAY 290 | 291 | t0 = time.monotonic() 292 | assert 
BlockingMyClass.my_static_method() == 43 293 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 294 | 295 | t0 = time.monotonic() 296 | assert BlockingMyClass.my_class_method() == 44 297 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 298 | 299 | assert list(z for z in obj) == list(range(42)) 300 | 301 | 302 | def test_class_sync_futures(synchronizer): 303 | BlockingMyClass = synchronizer.create_blocking(MyClass) 304 | BlockingBase = synchronizer.create_blocking(Base) 305 | assert BlockingMyClass.__name__ == "BlockingMyClass" 306 | obj = BlockingMyClass(x=42) 307 | assert isinstance(obj, BlockingMyClass) 308 | assert isinstance(obj, BlockingBase) 309 | obj.start() 310 | fut = obj.get_result(_future=True) 311 | assert isinstance(fut, concurrent.futures.Future) 312 | assert fut.result() == 1764 313 | 314 | t0 = time.monotonic() 315 | with obj as z: 316 | assert z == 42 317 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 318 | 319 | assert time.monotonic() - t0 + WINDOWS_TIME_RESOLUTION_FIX >= 2 * SLEEP_DELAY 320 | 321 | 322 | @pytest.mark.asyncio 323 | async def test_class_async_as_method_attribute(synchronizer): 324 | BlockingMyClass = synchronizer.create_blocking(MyClass) 325 | BlockingBase = synchronizer.create_blocking(Base) 326 | assert BlockingMyClass.__name__ == "BlockingMyClass" 327 | obj = BlockingMyClass(x=42) 328 | assert isinstance(obj, BlockingMyClass) 329 | assert isinstance(obj, BlockingBase) 330 | await obj.start.aio() 331 | coro = obj.get_result.aio() 332 | assert inspect.iscoroutine(coro) 333 | assert await coro == 1764 334 | 335 | t0 = time.monotonic() 336 | async with obj as z: 337 | assert z == 42 338 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 339 | 340 | assert time.monotonic() - t0 + WINDOWS_TIME_RESOLUTION_FIX >= 2 * SLEEP_DELAY 341 | 342 | lst = [] 343 | 344 | 
async for z in obj: 345 | lst.append(z) 346 | 347 | assert lst == list(range(42)) 348 | 349 | assert await obj.my_static_method.aio() == 43 350 | assert await obj.my_class_method.aio() == 44 351 | 352 | 353 | @pytest.mark.skip(reason="Skip this until we've made it impossible to re-synchronize objects") 354 | def test_event_loop(synchronizer): 355 | t0 = time.monotonic() 356 | f_s = synchronizer.create_blocking(f) 357 | assert f_s(42) == 42 * 42 358 | assert SLEEP_DELAY - WINDOWS_TIME_RESOLUTION_FIX <= time.monotonic() - t0 < 2 * SLEEP_DELAY 359 | assert synchronizer._thread.is_alive() 360 | assert synchronizer._loop.is_running() 361 | synchronizer._close_loop() 362 | assert not synchronizer._loop.is_running() 363 | assert not synchronizer._thread.is_alive() 364 | 365 | new_loop = asyncio.new_event_loop() 366 | synchronizer._start_loop(new_loop) 367 | assert synchronizer._loop == new_loop 368 | assert synchronizer._loop.is_running() 369 | assert synchronizer._thread.is_alive() 370 | 371 | # Starting a loop again before closing throws. 
372 | with pytest.raises(Exception): 373 | synchronizer._start_loop(new_loop) 374 | 375 | 376 | def test_doc_transfer(synchronizer): 377 | class Foo: 378 | """Hello""" 379 | 380 | async def foo(self): 381 | """hello""" 382 | 383 | output_class = synchronizer.create_blocking(Foo) 384 | 385 | assert output_class.__doc__ == "Hello" 386 | assert output_class.foo.__doc__ == "hello" 387 | assert output_class.foo.aio.__doc__ == "hello" 388 | 389 | 390 | def test_set_function_name(synchronizer): 391 | f_s = synchronizer.create_blocking(f, "xyz") 392 | assert f_s(42) == 1764 393 | assert f_s.__name__ == "xyz" 394 | 395 | 396 | def test_set_class_name(synchronizer): 397 | BlockingBase = synchronizer.create_blocking(Base, "XYZBase") 398 | assert BlockingBase.__name__ == "XYZBase" 399 | BlockingMyClass = synchronizer.create_blocking(MyClass, "XYZMyClass") 400 | assert BlockingMyClass.__name__ == "XYZMyClass" 401 | 402 | 403 | @pytest.mark.asyncio 404 | async def test_blocking_nested_aio_returns_blocking_obj(synchronizer): 405 | class Foo: 406 | async def get_self(self): 407 | return self 408 | 409 | BlockingFoo = synchronizer.create_blocking(Foo) 410 | 411 | original = BlockingFoo() 412 | assert original.get_self() == original 413 | 414 | self_from_aio_interface = await original.get_self.aio() 415 | assert self_from_aio_interface == original 416 | assert isinstance(self_from_aio_interface, BlockingFoo) 417 | 418 | 419 | def test_no_input_translation(monkeypatch, synchronizer): 420 | @synchronizer.create_blocking 421 | def does_input_translation(arg: float) -> str: 422 | return str(arg) 423 | 424 | @synchronizer.create_blocking 425 | @synchronizer.no_input_translation 426 | async def without_input_translation(arg: float) -> str: 427 | return str(arg) 428 | 429 | in_translate_spy = MagicMock(wraps=synchronizer._translate_scalar_in) 430 | monkeypatch.setattr(synchronizer, "_translate_scalar_in", in_translate_spy) 431 | does_input_translation(3.14) # test without decorator, this 
*should* do input translation 432 | in_translate_spy.assert_called_once_with(3.14) 433 | 434 | in_translate_spy.reset_mock() 435 | without_input_translation(3.14) 436 | in_translate_spy.assert_not_called() 437 | 438 | 439 | def test_no_output_translation(monkeypatch, synchronizer): 440 | @synchronizer.create_blocking 441 | def does_output_translation(arg: float) -> str: 442 | return str(arg) 443 | 444 | @synchronizer.create_blocking 445 | @synchronizer.no_output_translation 446 | async def without_output_translation(arg: float) -> str: 447 | return str(arg) 448 | 449 | out_translate_spy = MagicMock(wraps=synchronizer._translate_scalar_out) 450 | monkeypatch.setattr(synchronizer, "_translate_scalar_out", out_translate_spy) 451 | does_output_translation(3.14) # test without decorator, this *should* do input translation 452 | out_translate_spy.assert_called_once_with("3.14") 453 | 454 | out_translate_spy.reset_mock() 455 | without_output_translation(3.14) 456 | out_translate_spy.assert_not_called() 457 | 458 | 459 | @pytest.mark.asyncio 460 | async def test_non_async_aiter(synchronizer): 461 | async def some_async_gen(): 462 | yield "foo" 463 | yield "bar" 464 | 465 | class It: 466 | def __aiter__(self): 467 | self._gen = some_async_gen() 468 | return self 469 | 470 | async def __anext__(self): 471 | value = await self._gen.__anext__() 472 | return value 473 | 474 | async def aclose(self): 475 | await self._gen.aclose() 476 | 477 | WrappedIt = synchronizer.create_blocking(It, name="WrappedIt") 478 | 479 | # just a sanity check of the original iterable: 480 | orig_async_it = It() 481 | assert [v async for v in orig_async_it] == ["foo", "bar"] 482 | await orig_async_it.aclose() 483 | 484 | # check async iteration on the wrapped iterator 485 | it = WrappedIt() 486 | assert [v async for v in it] == ["foo", "bar"] 487 | await it.aclose() 488 | 489 | # check sync iteration on the wrapped iterator 490 | it = WrappedIt() 491 | assert list(it) == ["foo", "bar"] 492 | 
it.close() 493 | 494 | 495 | def test_generic_baseclass(synchronizer): 496 | T = typing.TypeVar("T") 497 | V = typing.TypeVar("V") 498 | 499 | class GenericClass(typing.Generic[T, V]): 500 | async def do_something(self): 501 | return 1 502 | 503 | WrappedGenericClass = synchronizer.create_blocking(GenericClass, name="BlockingGenericClass") 504 | 505 | assert WrappedGenericClass[str, float].__args__ == (str, float) 506 | 507 | instance: WrappedGenericClass[str, float] = WrappedGenericClass() # should be allowed 508 | assert isinstance(instance, WrappedGenericClass) 509 | assert instance.do_something() == 1 510 | 511 | Q = typing.TypeVar("Q") 512 | Y = typing.TypeVar("Y") 513 | 514 | class GenericSubclass(GenericClass[Q, Y]): 515 | pass 516 | 517 | WrappedGenericSubclass = synchronizer.create_blocking(GenericSubclass, name="BlockingGenericSubclass") 518 | assert WrappedGenericSubclass[bool, int].__args__ == (bool, int) 519 | instance_2 = WrappedGenericSubclass() 520 | assert isinstance(instance_2, WrappedGenericSubclass) 521 | assert isinstance(instance_2, WrappedGenericClass) # still instance of parent 522 | assert instance.do_something() == 1 # has base methods 523 | 524 | 525 | @pytest.mark.asyncio 526 | async def test_async_cancellation(synchronizer): 527 | states = [] 528 | 529 | async def foo(abort_cancellation: bool, cancel_self: bool = False): 530 | states.append("ready") 531 | if cancel_self: 532 | asyncio.tasks.current_task().cancel() 533 | try: 534 | await asyncio.sleep(10) 535 | except asyncio.CancelledError: 536 | states.append("cancelled") 537 | await asyncio.sleep(0.1) 538 | states.append("handled cancellation") 539 | if not abort_cancellation: 540 | raise 541 | return "done" 542 | 543 | wrapped_foo = synchronizer.create_blocking(foo) 544 | 545 | async def start_task(abort_cancellation: bool, cancel_self: bool = False): 546 | states.clear() 547 | calling_task = asyncio.create_task( 548 | wrapped_foo.aio(abort_cancellation=abort_cancellation, 
cancel_self=cancel_self) 549 | ) 550 | while "ready" not in states: 551 | await asyncio.sleep(0.01) # don't cancel before the task even starts 552 | return calling_task 553 | 554 | # Case 1: cancel in parent goes into the coroutine and comes back out: 555 | calling_task = await start_task(abort_cancellation=False) 556 | calling_task.cancel() 557 | with pytest.raises(asyncio.CancelledError): 558 | await calling_task 559 | assert states == ["ready", "cancelled", "handled cancellation"] 560 | 561 | # Case 2: cancel in parent goes into the coroutine and is "aborted" by the coroutine: 562 | # Note: This is explicitly not allowed anymore, since we can't distinguish it from the task 563 | # finishing successfully before a cancellation takes place, and no cancellation 564 | # getting raised - causing unintended aborted cancellations in the calling event loop 565 | # calling_task = await start_task(abort_cancellation=True) 566 | # calling_task.cancel() 567 | # assert await calling_task == "done" 568 | # assert states == ["ready", "cancelled", "handled cancellation"] 569 | 570 | # Case 3: cancellation from within the coroutine itself comes back out: 571 | calling_task = await start_task(abort_cancellation=False, cancel_self=True) 572 | with pytest.raises(asyncio.CancelledError): 573 | await calling_task 574 | assert states == ["ready", "cancelled", "handled cancellation"] 575 | 576 | # Case 4: cancellation of the synchronicity task containing the coroutine itself 577 | # but it's caught and should not be propagated to the caller: 578 | calling_task = await start_task(abort_cancellation=True, cancel_self=True) 579 | assert await calling_task == "done" 580 | assert "ready" in states 581 | assert states == ["ready", "cancelled", "handled cancellation"] 582 | 583 | 584 | @pytest.mark.asyncio 585 | async def test_async_cancel_completes_successfully_still_cancels(synchronizer): 586 | # Reproduces a race where the synchronizer event loop finishes a task 587 | # before a 
cancellation has a chance to get scheduled, and as such 588 | # never bubbles up the cancellation, even though it was never 589 | # caught 590 | e = threading.Event() 591 | 592 | @synchronizer.wrap 593 | async def well_behaved_coro(): 594 | e.wait() 595 | return 1 596 | 597 | local_task = asyncio.create_task(well_behaved_coro.aio()) 598 | await asyncio.sleep(0.1) # let other event loop block at e2.wait above 599 | # well_behaved_coro has not exited at this point, and local_task is not resolved 600 | assert not local_task.done() 601 | local_task.cancel() # this schedules cancellation on other thread 602 | await asyncio.sleep(0.1) 603 | assert not local_task.cancelled() # not yet fully cancelled! 604 | e.set() # release other event loop at resolve point, simulating a race 605 | # users would typically assume that a cancellation of a non-done well behaved 606 | # task would race a cancellation error in the next await of that task: 607 | with pytest.raises(asyncio.CancelledError): 608 | await local_task 609 | 610 | 611 | def test_async_inner_still_translates(synchronizer): 612 | class _V: 613 | pass 614 | 615 | V = synchronizer.wrap(_V) 616 | 617 | @synchronizer.wrap 618 | async def inner(): 619 | return _V() 620 | 621 | @synchronizer.wrap 622 | async def outer(): 623 | v = await inner.aio() 624 | assert isinstance(v, V) 625 | 626 | outer() 627 | -------------------------------------------------------------------------------- /test/threading_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import time 4 | 5 | 6 | def test_start_loop(synchronizer): 7 | # Make sure there's no race condition in _start_loop 8 | with concurrent.futures.ThreadPoolExecutor() as executor: 9 | ret = list(executor.map(lambda i: synchronizer._start_loop(), range(1000))) 10 | 11 | assert len(set(ret)) == 1 12 | assert isinstance(ret[0], asyncio.AbstractEventLoop) 13 | 14 | 15 | async def f(i): 16 | await 
asyncio.sleep(1.0) 17 | return i**2 18 | 19 | 20 | def test_multithreaded(synchronizer, n_threads=20): 21 | f_s = synchronizer.create_blocking(f) 22 | 23 | t0 = time.time() 24 | with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor: 25 | ret = list(executor.map(f_s, range(n_threads))) 26 | assert 1.0 <= time.time() - t0 < 1.2 27 | assert ret == [i**2 for i in range(n_threads)] 28 | -------------------------------------------------------------------------------- /test/tracebacks_test.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import pytest 3 | import sys 4 | import traceback 5 | from pathlib import Path 6 | from types import TracebackType 7 | 8 | 9 | class CustomException(Exception): 10 | pass 11 | 12 | 13 | async def raise_something(exc): 14 | raise exc 15 | 16 | 17 | async def gen(): 18 | raise CustomException("gen boom!") 19 | yield 20 | 21 | 22 | def check_traceback(tb: TracebackType, outside_frames=0, outside_frames_old_python=1): 23 | traceback_string = "\n".join(traceback.format_tb(tb)) 24 | assert str(Path(__file__)) in traceback_string # this file should be in traceback 25 | n_outside = 0 26 | for frame in traceback.extract_tb(tb): 27 | if frame.filename != __file__: 28 | n_outside += 1 29 | 30 | # don't allow more than allowed_outside_frames from outside of this file 31 | limit = outside_frames_old_python if sys.version_info < (3, 11) else outside_frames 32 | if n_outside != limit: 33 | print(traceback_string) 34 | raise Exception(f"Got {n_outside} frames outside user code, expected {limit}") 35 | 36 | 37 | def test_sync_to_async(synchronizer): 38 | raise_something_blocking = synchronizer.create_blocking(raise_something) 39 | with pytest.raises(CustomException) as exc_info: 40 | raise_something_blocking(CustomException("boom!")) 41 | 42 | check_traceback(exc_info.tb) 43 | traceback_string = "\n".join(traceback.format_tb(exc_info.tb)) 44 | assert 
'raise_something_blocking(CustomException("boom!"))' in traceback_string 45 | assert "raise exc" in traceback_string 46 | 47 | 48 | def test_full_traceback_flag(synchronizer, monkeypatch): 49 | monkeypatch.setattr("synchronicity.exceptions.SYNCHRONICITY_TRACEBACK", True) 50 | raise_something_blocking = synchronizer.create_blocking(raise_something) 51 | with pytest.raises(CustomException) as exc_info: 52 | raise_something_blocking(CustomException("boom!")) 53 | 54 | check_traceback(exc_info.tb, outside_frames=8, outside_frames_old_python=8) 55 | traceback_string = "\n".join(traceback.format_tb(exc_info.tb)) 56 | 57 | assert 'raise_something_blocking(CustomException("boom!"))' in traceback_string 58 | assert "raise exc" in traceback_string 59 | 60 | 61 | @pytest.mark.asyncio 62 | async def test_async_to_async(synchronizer): 63 | raise_something_wrapped = synchronizer.create_blocking(raise_something) 64 | with pytest.raises(CustomException) as exc_info: 65 | await raise_something_wrapped.aio(CustomException("boom!")) 66 | 67 | check_traceback(exc_info.tb) 68 | 69 | 70 | def test_sync_to_async_gen(synchronizer): 71 | gen_s = synchronizer.create_blocking(gen) 72 | with pytest.raises(CustomException) as exc_info: 73 | for x in gen_s(): 74 | pass 75 | 76 | check_traceback(exc_info.tb) 77 | 78 | 79 | @pytest.mark.asyncio 80 | async def test_async_to_async_gen(synchronizer): 81 | gen_s = synchronizer.create_blocking(gen) 82 | with pytest.raises(CustomException) as exc_info: 83 | async for x in gen_s.aio(): 84 | pass 85 | 86 | check_traceback(exc_info.tb) 87 | 88 | 89 | def test_sync_to_async_ctx_mgr(synchronizer): 90 | ctx_mgr = synchronizer.create_blocking(contextlib.asynccontextmanager(gen)) 91 | with pytest.raises(CustomException) as exc_info: 92 | with ctx_mgr(): 93 | pass 94 | 95 | # we allow one frame from contextlib which would be expected in non-synchronicity code 96 | # in old pythons we have to live with more synchronicity frames here due to multi 97 | # wrapping 
98 | check_traceback(exc_info.tb, outside_frames=1, outside_frames_old_python=3) 99 | 100 | 101 | @pytest.mark.asyncio 102 | async def test_async_to_async_ctx_mgr(synchronizer): 103 | ctx_mgr = synchronizer.create_blocking(contextlib.asynccontextmanager(gen)) 104 | with pytest.raises(CustomException) as exc_info: 105 | async with ctx_mgr(): 106 | pass 107 | 108 | # we allow one frame from contextlib which would be expected in non-synchronicity code 109 | # in old pythons we have to live with more synchronicity frames here due to multi 110 | # wrapping 111 | check_traceback(exc_info.tb, outside_frames=1, outside_frames_old_python=3) 112 | 113 | 114 | def test_recursive(synchronizer): 115 | async def f(n): 116 | if n == 0: 117 | raise CustomException("boom!") 118 | else: 119 | return await f(n - 1) 120 | 121 | f_blocking = synchronizer.create_blocking(f) 122 | 123 | with pytest.raises(CustomException) as exc_info: 124 | f_blocking(10) 125 | 126 | check_traceback(exc_info.tb) 127 | -------------------------------------------------------------------------------- /test/translate_test.py: -------------------------------------------------------------------------------- 1 | def test_translate(synchronizer): 2 | class Foo: 3 | pass 4 | 5 | class FooProvider: 6 | def __init__(self, foo=None): 7 | if foo is not None: 8 | assert type(foo) is Foo 9 | self.foo = foo 10 | else: 11 | self.foo = Foo() 12 | 13 | def get(self): 14 | return self.foo 15 | 16 | def get2(self): 17 | return [self.foo, self.foo] 18 | 19 | @property 20 | def pget(self): 21 | return self.foo 22 | 23 | def set(self, foo): 24 | assert type(foo) is Foo 25 | self.foo = foo 26 | 27 | @classmethod 28 | def cls_in(cls): 29 | assert cls == FooProvider 30 | 31 | @classmethod 32 | def cls_out(cls): 33 | return FooProvider 34 | 35 | BlockingFoo = synchronizer.create_blocking(Foo) 36 | assert BlockingFoo.__name__ == "BlockingFoo" 37 | BlockingFooProvider = synchronizer.create_blocking(FooProvider) 38 | assert 
BlockingFooProvider.__name__ == "BlockingFooProvider" 39 | foo_provider = BlockingFooProvider() 40 | 41 | # Make sure two instances translated out are the same 42 | foo1 = foo_provider.get() 43 | foo2 = foo_provider.get() 44 | assert foo1 == foo2 45 | 46 | # Make sure we can return a list 47 | foos = foo_provider.get2() 48 | assert foos == [foo1, foo2] 49 | 50 | # Make sure properties work 51 | foo = foo_provider.pget 52 | assert isinstance(foo, BlockingFoo) 53 | 54 | # Translate an object in and then back out, make sure it's the same 55 | foo = BlockingFoo() 56 | assert type(foo) is BlockingFoo 57 | foo_provider.set(foo) 58 | assert foo_provider.get() == foo 59 | 60 | # Make sure classes are translated properly too 61 | BlockingFooProvider.cls_in() 62 | assert BlockingFooProvider.cls_out() == BlockingFooProvider 63 | 64 | # Make sure the constructor works 65 | foo = BlockingFoo() 66 | foo_provider = BlockingFooProvider(foo) 67 | assert foo_provider.get() == foo 68 | -------------------------------------------------------------------------------- /test/type_stub_e2e_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import subprocess 3 | import sys 4 | import tempfile 5 | import textwrap 6 | from contextlib import contextmanager 7 | from pathlib import Path 8 | from traceback import print_exc 9 | 10 | from synchronicity.type_stubs import write_stub 11 | 12 | helpers_dir = Path(__file__).parent / "type_stub_helpers" 13 | assertion_file = helpers_dir / "e2e_example_type_assertions.py" 14 | 15 | 16 | class FailedMyPyCheck(Exception): 17 | def __init__(self, output): 18 | self.output = output 19 | 20 | 21 | def run_mypy(input_file, print_errors=True): 22 | with subprocess.Popen(["mypy", input_file], stderr=subprocess.STDOUT, stdout=subprocess.PIPE) as p: 23 | result_code = p.wait() 24 | if result_code != 0: 25 | mypy_report = p.stdout.read().decode("utf8") 26 | if print_errors: 27 | print(mypy_report, 
file=sys.stderr) 28 | raise FailedMyPyCheck(mypy_report) 29 | 30 | 31 | @contextmanager 32 | def temp_assertion_file(new_assertion): 33 | template = assertion_file.read_text() 34 | setup_code, default_assertions = template.split("# assert start") 35 | assertion_code = setup_code + new_assertion 36 | with tempfile.NamedTemporaryFile(dir=assertion_file.parent, suffix=".py") as new_file: 37 | new_file.write(assertion_code.encode("utf8")) 38 | new_file.flush() 39 | try: 40 | yield new_file.name 41 | except: 42 | print(f"Exception when running type assertions on:\n{assertion_code}") 43 | print_exc() 44 | raise 45 | 46 | 47 | @pytest.fixture(scope="session") 48 | def interface_file(): 49 | write_stub("test.type_stub_helpers.e2e_example_export") 50 | yield 51 | 52 | 53 | def test_mypy_assertions(interface_file): 54 | run_mypy(assertion_file) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "failing_assertion,error_matches", 59 | [ 60 | ( 61 | "e2e_example_export.BlockingFoo(1)", 62 | 'incompatible type "int"; expected "str"', 63 | ), 64 | ( 65 | "blocking_foo.some_static()", 66 | 'Missing positional argument "arg" in call to "some_static"', 67 | ), # missing argument 68 | ( 69 | "blocking_foo.some_static(True)", 70 | 'Argument 1 to "some_static" of "BlockingFoo" has incompatible type "bool"', 71 | ), # bool instead of str 72 | ( 73 | "e2e_example_export.listify(123)", 74 | 'Value of type variable "_T_Blocking" of "__call__" of "__listify_spec" cannot be "int"', 75 | ), # int does not satisfy the type bound of the typevar (!) 
76 | ( 77 | textwrap.dedent( 78 | """ 79 | async def a() -> None: 80 | aio_res = await e2e_example_export.returns_foo.aio("hello") 81 | """ 82 | ), 83 | 'Too many arguments for "aio" of "__returns_foo_spec"', 84 | ), 85 | ], 86 | ) 87 | @pytest.mark.skipif( 88 | sys.platform == "win32", reason="temp_assertion_file permissions issues on github actions (windows)" 89 | ) 90 | def test_failing_assertion(interface_file, failing_assertion, error_matches): 91 | # since there appears to be no good way of asserting failing type checks (and skipping to the next assertion) 92 | # we use the assertion file as a template to insert statements that should fail type checking 93 | with temp_assertion_file(failing_assertion) as custom_file: # we pass int instead of str 94 | with pytest.raises(FailedMyPyCheck, match=error_matches): 95 | run_mypy(custom_file, print_errors=False) 96 | -------------------------------------------------------------------------------- /test/type_stub_helpers/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyi 2 | -------------------------------------------------------------------------------- /test/type_stub_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modal-labs/synchronicity/731a20bbe9bff1d973eb8bbb87f73f2e9d4bad1a/test/type_stub_helpers/__init__.py -------------------------------------------------------------------------------- /test/type_stub_helpers/e2e_example_export.py: -------------------------------------------------------------------------------- 1 | # This file creates wrapped entities for async implementation in e2e_example_impl.py 2 | # This file is then used as input for generating type stubs 3 | from typing import Optional 4 | 5 | import synchronicity 6 | 7 | from . 
# Build the blocking ("wrapped") API surface from the async implementation module.
synchronizer = synchronicity.Synchronizer()

# Blocking counterpart of the async _Foo class; name/module are used in emitted stubs.
BlockingFoo = synchronizer.create_blocking(e2e_example_impl._Foo, "BlockingFoo", __name__)

# Module-level annotated variable, exercised by variable-annotation stub emission.
some_instance: Optional[BlockingFoo] = None

_T_Blocking = synchronizer.create_blocking(
    e2e_example_impl._T, "_T_Blocking", __name__
)  # synchronize the TypeVar to support translation of bounds
listify = synchronizer.create_blocking(e2e_example_impl._listify, "listify", __name__)

overloaded = synchronizer.create_blocking(e2e_example_impl._overloaded, "overloaded", __name__)

returns_foo = synchronizer.create_blocking(e2e_example_impl._returns_foo, "returns_foo", __name__)

wrapped_make_context = synchronizer.create_blocking(e2e_example_impl.make_context, "make_context", __name__)

# TODO: we shouldn't need to wrap typevars unless they have wrapped `bounds`
P = synchronizer.create_blocking(e2e_example_impl.P, "P", __name__)
R = synchronizer.create_blocking(e2e_example_impl.R, "R", __name__)


CallableWrapper = synchronizer.create_blocking(e2e_example_impl.CallableWrapper, "CallableWrapper", __name__)

wrap_callable = synchronizer.create_blocking(e2e_example_impl.wrap_callable, "wrap_callable", __name__)
23 | return # type: ignore 24 | 25 | @classmethod 26 | def clone(cls, foo: "_Foo") -> "_Foo": # self ref 27 | return # type: ignore 28 | 29 | 30 | _T = TypeVar("_T", bound=_Foo) 31 | 32 | 33 | async def _listify(t: _T) -> List[_T]: 34 | return [t] 35 | 36 | 37 | @overload 38 | def _overloaded(arg: str) -> float: 39 | pass 40 | 41 | 42 | @overload 43 | def _overloaded(arg: int) -> int: 44 | pass 45 | 46 | 47 | def _overloaded(arg: Union[str, int]): 48 | if isinstance(arg, str): 49 | return float(arg) 50 | return arg 51 | 52 | 53 | async def _returns_foo() -> _Foo: 54 | return _Foo("hello") 55 | 56 | 57 | @asynccontextmanager 58 | async def make_context(a: float) -> typing.AsyncGenerator[str, None]: 59 | yield "hello" 60 | 61 | 62 | P = typing_extensions.ParamSpec("P") 63 | R = typing.TypeVar("R") 64 | 65 | 66 | class CallableWrapper(typing.Generic[P, R]): 67 | async def func(self, *args: P.args, **kwargs: P.kwargs) -> R: 68 | return # type: ignore 69 | 70 | 71 | def wrap_callable(c: typing.Callable[P, R]) -> CallableWrapper[P, R]: 72 | return # type: ignore 73 | -------------------------------------------------------------------------------- /test/type_stub_helpers/e2e_example_type_assertions.py: -------------------------------------------------------------------------------- 1 | # this code is only meant to be "running" through mypy and not an actual python interpreter! 
async def async_block() -> None:
    # Type assertions for the `.aio` (async) interface of wrapped objects.
    # This function is only type-checked by mypy, never executed.
    res = await e2e_example_export.returns_foo.aio()
    assert_type(res, e2e_example_export.BlockingFoo)

    # the wrapped context manager should work in `async with` directly...
    async with e2e_example_export.wrapped_make_context(10.0) as c:
        assert_type(c, str)

    # not sure if this should actually be supported, but it is, for completeness:
    async with e2e_example_export.wrapped_make_context.aio(10.0) as c:
        assert_type(c, str)
def scalar_args(arg1: str, arg2: int) -> float:
    """Fixture with plain scalar annotations; the body is irrelevant to stub tests."""
    result = 0
    return result
def _function_source(func, target_module=__name__):
    """Emit a type stub for a single function and return it as source text."""
    emitter = StubEmitter(target_module)
    emitter.add_function(func, func.__name__)
    return emitter.get_source()
110 | """ 111 | ) 112 | 113 | 114 | def test_async_func(): 115 | assert _function_source(async_func) == "async def async_func() -> str:\n ...\n" 116 | 117 | 118 | def test_async_gen(): 119 | async def async_gen() -> typing.AsyncGenerator[int, None]: 120 | yield 0 121 | 122 | assert ( 123 | _function_source(async_gen) 124 | == "import typing\n\ndef async_gen() -> typing.AsyncGenerator[int, None]:\n ...\n" 125 | ) 126 | 127 | def weird_async_gen() -> typing.AsyncGenerator[int, None]: 128 | # non-async function that returns an async generator 129 | async def gen(): 130 | yield 0 131 | 132 | return gen() 133 | 134 | assert ( 135 | _function_source(weird_async_gen) 136 | == "import typing\n\ndef weird_async_gen() -> typing.AsyncGenerator[int, None]:\n ...\n" 137 | ) 138 | 139 | async def it() -> typing.AsyncIterator[str]: # this is the/a correct annotation 140 | yield "hello" 141 | 142 | src = _function_source(it) 143 | assert "yield" not in src 144 | # since the yield keyword is removed in a type stub, the async prefix needs to be removed as well 145 | # to avoid "double asyncness" (while keeping the remaining annotation) 146 | assert "async" not in src 147 | assert "def it() -> typing.AsyncIterator[str]:" in src 148 | 149 | 150 | class MixedClass: 151 | class_var: str 152 | 153 | def some_method(self) -> bool: 154 | return False 155 | 156 | @classmethod 157 | def some_class_method(cls) -> int: 158 | return 1 159 | 160 | @staticmethod 161 | def some_staticmethod() -> float: 162 | return 0.0 163 | 164 | @property 165 | def some_property(self) -> str: 166 | return "" 167 | 168 | @some_property.setter 169 | def some_property(self, val): 170 | print(val) 171 | 172 | @some_property.deleter 173 | def some_property(self, val): 174 | print(val) 175 | 176 | @classproperty 177 | def class_property(cls): 178 | return 1 179 | 180 | 181 | def test_class_generation(): 182 | emitter = StubEmitter(__name__) 183 | emitter.add_class(MixedClass, "MixedClass") 184 | source = 
def merged_signature(*sigs):
    """Return a copy of the first signature (the remaining ones are currently ignored)."""
    return sigs[0].copy()
def test_typevar():
    # a TypeVar should be emitted via an import from its defining module
    tvar = typing.TypeVar("T")
    tvar.__module__ = "source_mod"

    def foo(arg: tvar) -> tvar:
        return arg

    stub_src = _function_source(foo)
    for expected in ("import source_mod", "def foo(arg: source_mod.T) -> source_mod.T"):
        assert expected in stub_src
def test_synchronicity_type_translation():
    """Wrapped functions get a Protocol "spec" with translated annotations on
    both the blocking __call__ and the async .aio variants."""

    async def _get_foo(foo: _Foo) -> typing.AsyncContextManager[_Foo]:
        return foo

    get_foo = synchronizer.create_blocking(_get_foo, "get_foo", __name__)
    src = _function_source(get_foo)

    print(src)
    assert "class __get_foo_spec(typing_extensions.Protocol):" in src
    assert (
        " def __call__(self, /, foo: Foo) -> synchronicity.combined_types.AsyncAndBlockingContextManager[Foo]" in src
    )
    # python 3.13 has an exit type generic argument, e.g. typing.AsyncContextManager[Foo, bool | None]
    # but we want the type stubs to work on older versions of python too (without conditionals everywhere):
    assert " async def aio(self, /, foo: Foo) -> typing.AsyncContextManager[Foo]" in src
    # bug fix: this used to be `assert "get_foo: __get_foo_spec"` — a bare truthy
    # string that could never fail; the membership check was the intent
    assert "get_foo: __get_foo_spec" in src
def test_custom_generic():
    # TODO: build out this test a bit, as it currently creates an invalid stub (missing base types)
    _class_source(BlockingMyGeneric)  # smoke check only; output is not asserted on

    class Specific(MyGeneric[str]): ...

    specific_src = _class_source(Specific)
    assert "class Specific(MyGeneric[str]):" in specific_src
def test_paramspec_generic():
    # a Generic[P, T] class gets translated type vars on the class itself and
    # distinct "_INNER" copies inside each generated method-spec Protocol
    src = _class_source(BlockingParamSpecGeneric)
    expected_fragments = (
        "class BlockingParamSpecGeneric(typing.Generic[Translated_P, Translated_T])",
        "class __meth_spec(typing_extensions.Protocol[Translated_P_INNER, SUPERSELF]):",
        "def __call__(self, /, *args: Translated_P_INNER.args, **kwargs: Translated_P_INNER.kwargs) -> SUPERSELF",
        "def aio(self, /, *args: Translated_P_INNER.args, **kwargs: Translated_P_INNER.kwargs) -> SUPERSELF",
        "meth: __meth_spec[Translated_P, typing_extensions.Self]",
        "def syncfunc(self) -> Translated_T:",
    )
    for fragment in expected_fragments:
        assert fragment in src
def test_literal_alias(tmp_path):
    """Module-level Literal aliases should be emitted as typing.Literal[...]."""
    # the file's top-level `import importlib` does not guarantee that the
    # `importlib.util` submodule attribute exists; import it explicitly
    import importlib.util

    contents = dedent(
        """
        import typing
        from typing import Literal
        foo = typing.Literal["foo"]
        bar = Literal["bar"]
        """
    )
    fname = tmp_path / "foo.py"
    fname.write_text(contents, encoding="utf-8")

    spec = importlib.util.spec_from_file_location("foo", fname)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    emitter = StubEmitter.from_module(mod)
    src = emitter.get_source()
    # both the fully-qualified and the `from typing import` spellings normalize
    assert "foo = typing.Literal['foo']" in src
    assert "bar = typing.Literal['bar']" in src
def test_overloads_unwrapped_functions():
    # overload declarations are only tracked by synchronicity while
    # patched_overload() is active, so everything is defined inside it
    with overload_tracking.patched_overload():

        @typing.overload
        def _overloaded(arg: str) -> float: ...

        @typing.overload
        def _overloaded(arg: int) -> int: ...

        def _overloaded(arg: typing.Union[str, int]):
            if isinstance(arg, str):
                return float(arg)
            return arg

        overloaded = synchronizer.create_blocking(_overloaded, "overloaded")

    src = _function_source(overloaded)
    # only the two overload signatures should be emitted, not the implementation
    assert src.count("@typing.overload") == 2
    assert src.count("def overloaded") == 2  # original should be omitted
    assert "def overloaded(arg: str) -> float" in src
    assert "def overloaded(arg: int) -> int:" in src
def test_collections_iterator():
    # collections.abc.Iterator is subscriptable from 3.9; the stub should
    # preserve the parametrized annotation
    def foo() -> collections.abc.Iterator[int]:
        class _Ones(collections.abc.Iterator):
            def __iter__(self) -> collections.abc.Iterator[int]:
                return self

            def __next__(self) -> int:
                return 1

        return _Ones()

    assert "-> collections.abc.Iterator[int]" in _function_source(foo)
def test_contextvar():
    # a contextvars.ContextVar annotation should pull in a contextvars import
    import contextvars

    emitter = StubEmitter("blah")
    emitter.add_variable(contextvars.ContextVar, "c")
    stub_src = emitter.get_source()
    for fragment in ("import contextvars", "c: contextvars.ContextVar"):
        assert fragment in stub_src
def test_typeshed():
    """`_typeshed` annotations (importable only under TYPE_CHECKING) survive stub emission."""

    def foo() -> "_typeshed.OpenTextMode":
        return "r"

    stub_src = _function_source(foo)
    assert "import _typeshed" in stub_src
    assert "def foo() -> _typeshed.OpenTextMode:" in stub_src
def test_pathlib():
    # a pathlib.Path return annotation should pull in a pathlib import
    def test_path() -> pathlib.Path: ...

    stub_src = _function_source(test_path)
    assert "pathlib.Path" in stub_src
    assert "import pathlib\n" in stub_src
def test_multiwrap_warning(recwarn):
    """Re-wrapping an already-wrapped function should emit exactly one warning."""
    synchronizer = Synchronizer(multiwrap_warning=True)
    try:
        wrapped = synchronizer.create_blocking(f)
        assert wrapped(42) == 1764
        assert len(recwarn) == 0  # first wrap is warning-free
        double_wrapped = synchronizer.create_blocking(wrapped)
        assert double_wrapped(42) == 1764
        assert len(recwarn) == 1  # re-wrapping triggers the multiwrap warning
    finally:
        # close the synchronizer's event loop to avoid resource warnings
        synchronizer._close_loop()
async def asyncgen():
    """Async generator fixture that yields a single value (42)."""
    yield 42


async def returns_asyncgen():
    """Coroutine fixture returning a fresh, unstarted async generator."""
    gen = asyncgen()
    return gen