├── .github
└── workflows
│ └── main.yml
├── .gitignore
├── LICENSE
├── README.md
├── api
└── transform.py
├── index.html
├── mypy.ini
├── pyproject.toml
├── requirements.txt
├── script.js
├── setup.cfg
├── setup.py
├── style.css
├── tests
├── __init__.py
├── collector.py
├── exception_cases
│ ├── coroutine_gen_task.py
│ ├── yield_dict_call.py
│ ├── yield_dict_comprehension.py
│ └── yield_dict_literal.py
├── test_cases
│ ├── coroutine_multiple_generators
│ │ ├── after.py
│ │ └── before.py
│ ├── coroutine_with_nested_generator_function
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_return_call_no_args
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_return_call_none
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_return_dict_multi_line
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_return_statement
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_sleep
│ │ ├── after.py
│ │ └── before.py
│ ├── gen_test
│ │ ├── after.py
│ │ └── before.py
│ ├── module_level_raise
│ │ ├── after.py
│ │ └── before.py
│ ├── nested_coroutine
│ │ ├── after.py
│ │ └── before.py
│ ├── non_coroutine_returns_coroutine
│ │ ├── after.py
│ │ └── before.py
│ ├── simple_coroutine_from_tornado_import_gen
│ │ ├── after.py
│ │ └── before.py
│ ├── simple_coroutine_import_tornado
│ │ ├── after.py
│ │ └── before.py
│ ├── testing_gen_test
│ │ ├── after.py
│ │ └── before.py
│ ├── tornado_testing_gen_test
│ │ ├── after.py
│ │ └── before.py
│ ├── tornado_testing_gen_test_already_async
│ │ ├── after.py
│ │ └── before.py
│ ├── yield_list_comprehension
│ │ ├── after.py
│ │ └── before.py
│ └── yield_list_of_futures
│ │ ├── after.py
│ │ └── before.py
└── test_tornado_async_transformer.py
└── tornado_async_transformer
├── __init__.py
├── helpers.py
├── py.typed
├── tool.py
└── tornado_async_transformer.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on: [push]
4 |
5 | jobs:
6 |
7 | test:
8 | runs-on: ubuntu-latest
9 |
10 | strategy:
11 | matrix:
12 | python-version: [ '3.6', '3.7' ]
13 |
14 | name: pytest (${{ matrix.python-version }})
15 | steps:
16 | - uses: actions/checkout@v1
17 | - uses: actions/setup-python@v1
18 | with:
19 | python-version: ${{ matrix.python-version }}
20 | - run: pip3 install -r requirements.txt
21 | - run: pytest -vv
22 |
23 | lint:
24 |
25 | runs-on: ubuntu-latest
26 |
27 | steps:
28 | - uses: actions/checkout@v1
29 | - uses: actions/setup-python@v1
30 | with:
31 | python-version: '3.7'
32 | - run: pip3 install -r requirements.txt
33 | - run: black --check .
34 | # skipping until I can sort out some issues, for example:
35 | # error: Signature of "leave_Module" incompatible with supertype "CSTTypedTransformerFunctions"
36 | # - run: mypy tornado_async_transformer
37 |
38 | deploy:
39 |
40 | if: github.ref == 'refs/heads/master'
41 | runs-on: ubuntu-latest
42 | needs:
43 | - lint
44 | - test
45 |
46 | steps:
47 | - uses: actions/checkout@v1
48 | - uses: actions/setup-python@v1
49 | with:
50 | python-version: '3.7'
51 | - run: pip3 install -r requirements.txt
52 | - name: bundle
53 | run: python setup.py sdist bdist_wheel
54 | - name: deploy to pypi
55 | env:
56 | TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
57 | TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
58 | run: twine upload --skip-existing dist/*
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 |
123 | # Pyre type checker
124 | .pyre/
125 |
126 | # VS Code config
127 | .vscode/
128 |
129 | # Zeit now
130 | .now/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, SeatGeek
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tornado Async Transformer
2 |
3 | 
4 |
5 | A [libcst](https://github.com/Instagram/LibCST) transformer for updating tornado @gen.coroutine syntax to python3.5+ native async/await.
6 |
7 | [Check out the demo.](https://tornado-async-transformer.zhammer.now.sh/)
8 |
9 | ### Usage
10 | You can either:
11 | - Add `tornado_async_transformer.TornadoAsyncTransformer` to your existing libcst codemod.
12 | - Or run `python -m tornado_async_transformer.tool my_project/` from the command line.
13 |
14 | #### Example
15 | ```diff
16 | """
17 | A simple coroutine.
18 | """
19 | from tornado import gen
20 |
21 |
22 | -@gen.coroutine
23 | -def call_api():
24 | - response = yield fetch()
25 | +async def call_api():
26 | + response = await fetch()
27 | if response.status != 200:
28 | raise BadStatusError()
29 | - raise gen.Return(response.data)
30 | + return response.data
31 | ```
32 |
--------------------------------------------------------------------------------
/api/transform.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines the REST api route for the zeit now demo site, as I
3 | haven't been able to figure out how to nest it within the demo_site/
4 | directory.
5 | """
6 |
7 | from http.server import BaseHTTPRequestHandler
8 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError
9 | import json
10 | import libcst
11 |
12 |
def transform(source: str) -> str:
    """Convert tornado ``@gen.coroutine`` code in *source* to native async/await.

    The source is parsed into a libcst module, visited with a fresh
    ``TornadoAsyncTransformer`` and rendered back to code. Any parse or
    transform error propagates to the caller unchanged.
    """
    module = libcst.parse_module(source)
    return module.visit(TornadoAsyncTransformer()).code
17 |
18 |
class handler(BaseHTTPRequestHandler):
    """HTTP handler for the demo site's transform API."""

    def do_POST(self) -> None:
        """Transform the posted source and reply with ``{"source": <result>}``.

        The request body must be JSON of the form ``{"source": "<code>"}``.
        A request that cannot be read that way gets a 400 response whose JSON
        body is ``{"error": ...}`` instead of an unhandled exception. This
        covers a missing or invalid Content-Length, an undecodable or non-JSON
        body, or a body without a ``source`` key. If the transform itself
        fails, the ``repr`` of the exception is sent back as ``source`` with a
        200 so the demo editor can show it.
        """
        try:
            source = self._read_source()
        except (ValueError, KeyError, TypeError) as e:
            self._send_json(400, {"error": repr(e)})
            return

        try:
            transformed = transform(source)
        except Exception as e:
            # Deliberately broad: any transform failure is displayed to the
            # demo user rather than surfacing as a server error.
            transformed = repr(e)

        self._send_json(200, {"source": transformed})

    def _read_source(self):
        """Read the JSON request body and return its ``source`` field.

        Raises:
            ValueError: bad Content-Length, non-UTF-8 or non-JSON body.
            KeyError: the JSON object has no ``source`` key.
            TypeError: the JSON document is not an object.
        """
        # Default to 0 so a missing header is a clean ValueError below (empty
        # body is not JSON) instead of int(None) raising TypeError.
        content_length = int(self.headers.get("Content-Length", 0))
        if content_length < 0:
            # rfile.read(-1) would block reading until the client closes.
            raise ValueError("negative Content-Length")
        request_raw = self.rfile.read(content_length).decode("utf-8")
        request_body = json.loads(request_raw)
        return request_body["source"]

    def _send_json(self, status: int, payload: dict) -> None:
        """Send *payload* as a JSON response with HTTP status *status*."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
34 |
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Tornado Async Transformer
5 |
6 |
7 |
8 |
22 |
23 |
24 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.7
3 | check_untyped_defs = True
4 | disallow_incomplete_defs = True
5 | disallow_untyped_calls = True
6 | disallow_untyped_decorators = True
7 | disallow_untyped_defs = True
8 | ignore_missing_imports = True
9 | no_implicit_optional = True
10 | warn_redundant_casts = True
11 | warn_return_any = True
12 | warn_unused_configs = True
13 | warn_unused_ignores = True
14 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
# unfortunately 'black' doesn't support setup.cfg, and it's important
# not to auto format test_cases/ files for testing codemods preserving
# whitespace, indentation, etc.
[tool.black]
# Single-quoted (literal) TOML string so the backslash reaches the regex
# intact. An unescaped "." matches any character, so ".now" also excluded
# unrelated paths such as "known_issues.py" or "/nowhere".
exclude = '(test_cases|\.now)'
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
libcst==0.2.4
black==19.10b0
mypy==0.750
pytest==5.3.1
twine==3.1.1
wheel==0.33.6

# for zeit site, until python3.7 is supported: the dataclasses backport is
# only needed before python 3.7, where dataclasses joined the stdlib
dataclasses==0.6; python_version < '3.7'
10 |
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
1 | // globals ace
2 |
const EDITOR_THEME = "ace/theme/tomorrow_night";
const EDITOR_MODE = "ace/mode/python";

/**
 * Create an ace editor on the element with id `id`, using the shared theme and
 * python mode. Pass `{ readOnly: true }` for a display-only editor.
 */
function createPythonEditor(id, { readOnly = false } = {}) {
  const editor = ace.edit(id);
  if (readOnly) {
    editor.setReadOnly(true);
  }
  editor.setTheme(EDITOR_THEME);
  editor.session.setMode(EDITOR_MODE);
  return editor;
}

// Left pane: user-editable source. Right pane: read-only transformer output.
const editorSource = createPythonEditor("editor-source");
const editorTransformed = createPythonEditor("editor-transformed", {
  readOnly: true
});
11 |
/**
 * POST `source` to the transform API and resolve with the transformed code.
 *
 * Rejects when the request fails or the API answers with a non-2xx status,
 * so an error response is never mistaken for transformed source.
 */
async function fetchTransformedSource(source) {
  const response = await fetch("/api/transform", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ source })
  });
  if (!response.ok) {
    throw new Error(
      `transform request failed: ${response.status} ${response.statusText}`
    );
  }
  const body = await response.json();
  return body.source;
}
20 |
// Sequence number of the latest transform request. A response is only shown
// if no newer request was started while it was in flight.
let calls = 0;

/**
 * Re-transform the source editor's contents and show the result.
 *
 * Failures (network errors, rejected requests) are rendered into the
 * read-only editor as a python comment instead of becoming unhandled promise
 * rejections that leave stale output on screen.
 */
async function onEditorUpdated() {
  const call = ++calls;
  let transformed;
  try {
    transformed = await fetchTransformedSource(editorSource.getValue());
  } catch (error) {
    transformed = `# ${error}`;
  }
  if (call !== calls) {
    return;
  }
  editorTransformed.setValue(transformed);
  editorSource.clearSelection();
  editorTransformed.clearSelection();
}
31 |
// Example coroutine shown in the source editor when the page first loads.
const INITIAL_SOURCE = `"""
A simple coroutine.
"""
from tornado import gen

@gen.coroutine
def call_api():
    response = yield fetch()
    if response.status != 200:
        raise BadStatusError()
    raise gen.Return(response.data)
`;

// Seed the source editor and render its transform once; after that,
// re-transform on every edit. The listener is attached after seeding.
editorSource.setValue(INITIAL_SOURCE);
onEditorUpdated();
editorSource.on("change", onEditorUpdated);
46 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | addopts = tests tornado_async_transformer --ignore tests/test_cases --ignore tests/exception_cases --doctest-modules
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Packaging configuration for tornado-async-transformer."""
# Only setuptools' setup() is imported: plain distutils does not understand
# options used below such as install_requires or
# long_description_content_type.
from setuptools import find_packages, setup

with open("README.md", "r", encoding="utf-8") as readme:
    long_description = readme.read()

setup(
    name="tornado-async-transformer",
    version="0.2.0",
    description="libcst transformer and codemod for updating tornado @gen.coroutine syntax to python3.5+ native async/await",
    url="https://github.com/zhammer/tornado-async-transformer",
    packages=find_packages(exclude=["tests", "demo_site"]),
    package_data={"tornado_async_transformer": ["py.typed"]},
    install_requires=["libcst == 0.2.4"],
    author="Zach Hammer",
    author_email="zachary_hammer@alumni.brown.edu",
    # Must agree with the repository's LICENSE file (BSD 3-Clause).
    license="BSD 3-Clause License",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        "License :: OSI Approved :: BSD License",
        "Topic :: Software Development :: Libraries",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
    ],
)
27 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: fantasy;
3 | margin: 0;
4 | padding: 0;
5 | }
6 |
7 | .header {
8 | padding: 1em;
9 | }
10 |
11 | h1 {
12 | margin: 0;
13 | }
14 |
15 | .editor {
16 | display: block;
17 | height: 70vh;
18 | width: 50vw;
19 | }
20 |
21 | .editors {
22 | display: flex;
23 | justify-content: space-between;
24 | }
25 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seatgeek/tornado-async-transformer/c9353151bd1ffb9e532f68d929cf68e68799eb41/tests/__init__.py
--------------------------------------------------------------------------------
/tests/collector.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from typing import Any, List, NamedTuple, Tuple
4 |
5 | import pytest
6 |
7 |
class TestCase(NamedTuple):
    """A transformer test case: input source and expected transformed source."""

    # The "Test" prefix makes pytest try to collect this class wherever it is
    # imported into a test module, which only yields a "cannot collect test
    # class ... __new__ constructor" warning. Opt out explicitly.
    __test__ = False

    before: str  # contents of the case's before.py
    after: str  # contents of the case's after.py
11 |
12 |
def collect_test_cases() -> Tuple[Any, ...]:
    """Collect the before/after transformer cases under ``test_cases/``.

    Every directory below ``test_cases/`` (next to this file) that contains
    both ``before.py`` and ``after.py`` becomes one ``pytest.param`` wrapping a
    ``TestCase``. Its id is the directory name with underscores replaced by
    spaces. Directories missing either file are skipped.

    Subdirectories are visited in sorted order, so the parametrization order is
    the same on every filesystem. ``os.walk`` on its own yields directories in
    whatever order the OS returns them.
    """
    root_test_cases_directory = os.path.join(os.path.dirname(__file__), "test_cases")

    test_cases: List = []
    for root, dirs, files in os.walk(root_test_cases_directory):
        # A top-down os.walk honours in-place edits of `dirs`, so sorting here
        # fixes the traversal order.
        dirs.sort()
        if not {"before.py", "after.py"} <= set(files):
            continue

        test_case_name = os.path.basename(root).replace("_", " ")

        # Explicit encoding rather than the platform default.
        with open(os.path.join(root, "before.py"), encoding="utf-8") as before_file:
            before = before_file.read()

        with open(os.path.join(root, "after.py"), encoding="utf-8") as after_file:
            after = after_file.read()

        test_cases.append(
            pytest.param(TestCase(before=before, after=after), id=test_case_name)
        )

    return tuple(test_cases)
34 |
35 |
# A source file the transformer is expected to reject, paired with the error
# message it should report (collected from the file's module docstring).
ExceptionCase = NamedTuple(
    "ExceptionCase", [("source", str), ("expected_error_message", str)]
)
39 |
40 |
def collect_exception_cases() -> Tuple[Any, ...]:
    """Collect the transformer exception cases in ``exception_cases/``.

    Each top-level ``.py`` file in ``exception_cases/`` (next to this file)
    becomes one ``pytest.param`` wrapping an ``ExceptionCase``. The file's
    module docstring, with newlines replaced by spaces, is the expected error
    message. The param id is the file name with underscores replaced by
    spaces. Files are processed in sorted order so the parametrization order
    is deterministic.

    Raises:
        ValueError: if a case file has no module docstring.
    """
    root_exception_cases_directory = os.path.join(
        os.path.dirname(__file__), "exception_cases"
    )

    # all .py files in the top level of the exception cases directory, sorted
    # because os.listdir returns entries in arbitrary order
    python_files = [
        os.path.join(root_exception_cases_directory, filename)
        for filename in sorted(os.listdir(root_exception_cases_directory))
        if filename.endswith(".py")
    ]

    exception_cases: List = []
    for python_filename in python_files:
        with open(python_filename, encoding="utf-8") as python_file:
            source = python_file.read()

        # the module's docstring is the expected error message
        docstring = ast.get_docstring(ast.parse(source))
        if docstring is None:
            # Fail with a clear message instead of AttributeError on None.
            raise ValueError(
                "exception case {} has no module docstring to use as the "
                "expected error message".format(python_filename)
            )
        expected_error_message = docstring.replace("\n", " ")

        test_case_name = os.path.basename(python_filename).replace("_", " ")

        exception_cases.append(
            pytest.param(
                ExceptionCase(
                    source=source, expected_error_message=expected_error_message
                ),
                id=test_case_name,
            )
        )

    return tuple(exception_cases)
75 |
--------------------------------------------------------------------------------
/tests/exception_cases/coroutine_gen_task.py:
--------------------------------------------------------------------------------
1 | """
2 | gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task)
3 | from tornado 2.4.1 is unsupported by this codemod. This file has not been modified.
4 | Manually update to supported syntax before running again.
5 | """
6 | import time
7 | from tornado import gen
8 | from tornado.ioloop import IOLoop
9 |
10 |
11 | @gen.coroutine
12 | def ping():
13 | yield gen.Task(IOLoop.instance().add_timeout, time.time() + 1.5)
14 | raise gen.Return("pong")
15 |
--------------------------------------------------------------------------------
/tests/exception_cases/yield_dict_call.py:
--------------------------------------------------------------------------------
1 | """
2 | Yielding a dict of futures
3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen)
4 | added in tornado 3.2 is unsupported by the codemod. This file has not been
5 | modified. Manually update to supported syntax before running again.
6 | """
7 | from tornado import gen
8 |
9 |
10 | @gen.coroutine
11 | def get_user_friends_and_relatives(user_id):
12 | users = yield dict(
13 | friends=fetch("/friends", user_id), relatives=fetch("/relatives", user_id)
14 | )
15 | raise gen.Return(users)
16 |
--------------------------------------------------------------------------------
/tests/exception_cases/yield_dict_comprehension.py:
--------------------------------------------------------------------------------
1 | """
2 | Yielding a dict of futures
3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen)
4 | added in tornado 3.2 is unsupported by the codemod. This file has not been
5 | modified. Manually update to supported syntax before running again.
6 | """
7 | from tornado import gen
8 |
9 |
10 | @gen.coroutine
11 | def get_users_by_id(user_ids):
12 | users = yield {user_id: fetch(user_ids) for user_id in user_ids}
13 | raise gen.Return(users)
14 |
--------------------------------------------------------------------------------
/tests/exception_cases/yield_dict_literal.py:
--------------------------------------------------------------------------------
1 | """
2 | Yielding a dict of futures
3 | (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen)
4 | added in tornado 3.2 is unsupported by the codemod. This file has not been
5 | modified. Manually update to supported syntax before running again.
6 | """
7 | from tornado import gen
8 |
9 |
10 | @gen.coroutine
11 | def get_two_users_by_id(user_id_1, user_id_2):
12 | users = yield {user_id_1: fetch(user_id_1), user_id_2: fetch(user_id_2)}
13 | raise gen.Return(users)
14 |
--------------------------------------------------------------------------------
/tests/test_cases/coroutine_multiple_generators/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A few coroutines that have multiple decorators in addition to @gen.coroutine.
3 | """
4 | from tornado import gen
5 | from util.decorators import route, deprecated, log_args
6 |
7 | @route("/user/:id")
8 | async def user_page(id):
9 | user = await get_user(id)
10 | user_page = "{}
".format(user.name)
11 | return user_page
12 |
13 |
14 | @deprecated
15 | @log_args
16 | async def get_user(id):
17 | response = await fetch(id)
18 | return response.user
19 |
--------------------------------------------------------------------------------
/tests/test_cases/coroutine_multiple_generators/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A few coroutines that have multiple decorators in addition to @gen.coroutine.
3 | """
4 | from tornado import gen
5 | from util.decorators import route, deprecated, log_args
6 |
7 | @gen.coroutine
8 | @route("/user/:id")
9 | def user_page(id):
10 | user = yield get_user(id)
11 | user_page = "{}
".format(user.name)
12 | raise gen.Return(user_page)
13 |
14 |
15 | @deprecated
16 | @gen.coroutine
17 | @log_args
18 | def get_user(id):
19 | response = yield fetch(id)
20 | raise gen.Return(response.user)
21 |
--------------------------------------------------------------------------------
/tests/test_cases/coroutine_with_nested_generator_function/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine with a nested, non-coroutine generator function.
3 | """
4 | from typing import List
5 |
6 | from tornado import gen
7 |
8 |
9 | async def save_users(users):
10 | def build_user_ids(users):
11 | for user in users:
12 | yield "{}-{}".format(user.first_name, user.last_name)
13 |
14 | for user_id in build_user_ids(users):
15 | await fetch("POST", user_id)
16 |
--------------------------------------------------------------------------------
/tests/test_cases/coroutine_with_nested_generator_function/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine with a nested, non-coroutine generator function.
3 | """
4 | from typing import List
5 |
6 | from tornado import gen
7 |
8 |
9 | @gen.coroutine
10 | def save_users(users):
11 | def build_user_ids(users):
12 | for user in users:
13 | yield "{}-{}".format(user.first_name, user.last_name)
14 |
15 | for user_id in build_user_ids(users):
16 | yield fetch("POST", user_id)
17 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_call_no_args/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return() with no args.
3 | """
4 | from tornado import gen
5 |
6 |
7 | async def check_id_valid(id: str):
8 | response = await fetch(id)
9 | if response.status != 200:
10 | raise InvalidID()
11 |
12 | return
13 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_call_no_args/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return() with no args.
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def check_id_valid(id: str):
9 | response = yield fetch(id)
10 | if response.status != 200:
11 | raise InvalidID()
12 |
13 | raise gen.Return()
14 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_call_none/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return(None).
3 | """
4 | from tornado import gen
5 |
6 |
7 | async def check_id_valid(id: str):
8 | response = await fetch(id)
9 | if response.status != 200:
10 | raise InvalidID()
11 |
12 | return None
13 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_call_none/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return(None).
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def check_id_valid(id: str):
9 | response = yield fetch(id)
10 | if response.status != 200:
11 | raise InvalidID()
12 |
13 | raise gen.Return(None)
14 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_dict_multi_line/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that returns a dict that spanning multiple lines.
3 | """
4 | from tornado import gen
5 |
6 | async def fetch_user(id):
7 | response = await fetch(id)
8 | return {
9 | 'user': response.user,
10 | 'source': 'user-api'
11 | }
12 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_dict_multi_line/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that returns a dict that spanning multiple lines.
3 | """
4 | from tornado import gen
5 |
6 | @gen.coroutine
7 | def fetch_user(id):
8 | response = yield fetch(id)
9 | raise gen.Return({
10 | 'user': response.user,
11 | 'source': 'user-api'
12 | })
13 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_statement/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return directly instead of gen.Return(...).
3 | """
4 | from tornado import gen
5 |
6 |
7 | async def check_id_valid(id: str):
8 | response = await fetch(id)
9 | if response.status != 200:
10 | raise InvalidID()
11 |
12 | return
13 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_return_statement/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that raises gen.Return directly instead of gen.Return(...).
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def check_id_valid(id: str):
9 | response = yield fetch(id)
10 | if response.status != 200:
11 | raise InvalidID()
12 |
13 | raise gen.Return
14 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_sleep/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that calls gen.sleep.
3 | """
4 | from tornado import gen
5 | import asyncio
6 |
7 |
8 | async def ping():
9 | await asyncio.sleep(10)
10 | return "pong"
11 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_sleep/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that calls gen.sleep.
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def ping():
9 | yield gen.sleep(10)
10 | raise gen.Return("pong")
11 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_test/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @gen_test decorator.
3 | """
4 |
5 | from tornado.testing import gen_test, AsyncHTTPTestCase
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @gen_test
15 | async def test_ping_route(self):
16 | response = await self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/gen_test/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @gen_test decorator.
3 | """
4 |
5 | from tornado.testing import gen_test, AsyncHTTPTestCase
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @gen_test
15 | def test_ping_route(self):
16 | response = yield self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/module_level_raise/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A module that raises an exception outside of a function.
3 | """
4 | raise NotImplementedError("This module isn't ready!")
5 |
--------------------------------------------------------------------------------
/tests/test_cases/module_level_raise/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A module that raises an exception outside of a function.
3 | """
4 | raise NotImplementedError("This module isn't ready!")
5 |
--------------------------------------------------------------------------------
/tests/test_cases/nested_coroutine/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine that contains a nested coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | async def call_api():
8 | async def nested_callback(response):
9 | if response.status != 200:
10 | return response
11 |
12 | body = await response.json()
13 | if body["api-update-available"]:
14 | print("note: update api")
15 | return response
16 |
17 | response = await fetch(middleware=nested_callback)
18 | if response.status != 200:
19 | raise BadStatusError()
20 | return response.data
21 |
--------------------------------------------------------------------------------
/tests/test_cases/nested_coroutine/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine that contains a nested coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def call_api():
9 | @gen.coroutine
10 | def nested_callback(response):
11 | if response.status != 200:
12 | raise gen.Return(response)
13 |
14 | body = yield response.json()
15 | if body["api-update-available"]:
16 | print("note: update api")
17 | raise gen.Return(response)
18 |
19 | response = yield fetch(middleware=nested_callback)
20 | if response.status != 200:
21 | raise BadStatusError()
22 | raise gen.Return(response.data)
23 |
--------------------------------------------------------------------------------
/tests/test_cases/non_coroutine_returns_coroutine/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A non-coroutine function that returns a coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | def make_simple_fetch(url: str):
8 | async def my_simple_fetch(body):
9 | response = await fetch(url, body)
10 | return response.body
11 |
12 | return my_simple_fetch
13 |
14 |
--------------------------------------------------------------------------------
/tests/test_cases/non_coroutine_returns_coroutine/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A non-coroutine function that returns a coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | def make_simple_fetch(url: str):
8 | @gen.coroutine
9 | def my_simple_fetch(body):
10 | response = yield fetch(url, body)
11 | raise gen.Return(response.body)
12 |
13 | return my_simple_fetch
14 |
15 |
--------------------------------------------------------------------------------
/tests/test_cases/simple_coroutine_from_tornado_import_gen/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | async def call_api():
8 | response = await fetch()
9 | if response.status != 200:
10 | raise BadStatusError()
11 | return response.data
12 |
--------------------------------------------------------------------------------
/tests/test_cases/simple_coroutine_from_tornado_import_gen/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine.
3 | """
4 | from tornado import gen
5 |
6 |
7 | @gen.coroutine
8 | def call_api():
9 | response = yield fetch()
10 | if response.status != 200:
11 | raise BadStatusError()
12 | raise gen.Return(response.data)
13 |
--------------------------------------------------------------------------------
/tests/test_cases/simple_coroutine_import_tornado/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine in a module that imports the tornado package.
3 | """
4 | import tornado
5 |
6 |
7 | async def call_api():
8 | response = await fetch()
9 | if response.status != 200:
10 | raise BadStatusError()
11 | if response.status == 204:
12 | return
13 | return response.data
14 |
--------------------------------------------------------------------------------
/tests/test_cases/simple_coroutine_import_tornado/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple coroutine in a module that imports the tornado package.
3 | """
4 | import tornado
5 |
6 |
7 | @tornado.gen.coroutine
8 | def call_api():
9 | response = yield fetch()
10 | if response.status != 200:
11 | raise BadStatusError()
12 | if response.status == 204:
13 | raise tornado.gen.Return
14 | raise tornado.gen.Return(response.data)
15 |
--------------------------------------------------------------------------------
/tests/test_cases/testing_gen_test/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @testing.gen_test decorator.
3 | """
4 |
5 | from tornado import testing
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(testing.AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @testing.gen_test
15 | async def test_ping_route(self):
16 | response = await self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/testing_gen_test/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @testing.gen_test decorator.
3 | """
4 |
5 | from tornado import testing
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(testing.AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @testing.gen_test
15 | def test_ping_route(self):
16 | response = yield self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/tornado_testing_gen_test/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @tornado.testing.gen_test decorator.
3 | """
4 |
5 | import tornado
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @tornado.testing.gen_test
15 | async def test_ping_route(self):
16 | response = await self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/tornado_testing_gen_test/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @tornado.testing.gen_test decorator.
3 | """
4 |
5 | import tornado
6 |
7 | from my_app import make_application
8 |
9 |
10 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase):
11 | def get_app(self):
12 | return make_application()
13 |
14 | @tornado.testing.gen_test
15 | def test_ping_route(self):
16 | response = yield self.fetch("/ping")
17 | assert response.body == b"pong"
18 |
--------------------------------------------------------------------------------
/tests/test_cases/tornado_testing_gen_test_already_async/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @tornado.testing.gen_test decorator that is already
3 | an async function. Function should not be modified.
4 | """
5 |
6 | # This is a pretty contrived example to confirm these tests won't have yields changed to awaits.
7 | import some_experimental_pytest_decorator_that_allows_yielding_in_test_flow
8 | import tornado
9 |
10 | from my_app import make_application
11 |
12 |
13 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase):
14 | def get_app(self):
15 | return make_application()
16 |
17 | @tornado.testing.gen_test
18 | @some_experimental_pytest_decorator_that_allows_yielding_in_test_flow(b"pong")
19 | async def test_ping_route(self, yielder):
20 | response = await self.fetch("/ping")
21 | expected_response = yield yielder()
22 | assert response.body == expected_response
23 |
--------------------------------------------------------------------------------
/tests/test_cases/tornado_testing_gen_test_already_async/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A tornado gen_test using @tornado.testing.gen_test decorator that is already
3 | an async function. Function should not be modified.
4 | """
5 |
6 | # This is a pretty contrived example to confirm these tests won't have yields changed to awaits.
7 | import some_experimental_pytest_decorator_that_allows_yielding_in_test_flow
8 | import tornado
9 |
10 | from my_app import make_application
11 |
12 |
13 | class TestMyTornadoApp(tornado.testing.AsyncHTTPTestCase):
14 | def get_app(self):
15 | return make_application()
16 |
17 | @tornado.testing.gen_test
18 | @some_experimental_pytest_decorator_that_allows_yielding_in_test_flow(b"pong")
19 | async def test_ping_route(self, yielder):
20 | response = await self.fetch("/ping")
21 | expected_response = yield yielder()
22 | assert response.body == expected_response
23 |
--------------------------------------------------------------------------------
/tests/test_cases/yield_list_comprehension/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that yields a list comprehension that creates a list of yieldable objects.
3 | See: https://www.tornadoweb.org/en/stable/gen.html.
4 | """
5 | from tornado import gen
6 | import asyncio
7 |
8 |
9 | async def get_users(user_ids):
10 | users = await asyncio.gather(*[fetch(user_id) for user_id in user_ids])
11 | return users
12 |
--------------------------------------------------------------------------------
/tests/test_cases/yield_list_comprehension/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that yields a list comprehension that creates a list of yieldable objects.
3 | See: https://www.tornadoweb.org/en/stable/gen.html.
4 | """
5 | from tornado import gen
6 |
7 |
8 | @gen.coroutine
9 | def get_users(user_ids):
10 | users = yield [fetch(user_id) for user_id in user_ids]
11 | raise gen.Return(users)
12 |
--------------------------------------------------------------------------------
/tests/test_cases/yield_list_of_futures/after.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that yields a list of yieldable objects.
3 | See: https://www.tornadoweb.org/en/stable/gen.html.
4 | """
5 | from tornado import gen
6 | import asyncio
7 |
8 |
9 | async def get_two_users(user_id_1, user_id_2):
10 | response_1, reponse_2 = await asyncio.gather(*[fetch(user_id_1), fetch(user_id_2)])
11 | return (response_1.user, response_2.user)
12 |
--------------------------------------------------------------------------------
/tests/test_cases/yield_list_of_futures/before.py:
--------------------------------------------------------------------------------
1 | """
2 | A coroutine that yields a list of yieldable objects.
3 | See: https://www.tornadoweb.org/en/stable/gen.html.
4 | """
5 | from tornado import gen
6 |
7 |
8 | @gen.coroutine
9 | def get_two_users(user_id_1, user_id_2):
10 | response_1, reponse_2 = yield [fetch(user_id_1), fetch(user_id_2)]
11 | raise gen.Return((response_1.user, response_2.user))
12 |
--------------------------------------------------------------------------------
/tests/test_tornado_async_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Tuple
2 |
3 | import libcst
4 | import pytest
5 |
6 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError
7 |
8 | from tests.collector import (
9 | ExceptionCase,
10 | collect_exception_cases,
11 | TestCase,
12 | collect_test_cases,
13 | )
14 |
15 |
@pytest.mark.parametrize("test_case", collect_test_cases())
def test_python_module(test_case: TestCase) -> None:
    """
    Transforming a fixture's ``before`` source must reproduce its ``after``
    source exactly (formatting and comments included).
    """
    transformer = TornadoAsyncTransformer()
    transformed_module = libcst.parse_module(test_case.before).visit(transformer)
    assert transformed_module.code == test_case.after
21 |
22 |
@pytest.mark.parametrize("exception_case", collect_exception_cases())
def test_unsupported_python_module(exception_case: ExceptionCase) -> None:
    """
    Visiting a module that uses an unsupported construct must raise
    TransformError, and the error text must contain the case's expected message.
    """
    source_tree = libcst.parse_module(exception_case.source)

    # The visit result is irrelevant: the call is expected to raise.
    with pytest.raises(TransformError) as excinfo:
        source_tree.visit(TornadoAsyncTransformer())

    assert exception_case.expected_error_message in str(excinfo.value)
31 |
--------------------------------------------------------------------------------
/tornado_async_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .tornado_async_transformer import TornadoAsyncTransformer, TransformError
2 |
--------------------------------------------------------------------------------
/tornado_async_transformer/helpers.py:
--------------------------------------------------------------------------------
1 | from functools import singledispatch
2 | from typing import List, Sequence, Union
3 |
4 | import libcst as cst
5 | from libcst import matchers as m
6 |
7 |
def with_added_imports(
    module_node: cst.Module, import_nodes: Sequence[Union[cst.Import, cst.ImportFrom]]
) -> cst.Module:
    """
    Return a copy of ``module_node`` with ``import_nodes`` inserted, one per
    line, directly after the module's first import line.

    Two cases are handled specially:

    - If the first import line is a ``from __future__ import ...`` line, the new
      imports are placed after the *last* ``__future__`` import instead, since
      ``__future__`` imports must precede every other statement.
    - An import whose generated source (e.g. ``import asyncio``) already appears
      as a stand-alone import line in the module is skipped, so no duplicate
      imports are added.

    Raises:
        RuntimeError: if the module has no import line to anchor the new imports on.
    """
    body: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]] = list(
        module_node.body
    )
    import_indexes = [index for index, line in enumerate(body) if _is_import_line(line)]
    if not import_indexes:
        raise RuntimeError("Failed to add imports")

    anchor = import_indexes[0]
    if _is_future_import_line(body[anchor]):
        # Inserting between two __future__ imports would be a SyntaxError.
        anchor = max(
            index for index in import_indexes if _is_future_import_line(body[index])
        )

    # Source text of every existing single-import line, e.g. "import asyncio".
    existing_imports = {
        module_node.code_for_node(line.body[0])
        for line in body
        if isinstance(line, cst.SimpleStatementLine) and _is_import_line(line)
    }

    new_lines: List[cst.SimpleStatementLine] = []
    for import_node in import_nodes:
        import_code = module_node.code_for_node(import_node)
        if import_code in existing_imports:
            continue
        # Also guards against duplicates inside `import_nodes` itself.
        existing_imports.add(import_code)
        new_lines.append(cst.SimpleStatementLine(body=(import_node,)))

    body[anchor + 1 : anchor + 1] = new_lines
    return module_node.with_changes(body=tuple(body))


def _is_future_import_line(
    line: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
) -> bool:
    """Return True for a line consisting solely of ``from __future__ import ...``."""
    return m.matches(
        line,
        m.SimpleStatementLine(
            body=[m.ImportFrom(module=m.Name("__future__"), relative=[])]
        ),
    )
27 |
28 |
def _is_import_line(
    line: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
) -> bool:
    """
    Return True when ``line`` is a simple statement line holding exactly one
    ``import ...`` or ``from ... import ...`` statement and nothing else.
    """
    if not isinstance(line, cst.SimpleStatementLine):
        return False
    statements = line.body
    return len(statements) == 1 and isinstance(
        statements[0], (cst.Import, cst.ImportFrom)
    )
33 |
34 |
def name_attr_possibilities(tag: str) -> List[Union[m.Name, m.Attribute]]:
    """
    Let's say we want to find all instances of coroutine decorators in our code. The tornado coroutine
    decorator can be imported and used in the following ways:

    ```
    import tornado; @tornado.gen.coroutine
    from tornado import gen; @gen.coroutine
    from tornado.gen import coroutine; @coroutine
    ```

    If we want to see if a decorator is a coroutine decorator, and we don't want the overhead of knowing
    all of the module's imports, we can check if a decorator matches one of the following options to be
    fairly confident it's a coroutine:
    - tornado.gen.coroutine
    - gen.coroutine
    - coroutine

    This doesn't account for renamed imports (since we're not import aware) but does a decent enough job.
    Another option is to only match against the final Name in a Name, Attribute or nested Attribute, but
    there doesn't seem to be a much simpler way to do that at the moment with libcst, and this way we get
    some extra protection from considering @this.is.not.a.coroutine a match for @tornado.gen.coroutine.

    One matcher is returned per dotted suffix of ``tag``, from the most qualified to the least.

    # We run this function on "tornado.gen.coroutine"
    >>> tornado_gen_coroutine, gen_coroutine, coroutine = name_attr_possibilities("tornado.gen.coroutine")

    # We have just the matcher Name "coroutine"
    >>> coroutine
    Name(value='coroutine',...)

    # We have an attribute "gen.coroutine"
    >>> gen_coroutine
    Attribute(value=Name(value='gen',...), attr=Name(value='coroutine',...),...)

    # We have a nested attribute "tornado.gen.coroutine"
    >>> tornado_gen_coroutine
    Attribute(value=Attribute(value=Name(value='tornado',...), attr=Name(value='gen',...),...), attr=Name(value='coroutine',...),...)
    """

    def _make_name_or_attribute(parts: List[str]) -> Union[m.Name, m.Attribute]:
        # Fold left to right so ["tornado", "gen", "coroutine"] becomes
        # Attribute(value=Attribute(value=Name("tornado"), attr=Name("gen")), attr=Name("coroutine")),
        # the same nesting libcst produces when parsing a dotted name.
        if not parts:
            raise RuntimeError("Expected a non empty list of strings")

        matcher: Union[m.Name, m.Attribute] = m.Name(parts[0])
        for part in parts[1:]:
            matcher = m.Attribute(value=matcher, attr=m.Name(part))
        return matcher

    parts = tag.split(".")
    return [_make_name_or_attribute(parts[start:]) for start in range(len(parts))]
94 |
95 |
def some_version_of(tag: str) -> m.OneOf[Union[m.Name, m.Attribute]]:
    """
    Build one matcher that accepts any spelling from ``name_attr_possibilities(tag)``.

    For example, ``some_version_of("tornado.gen.coroutine")`` matches
    ``tornado.gen.coroutine``, ``gen.coroutine`` and a bare ``coroutine``.
    """
    # `Union` here is typing.Union (imported at the top of this module); the
    # previous `m.Union` only worked because libcst.matchers happens to import it.
    return m.OneOf(*name_attr_possibilities(tag))
101 |
--------------------------------------------------------------------------------
/tornado_async_transformer/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seatgeek/tornado-async-transformer/c9353151bd1ffb9e532f68d929cf68e68799eb41/tornado_async_transformer/py.typed
--------------------------------------------------------------------------------
/tornado_async_transformer/tool.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | from pathlib import Path
5 | from typing import List, Tuple
6 | import argparse
7 |
8 | import libcst as cst
9 | from libcst import CSTVisitorT
10 |
11 | from tornado_async_transformer import TornadoAsyncTransformer, TransformError
12 |
13 |
def transform_file(visitor: CSTVisitorT, filename: str) -> None:
    """
    Parse ``filename``, run ``visitor`` over it and rewrite the file in place if
    the tree changed.

    Parse failures and TransformError are reported on stdout, and the file is
    left untouched.
    """
    # Read raw bytes so libcst detects the source encoding itself (PEP 263
    # cookie / UTF-8 default) and keeps the file's newline style. Text mode
    # would decode with the locale encoding and normalise "\r\n" to "\n".
    with open(filename, "rb") as python_file:
        python_source = python_file.read()

    try:
        source_tree = cst.parse_module(python_source)
    except Exception as e:  # CLI boundary: report any unparsable file and skip it
        print("{} failed parse: {}".format(filename, str(e)))
        return

    try:
        visited_tree = source_tree.visit(visitor)
    except TransformError as e:
        print("{} failed transform: {}".format(filename, str(e)))
        return

    if not visited_tree.deep_equals(source_tree):
        # Module.bytes re-encodes with the encoding detected at parse time.
        with open(filename, "wb") as python_file:
            python_file.write(visited_tree.bytes)
33 |
34 |
def collect_files(base: str) -> Tuple[str, ...]:
    """
    Collect all python files (``.py`` / ``.pyi``) under a base path.

    If ``base`` is itself a python file it is returned alone. If it is a
    directory it is walked recursively, and matching files are returned in a
    deterministic order: sorted by name within each directory, with
    sub-directories visited in sorted order. Anything else, including a
    non-existent path, yields an empty tuple.
    """

    def is_python_file(path: str) -> bool:
        return bool(os.path.isfile(path) and re.search(r"\.pyi?$", path))

    if is_python_file(base):
        return (base,)

    if not os.path.isdir(base):
        return tuple()

    python_files: List[str] = []
    for root, dirs, filenames in os.walk(base):
        # Sorting `dirs` in place controls os.walk's descent order, so runs
        # are reproducible regardless of filesystem listing order.
        dirs.sort()
        # os.path.join avoids doubled or mixed separators (e.g. "dir//a.py").
        python_files.extend(
            full_filename
            for full_filename in (
                os.path.join(root, filename) for filename in sorted(filenames)
            )
            if is_python_file(full_filename)
        )
    return tuple(python_files)
58 |
59 |
def parse_args() -> argparse.Namespace:
    """
    Define the command-line interface and parse ``sys.argv``.

    The returned namespace has a ``bases`` attribute: the non-empty list of
    files and directories given on the command line.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Codemod for converting legacy tornado @gen.coroutine syntax to python3.5+ native async/await"
        )
    )
    cli.add_argument(
        "bases",
        nargs="+",
        type=str,
        help="Files and directories (recursive) including python files to be modified.",
    )
    namespace = cli.parse_args()
    return namespace
71 |
72 |
def main() -> None:
    """
    Command-line entry point: transform every python file reachable from the
    paths given on the command line.

    Each file is visited by a brand-new TornadoAsyncTransformer, so per-module
    state (such as required imports) starts fresh for every file.
    """
    cli_args = parse_args()
    python_files: List[str] = [
        path for base in cli_args.bases for path in collect_files(base)
    ]
    for path in python_files:
        transform_file(TornadoAsyncTransformer(), path)


if __name__ == "__main__":
    main()
86 |
--------------------------------------------------------------------------------
/tornado_async_transformer/tornado_async_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Set, Tuple, Union
2 |
3 | import libcst as cst
4 | from libcst import matchers as m
5 |
6 | from tornado_async_transformer.helpers import (
7 | name_attr_possibilities,
8 | some_version_of,
9 | with_added_imports,
10 | )
11 |
12 |
# matchers
# Every tag is fully qualified ("tornado.gen.X") so that some_version_of also
# accepts the `import tornado` spelling alongside `gen.X` and bare `X`.
gen_return_statement_matcher = m.Raise(exc=some_version_of("tornado.gen.Return"))
gen_return_call_with_args_matcher = m.Raise(
    exc=m.Call(func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)])
)
gen_return_call_matcher = m.Raise(
    exc=m.Call(func=some_version_of("tornado.gen.Return"))
)
gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher
# Previously tagged "gen.sleep" / "gen.Task", which missed `tornado.gen.sleep(...)`
# and `tornado.gen.Task(...)`; `gen.sleep` and bare `sleep` are still matched.
gen_sleep_matcher = m.Call(func=some_version_of("tornado.gen.sleep"))
gen_task_matcher = m.Call(func=some_version_of("tornado.gen.Task"))
gen_coroutine_decorator_matcher = m.Decorator(
    decorator=some_version_of("tornado.gen.coroutine")
)
# tornado.testing.gen_test may be used bare (`@gen_test`) or called with
# options (`@gen_test(timeout=10)`); both decorate a generator-based test.
gen_test_coroutine_decorator = m.Decorator(
    decorator=some_version_of("tornado.testing.gen_test")
    | m.Call(func=some_version_of("tornado.testing.gen_test"))
)
coroutine_decorator_matcher = (
    gen_coroutine_decorator_matcher | gen_test_coroutine_decorator
)
# A not-yet-async function carrying one of the coroutine decorators anywhere
# in its decorator list.
coroutine_matcher = m.FunctionDef(
    asynchronous=None,
    decorators=[m.ZeroOrMore(), coroutine_decorator_matcher, m.ZeroOrMore()],
)
37 |
38 |
class TransformError(Exception):
    """
    Raised when the transformer reaches source it knows it cannot convert,
    such as ``gen.Task`` or a yielded dict of futures.

    The command-line tool reports this error and leaves the file unmodified.
    """
44 |
45 |
class TornadoAsyncTransformer(cst.CSTTransformer):
    """
    A libcst transformer that replaces the legacy @gen.coroutine/yield
    async syntax with the native async/await syntax (python 3.5+).

    Functions decorated with a tornado coroutine decorator are converted:

    - ``@gen.coroutine`` functions become ``async def`` and lose the decorator;
      ``@gen_test`` functions become ``async def`` and keep their decorator.
    - Inside them, ``yield x`` becomes ``await x``. A yielded list or list
      comprehension becomes ``await asyncio.gather(*...)``. ``raise gen.Return(x)``
      becomes ``return x``, and ``gen.sleep`` calls become ``asyncio.sleep``.
    - ``import asyncio`` is added to the module when a rewrite needs it.

    Unsupported constructs raise TransformError. This transformer doesn't
    remove any tornado imports from modified files. Per-module state is reset
    when a module is entered, so an instance can be reused across modules.
    """

    def __init__(self) -> None:
        # One entry per enclosing FunctionDef: True when that function is a
        # coroutine being converted. Only the innermost entry is consulted.
        self.coroutine_stack: List[bool] = []
        # Root package names that must be imported into the current module.
        self.required_imports: Set[str] = set()

    def visit_Module(self, node: cst.Module) -> Optional[bool]:
        # Reset per-module state so a reused instance, or one whose previous
        # visit aborted with TransformError, does not carry imports or stale
        # stack entries into this module.
        self.coroutine_stack = []
        self.required_imports = set()
        return True

    def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.Module:
        if not self.required_imports:
            return updated_node

        # Sorted for deterministic output; set iteration order is arbitrary.
        imports = [
            self.make_simple_package_import(required_import)
            for required_import in sorted(self.required_imports)
        ]

        return with_added_imports(updated_node, imports)

    def visit_Call(self, node: cst.Call) -> Optional[bool]:
        # gen.Task is rejected anywhere in the module, not only inside coroutines.
        if m.matches(node, gen_task_matcher):
            raise TransformError(
                "gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task) from tornado 2.4.1 is unsupported by this codemod. This file has not been modified. Manually update to supported syntax before running again."
            )

        return True

    def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
        if not self.in_coroutine(self.coroutine_stack):
            return updated_node

        # Only the callee is swapped; an enclosing `yield` becomes `await`
        # in leave_Yield.
        if m.matches(updated_node, gen_sleep_matcher):
            self.required_imports.add("asyncio")
            return updated_node.with_changes(
                func=cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("sleep"))
            )

        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        self.coroutine_stack.append(m.matches(node, coroutine_matcher))
        # always continue to visit function
        return True

    def leave_FunctionDef(
        self, node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        leaving_coroutine = self.coroutine_stack.pop()
        if not leaving_coroutine:
            return updated_node

        # Only @gen.coroutine is dropped; @gen_test still wraps async tests.
        return updated_node.with_changes(
            decorators=[
                decorator
                for decorator in updated_node.decorators
                if not m.matches(decorator, gen_coroutine_decorator_matcher)
            ],
            asynchronous=cst.Asynchronous(),
        )

    def leave_Raise(
        self, node: cst.Raise, updated_node: cst.Raise
    ) -> Union[cst.Return, cst.Raise]:
        """Inside a coroutine, turn ``raise gen.Return[(value)]`` into ``return [value]``."""
        if not self.in_coroutine(self.coroutine_stack):
            return updated_node

        if not m.matches(node, gen_return_matcher):
            return updated_node

        return_value, whitespace_after = self.pluck_gen_return_value(updated_node)
        return cst.Return(
            value=return_value,
            whitespace_after_return=whitespace_after,
            semicolon=updated_node.semicolon,
        )

    def leave_Yield(
        self, node: cst.Yield, updated_node: cst.Yield
    ) -> Union[cst.Await, cst.Yield]:
        """
        Inside a coroutine, turn ``yield <expr>`` into ``await <expr>``.

        A yielded list or list comprehension is wrapped as
        ``asyncio.gather(*<list>)``. Yielding a dict (literal, comprehension
        or ``dict(...)`` call) raises TransformError.
        """
        if not self.in_coroutine(self.coroutine_stack):
            return updated_node

        # Bare `yield` (value None) and `yield from` (value is cst.From) have
        # no plain expression to await.
        if not isinstance(updated_node.value, cst.BaseExpression):
            return updated_node

        if isinstance(updated_node.value, (cst.List, cst.ListComp)):
            self.required_imports.add("asyncio")
            expression = self.pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
                updated_node
            )

        elif m.matches(
            updated_node,
            m.Yield(value=m.Dict() | m.DictComp() | m.Call(func=m.Name("dict"))),
        ):
            raise TransformError(
                "Yielding a dict of futures (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) added in tornado 3.2 is unsupported by the codemod. This file has not been modified. Manually update to supported syntax before running again."
            )

        else:
            expression = updated_node.value

        return cst.Await(
            expression=expression,
            whitespace_after_await=updated_node.whitespace_after_yield,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )

    @staticmethod
    def pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
        node: cst.Yield,
    ) -> cst.BaseExpression:
        """Build ``asyncio.gather(*<yielded list>)`` from a ``yield [...]`` node."""
        return cst.Call(
            func=cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("gather")),
            args=[cst.Arg(value=node.value, star="*")],
        )

    @staticmethod
    def in_coroutine(coroutine_stack: List[bool]) -> bool:
        """Return whether the innermost enclosing function is a converted coroutine."""
        if not coroutine_stack:
            return False

        return coroutine_stack[-1]

    @staticmethod
    def pluck_gen_return_value(
        node: cst.Raise,
    ) -> Tuple[Optional[cst.BaseExpression], cst.SimpleWhitespace]:
        """
        Return ``(value, whitespace_after)`` for the ``return`` that replaces a
        ``raise gen.Return...`` node: the first call argument and the original
        whitespace after ``raise``, or ``(None, "")`` when there is no argument.
        """
        if m.matches(node, gen_return_call_with_args_matcher):
            return node.exc.args[0].value, node.whitespace_after_raise

        # if there's no return value, we don't preserve whitespace after 'raise'
        return None, cst.SimpleWhitespace("")

    @staticmethod
    def make_simple_package_import(package: str) -> cst.Import:
        """
        Build an ``import <package>`` node for a root package name.

        Raises:
            ValueError: if ``package`` is dotted; only root packages such as
                ``"asyncio"`` are supported.
        """
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if "." in package:
            raise ValueError("this only supports a root package, e.g. 'import os'")
        return cst.Import(names=[cst.ImportAlias(name=cst.Name(package))])
--------------------------------------------------------------------------------