├── .github └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .libcst.codemod.yaml ├── .pre-commit-hooks.yaml ├── README.md ├── autotyping ├── __init__.py ├── __main__.py ├── autotyping.py └── guess_type.py ├── pyproject.toml ├── tests └── test_codemod.py └── tox.ini /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # Based on 2 | # https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ 3 | 4 | name: Publish Python distributions to PyPI 5 | 6 | on: push 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python distributions to PyPI 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 3.12 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: 3.12 18 | - name: Install pypa/build 19 | run: >- 20 | python -m 21 | pip install 22 | build 23 | --user 24 | - name: Build a binary wheel and a source tarball 25 | run: >- 26 | python -m 27 | build 28 | --sdist 29 | --wheel 30 | --outdir dist/ 31 | . 
32 | - name: Publish distribution to PyPI 33 | if: startsWith(github.ref, 'refs/tags') 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: autotyping 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13-dev"] 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install tox tox-gh-actions 24 | - name: Test with tox 25 | run: tox 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .tox/ 3 | *.egg-info/ 4 | dist/ 5 | -------------------------------------------------------------------------------- /.libcst.codemod.yaml: -------------------------------------------------------------------------------- 1 | # String that LibCST should look for in code which indicates that the 2 | # module is generated code. 3 | generated_code_marker: '@generated' 4 | # Command line and arguments for invoking a code formatter. Anything 5 | # specified here must be capable of taking code via stdin and returning 6 | # formatted code via stdout. 7 | formatter: ['black', '-'] 8 | # List of regex patterns which LibCST will evaluate against filenames to 9 | # determine if the module should be touched. 10 | blacklist_patterns: [] 11 | # List of modules that contain codemods inside of them. 
12 | modules: 13 | - 'libcst.codemod.commands' 14 | - 'autotyping' 15 | # Absolute or relative path of the repository root, used for providing 16 | # full-repo metadata. Relative paths should be specified with this file 17 | # location as the base. 18 | repo_root: '.' 19 | -------------------------------------------------------------------------------- /.pre-commit-hooks.yaml: -------------------------------------------------------------------------------- 1 | - id: autotyping 2 | name: autotyping 3 | description: Automatically add simple type-annotations to your code. 4 | entry: autotyping 5 | language: python 6 | types_or: [python, pyi] 7 | args: [] 8 | additional_dependencies: [] 9 | minimum_pre_commit_version: 3.0.0 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | When I refactor code I often find myself tediously adding type 2 | annotations that are obvious from context: functions that don't 3 | return anything, boolean flags, etcetera. That's where autotyping 4 | comes in: it automatically adds those types and inserts the right 5 | annotations. 6 | 7 | # Usage 8 | 9 | Autotyping can be called directly from the CLI, be used as a [pre-commit hook](#pre-commit-hook) or run via the [`libcst` interface](#LibCST) as a codemod. 10 | Here's how to use it from the CLI: 11 | 12 | - `pip install autotyping` 13 | - `python -m autotyping /path/to/my/code` 14 | 15 | By default it does nothing; you have to add flags to make it do 16 | more transformations. The following are supported: 17 | 18 | - Annotating return types: 19 | - `--none-return`: add a `-> None` return type to functions without any 20 | return, yield, or raise in their body 21 | - `--scalar-return`: add a return annotation to functions that only return 22 | literal bool, str, bytes, int, or float objects. 
23 | - Annotating parameter types: 24 | - `--bool-param`: add a `: bool` annotation to any function 25 | parameter with a default of `True` or `False` 26 | - `--int-param`, `--float-param`, `--str-param`, `--bytes-param`: add 27 | an annotation to any parameter for which the default is a literal int, 28 | float, str, or bytes object 29 | - `--annotate-optional foo:bar.Baz`: for any parameter of the form 30 | `foo=None`, add `Optional[Baz]`, with `Baz` imported from `bar`, as the type. For example, 31 | use `--annotate-optional uid:my_types.Uid` to annotate any `uid` in your 32 | codebase with a `None` default as `Optional[Uid]`, importing `Uid` from `my_types`. 33 | - `--annotate-named-param foo:bar.Baz`: annotate any parameter with no 34 | default that is named `foo` with `Baz`, imported from `bar`. For example, use 35 | `--annotate-named-param uid:my_types.Uid` to annotate any `uid` 36 | parameter in your codebase with no default as `Uid`, imported from `my_types`. 37 | - `--guess-common-names`: infer certain parameter types from their names 38 | based on common patterns in open-source Python code. For example, infer 39 | that a `verbose` parameter is of type `bool`. 40 | - Annotating magic methods: 41 | - `--annotate-magics`: add type annotations to certain magic methods.
42 | Currently this does the following: 43 | - `__str__` returns `str` 44 | - `__repr__` returns `str` 45 | - `__len__` returns `int` 46 | - `__length_hint__` returns `int` 47 | - `__init__` returns `None` 48 | - `__del__` returns `None` 49 | - `__bool__` returns `bool` 50 | - `__bytes__` returns `bytes` 51 | - `__format__` returns `str` 52 | - `__contains__` returns `bool` 53 | - `__complex__` returns `complex` 54 | - `__int__` returns `int` 55 | - `__float__` returns `float` 56 | - `__index__` returns `int` - `__setattr__`, `__delattr__`, `__setitem__`, `__delitem__`, and `__set__` return `None` - `__instancecheck__` and `__subclasscheck__` return `bool` 57 | - `__exit__`: the three parameters are `Optional[Type[BaseException]]`, 58 | `Optional[BaseException]`, and `Optional[TracebackType]` 59 | - `__aexit__`: same as `__exit__` 60 | - `--annotate-imprecise-magics`: add imprecise type annotations for 61 | some additional magic methods. Currently this adds `typing.Iterator` 62 | return annotations to `__iter__`, `__await__`, and `__reversed__`. 63 | These annotations should have a generic parameter to indicate what 64 | you're iterating over, but that's too hard for autotyping to figure 65 | out. 66 | - External integrations 67 | - `--pyanalyze-report`: takes types suggested by 68 | [pyanalyze](https://github.com/quora/pyanalyze)'s `suggested_parameter_type` 69 | and `suggested_return_type` codes and applies them. You can generate these 70 | with a command like: 71 | `pyanalyze --json-output failures.json -e suggested_return_type -e suggested_parameter_type -v .` 72 | - `--only-without-imports`: only apply pyanalyze suggestions that do not require 73 | new imports. This is useful because suggestions that require imports may need 74 | more manual work. 75 | 76 | There are two shortcut flags to enable multiple transformations at once: 77 | 78 | - `--safe` enables changes that should always be safe. This includes 79 | `--none-return`, `--scalar-return`, and `--annotate-magics`. 80 | - `--aggressive` enables riskier changes that are more likely to produce 81 | new type checker errors.
It includes all of `--safe` as well as `--bool-param`, 82 | `--int-param`, `--float-param`, `--str-param`, `--bytes-param`, 83 | `--annotate-imprecise-magics`, and `--guess-common-names`. 84 | 85 | # LibCST 86 | 87 | Autotyping is built as a LibCST codemod; see the 88 | [LibCST documentation](https://libcst.readthedocs.io/en/latest/codemods_tutorial.html) 89 | for more information on how to use codemods. 90 | 91 | If you wish to run things through the `libcst.tool` interface, you can do this like so: 92 | - Make sure you have a `.libcst.codemod.yaml` with `'autotyping'` in the `modules` list. 93 | For an example, see the `.libcst.codemod.yaml` in this repo. 94 | - Run `python -m libcst.tool codemod autotyping.AutotypeCommand /path/to/my/code` 95 | 96 | 97 | # pre-commit hook 98 | 99 | Pre-commit hooks are scripts that run automatically before a commit is made, 100 | which makes them really handy for checking and enforcing code formatting 101 | (or, in this case, typing). 102 | 103 | 1. To add `autotyping` as a [pre-commit](https://pre-commit.com/) hook, 104 | you will first need to install pre-commit if you haven't already: 105 | ``` 106 | pip install pre-commit 107 | ``` 108 | 109 | 2. After that, create or update the `.pre-commit-config.yaml` file at the root 110 | of your repository and add: 111 | 112 | ```yaml 113 | repos: 114 | - repo: https://github.com/JelleZijlstra/autotyping 115 | rev: 24.9.0 116 | hooks: 117 | - id: autotyping 118 | stages: [commit] 119 | types: [python] 120 | args: [--safe] # or alternatively, --aggressive, or any of the other flags mentioned above 121 | ``` 122 | 123 | 3. Finally, run the following command to install the pre-commit hook 124 | in your repository: 125 | 126 | ``` 127 | pre-commit install 128 | ``` 129 | 130 | Now whenever you commit changes, autotyping will automatically add 131 | type annotations to your code!
132 | 133 | 134 | # Limitations 135 | 136 | Autotyping is intended to be a simple tool that uses heuristics to find 137 | annotations that would be tedious to add by hand. The heuristics may fail, 138 | and after you run autotyping you should run a type checker to verify that 139 | the types it added are correct. 140 | 141 | Known limitations: 142 | 143 | - autotyping does not model code flow through a function, so it may miss 144 | implicit `None` returns 145 | 146 | # Changelog 147 | 148 | ## 24.9.0 (September 23, 2024) 149 | 150 | - Add pre-commit support. (Thanks to Akshit Tyagi and Matthew Akram.) 151 | - Add missing dependency. (Thanks to Stefane Fermigier.) 152 | 153 | ## 24.3.0 (March 25, 2024) 154 | 155 | - Add simpler ways to invoke autotyping. Now, it is possible to simply use 156 | `python3 -m autotyping` to invoke the tool. (Thanks to Shantanu Jain.) 157 | - Drop support for Python 3.7; add support for Python 3.12. (Thanks to Hugo 158 | van Kemenade.) 159 | - Infer return types for some more magic methods. (Thanks to Dhruv Manilawala.) 
160 | 161 | ## 23.3.0 (March 3, 2023) 162 | 163 | - Fix crash on certain argument names like `iterables` (contributed by 164 | Marco Gorelli) 165 | 166 | ## 23.2.0 (February 3, 2023) 167 | 168 | - Add `--guess-common-names` (contributed by John Litborn) 169 | - Fix the `--safe` and `--aggressive` flags so they don't take 170 | ignored arguments 171 | - `--length-hint` should return `int` (contributed by Nikita Sobolev) 172 | - Fix bug in import adding (contributed by Shantanu) 173 | 174 | ## 22.9.0 (September 5, 2022) 175 | 176 | - Add `--safe` and `--aggressive` 177 | - Add `--pyanalyze-report` 178 | - Do not add `None` return types to methods marked with `@abstractmethod` and 179 | to methods in stub files 180 | - Improve type inference: 181 | - `"string" % ...` is always `str` 182 | - `b"bytes" % ...` is always `bytes` 183 | - An `and` or `or` operator where left and right sides are of the same type 184 | returns that type 185 | - `is`, `is not`, `in`, and `not in` always return `bool` 186 | 187 | ## 21.12.0 (December 21, 2021) 188 | 189 | - Initial PyPI release 190 | -------------------------------------------------------------------------------- /autotyping/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JelleZijlstra/autotyping/4dcb8d9dcb3dceedce8de4df074d18de0b7cfe3c/autotyping/__init__.py -------------------------------------------------------------------------------- /autotyping/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from autotyping.autotyping import AutotypeCommand 6 | 7 | from libcst.codemod import ( 8 | CodemodContext, 9 | gather_files, 10 | parallel_exec_transform_with_prettyprint, 11 | ) 12 | 13 | 14 | def main() -> int: 15 | parser = argparse.ArgumentParser() 16 | AutotypeCommand.add_args(parser) 17 | parser.add_argument("path", nargs="+") 18 | args = 
parser.parse_args() 19 | 20 | kwargs = vars(args) 21 | del args 22 | 23 | path = kwargs.pop("path") 24 | 25 | bases = list(map(os.path.abspath, path)) 26 | root = os.path.commonpath(bases) 27 | root = os.path.dirname(root) if os.path.isfile(root) else root 28 | 29 | # Based on: 30 | # https://github.com/Instagram/LibCST/blob/36e791ebe5f008af91a2ccc6be4900e69fad190d/libcst/tool.py#L593 31 | files = gather_files(bases, include_stubs=True) 32 | try: 33 | result = parallel_exec_transform_with_prettyprint( 34 | AutotypeCommand(CodemodContext(), **kwargs), files, repo_root=root 35 | ) 36 | except KeyboardInterrupt: 37 | print("Interrupted!", file=sys.stderr) 38 | return 2 39 | 40 | print( 41 | f"Finished codemodding {result.successes + result.skips + result.failures} files!", 42 | file=sys.stderr, 43 | ) 44 | print(f" - Transformed {result.successes} files successfully.", file=sys.stderr) 45 | print(f" - Skipped {result.skips} files.", file=sys.stderr) 46 | print(f" - Failed to codemod {result.failures} files.", file=sys.stderr) 47 | print(f" - {result.warnings} warnings were generated.", file=sys.stderr) 48 | return 1 if result.failures > 0 else 0 49 | 50 | 51 | if __name__ == "__main__": 52 | sys.exit(main()) 53 | -------------------------------------------------------------------------------- /autotyping/autotyping.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass, field 3 | import enum 4 | import json 5 | from typing import Dict, List, Optional, Sequence, Set, Tuple, Type 6 | from typing_extensions import TypedDict 7 | 8 | import libcst 9 | from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand 10 | from libcst.codemod.visitors import AddImportsVisitor 11 | from libcst.metadata import CodePosition, CodeRange, PositionProvider 12 | 13 | from autotyping.guess_type import guess_type_from_argname 14 | 15 | _DEFAULT_POSITION = CodePosition(0, 0) 16 | 
_DEFAULT_CODE_RANGE = CodeRange(_DEFAULT_POSITION, _DEFAULT_POSITION) 17 | 18 | 19 | @dataclass 20 | class NamedParam: 21 | name: str 22 | module: Optional[str] 23 | type_name: str 24 | 25 | @classmethod 26 | def make(cls, input: str) -> "NamedParam": 27 | name, type_path = input.split(":") 28 | if "." in type_path: 29 | module, type_name = type_path.rsplit(".", maxsplit=1) 30 | else: 31 | module = None 32 | type_name = type_path 33 | return NamedParam(name, module, type_name) 34 | 35 | 36 | class PyanalyzeSuggestion(TypedDict): 37 | suggested_type: str 38 | imports: List[str] 39 | 40 | 41 | class DecoratorKind(enum.Enum): 42 | asynq = 1 43 | abstractmethod = 2 44 | 45 | 46 | @dataclass 47 | class State: 48 | annotate_optionals: List[NamedParam] 49 | annotate_named_params: List[NamedParam] 50 | annotate_magics: bool 51 | annotate_imprecise_magics: bool 52 | none_return: bool 53 | scalar_return: bool 54 | param_types: Set[Type[object]] 55 | seen_return_statement: List[bool] = field(default_factory=lambda: [False]) 56 | seen_return_types: List[Set[Optional[Type[object]]]] = field( 57 | default_factory=lambda: [set()] 58 | ) 59 | seen_raise_statement: List[bool] = field(default_factory=lambda: [False]) 60 | seen_yield: List[bool] = field(default_factory=lambda: [False]) 61 | in_lambda: bool = False 62 | pyanalyze_suggestions: Dict[Tuple[str, int, int], PyanalyzeSuggestion] = field( 63 | default_factory=dict 64 | ) 65 | only_without_imports: bool = False 66 | guess_common_names: bool = False 67 | 68 | 69 | SIMPLE_MAGICS = { 70 | "__str__": "str", 71 | "__repr__": "str", 72 | "__len__": "int", 73 | "__length_hint__": "int", 74 | "__init__": "None", 75 | "__del__": "None", 76 | "__bool__": "bool", 77 | "__bytes__": "bytes", 78 | "__format__": "str", 79 | "__contains__": "bool", 80 | "__complex__": "complex", 81 | "__int__": "int", 82 | "__float__": "float", 83 | "__index__": "int", 84 | "__setattr__": "None", 85 | "__delattr__": "None", 86 | "__setitem__": "None", 87 | 
"__delitem__": "None", 88 | "__set__": "None", 89 | "__instancecheck__": "bool", 90 | "__subclasscheck__": "bool", 91 | } 92 | IMPRECISE_MAGICS = { 93 | "__iter__": ("typing", "Iterator"), 94 | "__reversed__": ("typing", "Iterator"), 95 | "__await__": ("typing", "Iterator"), 96 | } 97 | 98 | 99 | class AutotypeCommand(VisitorBasedCodemodCommand): 100 | # Add a description so that future codemodders can see what this does. 101 | DESCRIPTION: str = "Automatically adds simple type annotations." 102 | METADATA_DEPENDENCIES = (PositionProvider,) 103 | 104 | state: State 105 | 106 | @staticmethod 107 | def add_args(arg_parser: argparse.ArgumentParser) -> None: 108 | arg_parser.add_argument( 109 | "--annotate-optional", 110 | nargs="*", 111 | help=( 112 | "foo:bar.Baz annotates any argument named 'foo' with a default of None" 113 | " as 'bar.Baz'" 114 | ), 115 | ) 116 | arg_parser.add_argument( 117 | "--annotate-named-param", 118 | nargs="*", 119 | help=( 120 | "foo:bar.Baz annotates any argument named 'foo' with no default" 121 | " as 'bar.Baz'" 122 | ), 123 | ) 124 | arg_parser.add_argument( 125 | "--none-return", 126 | action="store_true", 127 | default=False, 128 | help="Automatically add None return types", 129 | ) 130 | arg_parser.add_argument( 131 | "--scalar-return", 132 | action="store_true", 133 | default=False, 134 | help="Automatically add int, str, bytes, float, and bool return types", 135 | ) 136 | arg_parser.add_argument( 137 | "--bool-param", 138 | action="store_true", 139 | default=False, 140 | help=( 141 | "Automatically add bool annotation to parameters with a default of True" 142 | " or False" 143 | ), 144 | ) 145 | arg_parser.add_argument( 146 | "--int-param", 147 | action="store_true", 148 | default=False, 149 | help="Automatically add int annotation to parameters with an int default", 150 | ) 151 | arg_parser.add_argument( 152 | "--float-param", 153 | action="store_true", 154 | default=False, 155 | help=( 156 | "Automatically add float annotation to 
parameters with a float default" 157 | ), 158 | ) 159 | arg_parser.add_argument( 160 | "--str-param", 161 | action="store_true", 162 | default=False, 163 | help="Automatically add str annotation to parameters with a str default", 164 | ) 165 | arg_parser.add_argument( 166 | "--bytes-param", 167 | action="store_true", 168 | default=False, 169 | help=( 170 | "Automatically add bytes annotation to parameters with a bytes default" 171 | ), 172 | ) 173 | arg_parser.add_argument( 174 | "--annotate-magics", 175 | action="store_true", 176 | default=False, 177 | help="Add annotations to certain magic methods (e.g., __str__)", 178 | ) 179 | arg_parser.add_argument( 180 | "--annotate-imprecise-magics", 181 | action="store_true", 182 | default=False, 183 | help=( 184 | "Add annotations to magic methods that are less precise (e.g., Iterable" 185 | " for __iter__)" 186 | ), 187 | ) 188 | arg_parser.add_argument( 189 | "--guess-common-names", 190 | action="store_true", 191 | default=False, 192 | help="Guess types from argument name", 193 | ) 194 | arg_parser.add_argument( 195 | "--pyanalyze-report", 196 | type=str, 197 | default=None, 198 | help="Path to a pyanalyze --json-report file to use for suggested types.", 199 | ) 200 | arg_parser.add_argument( 201 | "--only-without-imports", 202 | action="store_true", 203 | default=False, 204 | help="Only apply pyanalyze suggestions that do not require imports", 205 | ) 206 | arg_parser.add_argument( 207 | "--safe", 208 | action="store_true", 209 | default=False, 210 | help="Apply all safe transformations", 211 | ) 212 | arg_parser.add_argument( 213 | "--aggressive", 214 | action="store_true", 215 | default=False, 216 | help="Apply all transformations that do not require arguments", 217 | ) 218 | 219 | def __init__( 220 | self, 221 | context: CodemodContext, 222 | *, 223 | annotate_optional: Optional[Sequence[str]] = None, 224 | annotate_named_param: Optional[Sequence[str]] = None, 225 | annotate_magics: bool = False, 226 | 
annotate_imprecise_magics: bool = False, 227 | none_return: bool = False, 228 | scalar_return: bool = False, 229 | bool_param: bool = False, 230 | str_param: bool = False, 231 | bytes_param: bool = False, 232 | float_param: bool = False, 233 | int_param: bool = False, 234 | guess_common_names: bool = False, 235 | pyanalyze_report: Optional[str] = None, 236 | only_without_imports: bool = False, 237 | safe: bool = False, 238 | aggressive: bool = False, 239 | ) -> None: 240 | if safe or aggressive: 241 | none_return = True 242 | scalar_return = True 243 | annotate_magics = True 244 | if aggressive: 245 | bool_param = True 246 | int_param = True 247 | float_param = True 248 | str_param = True 249 | bytes_param = True 250 | annotate_imprecise_magics = True 251 | guess_common_names = True 252 | super().__init__(context) 253 | param_type_pairs = [ 254 | (bool_param, bool), 255 | (str_param, str), 256 | (bytes_param, bytes), 257 | (int_param, int), 258 | (float_param, float), 259 | ] 260 | pyanalyze_suggestions = {} 261 | if pyanalyze_report is not None: 262 | with open(pyanalyze_report) as f: 263 | data = json.load(f) 264 | for failure in data: 265 | if "lineno" not in failure or "col_offset" not in failure: 266 | continue 267 | metadata = failure.get("extra_metadata") 268 | if not metadata: 269 | continue 270 | if "suggested_type" not in metadata or "imports" not in metadata: 271 | continue 272 | if failure.get("code") not in ( 273 | "suggested_parameter_type", 274 | "suggested_return_type", 275 | ): 276 | continue 277 | pyanalyze_suggestions[ 278 | ( 279 | failure["absolute_filename"], 280 | failure["lineno"], 281 | failure["col_offset"], 282 | ) 283 | ] = metadata 284 | self.state = State( 285 | annotate_optionals=( 286 | [NamedParam.make(s) for s in annotate_optional] 287 | if annotate_optional 288 | else [] 289 | ), 290 | annotate_named_params=( 291 | [NamedParam.make(s) for s in annotate_named_param] 292 | if annotate_named_param 293 | else [] 294 | ), 295 | 
none_return=none_return, 296 | scalar_return=scalar_return, 297 | param_types={typ for param, typ in param_type_pairs if param}, 298 | annotate_magics=annotate_magics, 299 | annotate_imprecise_magics=annotate_imprecise_magics, 300 | pyanalyze_suggestions=pyanalyze_suggestions, 301 | only_without_imports=only_without_imports, 302 | guess_common_names=guess_common_names, 303 | ) 304 | 305 | def is_stub(self) -> bool: 306 | filename = self.context.filename 307 | return filename is not None and filename.endswith(".pyi") 308 | 309 | def visit_FunctionDef(self, node: libcst.FunctionDef) -> None: 310 | self.state.seen_return_statement.append(False) 311 | self.state.seen_raise_statement.append(False) 312 | self.state.seen_yield.append(False) 313 | self.state.seen_return_types.append(set()) 314 | 315 | def visit_Return(self, node: libcst.Return) -> None: 316 | if node.value is not None: 317 | self.state.seen_return_statement[-1] = True 318 | self.state.seen_return_types[-1].add(type_of_expression(node.value)) 319 | else: 320 | self.state.seen_return_types[-1].add(None) 321 | 322 | def visit_Raise(self, node: libcst.Raise) -> None: 323 | self.state.seen_raise_statement[-1] = True 324 | 325 | def visit_Yield(self, node: libcst.Yield) -> None: 326 | self.state.seen_yield[-1] = True 327 | 328 | def visit_Lambda(self, node: libcst.Lambda) -> None: 329 | self.state.in_lambda = True 330 | 331 | def leave_Lambda( 332 | self, original_node: libcst.Lambda, updated_node: libcst.Lambda 333 | ) -> libcst.CSTNode: 334 | self.state.in_lambda = False 335 | return updated_node 336 | 337 | def leave_FunctionDef( 338 | self, original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef 339 | ) -> libcst.CSTNode: 340 | kinds = {get_decorator_kind(decorator) for decorator in updated_node.decorators} 341 | is_asynq = DecoratorKind.asynq in kinds 342 | is_abstractmethod = DecoratorKind.abstractmethod in kinds 343 | seen_return = self.state.seen_return_statement.pop() 344 | seen_raise = 
self.state.seen_raise_statement.pop() 345 | seen_yield = self.state.seen_yield.pop() 346 | return_types = self.state.seen_return_types.pop() 347 | name = original_node.name.value 348 | if self.state.annotate_magics and name in ("__exit__", "__aexit__"): 349 | updated_node = self.annotate_exit(updated_node) 350 | 351 | if original_node.returns is not None: 352 | return updated_node 353 | 354 | if self.state.pyanalyze_suggestions and self.context.filename: 355 | # Currently pyanalyze gives the lineno of the first decorator 356 | # and libcst that of the def. 357 | # TODO I think the AST behavior changed in later Python versions. 358 | if original_node.decorators: 359 | lineno_node = original_node.decorators[0] 360 | else: 361 | lineno_node = original_node 362 | pos = self.get_metadata( 363 | PositionProvider, lineno_node, _DEFAULT_CODE_RANGE 364 | ).start 365 | key = (self.context.filename, pos.line, pos.column) 366 | suggestion = self.state.pyanalyze_suggestions.get(key) 367 | if suggestion is not None and not ( 368 | suggestion["imports"] and self.state.only_without_imports 369 | ): 370 | for import_line in suggestion["imports"]: 371 | if "." 
not in import_line: 372 | AddImportsVisitor.add_needed_import(self.context, import_line) 373 | else: 374 | mod, name = import_line.rsplit(".", maxsplit=1) 375 | AddImportsVisitor.add_needed_import(self.context, mod, name) 376 | annotation = libcst.Annotation( 377 | annotation=libcst.parse_expression(suggestion["suggested_type"]) 378 | ) 379 | return updated_node.with_changes(returns=annotation) 380 | 381 | if self.state.annotate_magics: 382 | if name in SIMPLE_MAGICS: 383 | return updated_node.with_changes( 384 | returns=libcst.Annotation( 385 | annotation=libcst.Name(value=SIMPLE_MAGICS[name]) 386 | ) 387 | ) 388 | if self.state.annotate_imprecise_magics: 389 | if name in IMPRECISE_MAGICS: 390 | module, imported_name = IMPRECISE_MAGICS[name] 391 | AddImportsVisitor.add_needed_import(self.context, module, imported_name) 392 | return updated_node.with_changes( 393 | returns=libcst.Annotation( 394 | annotation=libcst.Name(value=imported_name) 395 | ) 396 | ) 397 | 398 | if ( 399 | self.state.none_return 400 | and not seen_raise 401 | and not seen_return 402 | and (is_asynq or not seen_yield) 403 | and not is_abstractmethod 404 | and not self.is_stub() 405 | ): 406 | return updated_node.with_changes( 407 | returns=libcst.Annotation(annotation=libcst.Name(value="None")) 408 | ) 409 | 410 | if ( 411 | self.state.scalar_return 412 | and (is_asynq or not seen_yield) 413 | and len(return_types) == 1 414 | ): 415 | return_type = next(iter(return_types)) 416 | if return_type in {bool, int, float, str, bytes}: 417 | return updated_node.with_changes( 418 | returns=libcst.Annotation( 419 | annotation=libcst.Name(value=return_type.__name__) 420 | ) 421 | ) 422 | 423 | return updated_node 424 | 425 | def annotate_exit(self, node: libcst.FunctionDef) -> libcst.FunctionDef: 426 | if ( 427 | node.params.star_arg is not libcst.MaybeSentinel.DEFAULT 428 | or node.params.kwonly_params 429 | or node.params.star_kwarg 430 | ): 431 | return node 432 | # 4 for def __exit__(self, type, 
value, tb) 433 | if len(node.params.params) == 4: 434 | params = node.params.params 435 | is_pos_only = False 436 | elif len(node.params.posonly_params) == 4: 437 | params = node.params.posonly_params 438 | is_pos_only = True 439 | else: 440 | return node 441 | new_params = [params[0]] 442 | 443 | # type 444 | if params[1].annotation: 445 | new_params.append(params[1]) 446 | else: 447 | AddImportsVisitor.add_needed_import(self.context, "typing", "Optional") 448 | AddImportsVisitor.add_needed_import(self.context, "typing", "Type") 449 | type_anno = libcst.Subscript( 450 | value=libcst.Name(value="Type"), 451 | slice=[ 452 | libcst.SubscriptElement( 453 | slice=libcst.Index(value=libcst.Name(value="BaseException")) 454 | ) 455 | ], 456 | ) 457 | anno = libcst.Subscript( 458 | value=libcst.Name(value="Optional"), 459 | slice=[libcst.SubscriptElement(slice=libcst.Index(value=type_anno))], 460 | ) 461 | new_params.append( 462 | params[1].with_changes(annotation=libcst.Annotation(annotation=anno)) 463 | ) 464 | 465 | # value 466 | if params[2].annotation: 467 | new_params.append(params[2]) 468 | else: 469 | AddImportsVisitor.add_needed_import(self.context, "typing", "Optional") 470 | anno = libcst.Subscript( 471 | value=libcst.Name(value="Optional"), 472 | slice=[ 473 | libcst.SubscriptElement( 474 | slice=libcst.Index(value=libcst.Name(value="BaseException")) 475 | ) 476 | ], 477 | ) 478 | new_params.append( 479 | params[2].with_changes(annotation=libcst.Annotation(annotation=anno)) 480 | ) 481 | 482 | # tb 483 | if params[3].annotation: 484 | new_params.append(params[3]) 485 | else: 486 | AddImportsVisitor.add_needed_import(self.context, "types", "TracebackType") 487 | anno = libcst.Subscript( 488 | value=libcst.Name(value="Optional"), 489 | slice=[ 490 | libcst.SubscriptElement( 491 | slice=libcst.Index(value=libcst.Name(value="TracebackType")) 492 | ) 493 | ], 494 | ) 495 | new_params.append( 496 | params[3].with_changes(annotation=libcst.Annotation(annotation=anno)) 
497 | ) 498 | 499 | if is_pos_only: 500 | new_parameters = node.params.with_changes(posonly_params=new_params) 501 | else: 502 | new_parameters = node.params.with_changes(params=new_params) 503 | return node.with_changes(params=new_parameters) 504 | 505 | def leave_Param( 506 | self, original_node: libcst.Param, updated_node: libcst.Param 507 | ) -> libcst.CSTNode: 508 | if self.state.in_lambda: 509 | # Lambdas can't have annotations 510 | return updated_node 511 | # don't modify if there's already any annotations set 512 | if original_node.annotation is not None: 513 | return updated_node 514 | # pyanalyze suggestions 515 | if self.state.pyanalyze_suggestions and self.context.filename: 516 | pos = self.get_metadata( 517 | PositionProvider, original_node, _DEFAULT_CODE_RANGE 518 | ).start 519 | key = (self.context.filename, pos.line, pos.column) 520 | suggestion = self.state.pyanalyze_suggestions.get(key) 521 | if suggestion is not None and not ( 522 | suggestion["imports"] and self.state.only_without_imports 523 | ): 524 | for import_line in suggestion["imports"]: 525 | if "." 
not in import_line: 526 | AddImportsVisitor.add_needed_import(self.context, import_line) 527 | else: 528 | mod, name = import_line.rsplit(".", maxsplit=1) 529 | AddImportsVisitor.add_needed_import(self.context, mod, name) 530 | annotation = libcst.Annotation( 531 | annotation=libcst.parse_expression(suggestion["suggested_type"]) 532 | ) 533 | return updated_node.with_changes(annotation=annotation) 534 | 535 | # infer from default non-None value 536 | if original_node.default is not None: 537 | default_type = type_of_expression(original_node.default) 538 | if default_type is not None and default_type in self.state.param_types: 539 | return updated_node.with_changes( 540 | annotation=libcst.Annotation( 541 | annotation=libcst.Name(value=default_type.__name__) 542 | ) 543 | ) 544 | 545 | parameter_name = original_node.name.value 546 | default_is_none = ( 547 | original_node.default is not None 548 | and isinstance(original_node.default, libcst.Name) 549 | and original_node.default.value == "None" 550 | ) 551 | # default value is None, i.e. `def foo(bar=None)` 552 | if default_is_none: 553 | # check if user has explicitly specified a type for this name 554 | for anno_optional in self.state.annotate_optionals: 555 | if parameter_name == anno_optional.name: 556 | return self._annotate_param( 557 | anno_optional, updated_node, containers=["Optional"] 558 | ) 559 | 560 | # no default value, i.e. 
`def foo(bar)` 561 | elif original_node.default is None: 562 | # check if user has explicitly specified a type for this name 563 | for anno_named_param in self.state.annotate_named_params: 564 | if original_node.name.value == anno_named_param.name: 565 | return self._annotate_param(anno_named_param, updated_node, []) 566 | 567 | # guess type from name 568 | if self.state.guess_common_names: 569 | guessed_type, containers = guess_type_from_argname(parameter_name) 570 | if guessed_type is not None: 571 | if default_is_none: 572 | containers += ["Optional"] 573 | return self._annotate_param( 574 | NamedParam("", None, guessed_type), updated_node, containers 575 | ) 576 | 577 | return updated_node 578 | 579 | def _annotate_param( 580 | self, param: NamedParam, updated_node: libcst.Param, containers: List[str] 581 | ) -> libcst.Param: 582 | if param.module is not None: 583 | AddImportsVisitor.add_needed_import( 584 | self.context, param.module, param.type_name 585 | ) 586 | anno = libcst.Name(value=param.type_name) 587 | 588 | for container in containers: 589 | # Should be updated when python <3.9 support is dropped 590 | AddImportsVisitor.add_needed_import(self.context, "typing", container) 591 | 592 | anno = libcst.Subscript( 593 | value=libcst.Name(value=container), 594 | slice=[libcst.SubscriptElement(slice=libcst.Index(value=anno))], 595 | ) 596 | return updated_node.with_changes(annotation=libcst.Annotation(annotation=anno)) 597 | 598 | 599 | def type_of_expression(expr: libcst.BaseExpression) -> Optional[Type[object]]: 600 | """Very simple type inference for expressions. 601 | 602 | Return None if the type cannot be inferred. 
603 | 604 | """ 605 | if isinstance(expr, libcst.Float): 606 | return float 607 | elif isinstance(expr, libcst.Integer): 608 | return int 609 | elif isinstance(expr, libcst.Imaginary): 610 | return complex 611 | elif isinstance(expr, libcst.FormattedString): 612 | return str # f-strings can only be str, not bytes 613 | elif isinstance(expr, libcst.SimpleString): 614 | if "b" in expr.prefix: 615 | return bytes 616 | else: 617 | return str 618 | elif isinstance(expr, libcst.ConcatenatedString): 619 | left = type_of_expression(expr.left) 620 | right = type_of_expression(expr.right) 621 | if left == right: 622 | return left 623 | else: 624 | return None 625 | elif isinstance(expr, libcst.Name) and expr.value in ("True", "False"): 626 | return bool 627 | elif isinstance(expr, libcst.UnaryOperation) and isinstance( 628 | expr.operator, libcst.Not 629 | ): 630 | return bool 631 | elif isinstance(expr, libcst.BinaryOperation): 632 | left = type_of_expression(expr.left) 633 | if left in (str, bytes) and isinstance(expr.operator, libcst.Modulo): 634 | return left 635 | return None 636 | elif isinstance(expr, libcst.BooleanOperation): 637 | left = type_of_expression(expr.left) 638 | right = type_of_expression(expr.right) 639 | # For AND and OR, if both types are the same, we can infer that type. 
640 | if left == right: 641 | return left 642 | else: 643 | return None 644 | elif isinstance(expr, libcst.Comparison): 645 | types = {type(comp.operator) for comp in expr.comparisons} 646 | # Only these are actually guaranteed to return bool 647 | if types <= {libcst.In, libcst.Is, libcst.IsNot, libcst.NotIn}: 648 | return bool 649 | return None 650 | elif ( 651 | isinstance(expr, libcst.Call) 652 | and isinstance(expr.func, libcst.Attribute) 653 | and isinstance(expr.func.value, libcst.BaseString) 654 | and expr.func.attr.value in ("format", "lower", "upper", "title") 655 | ): 656 | return str 657 | else: 658 | return None 659 | 660 | 661 | def get_decorator_kind(dec: libcst.Decorator) -> Optional[DecoratorKind]: 662 | """Is this @asynq() or @abstractmethod?""" 663 | if isinstance(dec.decorator, libcst.Call): 664 | call = dec.decorator 665 | if not isinstance(call.func, libcst.Name): 666 | return None 667 | if call.func.value != "asynq": 668 | return None 669 | if call.args: 670 | # @asynq() with custom arguments may do something unexpected 671 | return None 672 | return DecoratorKind.asynq 673 | elif isinstance(dec.decorator, libcst.Name): 674 | if dec.decorator.value == "abstractmethod": 675 | return DecoratorKind.abstractmethod 676 | elif isinstance(dec.decorator, libcst.Attribute): 677 | if ( 678 | dec.decorator.attr.value == "abstractmethod" 679 | and isinstance(dec.decorator.value, libcst.Name) 680 | and dec.decorator.value.value == "abc" 681 | ): 682 | return DecoratorKind.abstractmethod 683 | return None 684 | -------------------------------------------------------------------------------- /autotyping/guess_type.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, List 2 | import re 3 | 4 | 5 | # strategy heavily inspired by 6 | # https://github.com/Zac-HD/hypothesis/blob/07ff885edaa0c11f480a8639a75101c6fe14844f/hypothesis-python/src/hypothesis/extra/ghostwriter.py#L319 7 | def 
guess_type_from_argname(name: str) -> Tuple[Optional[str], List[str]]: 8 | """ 9 | If all else fails, we try guessing a type based on common argument names. 10 | 11 | A "good guess" is _usually correct_, and _a reasonable mistake_ if not. 12 | The logic below is therefore based on a manual reading of the builtins and 13 | some standard-library docs, plus the analysis of about three hundred million 14 | arguments in https://github.com/HypothesisWorks/hypothesis/issues/3311 15 | """ 16 | 17 | containers = "deque|list|set|iterator|tuple|iter|iterable" 18 | # not using 'sequence', 'counter' or 'collection' due to likely false alarms 19 | 20 | # (container)_(int|float|str|bool)s? 21 | # e.g. list_ints => List[int] 22 | # only check for built-in types to avoid false alarms, e.g. list_create, list_length 23 | if m := re.fullmatch( 24 | rf"(?P<container>{containers})_(?P<elems>int|float|str|bool)s?", name 25 | ): 26 | container_type = m.group("container").capitalize() 27 | if container_type == "Iter": 28 | container_type = "Iterable" 29 | return m.group("elems"), [container_type] 30 | 31 | # (elem)s?_(container) 32 | # e.g. latitude_list => List[float] 33 | # (container)_of_(elem)s? 34 | # e.g.
set_of_widths => Set[int] 35 | if m := re.fullmatch( 36 | rf"(?P<elems>\w+?)_?(?P<container>{containers})", name 37 | ) or re.fullmatch(rf"(?P<container>{containers})_of_(?P<elems>\w+)", name): 38 | # only do a simple container match 39 | # and don't check all of BOOL_NAMES to not trigger on stuff like "save_list" 40 | elems = m.group("elems") 41 | for names, name_type in ( 42 | (("bool", "boolean"), "bool"), 43 | # don't trigger on `real_list` 44 | (FLOAT_NAMES - {"real"}, "float"), 45 | (INTEGER_NAMES, "int"), 46 | (STRING_NAMES | {"string", "str"}, "str"), 47 | ): 48 | if elems in names or (elems[-1] == "s" and elems[:-1] in names): 49 | return name_type, [m.group("container").capitalize()] 50 | 51 | # Names which imply the value is a boolean 52 | if name.startswith("is_") or name in BOOL_NAMES: 53 | return "bool", [] 54 | 55 | if ( 56 | name.endswith("_size") 57 | or (name.endswith("size") and "_" not in name) 58 | or re.fullmatch(r"n(um)?_[a-z_]*s", name) 59 | or name in INTEGER_NAMES 60 | ): 61 | return "int", [] 62 | 63 | if name in FLOAT_NAMES: 64 | return "float", [] 65 | 66 | if ( 67 | "file" in name 68 | or "path" in name 69 | or name.endswith("_dir") 70 | or name in ("fname", "dir", "dirname", "directory", "folder") 71 | ): 72 | # Common names for filesystem paths: these are usually strings, but we 73 | # don't want to make strings more convenient than pathlib.Path.
74 | return None, [] 75 | 76 | if ( 77 | name.endswith("_name") 78 | or (name.endswith("name") and "_" not in name) 79 | or ("string" in name and "as" not in name) 80 | or name.endswith("label") 81 | or name in STRING_NAMES 82 | ): 83 | return "str", [] 84 | 85 | # Last clever idea: maybe we're looking a plural, and know the singular: 86 | # don't trigger on multiple ending "s" to avoid nested calls 87 | if re.fullmatch(r"\w*[^s]s", name): 88 | elems, container = guess_type_from_argname(name[:-1]) 89 | if elems is not None and not container: 90 | return elems, ["Sequence"] 91 | 92 | return None, [] 93 | 94 | 95 | BOOL_NAMES = { 96 | "keepdims", 97 | "verbose", 98 | "debug", 99 | "force", 100 | "train", 101 | "training", 102 | "trainable", 103 | "bias", 104 | "shuffle", 105 | "show", 106 | "load", 107 | "pretrained", 108 | "save", 109 | "overwrite", 110 | "normalize", 111 | "reverse", 112 | "success", 113 | "enabled", 114 | "strict", 115 | "copy", 116 | "quiet", 117 | "required", 118 | "inplace", 119 | "recursive", 120 | "enable", 121 | "active", 122 | "create", 123 | "validate", 124 | "refresh", 125 | "use_bias", 126 | } 127 | INTEGER_NAMES = { 128 | "width", 129 | "size", 130 | "length", 131 | "limit", 132 | "idx", 133 | "stride", 134 | "epoch", 135 | "epochs", 136 | "depth", 137 | "pid", 138 | "steps", 139 | "iteration", 140 | "iterations", 141 | "vocab_size", 142 | "ttl", 143 | "count", 144 | "offset", 145 | "seed", 146 | "dim", 147 | "total", 148 | "priority", 149 | "port", 150 | "number", 151 | "num", 152 | "int", 153 | } 154 | FLOAT_NAMES = { 155 | "real", 156 | "imag", 157 | "alpha", 158 | "theta", 159 | "beta", 160 | "sigma", 161 | "gamma", 162 | "angle", 163 | "reward", 164 | "learning_rate", 165 | "dropout", 166 | "dropout_rate", 167 | "epsilon", 168 | "eps", 169 | "prob", 170 | "tau", 171 | "temperature", 172 | "lat", 173 | "latitude", 174 | "lon", 175 | "longitude", 176 | "radius", 177 | "tol", 178 | "tolerance", 179 | "rate", 180 | "treshold", 181 | 
"float", 182 | } 183 | STRING_NAMES = { 184 | "text", 185 | "txt", 186 | "password", 187 | "label", 188 | "prefix", 189 | "suffix", 190 | "desc", 191 | "description", 192 | "str", 193 | "pattern", 194 | "subject", 195 | "reason", 196 | "comment", 197 | "prompt", 198 | "sentence", 199 | "sep", 200 | "host", 201 | "hostname", 202 | "email", 203 | "word", 204 | "slug", 205 | "api_key", 206 | "char", 207 | "character", 208 | } 209 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autotyping" 3 | version = "24.9.0" 4 | description = "A tool for autoadding simple type annotations." 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | authors = [ 8 | { name = "Jelle Zijlstra", email = "jelle.zijlstra@gmail.com" }, 9 | ] 10 | keywords = [ 11 | "annotations", 12 | "typing", 13 | ] 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "Environment :: Console", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | "Programming Language :: Python", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | "Programming Language :: Python :: 3.13", 27 | "Topic :: Software Development", 28 | ] 29 | dependencies = [ 30 | "libcst", 31 | "typing-extensions", 32 | ] 33 | 34 | [project.scripts] 35 | autotyping = "autotyping.__main__:main" 36 | 37 | [project.urls] 38 | "Bug Tracker" = "https://github.com/JelleZijlstra/autotyping/issues" 39 | Homepage = "https://github.com/JelleZijlstra/autotyping" 40 | 41 | [build-system] 42 | requires = ["setuptools>=69", "wheel"] 43 | build-backend = "setuptools.build_meta" 44 | 45 | [tool.black] 46 | target_version = 
['py38', 'py39', 'py310', 'py311', 'py312', 'py313'] 47 | include = '\.pyi?$' 48 | 49 | exclude = ''' 50 | /( 51 | \.git 52 | | \.tox 53 | | \.venv 54 | )/ 55 | ''' 56 | preview = true 57 | skip_magic_trailing_comma = true 58 | 59 | [tool.pyanalyze] 60 | missing_parameter_annotation = true 61 | missing_return_annotation = true 62 | incompatible_override = true 63 | -------------------------------------------------------------------------------- /tests/test_codemod.py: -------------------------------------------------------------------------------- 1 | from libcst.codemod import CodemodTest, CodemodContext 2 | from autotyping.autotyping import AutotypeCommand 3 | 4 | 5 | class TestAutotype(CodemodTest): 6 | TRANSFORM = AutotypeCommand 7 | 8 | def test_noop(self) -> None: 9 | before = """ 10 | def f(): 11 | pass 12 | """ 13 | after = """ 14 | def f(): 15 | pass 16 | """ 17 | 18 | # By default, we do nothing. 19 | self.assertCodemod(before, after) 20 | 21 | def test_none_return(self) -> None: 22 | before = """ 23 | def foo(): 24 | pass 25 | 26 | def bar(): 27 | return 1 28 | 29 | def baz(): 30 | return 31 | 32 | @abstractmethod 33 | def very_abstract(): 34 | pass 35 | 36 | @abc.abstractmethod 37 | def very_abstract_without_import_from(): 38 | pass 39 | """ 40 | after = """ 41 | def foo() -> None: 42 | pass 43 | 44 | def bar(): 45 | return 1 46 | 47 | def baz() -> None: 48 | return 49 | 50 | @abstractmethod 51 | def very_abstract(): 52 | pass 53 | 54 | @abc.abstractmethod 55 | def very_abstract_without_import_from(): 56 | pass 57 | """ 58 | self.assertCodemod(before, after, none_return=True) 59 | 60 | def test_none_return_stub(self) -> None: 61 | before = """ 62 | def foo(): 63 | pass 64 | """ 65 | after = """ 66 | def foo(): 67 | pass 68 | """ 69 | self.assertCodemod( 70 | before, 71 | after, 72 | none_return=True, 73 | context_override=CodemodContext(filename="stub.pyi"), 74 | ) 75 | 76 | def test_scalar_return(self) -> None: 77 | before = """ 78 | def foo(): 79 | 
pass 80 | 81 | def bar(): 82 | return 1 83 | 84 | def return_not(x): 85 | return not x 86 | 87 | def formatter(x): 88 | return "{}".format(x) 89 | 90 | def old_school_formatter(x): 91 | return "%s" % x 92 | 93 | def bytes_formatter(x): 94 | return b"%s" % x 95 | 96 | def boolean_return(x, y, z): 97 | return x is y or x in z 98 | 99 | def baz() -> float: 100 | return "not a float" 101 | 102 | def two_returns(condition): 103 | if condition: 104 | return 42 105 | else: 106 | return 107 | """ 108 | after = """ 109 | def foo(): 110 | pass 111 | 112 | def bar() -> int: 113 | return 1 114 | 115 | def return_not(x) -> bool: 116 | return not x 117 | 118 | def formatter(x) -> str: 119 | return "{}".format(x) 120 | 121 | def old_school_formatter(x) -> str: 122 | return "%s" % x 123 | 124 | def bytes_formatter(x) -> bytes: 125 | return b"%s" % x 126 | 127 | def boolean_return(x, y, z) -> bool: 128 | return x is y or x in z 129 | 130 | def baz() -> float: 131 | return "not a float" 132 | 133 | def two_returns(condition): 134 | if condition: 135 | return 42 136 | else: 137 | return 138 | """ 139 | self.assertCodemod(before, after, scalar_return=True) 140 | 141 | def test_asynq_return(self) -> None: 142 | before = """ 143 | from asynq import asynq 144 | 145 | @asynq() 146 | def ret_none(): 147 | yield bar.asynq() 148 | 149 | @asynq() 150 | def ret_int(): 151 | yield bar.asynq() 152 | return 3 153 | 154 | @asink() 155 | def not_asynq(): 156 | yield bar.asynq() 157 | """ 158 | after = """ 159 | from asynq import asynq 160 | 161 | @asynq() 162 | def ret_none() -> None: 163 | yield bar.asynq() 164 | 165 | @asynq() 166 | def ret_int() -> int: 167 | yield bar.asynq() 168 | return 3 169 | 170 | @asink() 171 | def not_asynq(): 172 | yield bar.asynq() 173 | """ 174 | self.assertCodemod(before, after, scalar_return=True, none_return=True) 175 | 176 | def test_bool_param(self) -> None: 177 | before = """ 178 | def foo(x = False, y = 0, z: int = False): 179 | lambda x=False: None 180 | 181 | 
lambda x=False: None 182 | """ 183 | after = """ 184 | def foo(x: bool = False, y = 0, z: int = False): 185 | lambda x=False: None 186 | 187 | lambda x=False: None 188 | """ 189 | self.assertCodemod(before, after, bool_param=True) 190 | 191 | def test_typed_params(self) -> None: 192 | before = """ 193 | def foo(x=0, y=0.0, z=f"x", alpha="", beta="b" "a", gamma=b"a"): 194 | pass 195 | """ 196 | after = """ 197 | def foo(x: int=0, y: float=0.0, z: str=f"x", alpha: str="", beta: str="b" "a", gamma: bytes=b"a"): 198 | pass 199 | """ 200 | self.assertCodemod( 201 | before, 202 | after, 203 | str_param=True, 204 | bytes_param=True, 205 | float_param=True, 206 | int_param=True, 207 | ) 208 | 209 | def test_annotate_optional(self) -> None: 210 | before = """ 211 | def foo(uid=None, qid=None): 212 | pass 213 | 214 | def bar(uid): 215 | pass 216 | """ 217 | after = """ 218 | from my_types import Uid 219 | from typing import Optional 220 | 221 | def foo(uid: Optional[Uid]=None, qid=None): 222 | pass 223 | 224 | def bar(uid): 225 | pass 226 | """ 227 | self.assertCodemod(before, after, annotate_optional=["uid:my_types.Uid"]) 228 | 229 | def test_annotate_optional_with_builtin(self) -> None: 230 | before = """ 231 | def foo(number=None): 232 | pass 233 | 234 | def bar(number): 235 | pass 236 | """ 237 | after = """ 238 | from typing import Optional 239 | 240 | def foo(number: Optional[int]=None): 241 | pass 242 | 243 | def bar(number): 244 | pass 245 | """ 246 | self.assertCodemod(before, after, annotate_optional=["number:int"]) 247 | 248 | def test_annotate_named_param(self) -> None: 249 | before = """ 250 | def foo(uid, qid): 251 | pass 252 | 253 | def bar(uid=1): 254 | pass 255 | """ 256 | after = """ 257 | from my_types import Uid 258 | 259 | def foo(uid: Uid, qid): 260 | pass 261 | 262 | def bar(uid=1): 263 | pass 264 | """ 265 | self.assertCodemod(before, after, annotate_named_param=["uid:my_types.Uid"]) 266 | 267 | def test_annotate_magics(self) -> None: 268 | before = 
""" 269 | def __str__(): 270 | pass 271 | 272 | def __not_str__(): 273 | pass 274 | """ 275 | after = """ 276 | def __str__() -> str: 277 | pass 278 | 279 | def __not_str__(): 280 | pass 281 | """ 282 | self.assertCodemod(before, after, annotate_magics=True) 283 | 284 | def test_annotate_magics_len(self) -> None: 285 | before = """ 286 | def __len__(): 287 | pass 288 | 289 | def __length_hint__(): 290 | pass 291 | """ 292 | after = """ 293 | def __len__() -> int: 294 | pass 295 | 296 | def __length_hint__() -> int: 297 | pass 298 | """ 299 | self.assertCodemod(before, after, annotate_magics=True) 300 | 301 | def test_exit(self) -> None: 302 | before = """ 303 | def __exit__(self, typ, value, tb): 304 | pass 305 | 306 | def __aexit__(self, typ, value, tb): 307 | pass 308 | 309 | def __exit__(self, *args): 310 | pass 311 | 312 | def __exit__(self, a, b, c, d): 313 | pass 314 | """ 315 | after = """ 316 | from types import TracebackType 317 | from typing import Optional, Type 318 | 319 | def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException], tb: Optional[TracebackType]): 320 | pass 321 | 322 | def __aexit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException], tb: Optional[TracebackType]): 323 | pass 324 | 325 | def __exit__(self, *args): 326 | pass 327 | 328 | def __exit__(self, a, b, c, d): 329 | pass 330 | """ 331 | self.assertCodemod(before, after, annotate_magics=True) 332 | 333 | def test_empty_elems(self) -> None: 334 | before = """ 335 | def foo(iterables): 336 | ... 337 | """ 338 | after = """ 339 | def foo(iterables) -> None: 340 | ... 
341 | """ 342 | self.assertCodemod(before, after, none_return=True, guess_common_names=True) 343 | 344 | def test_annotate_imprecise_magics(self) -> None: 345 | before = """ 346 | def __iter__(): 347 | pass 348 | 349 | def __not_iter__(): 350 | pass 351 | """ 352 | after = """ 353 | from typing import Iterator 354 | 355 | def __iter__() -> Iterator: 356 | pass 357 | 358 | def __not_iter__(): 359 | pass 360 | """ 361 | self.assertCodemod(before, after, annotate_imprecise_magics=True) 362 | 363 | def test_guessed_name(self) -> None: 364 | before = """ 365 | def foo(name): 366 | pass 367 | """ 368 | after = """ 369 | def foo(name: str): 370 | pass 371 | """ 372 | self.assertCodemod(before, after, guess_common_names=True) 373 | 374 | def test_guessed_name_optional(self) -> None: 375 | before = """ 376 | def foo(name=None): 377 | pass 378 | """ 379 | after = """ 380 | from typing import Optional 381 | 382 | def foo(name: Optional[str]=None): 383 | pass 384 | """ 385 | self.assertCodemod(before, after, guess_common_names=True) 386 | 387 | def test_guessed_and_named_param(self) -> None: 388 | before = """ 389 | def foo(uid, qid): 390 | pass 391 | def bar(name): 392 | pass 393 | """ 394 | after = """ 395 | from my_types import Uid 396 | 397 | def foo(uid: Uid, qid): 398 | pass 399 | def bar(name: str): 400 | pass 401 | """ 402 | self.assertCodemod( 403 | before, 404 | after, 405 | guess_common_names=True, 406 | annotate_named_param=["uid:my_types.Uid"], 407 | ) 408 | 409 | def test_optional_guessed_and_annotate_optional(self) -> None: 410 | before = """ 411 | def foo(real=None, qid=None): 412 | pass 413 | def bar(name=None): 414 | pass 415 | """ 416 | after = """ 417 | from my_types import Uid 418 | from typing import Optional 419 | 420 | def foo(real: Optional[Uid]=None, qid=None): 421 | pass 422 | def bar(name: Optional[str]=None): 423 | pass 424 | """ 425 | self.assertCodemod( 426 | before, 427 | after, 428 | guess_common_names=True, 429 | 
annotate_optional=["real:my_types.Uid"], 430 | ) 431 | 432 | def test_guess_type_from_argname1(self) -> None: 433 | before = """ 434 | def foo(list_int, set_ints): 435 | ... 436 | def bar(iterator_bool, deque_float): 437 | ... 438 | def no_guess(list_widths, container_int): 439 | ... 440 | """ 441 | after = """ 442 | from typing import Deque, Iterator, List, Set 443 | 444 | def foo(list_int: List[int], set_ints: Set[int]): 445 | ... 446 | def bar(iterator_bool: Iterator[bool], deque_float: Deque[float]): 447 | ... 448 | def no_guess(list_widths, container_int): 449 | ... 450 | """ 451 | self.assertCodemod(before, after, guess_common_names=True) 452 | 453 | def test_guess_type_from_argname2(self) -> None: 454 | before = """ 455 | def foo(int_list, ints_list, intslist): 456 | ... 457 | def bar(width_list, words_list, bool_list): 458 | ... 459 | def no_guess(save_list, real_list, int_lists): 460 | ... 461 | """ 462 | after = """ 463 | from typing import List 464 | 465 | def foo(int_list: List[int], ints_list: List[int], intslist: List[int]): 466 | ... 467 | def bar(width_list: List[int], words_list: List[str], bool_list: List[bool]): 468 | ... 469 | def no_guess(save_list, real_list, int_lists): 470 | ... 471 | """ 472 | self.maxDiff = None 473 | self.assertCodemod(before, after, guess_common_names=True) 474 | 475 | def test_guess_type_from_argname3(self) -> None: 476 | before = """ 477 | def foo(list_of_int, tuple_of_ints): 478 | ... 479 | def bar(deque_of_alphas, list_of_string): 480 | ... 481 | def no_guess(list_of_save, list_of_reals, list_of_list_of_int): 482 | ... 483 | """ 484 | after = """ 485 | from typing import Deque, List, Tuple 486 | 487 | def foo(list_of_int: List[int], tuple_of_ints: Tuple[int]): 488 | ... 489 | def bar(deque_of_alphas: Deque[float], list_of_string: List[str]): 490 | ... 491 | def no_guess(list_of_save, list_of_reals, list_of_list_of_int): 492 | ... 
493 | """ 494 | self.assertCodemod(before, after, guess_common_names=True) 495 | 496 | def test_guess_type_from_argname4(self) -> None: 497 | before = """ 498 | def foo(reals, texts): 499 | ... 500 | def bar(shuffles, saves): 501 | ... 502 | def no_guess(radiuss, radius_s): 503 | ... 504 | """ 505 | after = """ 506 | from typing import Sequence 507 | 508 | def foo(reals: Sequence[float], texts: Sequence[str]): 509 | ... 510 | def bar(shuffles: Sequence[bool], saves: Sequence[bool]): 511 | ... 512 | def no_guess(radiuss, radius_s): 513 | ... 514 | """ 515 | self.assertCodemod(before, after, guess_common_names=True) 516 | 517 | def test_guess_type_from_argname_optional(self) -> None: 518 | before = """ 519 | def foo(reals = None, set_of_int = None): 520 | ... 521 | def foo2(int_tuple = None, iterator_int = None): 522 | ... 523 | """ 524 | after = """ 525 | from typing import Iterator, Optional, Sequence, Set, Tuple 526 | 527 | def foo(reals: Optional[Sequence[float]] = None, set_of_int: Optional[Set[int]] = None): 528 | ... 529 | def foo2(int_tuple: Optional[Tuple[int]] = None, iterator_int: Optional[Iterator[int]] = None): 530 | ... 531 | """ 532 | self.assertCodemod(before, after, guess_common_names=True) 533 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion=2.3.1 3 | envlist = py38,py39,py310,py311,py312,black,pyanalyze 4 | 5 | [testenv] 6 | deps = 7 | libcst 8 | pytest 9 | commands = 10 | pytest tests/test_codemod.py 11 | 12 | [testenv:black] 13 | deps = 14 | black == 24.8.0 15 | commands = 16 | black --check . 
17 | 18 | [testenv:pyanalyze] 19 | deps = 20 | pyanalyze == 0.13.1 21 | commands = 22 | python -m pyanalyze --config pyproject.toml -v autotyping 23 | 24 | [gh-actions] 25 | python = 26 | 3.8: py38 27 | 3.9: py39 28 | 3.10: py310 29 | 3.11: py311 30 | 3.12: py312, black, pyanalyze 31 | --------------------------------------------------------------------------------