├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── api └── transform.py ├── index.html ├── mypy.ini ├── pyproject.toml ├── requirements.txt ├── script.js ├── setup.cfg ├── setup.py ├── style.css ├── tests ├── __init__.py ├── collector.py ├── exception_cases │ ├── coroutine_gen_task.py │ ├── yield_dict_call.py │ ├── yield_dict_comprehension.py │ └── yield_dict_literal.py ├── test_cases │ ├── coroutine_multiple_generators │ │ ├── after.py │ │ └── before.py │ ├── coroutine_with_nested_generator_function │ │ ├── after.py │ │ └── before.py │ ├── gen_return_call_no_args │ │ ├── after.py │ │ └── before.py │ ├── gen_return_call_none │ │ ├── after.py │ │ └── before.py │ ├── gen_return_dict_multi_line │ │ ├── after.py │ │ └── before.py │ ├── gen_return_statement │ │ ├── after.py │ │ └── before.py │ ├── gen_sleep │ │ ├── after.py │ │ └── before.py │ ├── gen_test │ │ ├── after.py │ │ └── before.py │ ├── module_level_raise │ │ ├── after.py │ │ └── before.py │ ├── nested_coroutine │ │ ├── after.py │ │ └── before.py │ ├── non_coroutine_returns_coroutine │ │ ├── after.py │ │ └── before.py │ ├── simple_coroutine_from_tornado_import_gen │ │ ├── after.py │ │ └── before.py │ ├── simple_coroutine_import_tornado │ │ ├── after.py │ │ └── before.py │ ├── testing_gen_test │ │ ├── after.py │ │ └── before.py │ ├── tornado_testing_gen_test │ │ ├── after.py │ │ └── before.py │ ├── tornado_testing_gen_test_already_async │ │ ├── after.py │ │ └── before.py │ ├── yield_list_comprehension │ │ ├── after.py │ │ └── before.py │ └── yield_list_of_futures │ │ ├── after.py │ │ └── before.py └── test_tornado_async_transformer.py └── tornado_async_transformer ├── __init__.py ├── helpers.py ├── py.typed ├── tool.py └── tornado_async_transformer.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push] 4 | 5 | jobs: 6 | 7 | test: 8 | runs-on: ubuntu-latest 9 | 10 | strategy: 11 | matrix: 
12 | python-version: [ '3.6', '3.7' ] 13 | 14 | name: pytest (${{ matrix.python-version }}) 15 | steps: 16 | - uses: actions/checkout@v1 17 | - uses: actions/setup-python@v1 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - run: pip3 install -r requirements.txt 21 | - run: pytest -vv 22 | 23 | lint: 24 | 25 | runs-on: ubuntu-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v1 29 | - uses: actions/setup-python@v1 30 | with: 31 | python-version: '3.7' 32 | - run: pip3 install -r requirements.txt 33 | - run: black --check . 34 | # skipping until i can sort out some issues, example: 35 | # error: Signature of "leave_Module" incompatible with supertype "CSTTypedTransformerFunctions" 36 | # - run: mypy tornado_async_transformer 37 | 38 | deploy: 39 | 40 | if: github.ref == 'refs/heads/master' 41 | runs-on: ubuntu-latest 42 | needs: 43 | - lint 44 | - test 45 | 46 | steps: 47 | - uses: actions/checkout@v1 48 | - uses: actions/setup-python@v1 49 | with: 50 | python-version: '3.7' 51 | - run: pip3 install -r requirements.txt 52 | - name: bundle 53 | run: python setup.py sdist bdist_wheel 54 | - name: deploy to pypi 55 | env: 56 | TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} 57 | TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} 58 | run: twine upload --skip-existing dist/* 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script 
from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # VS Code config 127 | .vscode/ 128 | 129 | # Zeit now 130 | .now/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, SeatGeek 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tornado Async Transformer 2 | 3 | ![](https://github.com/zhammer/tornado-async-transformer/workflows/CI/badge.svg) 4 | 5 | A [libcst](https://github.com/Instagram/LibCST) transformer for updating tornado @gen.coroutine syntax to python3.5+ native async/await. 6 | 7 | [Check out the demo.](https://tornado-async-transformer.zhammer.now.sh/) 8 | 9 | ### Usage 10 | You can either: 11 | - Add `tornado_async_transformer.TornadoAsyncTransformer` to your existing libcst codemod. 12 | - Or run `python -m tornado_async_transformer.tool my_project/` from the commandline. 13 | 14 | #### Example 15 | ```diff 16 | """ 17 | A simple coroutine. 
18 | """ 19 | from tornado import gen 20 | 21 | 22 | -@gen.coroutine 23 | -def call_api(): 24 | - response = yield fetch() 25 | +async def call_api(): 26 | + response = await fetch() 27 | if response.status != 200: 28 | raise BadStatusError() 29 | - raise gen.Return(response.data) 30 | + return response.data 31 | ``` 32 | -------------------------------------------------------------------------------- /api/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the REST api route for the zeit now demo site, as I 3 | haven't been able to figure out how to nest in within the demo_site/ 4 | directory. 5 | """ 6 | 7 | from http.server import BaseHTTPRequestHandler 8 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError 9 | import json 10 | import libcst 11 | 12 | 13 | def transform(source: str) -> str: 14 | source_tree = libcst.parse_module(source) 15 | visited_tree = source_tree.visit(TornadoAsyncTransformer()) 16 | return visited_tree.code 17 | 18 | 19 | class handler(BaseHTTPRequestHandler): 20 | def do_POST(self): 21 | request_raw = self.rfile.read(int(self.headers.get("Content-Length"))).decode() 22 | request_body = json.loads(request_raw) 23 | source = request_body["source"] 24 | 25 | try: 26 | transformed = transform(source) 27 | except Exception as e: 28 | transformed = repr(e) 29 | 30 | self.send_response(200) 31 | self.send_header("Content-type", "application/json") 32 | self.end_headers() 33 | self.wfile.write(json.dumps({"source": transformed}).encode()) 34 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Tornado Async Transformer 5 | 6 | 7 | 8 |
9 |
10 |

11 | Tornado Async Transformer - 12 | github 15 |

16 |
17 |
18 |
19 |
20 |
21 |
22 | 23 | 24 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.7 3 | check_untyped_defs = True 4 | disallow_incomplete_defs = True 5 | disallow_untyped_calls = True 6 | disallow_untyped_decorators = True 7 | disallow_untyped_defs = True 8 | ignore_missing_imports = True 9 | no_implicit_optional = True 10 | warn_redundant_casts = True 11 | warn_return_any = True 12 | warn_unused_configs = True 13 | warn_unused_ignores = True 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # unfortunately 'black' doesn't support setup.cfg, and it's important 2 | # not to auto format test_cases/ files for testing codemods preserving 3 | # whitespace, indentation, etc. 4 | [tool.black] 5 | exclude = "(test_cases|.now)" 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | libcst==0.2.4 2 | black==19.10b0 3 | mypy==0.750 4 | pytest==5.3.1 5 | twine==3.1.1 6 | wheel==0.33.6 7 | 8 | # for zeit site, until python3.7 is supported (dataclasses is stdlib from 3.7) 9 | dataclasses==0.6; python_version < '3.7' 10 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | // globals ace 2 | 3 | const editorSource = ace.edit("editor-source"); 4 | editorSource.setTheme("ace/theme/tomorrow_night"); 5 | editorSource.session.setMode("ace/mode/python"); 6 | 7 | const editorTransformed = ace.edit("editor-transformed"); 8 | editorTransformed.setReadOnly(true); 9 | editorTransformed.setTheme("ace/theme/tomorrow_night"); 10 |
editorTransformed.session.setMode("ace/mode/python"); 11 | 12 | async function fetchTransformedSource(source) { 13 | const response = await fetch("/api/transform", { 14 | method: "POST", 15 | body: JSON.stringify({ source }) 16 | }); 17 | const body = await response.json(); 18 | return body.source; 19 | } 20 | 21 | let calls = 0; 22 | async function onEditorUpdated() { 23 | const call = ++calls; 24 | const transformed = await fetchTransformedSource(editorSource.getValue()); 25 | if (call === calls) { 26 | editorTransformed.setValue(transformed); 27 | editorSource.clearSelection(); 28 | editorTransformed.clearSelection(); 29 | } 30 | } 31 | 32 | editorSource.setValue(`""" 33 | A simple coroutine. 34 | """ 35 | from tornado import gen 36 | 37 | @gen.coroutine 38 | def call_api(): 39 | response = yield fetch() 40 | if response.status != 200: 41 | raise BadStatusError() 42 | raise gen.Return(response.data) 43 | `); 44 | onEditorUpdated(); 45 | editorSource.on("change", onEditorUpdated); 46 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | addopts = tests tornado_async_transformer --ignore tests/test_cases --ignore tests/exception_cases --doctest-modules 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from distutils.core import setup 3 | 4 | with open("README.md", "r") as readme: 5 | long_description = readme.read() 6 | 7 | setup( 8 | name="tornado-async-transformer", 9 | version="0.2.0", 10 | description="libcst transformer and codemod for updating tornado @gen.coroutine syntax to python3.5+ native async/await", 11 | url="https://github.com/zhammer/tornado-async-transformer", 12 | packages=find_packages(exclude=["tests", "demo_site"]), 13 | 
package_data={"tornado_async_transformer": ["py.typed"]}, 14 | install_requires=["libcst == 0.2.4"], 15 | author="Zach Hammer", 16 | author_email="zachary_hammer@alumni.brown.edu", 17 | license="MIT License", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | classifiers=[ 21 | "License :: OSI Approved :: MIT License", 22 | "Topic :: Software Development :: Libraries", 23 | "Programming Language :: Python :: 3.6", 24 | "Programming Language :: Python :: 3.7", 25 | ], 26 | ) 27 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: fantasy; 3 | margin: 0; 4 | padding: 0; 5 | } 6 | 7 | .header { 8 | padding: 1em; 9 | } 10 | 11 | h1 { 12 | margin: 0; 13 | } 14 | 15 | .editor { 16 | display: block; 17 | height: 70vh; 18 | width: 50vw; 19 | } 20 | 21 | .editors { 22 | display: flex; 23 | justify-content: space-between; 24 | } 25 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seatgeek/tornado-async-transformer/c9353151bd1ffb9e532f68d929cf68e68799eb41/tests/__init__.py -------------------------------------------------------------------------------- /tests/collector.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | from typing import Any, List, NamedTuple, Tuple 4 | 5 | import pytest 6 | 7 | 8 | class TestCase(NamedTuple): 9 | before: str 10 | after: str 11 | 12 | 13 | def collect_test_cases() -> Tuple[Any, ...]: 14 | root_test_cases_directory = os.path.join(os.path.dirname(__file__), "test_cases") 15 | 16 | test_cases: List = [] 17 | for root, _, files in os.walk(root_test_cases_directory): 18 | if not {"before.py", "after.py"} <= set(files): 19 
| continue 20 | 21 | test_case_name = os.path.basename(root).replace("_", " ") 22 | 23 | with open(os.path.join(root, "before.py")) as before_file: 24 | before = before_file.read() 25 | 26 | with open(os.path.join(root, "after.py")) as after_file: 27 | after = after_file.read() 28 | 29 | test_cases.append( 30 | pytest.param(TestCase(before=before, after=after), id=test_case_name) 31 | ) 32 | 33 | return tuple(test_cases) 34 | 35 | 36 | class ExceptionCase(NamedTuple): 37 | source: str 38 | expected_error_message: str 39 | 40 | 41 | def collect_exception_cases() -> Tuple[Any, ...]: 42 | root_exception_cases_directory = os.path.join( 43 | os.path.dirname(__file__), "exception_cases" 44 | ) 45 | 46 | # all .py files in the top-levl of exception cases directory 47 | python_files = [ 48 | os.path.join(root_exception_cases_directory, file) 49 | for file in os.listdir(root_exception_cases_directory) 50 | if file[-3:] == ".py" 51 | ] 52 | 53 | exception_cases: List = [] 54 | for python_filename in python_files: 55 | with open(python_filename) as python_file: 56 | source = python_file.read() 57 | 58 | # the module's docstring is the expected error message 59 | ast_tree = ast.parse(source) 60 | docstring = ast.get_docstring(ast_tree) 61 | expected_error_message = docstring.replace("\n", " ") 62 | 63 | test_case_name = os.path.basename(python_filename).replace("_", " ") 64 | 65 | exception_cases.append( 66 | pytest.param( 67 | ExceptionCase( 68 | source=source, expected_error_message=expected_error_message 69 | ), 70 | id=test_case_name, 71 | ) 72 | ) 73 | 74 | return tuple(exception_cases) 75 | -------------------------------------------------------------------------------- /tests/exception_cases/coroutine_gen_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task) 3 | from tornado 2.4.1 is unsupported by this codemod. This file has not been modified. 
4 | Manually update to supported syntax before running again. 5 | """ 6 | import time 7 | from tornado import gen 8 | from tornado.ioloop import IOLoop 9 | 10 | 11 | @gen.coroutine 12 | def ping(): 13 | yield gen.Task(IOLoop.instance().add_timeout, time.time() + 1.5) 14 | raise gen.Return("pong") 15 | -------------------------------------------------------------------------------- /tests/exception_cases/yield_dict_call.py: -------------------------------------------------------------------------------- 1 | """ 2 | Yielding a dict of futures 3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) 4 | added in tornado 3.2 is unsupported by the codemod. This file has not been 5 | modified. Manually update to supported syntax before running again. 6 | """ 7 | from tornado import gen 8 | 9 | 10 | @gen.coroutine 11 | def get_user_friends_and_relatives(user_id): 12 | users = yield dict( 13 | friends=fetch("/friends", user_id), relatives=fetch("/relatives", user_id) 14 | ) 15 | raise gen.Return(users) 16 | -------------------------------------------------------------------------------- /tests/exception_cases/yield_dict_comprehension.py: -------------------------------------------------------------------------------- 1 | """ 2 | Yielding a dict of futures 3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) 4 | added in tornado 3.2 is unsupported by the codemod. This file has not been 5 | modified. Manually update to supported syntax before running again. 
6 | """ 7 | from tornado import gen 8 | 9 | 10 | @gen.coroutine 11 | def get_users_by_id(user_ids): 12 | users = yield {user_id: fetch(user_ids) for user_id in user_ids} 13 | raise gen.Return(users) 14 | -------------------------------------------------------------------------------- /tests/exception_cases/yield_dict_literal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Yielding a dict of futures 3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) 4 | added in tornado 3.2 is unsupported by the codemod. This file has not been 5 | modified. Manually update to supported syntax before running again. 6 | """ 7 | from tornado import gen 8 | 9 | 10 | @gen.coroutine 11 | def get_two_users_by_id(user_id_1, user_id_2): 12 | users = yield {user_id_1: fetch(user_id_1), user_id_2: fetch(user_id_2)} 13 | raise gen.Return(users) 14 | -------------------------------------------------------------------------------- /tests/test_cases/coroutine_multiple_generators/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A few coroutines that have multiple decorators in addition to @gen.coroutine. 3 | """ 4 | from tornado import gen 5 | from util.decorators import route, deprecated, log_args 6 | 7 | @route("/user/:id") 8 | async def user_page(id): 9 | user = await get_user(id) 10 | user_page = "
{}
".format(user.name) 11 | return user_page 12 | 13 | 14 | @deprecated 15 | @log_args 16 | async def get_user(id): 17 | response = await fetch(id) 18 | return response.user 19 | -------------------------------------------------------------------------------- /tests/test_cases/coroutine_multiple_generators/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A few coroutines that have multiple decorators in addition to @gen.coroutine. 3 | """ 4 | from tornado import gen 5 | from util.decorators import route, deprecated, log_args 6 | 7 | @gen.coroutine 8 | @route("/user/:id") 9 | def user_page(id): 10 | user = yield get_user(id) 11 | user_page = "
{}
".format(user.name) 12 | raise gen.Return(user_page) 13 | 14 | 15 | @deprecated 16 | @gen.coroutine 17 | @log_args 18 | def get_user(id): 19 | response = yield fetch(id) 20 | raise gen.Return(response.user) 21 | -------------------------------------------------------------------------------- /tests/test_cases/coroutine_with_nested_generator_function/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine with a nested, non-coroutine generator function. 3 | """ 4 | from typing import List 5 | 6 | from tornado import gen 7 | 8 | 9 | async def save_users(users): 10 | def build_user_ids(users): 11 | for user in users: 12 | yield "{}-{}".format(user.first_name, user.last_name) 13 | 14 | for user_id in build_user_ids(users): 15 | await fetch("POST", user_id) 16 | -------------------------------------------------------------------------------- /tests/test_cases/coroutine_with_nested_generator_function/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine with a nested, non-coroutine generator function. 3 | """ 4 | from typing import List 5 | 6 | from tornado import gen 7 | 8 | 9 | @gen.coroutine 10 | def save_users(users): 11 | def build_user_ids(users): 12 | for user in users: 13 | yield "{}-{}".format(user.first_name, user.last_name) 14 | 15 | for user_id in build_user_ids(users): 16 | yield fetch("POST", user_id) 17 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_call_no_args/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return() with no args. 
3 | """ 4 | from tornado import gen 5 | 6 | 7 | async def check_id_valid(id: str): 8 | response = await fetch(id) 9 | if response.status != 200: 10 | raise InvalidID() 11 | 12 | return 13 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_call_no_args/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return() with no args. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def check_id_valid(id: str): 9 | response = yield fetch(id) 10 | if response.status != 200: 11 | raise InvalidID() 12 | 13 | raise gen.Return() 14 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_call_none/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return(None). 3 | """ 4 | from tornado import gen 5 | 6 | 7 | async def check_id_valid(id: str): 8 | response = await fetch(id) 9 | if response.status != 200: 10 | raise InvalidID() 11 | 12 | return None 13 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_call_none/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return(None). 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def check_id_valid(id: str): 9 | response = yield fetch(id) 10 | if response.status != 200: 11 | raise InvalidID() 12 | 13 | raise gen.Return(None) 14 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_dict_multi_line/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that returns a dict that spanning multiple lines. 
3 | """ 4 | from tornado import gen 5 | 6 | async def fetch_user(id): 7 | response = await fetch(id) 8 | return { 9 | 'user': response.user, 10 | 'source': 'user-api' 11 | } 12 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_dict_multi_line/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that returns a dict that spanning multiple lines. 3 | """ 4 | from tornado import gen 5 | 6 | @gen.coroutine 7 | def fetch_user(id): 8 | response = yield fetch(id) 9 | raise gen.Return({ 10 | 'user': response.user, 11 | 'source': 'user-api' 12 | }) 13 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_statement/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return directly instead of gen.Return(...). 3 | """ 4 | from tornado import gen 5 | 6 | 7 | async def check_id_valid(id: str): 8 | response = await fetch(id) 9 | if response.status != 200: 10 | raise InvalidID() 11 | 12 | return 13 | -------------------------------------------------------------------------------- /tests/test_cases/gen_return_statement/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that raises gen.Return directly instead of gen.Return(...). 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def check_id_valid(id: str): 9 | response = yield fetch(id) 10 | if response.status != 200: 11 | raise InvalidID() 12 | 13 | raise gen.Return 14 | -------------------------------------------------------------------------------- /tests/test_cases/gen_sleep/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that calls gen.sleep. 
3 | """ 4 | from tornado import gen 5 | import asyncio 6 | 7 | 8 | async def ping(): 9 | await asyncio.sleep(10) 10 | return "pong" 11 | -------------------------------------------------------------------------------- /tests/test_cases/gen_sleep/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that calls gen.sleep. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def ping(): 9 | yield gen.sleep(10) 10 | raise gen.Return("pong") 11 | -------------------------------------------------------------------------------- /tests/test_cases/gen_test/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @gen_test decorator. 3 | """ 4 | 5 | from tornado.testing import gen_test, AsyncHTTPTestCase 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @gen_test 15 | async def test_ping_route(self): 16 | response = await self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/gen_test/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @gen_test decorator. 
3 | """ 4 | 5 | from tornado.testing import gen_test, AsyncHTTPTestCase 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @gen_test 15 | def test_ping_route(self): 16 | response = yield self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/module_level_raise/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module that raises an exception outside of a function. 3 | """ 4 | raise NotImplementedError("This module isn't ready!") 5 | -------------------------------------------------------------------------------- /tests/test_cases/module_level_raise/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module that raises an exception outside of a function. 3 | """ 4 | raise NotImplementedError("This module isn't ready!") 5 | -------------------------------------------------------------------------------- /tests/test_cases/nested_coroutine/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine that contains a nested coroutine. 
3 | """ 4 | from tornado import gen 5 | 6 | 7 | async def call_api(): 8 | async def nested_callback(response): 9 | if response.status != 200: 10 | return response 11 | 12 | body = await response.json() 13 | if body["api-update-available"]: 14 | print("note: update api") 15 | return response 16 | 17 | response = await fetch(middleware=nested_callback) 18 | if response.status != 200: 19 | raise BadStatusError() 20 | return response.data 21 | -------------------------------------------------------------------------------- /tests/test_cases/nested_coroutine/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine that contains a nested coroutine. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def call_api(): 9 | @gen.coroutine 10 | def nested_callback(response): 11 | if response.status != 200: 12 | raise gen.Return(response) 13 | 14 | body = yield response.json() 15 | if body["api-update-available"]: 16 | print("note: update api") 17 | raise gen.Return(response) 18 | 19 | response = yield fetch(middleware=nested_callback) 20 | if response.status != 200: 21 | raise BadStatusError() 22 | raise gen.Return(response.data) 23 | -------------------------------------------------------------------------------- /tests/test_cases/non_coroutine_returns_coroutine/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A non-coroutine function that returns a coroutine. 
3 | """ 4 | from tornado import gen 5 | 6 | 7 | def make_simple_fetch(url: str): 8 | async def my_simple_fetch(body): 9 | response = await fetch(url, body) 10 | return response.body 11 | 12 | return my_simple_fetch 13 | 14 | -------------------------------------------------------------------------------- /tests/test_cases/non_coroutine_returns_coroutine/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A non-coroutine function that returns a coroutine. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | def make_simple_fetch(url: str): 8 | @gen.coroutine 9 | def my_simple_fetch(body): 10 | response = yield fetch(url, body) 11 | raise gen.Return(response.body) 12 | 13 | return my_simple_fetch 14 | 15 | -------------------------------------------------------------------------------- /tests/test_cases/simple_coroutine_from_tornado_import_gen/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | async def call_api(): 8 | response = await fetch() 9 | if response.status != 200: 10 | raise BadStatusError() 11 | return response.data 12 | -------------------------------------------------------------------------------- /tests/test_cases/simple_coroutine_from_tornado_import_gen/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine. 3 | """ 4 | from tornado import gen 5 | 6 | 7 | @gen.coroutine 8 | def call_api(): 9 | response = yield fetch() 10 | if response.status != 200: 11 | raise BadStatusError() 12 | raise gen.Return(response.data) 13 | -------------------------------------------------------------------------------- /tests/test_cases/simple_coroutine_import_tornado/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine in a module that imports the tornado package. 
3 | """ 4 | import tornado 5 | 6 | 7 | async def call_api(): 8 | response = await fetch() 9 | if response.status != 200: 10 | raise BadStatusError() 11 | if response.status == 204: 12 | return 13 | return response.data 14 | -------------------------------------------------------------------------------- /tests/test_cases/simple_coroutine_import_tornado/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple coroutine in a module that imports the tornado package. 3 | """ 4 | import tornado 5 | 6 | 7 | @tornado.gen.coroutine 8 | def call_api(): 9 | response = yield fetch() 10 | if response.status != 200: 11 | raise BadStatusError() 12 | if response.status == 204: 13 | raise tornado.gen.Return 14 | raise tornado.gen.Return(response.data) 15 | -------------------------------------------------------------------------------- /tests/test_cases/testing_gen_test/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @testing.gen_test decorator. 3 | """ 4 | 5 | from tornado import testing 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(testing.AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @testing.gen_test 15 | async def test_ping_route(self): 16 | response = await self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/testing_gen_test/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @testing.gen_test decorator. 
3 | """ 4 | 5 | from tornado import testing 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(testing.AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @testing.gen_test 15 | def test_ping_route(self): 16 | response = yield self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/tornado_testing_gen_test/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @tornado.testing.gen_test decorator. 3 | """ 4 | 5 | import tornado 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @tornado.testing.gen_test 15 | async def test_ping_route(self): 16 | response = await self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/tornado_testing_gen_test/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @tornado.testing.gen_test decorator. 3 | """ 4 | 5 | import tornado 6 | 7 | from my_app import make_application 8 | 9 | 10 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase): 11 | def get_app(self): 12 | return make_application() 13 | 14 | @tornado.testing.gen_test 15 | def test_ping_route(self): 16 | response = yield self.fetch("/ping") 17 | assert response.body == b"pong" 18 | -------------------------------------------------------------------------------- /tests/test_cases/tornado_testing_gen_test_already_async/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @tornado.testing.gen_test decorator that is already 3 | an async function. 
Function should not be modified. 4 | """ 5 | 6 | # This is a pretty contrived example to confirm these tests won't have yields changed to awaits. 7 | import some_experimental_pytest_decorator_that_allows_yielding_in_test_flow 8 | import tornado 9 | 10 | from my_app import make_application 11 | 12 | 13 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase): 14 | def get_app(self): 15 | return make_application() 16 | 17 | @tornado.testing.gen_test 18 | @some_experimental_pytest_decorator_that_allows_yielding_in_test_flow(b"pong") 19 | async def test_ping_route(self, yielder): 20 | response = await self.fetch("/ping") 21 | expected_response = yield yielder() 22 | assert response.body == expected_response 23 | -------------------------------------------------------------------------------- /tests/test_cases/tornado_testing_gen_test_already_async/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tornado gen_test using @tornado.testing.gen_test decorator that is already 3 | an async function. Function should not be modified. 4 | """ 5 | 6 | # This is a pretty contrived example to confirm these tests won't have yields changed to awaits. 
7 | import some_experimental_pytest_decorator_that_allows_yielding_in_test_flow 8 | import tornado 9 | 10 | from my_app import make_application 11 | 12 | 13 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase): 14 | def get_app(self): 15 | return make_application() 16 | 17 | @tornado.testing.gen_test 18 | @some_experimental_pytest_decorator_that_allows_yielding_in_test_flow(b"pong") 19 | async def test_ping_route(self, yielder): 20 | response = await self.fetch("/ping") 21 | expected_response = yield yielder() 22 | assert response.body == expected_response 23 | -------------------------------------------------------------------------------- /tests/test_cases/yield_list_comprehension/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that yields a list comprehension that creates a list of yieldable objects. 3 | See: https://www.tornadoweb.org/en/stable/gen.html. 4 | """ 5 | from tornado import gen 6 | import asyncio 7 | 8 | 9 | async def get_users(user_ids): 10 | users = await asyncio.gather(*[fetch(user_id) for user_id in user_ids]) 11 | return users 12 | -------------------------------------------------------------------------------- /tests/test_cases/yield_list_comprehension/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that yields a list comprehension that creates a list of yieldable objects. 3 | See: https://www.tornadoweb.org/en/stable/gen.html. 4 | """ 5 | from tornado import gen 6 | 7 | 8 | @gen.coroutine 9 | def get_users(user_ids): 10 | users = yield [fetch(user_id) for user_id in user_ids] 11 | raise gen.Return(users) 12 | -------------------------------------------------------------------------------- /tests/test_cases/yield_list_of_futures/after.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that yields a list of yieldable objects. 
3 | See: https://www.tornadoweb.org/en/stable/gen.html. 4 | """ 5 | from tornado import gen 6 | import asyncio 7 | 8 | 9 | async def get_two_users(user_id_1, user_id_2): 10 | response_1, reponse_2 = await asyncio.gather(*[fetch(user_id_1), fetch(user_id_2)]) 11 | return (response_1.user, response_2.user) 12 | -------------------------------------------------------------------------------- /tests/test_cases/yield_list_of_futures/before.py: -------------------------------------------------------------------------------- 1 | """ 2 | A coroutine that yields a list of yieldable objects. 3 | See: https://www.tornadoweb.org/en/stable/gen.html. 4 | """ 5 | from tornado import gen 6 | 7 | 8 | @gen.coroutine 9 | def get_two_users(user_id_1, user_id_2): 10 | response_1, reponse_2 = yield [fetch(user_id_1), fetch(user_id_2)] 11 | raise gen.Return((response_1.user, response_2.user)) 12 | -------------------------------------------------------------------------------- /tests/test_tornado_async_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple 2 | 3 | import libcst 4 | import pytest 5 | 6 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError 7 | 8 | from tests.collector import ( 9 | ExceptionCase, 10 | collect_exception_cases, 11 | TestCase, 12 | collect_test_cases, 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize("test_case", collect_test_cases()) 17 | def test_python_module(test_case: TestCase) -> None: 18 | source_tree = libcst.parse_module(test_case.before) 19 | visited_tree = source_tree.visit(TornadoAsyncTransformer()) 20 | assert visited_tree.code == test_case.after 21 | 22 | 23 | @pytest.mark.parametrize("exception_case", collect_exception_cases()) 24 | def test_unsupported_python_module(exception_case: ExceptionCase) -> None: 25 | source_tree = libcst.parse_module(exception_case.source) 26 | 27 | with pytest.raises(TransformError) as exception: 28 | 
visited_tree = source_tree.visit(TornadoAsyncTransformer()) 29 | 30 | assert exception_case.expected_error_message in str(exception.value) 31 | -------------------------------------------------------------------------------- /tornado_async_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .tornado_async_transformer import TornadoAsyncTransformer, TransformError 2 | -------------------------------------------------------------------------------- /tornado_async_transformer/helpers.py: -------------------------------------------------------------------------------- 1 | from functools import singledispatch 2 | from typing import List, Sequence, Union 3 | 4 | import libcst as cst 5 | from libcst import matchers as m 6 | 7 | 8 | def with_added_imports( 9 | module_node: cst.Module, import_nodes: Sequence[Union[cst.Import, cst.ImportFrom]] 10 | ) -> cst.Module: 11 | """ 12 | Adds new import `import_node` after the first import in the module `module_node`. 13 | """ 14 | updated_body: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]] = [] 15 | added_import = False 16 | for line in module_node.body: 17 | updated_body.append(line) 18 | if not added_import and _is_import_line(line): 19 | for import_node in import_nodes: 20 | updated_body.append(cst.SimpleStatementLine(body=tuple([import_node]))) 21 | added_import = True 22 | 23 | if not added_import: 24 | raise RuntimeError("Failed to add imports") 25 | 26 | return module_node.with_changes(body=tuple(updated_body)) 27 | 28 | 29 | def _is_import_line( 30 | line: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement] 31 | ) -> bool: 32 | return m.matches(line, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])) 33 | 34 | 35 | def name_attr_possibilities(tag: str) -> List[Union[m.Name, m.Attribute]]: 36 | """ 37 | Let's say we want to find all instances of coroutine decorators in our code. 
The tornado coroutine 38 | decorator can be imported and used in the following ways: 39 | 40 | ``` 41 | import tornado; @tornado.gen.coroutine 42 | from tornado import gen; @gen.coroutine 43 | from tornado.gen import coroutine; @coroutine 44 | ``` 45 | 46 | If we want to see if a decorator is a coroutine decorator, and we don't want the overhead of knowing 47 | all of the module's imports, we can check if a decorator matches one of the following options to be 48 | fairly confident it's a coroutine: 49 | - tornado.gen.coroutine 50 | - gen.coroutine 51 | - coroutine 52 | 53 | This doesn't account for renamed imports (since we're not import-aware) but does a decent enough job. 54 | Another option is to only match against the final Name in a Name, Attribute or nested Attribute, but 55 | there doesn't seem to be a significantly simpler way to do that at the moment with libcst, and this way we get some 56 | extra protection from considering @this.is.not.a.coroutine a match for @tornado.gen.coroutine. 57 | 58 | # We run this function on "tornado.gen.coroutine" 59 | >>> tornado_gen_coroutine, gen_coroutine, coroutine = name_attr_possibilities("tornado.gen.coroutine") 60 | 61 | # We have just the matcher Name "coroutine" 62 | >>> coroutine 63 | Name(value='coroutine',...) 64 | 65 | # We have an attribute "gen.coroutine" 66 | >>> gen_coroutine 67 | Attribute(value=Name(value='gen',...), attr=Name(value='coroutine',...),...) 68 | 69 | # We have a nested attribute "tornado.gen.coroutine" 70 | >>> tornado_gen_coroutine 71 | Attribute(value=Attribute(value=Name(value='tornado',...), attr=Name(value='gen',...),...), attr=Name(value='coroutine',...),...) 72 | """ 73 | 74 | def _make_name_or_attribute(parts: List[str]) -> Union[m.Name, m.Attribute]: 75 | if not parts: 76 | raise RuntimeError("Expected a non empty list of strings") 77 | 78 | # just a name, e.g. `coroutine` 79 | if len(parts) == 1: 80 | return m.Name(parts[0]) 81 | 82 | # a name and attribute, e.g.
`gen.coroutine` 83 | if len(parts) == 2: 84 | return m.Attribute(value=m.Name(parts[0]), attr=m.Name(parts[1])) 85 | 86 | # a complex attribute, e.g. `tornado.gen.coroutine`, we want to make 87 | # the attribute with value `tornado.gen` and attr `coroutine` 88 | value = _make_name_or_attribute(parts[:-1]) 89 | attr = _make_name_or_attribute(parts[-1:]) 90 | return m.Attribute(value=value, attr=attr) 91 | 92 | parts = tag.split(".") 93 | return [_make_name_or_attribute(parts[start:]) for start in range(len(parts))] 94 | 95 | 96 | def some_version_of(tag: str) -> m.OneOf[m.Union[m.Name, m.Attribute]]: 97 | """ 98 | Poorly named wrapper around name_attr_possibilities. 99 | """ 100 | return m.OneOf(*name_attr_possibilities(tag)) 101 | -------------------------------------------------------------------------------- /tornado_async_transformer/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seatgeek/tornado-async-transformer/c9353151bd1ffb9e532f68d929cf68e68799eb41/tornado_async_transformer/py.typed -------------------------------------------------------------------------------- /tornado_async_transformer/tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | from pathlib import Path 5 | from typing import List, Tuple 6 | import argparse 7 | 8 | import libcst as cst 9 | from libcst import CSTVisitorT 10 | 11 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError 12 | 13 | 14 | def transform_file(visitor: CSTVisitorT, filename: str) -> None: 15 | with open(filename, "r") as python_file: 16 | python_source = python_file.read() 17 | 18 | try: 19 | source_tree = cst.parse_module(python_source) 20 | except Exception as e: 21 | print("{} failed parse: {}".format(filename, str(e))) 22 | return 23 | 24 | try: 25 | visited_tree = source_tree.visit(visitor) 26 | except TransformError as e: 27 | 
print("{} failed transform: {}".format(filename, str(e))) 28 | return 29 | 30 | if not visited_tree.deep_equals(source_tree): 31 | with open(filename, "w") as python_file: 32 | python_file.write(visited_tree.code) 33 | 34 | 35 | def collect_files(base: str) -> Tuple[str, ...]: 36 | """ 37 | Collect all python files under a base directory. 38 | """ 39 | 40 | def is_python_file(path: str) -> bool: 41 | return bool(os.path.isfile(path) and re.search(r"\.pyi?$", path)) 42 | 43 | if is_python_file(base): 44 | return (base,) 45 | 46 | if os.path.isdir(base): 47 | python_files: List[str] = [] 48 | for root, dirs, filenames in os.walk(base): 49 | full_filenames = (f"{root}/{filename}" for filename in filenames) 50 | python_files += [ 51 | full_filename 52 | for full_filename in full_filenames 53 | if is_python_file(full_filename) 54 | ] 55 | return tuple(python_files) 56 | 57 | return tuple() 58 | 59 | 60 | def parse_args() -> argparse.Namespace: 61 | parser = argparse.ArgumentParser( 62 | description="Codemod for converting legacy tornado @gen.coroutine syntax to python3.5+ native async/await" 63 | ) 64 | parser.add_argument( 65 | "bases", 66 | type=str, 67 | nargs="+", 68 | help="Files and directories (recursive) including python files to be modified.", 69 | ) 70 | return parser.parse_args() 71 | 72 | 73 | def main() -> None: 74 | args = parse_args() 75 | 76 | python_files: List[str] = [] 77 | for base in args.bases: 78 | python_files += collect_files(base) 79 | 80 | for python_file in python_files: 81 | transform_file(TornadoAsyncTransformer(), python_file) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /tornado_async_transformer/tornado_async_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Set, Tuple, Union 2 | 3 | import libcst as cst 4 | from libcst import matchers as m 5 | 6 | from 
tornado_async_transformer.helpers import ( 7 | name_attr_possibilities, 8 | some_version_of, 9 | with_added_imports, 10 | ) 11 | 12 | 13 | # matchers 14 | gen_return_statement_matcher = m.Raise(exc=some_version_of("tornado.gen.Return")) 15 | gen_return_call_with_args_matcher = m.Raise( 16 | exc=m.Call(func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)]) 17 | ) 18 | gen_return_call_matcher = m.Raise( 19 | exc=m.Call(func=some_version_of("tornado.gen.Return")) 20 | ) 21 | gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher 22 | gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep")) 23 | gen_task_matcher = m.Call(func=some_version_of("gen.Task")) 24 | gen_coroutine_decorator_matcher = m.Decorator( 25 | decorator=some_version_of("tornado.gen.coroutine") 26 | ) 27 | gen_test_coroutine_decorator = m.Decorator( 28 | decorator=some_version_of("tornado.testing.gen_test") 29 | ) 30 | coroutine_decorator_matcher = ( 31 | gen_coroutine_decorator_matcher | gen_test_coroutine_decorator 32 | ) 33 | coroutine_matcher = m.FunctionDef( 34 | asynchronous=None, 35 | decorators=[m.ZeroOrMore(), coroutine_decorator_matcher, m.ZeroOrMore()], 36 | ) 37 | 38 | 39 | class TransformError(Exception): 40 | """ 41 | Error raised upon encountering a known error while attempting to transform 42 | the tree. 43 | """ 44 | 45 | 46 | class TornadoAsyncTransformer(cst.CSTTransformer): 47 | """ 48 | A libcst transformer that replaces the legacy @gen.coroutine/yield 49 | async syntax with the python3.7 native async/await syntax. 50 | 51 | This transformer doesn't remove any tornado imports from modified 52 | files. 
53 | """ 54 | 55 | def __init__(self) -> None: 56 | self.coroutine_stack: List[bool] = [] 57 | self.required_imports: Set[str] = set() 58 | 59 | def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.Module: 60 | if not self.required_imports: 61 | return updated_node 62 | 63 | imports = [ 64 | self.make_simple_package_import(required_import) 65 | for required_import in self.required_imports 66 | ] 67 | 68 | return with_added_imports(updated_node, imports) 69 | 70 | def visit_Call(self, node: cst.Call) -> Optional[bool]: 71 | if m.matches(node, gen_task_matcher): 72 | raise TransformError( 73 | "gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task) from tornado 2.4.1 is unsupported by this codemod. This file has not been modified. Manually update to supported syntax before running again." 74 | ) 75 | 76 | return True 77 | 78 | def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call: 79 | if not self.in_coroutine(self.coroutine_stack): 80 | return updated_node 81 | 82 | if m.matches(updated_node, gen_sleep_matcher): 83 | self.required_imports.add("asyncio") 84 | return updated_node.with_changes( 85 | func=cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("sleep")) 86 | ) 87 | 88 | return updated_node 89 | 90 | def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: 91 | self.coroutine_stack.append(m.matches(node, coroutine_matcher)) 92 | # always continue to visit function 93 | return True 94 | 95 | def leave_FunctionDef( 96 | self, node: cst.FunctionDef, updated_node: cst.FunctionDef 97 | ) -> cst.FunctionDef: 98 | leaving_coroutine = self.coroutine_stack.pop() 99 | if not leaving_coroutine: 100 | return updated_node 101 | 102 | return updated_node.with_changes( 103 | decorators=[ 104 | decorator 105 | for decorator in updated_node.decorators 106 | if not m.matches(decorator, gen_coroutine_decorator_matcher) 107 | ], 108 | asynchronous=cst.Asynchronous(), 109 | ) 110 | 111 | def 
leave_Raise( 112 | self, node: cst.Raise, updated_node: cst.Raise 113 | ) -> Union[cst.Return, cst.Raise]: 114 | if not self.in_coroutine(self.coroutine_stack): 115 | return updated_node 116 | 117 | if not m.matches(node, gen_return_matcher): 118 | return updated_node 119 | 120 | return_value, whitespace_after = self.pluck_gen_return_value(updated_node) 121 | return cst.Return( 122 | value=return_value, 123 | whitespace_after_return=whitespace_after, 124 | semicolon=updated_node.semicolon, 125 | ) 126 | 127 | def leave_Yield( 128 | self, node: cst.Yield, updated_node: cst.Yield 129 | ) -> Union[cst.Await, cst.Yield]: 130 | if not self.in_coroutine(self.coroutine_stack): 131 | return updated_node 132 | 133 | if not isinstance(updated_node.value, cst.BaseExpression): 134 | return updated_node 135 | 136 | if isinstance(updated_node.value, (cst.List, cst.ListComp)): 137 | self.required_imports.add("asyncio") 138 | expression = self.pluck_asyncio_gather_expression_from_yield_list_or_list_comp( 139 | updated_node 140 | ) 141 | 142 | elif m.matches( 143 | updated_node, 144 | m.Yield(value=((m.Dict() | m.DictComp())) | m.Call(func=m.Name("dict"))), 145 | ): 146 | raise TransformError( 147 | "Yielding a dict of futures (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) added in tornado 3.2 is unsupported by the codemod. This file has not been modified. Manually update to supported syntax before running again." 
148 | ) 149 | 150 | else: 151 | expression = updated_node.value 152 | 153 | return cst.Await( 154 | expression=expression, 155 | whitespace_after_await=updated_node.whitespace_after_yield, 156 | lpar=updated_node.lpar, 157 | rpar=updated_node.rpar, 158 | ) 159 | 160 | @staticmethod 161 | def pluck_asyncio_gather_expression_from_yield_list_or_list_comp( 162 | node: cst.Yield, 163 | ) -> cst.BaseExpression: 164 | return cst.Call( 165 | func=cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("gather")), 166 | args=[cst.Arg(value=node.value, star="*")], 167 | ) 168 | 169 | @staticmethod 170 | def in_coroutine(coroutine_stack: List[bool]) -> bool: 171 | if not coroutine_stack: 172 | return False 173 | 174 | return coroutine_stack[-1] 175 | 176 | @staticmethod 177 | def pluck_gen_return_value( 178 | node: cst.Raise, 179 | ) -> Tuple[Optional[cst.BaseExpression], cst.SimpleWhitespace]: 180 | if m.matches(node, gen_return_call_with_args_matcher): 181 | return node.exc.args[0].value, node.whitespace_after_raise 182 | 183 | # if there's no return value, we don't preserve whitespace after 'raise' 184 | return None, cst.SimpleWhitespace("") 185 | 186 | @staticmethod 187 | def make_simple_package_import(package: str) -> cst.Import: 188 | assert not "." in package, "this only supports a root package, e.g. 'import os'" 189 | return cst.Import(names=[cst.ImportAlias(name=cst.Name(package))]) 190 | --------------------------------------------------------------------------------